module Unison.MCP.Cli
( handleInputMCP,
ppForProjectContext,
cliToMCP,
virtualSourceName,
)
where
import Control.Monad.Except (ExceptT (..), throwError)
import Control.Monad.Reader
import Crypto.Random qualified as Random
import Data.Aeson
import Data.IORef
import Data.Sequence qualified as Seq
import Data.Text.IO qualified as Text
import GHC.IO.Handle (hDuplicate, hDuplicateTo)
import U.Codebase.Sqlite.Queries qualified as Queries
import Unison.Auth.CredentialManager qualified as AuthN
import Unison.Auth.HTTPClient qualified as AuthN
import Unison.Auth.Tokens qualified as AuthN
import Unison.Cli.Monad qualified as Cli
import Unison.Codebase qualified as Codebase
import Unison.Codebase.Editor.HandleInput qualified as HandleInput
import Unison.Codebase.Editor.Input (Event, Input)
import Unison.Codebase.Editor.Output qualified as Output
import Unison.Codebase.Path qualified as Path
import Unison.Codebase.ProjectPath qualified as PP
import Unison.CommandLine (defaultLoadSourceFile, defaultWriteSourceFile)
import Unison.CommandLine.OutputMessages qualified as Output
import Unison.MCP.Types
import Unison.MCP.Types qualified as MCP
import Unison.Prelude
import Unison.Sqlite (Transaction)
import Unison.Syntax.Parser qualified as Parser
import Unison.Util.Pretty qualified as Pretty
import UnliftIO qualified
import UnliftIO.IO qualified as IO
import UnliftIO.STM
import UnliftIO.Temporary (withSystemTempFile)
import Prelude hiding (readFile, writeFile)
-- | Pseudo file name used for the MCP session's in-memory scratch file.
--
-- Writes whose source name equals this value are collected into
-- 'sourceCodeUpdates' of the resulting 'CliOutput' instead of being written
-- to disk (see the @writeSource@ callback built in 'cliToMCP').
virtualSourceName :: Text
virtualSourceName =
  "<mcp-virtual-source>"
-- | Everything produced while running a batch of CLI commands on behalf of an
-- MCP request, in a shape that is serialised back to the client.
data CliOutput = CliOutput
  { -- | Text the commands wrote to 'virtualSourceName' rather than to a real
    -- file. A write with @replace@ set discards all earlier entries.
    sourceCodeUpdates :: [Text],
    -- | Everything written to the process's stdout while the commands ran.
    stdout :: Text,
    -- | Everything written to the process's stderr while the commands ran.
    stderr :: Text,
    -- | Plain-text rendering of every output message, in emission order.
    outputMessages :: [Text],
    -- | Plain-text rendering of outputs classified as failures, plus any
    -- error message appended by 'handleInputMCP'.
    errorMessages :: [Text]
  }
  deriving (Eq, Show)
-- | Field-wise concatenation: each field of the left operand is followed by
-- the corresponding field of the right operand, so output from an earlier run
-- stays ahead of output from a later one.
instance Semigroup CliOutput where
  -- Matching on the constructor keeps both operands strict, as a positional
  -- pattern match would.
  lhs@CliOutput {} <> rhs@CliOutput {} =
    CliOutput
      { sourceCodeUpdates = lhs.sourceCodeUpdates <> rhs.sourceCodeUpdates,
        stdout = lhs.stdout <> rhs.stdout,
        stderr = lhs.stderr <> rhs.stderr,
        outputMessages = lhs.outputMessages <> rhs.outputMessages,
        errorMessages = lhs.errorMessages <> rhs.errorMessages
      }
-- | The empty output: no source updates, no captured text and no messages.
-- It is the identity for the field-wise '(<>)'.
instance Monoid CliOutput where
  mempty :: CliOutput
  mempty =
    CliOutput
      { sourceCodeUpdates = [],
        stdout = "",
        stderr = "",
        outputMessages = [],
        errorMessages = []
      }
-- | Encodes a 'CliOutput' as a flat JSON object with one key per field,
-- each key spelled exactly like the corresponding record field.
instance ToJSON CliOutput where
  toJSON :: CliOutput -> Value
  toJSON out =
    object
      [ "sourceCodeUpdates" .= out.sourceCodeUpdates,
        "stdout" .= out.stdout,
        "stderr" .= out.stderr,
        "outputMessages" .= out.outputMessages,
        "errorMessages" .= out.errorMessages
      ]
-- | Resolve the project and branch named in a 'ProjectContext' to a project
-- path positioned at the branch's root namespace ('Path.Root').
--
-- Throws (in 'ExceptT') @"Project not found: <name>"@ when the project lookup
-- returns nothing, or @"Branch not found: <name>"@ when the branch lookup
-- does. The project is looked up first.
ppForProjectContext :: ProjectContext -> ExceptT Text Transaction PP.ProjectPath
ppForProjectContext ProjectContext {projectName, branchName} = do
  project <-
    lift (Queries.loadProjectByName projectName)
      >>= maybe (throwError ("Project not found: " <> into @Text projectName)) pure
  branch <-
    lift (Queries.loadProjectBranchByName project.projectId branchName)
      >>= maybe (throwError ("Branch not found: " <> into @Text branchName)) pure
  pure (PP.fromProjectAndBranch (PP.ProjectAndBranch project branch) Path.Root)
-- | Run a batch of UCM events/inputs against the branch named by
-- @projectContext@ and return everything they produced.
--
-- Every input is run, even after an earlier one has failed. If the error
-- callback handed to 'cliToMCP' fires at any point, the batch is aborted
-- after the last input with a generic message. That message, like any other
-- early-exit message, is appended to 'errorMessages'.
handleInputMCP :: ProjectContext -> [Either Event Input] -> ExceptT Text MCP CliOutput
handleInputMCP projectContext input = do
  errorSeen <- newTVarIO False
  let markFailed _msg = atomically (writeTVar errorSeen True)
  (outcome, captured) <-
    cliToMCP projectContext markFailed do
      Cli.labelE \bail -> do
        for_ input HandleInput.loop
        failed <- readTVarIO errorSeen
        when failed $
          bail "An error occurred during input handling."
  pure $ case outcome of
    Just (Left err) -> captured <> mempty {errorMessages = [err]}
    Just (Right ()) -> captured
    Nothing -> captured
-- | Run a 'Cli.Cli' action against the project branch named by @projCtx@,
-- using the codebase, runtime, UCM version and working directory from the
-- MCP environment.
--
-- Returns 'Just' the action's result if it finishes with 'Cli.Success', or
-- 'Nothing' if it ends with 'Cli.Continue' or 'Cli.HaltRepl'. The second
-- component is everything the action produced (see 'CliOutput'): captured
-- stdout/stderr, rendered output messages, failure messages, and text written
-- to 'virtualSourceName'.
--
-- @onError@ is called with the plain-text rendering of every output
-- classified as a failure, whether it arrives as a plain or a numbered
-- output.
--
-- Fails (in 'ExceptT') when the project or branch in @projCtx@ does not exist.
cliToMCP :: ProjectContext -> (Text -> IO ()) -> Cli.Cli a -> ExceptT Text MCP (Maybe a, CliOutput)
cliToMCP projCtx onError cli = do
  MCP.Env {ucmVersion, codebase, runtime, workDir} <- ask
  initialPP <-
    ExceptT . liftIO $
      Codebase.runTransactionExceptT codebase (ppForProjectContext projCtx)
  let credMan = AuthN.globalCredentialManager
  let tokenProvider :: AuthN.TokenProvider
      tokenProvider = AuthN.newTokenProvider credMan
  authenticatedHTTPClient <- AuthN.newAuthenticatedHTTPClient tokenProvider ucmVersion
  outputVar <- newTVarIO Seq.empty
  errorsVar <- newTVarIO Seq.empty
  sourceCodeUpdatesVar <- newTVarIO Seq.empty
  -- Record one rendered message. Failures are also added to the error list
  -- and reported through @onError@.
  let recordOutput pretty failed = do
        atomically $ modifyTVar' outputVar (<> Seq.singleton pretty)
        when failed do
          atomically $ modifyTVar' errorsVar (<> Seq.singleton pretty)
          onError (Pretty.toPlain 0 pretty)
  let notify output = do
        pretty <- Output.notifyUser workDir Output.fetchIssueFromGitHub output
        recordOutput pretty (Output.isFailure output)
  let notifyNumbered output = do
        let (pretty, nargs) = Output.notifyNumbered output
        -- Numbered outputs can be failures too. Previously they were never
        -- added to the error list or reported through @onError@.
        recordOutput pretty (Output.isNumberedFailure output)
        pure nargs
  -- Writes to the virtual source are captured in memory; any other source
  -- name is written through to the file system as usual.
  let writeSource sourceName content replace
        | sourceName /= virtualSourceName = defaultWriteSourceFile sourceName content replace
        | replace = atomically $ writeTVar sourceCodeUpdatesVar (Seq.singleton content)
        | otherwise = atomically $ modifyTVar' sourceCodeUpdatesVar (<> Seq.singleton content)
  -- NOTE(review): this counter starts at 0 on every call to 'cliToMCP', so
  -- each call seeds the unique-name generator with the same sequence of
  -- seeds. Confirm this determinism is intended.
  seedRef <- liftIO $ newIORef (0 :: Int)
  let cliEnv =
        Cli.Env
          { authHTTPClient = authenticatedHTTPClient,
            codebase,
            credentialManager = credMan,
            generateUniqueName = do
              i <- atomicModifyIORef' seedRef \i -> let !i' = i + 1 in (i', i)
              pure (Parser.uniqueBase32Namegen (Random.drgNewSeed (Random.seedFromInteger (fromIntegral i)))),
            loadSource = defaultLoadSourceFile,
            lspCheckForChanges = \_ -> pure (),
            writeSource,
            notify,
            notifyNumbered,
            runtime,
            sandboxedRuntime = error "Sandboxed runtime not implemented in MCP Server",
            serverBaseUrl = Nothing,
            ucmVersion,
            isTranscriptTest = False,
            watchState = Nothing
          }
  let startState = Cli.loopState0 (PP.toIds initialPP)
  (stdout, stderr, (cliResult, _loopState)) <-
    liftIO $ captureHandles (Cli.runCli cliEnv startState cli)
  cliOut <- atomically do
    msgs <- readTVar outputVar
    errs <- readTVar errorsVar
    sourceCodeUpdates <- toList <$> readTVar sourceCodeUpdatesVar
    let render prettys = toList (Pretty.toPlain 0 <$> prettys)
    pure
      CliOutput
        { sourceCodeUpdates,
          stdout,
          stderr,
          outputMessages = render msgs,
          errorMessages = render errs
        }
  pure $ case cliResult of
    Cli.Success a -> (Just a, cliOut)
    Cli.Continue -> (Nothing, cliOut)
    Cli.HaltRepl -> (Nothing, cliOut)
-- | Run an action with the process-wide 'IO.stdout' and 'IO.stderr' handles
-- pointed at temporary files. Returns the text written to stdout, the text
-- written to stderr, and the action's result.
--
-- Each handle is redirected and restored by its own 'UnliftIO.bracket'.
-- Previously both redirections were acquired inside a single bracket. If
-- redirecting stderr threw after stdout had already been redirected, the
-- release action never ran, so stdout stayed redirected to the temporary file
-- and the saved duplicate of the original stdout leaked.
--
-- NOTE(review): this swaps the global handles, so anything another thread
-- writes to stdout/stderr while the action runs is captured as well.
captureHandles :: IO a -> IO (Text, Text, a)
captureHandles action =
  withSystemTempFile "stdout.txt" \stdoutPath stdoutHandle ->
    withSystemTempFile "stderr.txt" \stderrPath stderrHandle -> do
      a <- redirect IO.stdout stdoutHandle (redirect IO.stderr stderrHandle action)
      -- Close the temp-file handles, then read each file back by path.
      IO.hClose stdoutHandle
      output <- Text.readFile stdoutPath
      IO.hClose stderrHandle
      errOutput <- Text.readFile stderrPath
      pure (output, errOutput, a)
  where
    -- Make @target@ write to @dest@'s file while @inner@ runs. Afterwards,
    -- point @target@ back at what it referred to before and close the saved
    -- duplicate, even if restoring fails.
    redirect target dest inner =
      UnliftIO.bracket
        ( do
            saved <- hDuplicate target
            -- Don't leak the saved duplicate if the redirection itself fails.
            hDuplicateTo dest target `UnliftIO.onException` IO.hClose saved
            pure saved
        )
        (\saved -> hDuplicateTo saved target `UnliftIO.finally` IO.hClose saved)
        (\_ -> inner)