{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeFamilies #-}

-- | A wrapper that provides a safer interface for constructing an MCP server.
module Unison.MCP.Wrapper
  ( Tool (..),
    Prompt (..),
    HasInputSchema (..),
    mkServer,
    CallToolResult (..),
    PromptArgument (..),
    StaticResources,
    Server,
    MCP.ServerCapabilities (..),
    MCP.ToolAnnotations (..),
    MCP.Implementation (..),
    MCP.ResourcesCapability (..),
    MCP.ToolsCapability (..),
    MCP.PromptsCapability (..),
    MCP.PromptContentType (..),
    errorToolResult,
    textToolResult,
    jsonToolResult,
  )
where

import Data.Aeson (FromJSON)
import Data.Aeson qualified as Aeson
import Data.ByteString.Lazy.Char8 qualified as BL
import Data.Data (Proxy)
import Data.Map qualified as Map
import Data.Text qualified as Text
import Data.Text.Encoding qualified as Text
import Network.MCP.Server
import Network.MCP.Types (CallToolResult (CallToolResult))
import Network.MCP.Types qualified as MCP
import Unison.Prelude
import UnliftIO qualified
import UnliftIO.Environment (lookupEnv)

-- | Static resources served by the MCP server, keyed by the URI that read
-- requests are looked up by. Each entry pairs the advertised resource with the
-- content returned when it is read.
--
-- NOTE(review): presumably each key should equal the URI of its paired
-- 'MCP.Resource'; nothing here checks that — confirm with callers.
type StaticResources = Map Text (MCP.Resource, MCP.ResourceContent)

-- | Types that can describe the JSON schema of the arguments a tool accepts.
-- The schema is advertised to clients as the tool's input schema.
class HasInputSchema a where
  -- | The JSON schema for @a@; the proxy argument is only used to pick the type.
  toInputSchema :: Proxy a -> Aeson.Value

-- | Tools taking @()@ accept no arguments: their schema is an object with no
-- properties and no required fields.
instance HasInputSchema () where
  toInputSchema _ = Aeson.object emptyObjectSchema
    where
      emptyObjectSchema =
        [ ("type", Aeson.String "object"),
          ("properties", Aeson.object []),
          ("required", Aeson.Array mempty)
        ]

-- | A tool exposed by the MCP server.
--
-- Each tool chooses its own argument type @arg@, which is existentially hidden
-- so tools with different argument types can live in one list. @arg@ must be
-- decodable from the client's JSON arguments ('FromJSON') and able to describe
-- its own JSON schema ('HasInputSchema').
data Tool m
  = forall arg.
  (FromJSON arg, HasInputSchema arg) =>
  Tool
  { -- | Name the tool is advertised under and dispatched by. Names should be
    -- distinct, since calls are routed through a map keyed by this name.
    toolName :: Text,
    -- | Description advertised to clients.
    toolDescription :: Text,
    -- | Annotations advertised alongside the tool.
    toolAnnotations :: MCP.ToolAnnotations,
    -- | Fixes the argument type; used only to derive the advertised input schema.
    toolArgType :: Proxy arg,
    -- | Runs the tool on its decoded arguments.
    toolHandler :: arg -> m MCP.CallToolResult
  }

-- | A prompt exposed by the MCP server.
data Prompt m = Prompt
  { -- | Name the prompt is advertised under and looked up by.
    promptName :: Text,
    -- | Description advertised to clients.
    promptDescription :: Text,
    -- | Advertised arguments, keyed by argument name.
    promptArgs :: Map Text PromptArgument,
    -- | Builds the prompt from the argument map sent by the client.
    promptHandler :: Map Text Text -> m MCP.GetPromptResult
  }

-- | Description of a single prompt argument. The argument's name is the key it
-- is stored under in 'promptArgs'.
data PromptArgument = PromptArgument
  { -- | Description advertised to clients.
    promptArgumentDescription :: Text,
    -- | Whether the argument is required
    promptArgumentRequired :: Bool
  }

-- | Create an MCP server and register the given static resources, tools and
-- prompts on it.
--
-- Each capability is advertised as enabled exactly when the corresponding
-- collection is non-empty. Registration happens in the order resources,
-- tools, prompts, after which the server is returned.
mkServer :: (MonadUnliftIO m) => MCP.ServerInfo -> Text -> StaticResources -> [Tool m] -> [Prompt m] -> m Server
mkServer serverInfo serverDescription staticResources tools prompts = do
  server <- liftIO (createServer serverInfo capabilities serverDescription)
  doResources server staticResources
  doTools server tools
  doPrompts server prompts
  pure server
  where
    hasResources = not (Map.null staticResources)
    hasTools = not (null tools)
    hasPrompts = not (null prompts)
    capabilities =
      MCP.ServerCapabilities
        { resourcesCapability = Just (MCP.ResourcesCapability hasResources),
          toolsCapability = Just (MCP.ToolsCapability hasTools),
          promptsCapability = Just (MCP.PromptsCapability hasPrompts)
        }

-- | Register the static resources on the server, plus a read handler that
-- serves their content by URI.
--
-- Reading a URI that is not in the map yields an empty
-- 'MCP.ReadResourceResult' rather than an error.
doResources :: (MonadUnliftIO m) => Server -> StaticResources -> m ()
doResources server staticResources = liftIO $ do
  registerResources server (fmap fst (Map.elems staticResources))

  -- Register resource read handler
  registerResourceReadHandler server \(MCP.ReadResourceRequest {resourceReadUri}) ->
    pure . MCP.ReadResourceResult $
      maybe [] (\(_, content) -> [content]) (Map.lookup resourceReadUri staticResources)

-- | Timeout, in seconds, applied to MCP tool calls when @UNISON_MCP_TIMEOUT@
-- does not supply a usable value.
defaultMcpTimeoutSeconds :: Int
defaultMcpTimeoutSeconds = 60 -- one minute

-- | Get the MCP tool timeout in microseconds from the @UNISON_MCP_TIMEOUT@
-- environment variable.
--
-- The variable is specified in whole seconds for user convenience. The default
-- of 'defaultMcpTimeoutSeconds' is used when the variable is:
--
-- * unset;
-- * not a valid integer;
-- * not strictly positive;
-- * so large that converting it to microseconds would overflow 'Int'.
getMcpTimeoutMicroseconds :: (MonadIO m) => m Int
getMcpTimeoutMicroseconds = do
  setting <- liftIO (lookupEnv "UNISON_MCP_TIMEOUT")
  pure $ case setting >>= readMaybe >>= validSeconds of
    Just seconds -> seconds * microsPerSecond
    Nothing -> defaultMcpTimeoutSeconds * microsPerSecond
  where
    microsPerSecond :: Int
    microsPerSecond = 1_000_000

    -- Reject values the timeout cannot meaningfully use:
    -- * zero would make every tool call time out immediately;
    -- * a negative value would silently disable the timeout;
    -- * anything above the bound would overflow when scaled to microseconds.
    validSeconds :: Int -> Maybe Int
    validSeconds s
      | s > 0 && s <= maxBound `div` microsPerSecond = Just s
      | otherwise = Nothing

-- | Register every tool on the server, plus one call handler that dispatches
-- incoming 'MCP.CallToolRequest's by tool name.
--
-- For a known tool, the request's JSON arguments are decoded with the tool's
-- 'FromJSON' instance. The tool's handler then runs under the timeout from
-- 'getMcpTimeoutMicroseconds'. Unknown tools, undecodable arguments and
-- timeouts are all reported to the client as 'errorToolResult's.
doTools :: (MonadUnliftIO m) => Server -> [Tool m] -> m ()
doTools server tools = do
  unlift <- askRunInIO
  limitMicros <- getMcpTimeoutMicroseconds
  let limitSeconds = limitMicros `div` 1_000_000
      toolsByName = Map.fromList [(toolName t, t) | t <- tools]

      -- Advertised description of one tool.
      toMcpTool (Tool {toolName, toolDescription, toolAnnotations, toolArgType}) =
        MCP.Tool
          { MCP.toolName,
            MCP.toolDescription = Just toolDescription,
            MCP.toolInputSchema = toInputSchema toolArgType,
            MCP.toolAnnotations = Just toolAnnotations
          }

      failWith msg = pure (errorToolResult msg)

      -- Route one call to its tool, decoding arguments and enforcing the timeout.
      dispatch name arguments =
        case Map.lookup name toolsByName of
          Nothing -> failWith ("Tool '" <> name <> "' not found.")
          Just (Tool {toolHandler}) ->
            case Aeson.fromJSON arguments of
              Aeson.Error err ->
                failWith ("Failed to parse arguments for tool '" <> name <> "': " <> Text.pack err)
              Aeson.Success arg -> do
                outcome <- UnliftIO.timeout limitMicros (toolHandler arg)
                case outcome of
                  Just result -> pure result
                  Nothing ->
                    failWith ("Tool '" <> name <> "' timed out after " <> Text.pack (show limitSeconds) <> " seconds.")
  liftIO $ registerTools server (map toMcpTool tools)
  liftIO $ registerToolCallHandler server \(MCP.CallToolRequest {callToolName, callToolArguments}) ->
    unlift (dispatch callToolName callToolArguments)

-- | A tool result carrying a single text message and flagged as an error.
errorToolResult :: Text -> MCP.CallToolResult
errorToolResult errMsg =
  MCP.CallToolResult
    { MCP.callToolContent = [textContent],
      MCP.callToolIsError = True
    }
  where
    textContent = MCP.ToolContent MCP.TextualContent (Just errMsg)

-- | Register the prompts on the server, plus a handler that dispatches
-- 'MCP.GetPromptRequest's to the matching prompt by name.
--
-- Prompt arguments are advertised in ascending key order, which is the order
-- of 'Map.toList'. The prompt name in a request comes from the client, so a
-- request for an unknown name is reachable. It is reported by throwing a
-- @StringException@ (via 'UnliftIO.throwString') from the handler action
-- rather than by calling the partial 'error'.
doPrompts :: (MonadUnliftIO m) => Server -> [Prompt m] -> m ()
doPrompts server prompts = do
  let mcpPrompts =
        prompts <&> \(Prompt {promptName, promptDescription, promptArgs}) ->
          MCP.Prompt
            { MCP.promptName,
              MCP.promptDescription = Just promptDescription,
              MCP.promptArguments =
                promptArgs
                  & Map.toList
                  <&> \(argName, PromptArgument {promptArgumentDescription, promptArgumentRequired}) ->
                    MCP.PromptArgument
                      { MCP.promptArgumentName = argName,
                        MCP.promptArgumentDescription = Just promptArgumentDescription,
                        MCP.promptArgumentRequired = promptArgumentRequired
                      }
            }
  let promptsMap = Map.fromList $ prompts <&> (\p -> (promptName p, p))
  liftIO $ registerPrompts server mcpPrompts
  runInIO <- askRunInIO
  liftIO $ registerPromptHandler server $ \(MCP.GetPromptRequest {getPromptName, getPromptArguments}) -> runInIO do
    case Map.lookup getPromptName promptsMap of
      -- Client-controlled input: throw a proper exception in the monad
      -- instead of evaluating a bottom value.
      Nothing -> UnliftIO.throwString $ "Prompt '" <> Text.unpack getPromptName <> "' not found."
      Just (Prompt {promptHandler}) -> promptHandler getPromptArguments

-- | A successful tool result carrying a single text message.
textToolResult :: Text -> MCP.CallToolResult
textToolResult msg =
  let content = MCP.ToolContent MCP.TextualContent (Just msg)
   in MCP.CallToolResult
        { MCP.callToolContent = [content],
          MCP.callToolIsError = False
        }

-- | A successful tool result whose text is the JSON encoding of the given value.
--
-- 'Aeson.encode' produces UTF-8 bytes, so they are decoded as UTF-8 here.
-- Unpacking them with "Data.ByteString.Lazy.Char8" would map each byte to its
-- own 'Char', turning every non-ASCII character into mojibake (e.g. @é@ into
-- @Ã©@). Aeson's output is always valid UTF-8, so the decode cannot fail.
jsonToolResult :: (Aeson.ToJSON a) => a -> MCP.CallToolResult
jsonToolResult msg =
  textToolResult (Text.decodeUtf8 (BL.toStrict (Aeson.encode msg)))