{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Network.MCP.Transport.StdIO
  ( STDIOTransport (..),
    newSTDIOTransport,
  )
where

import Control.Exception (SomeAsyncException, SomeException, fromException, handle, throwIO, try)
import Data.Aeson
import Data.ByteString.Char8 qualified as BS8
import Data.ByteString.Lazy qualified as BL
import Network.MCP.Transport.Types
import System.IO

-- | 'Transport' that talks over a triple of process handles: one for
-- incoming messages, one for outgoing messages and one for diagnostics.
data STDIOTransport = STDIOTransport
  { stdinHandle :: Handle
  -- ^ Source of incoming messages, read one line at a time.
  , stdoutHandle :: Handle
  -- ^ Destination for outgoing messages.
  , stderrHandle :: Handle
  -- ^ Where diagnostics (decode errors, handler failures) are written.
  }

-- | Create an STDIO transport bound to the process's own 'stdin', 'stdout'
-- and 'stderr'.
--
-- This takes no message handler; the handler is supplied later to
-- 'handleMessages'.
--
-- Side effects: switches 'stdin' and 'stdout' to line buffering and UTF-8
-- encoding. 'stderr' is left with its current settings.
newSTDIOTransport :: IO STDIOTransport
newSTDIOTransport = do
  -- The transport exchanges one message per line, so line buffering lets
  -- each message be delivered as soon as its terminating newline arrives.
  hSetBuffering stdin LineBuffering
  hSetBuffering stdout LineBuffering
  hSetEncoding stdin utf8
  hSetEncoding stdout utf8

  return
    STDIOTransport
      { stdinHandle = stdin,
        stdoutHandle = stdout,
        stderrHandle = stderr
      }

-- | A simple, synchronous transport implementation using standard input/output.
--
-- Messages are exchanged as one JSON value per line. Input lines that do not
-- decode are reported on the error handle and skipped; blank lines are
-- ignored.
instance Transport STDIOTransport where
  -- Read messages from 'stdinHandle' until end of input. Each message is
  -- passed to @handler@, and any 'Just' response is written to
  -- 'stdoutHandle'.
  --
  -- Error policy:
  --   * an exception from @handler@, or while writing its response, is logged
  --     to 'stderrHandle' and the loop continues with the next message;
  --   * an exception while reading input is logged and ends the loop,
  --     because retrying a broken input handle would spin forever;
  --   * asynchronous exceptions ('UserInterrupt', 'ThreadKilled', ...) are
  --     always re-thrown so the transport can be interrupted or cancelled.
  handleMessages :: STDIOTransport -> (Message -> IO (Maybe Message)) -> IO ()
  handleMessages (STDIOTransport {stdinHandle, stdoutHandle, stderrHandle}) handler = do
    hSetBuffering stdinHandle LineBuffering
    hSetBuffering stdoutHandle LineBuffering
    hSetEncoding stdinHandle utf8
    hSetEncoding stdoutHandle utf8
    loop
    where
      -- One iteration per message. The recursive call sits outside every
      -- exception handler, so it is a real tail call: no catch frames pile
      -- up, and the loop never keeps running inside a (masked) handler.
      loop :: IO ()
      loop = do
        next <- try readMessage
        case next of
          Left err -> do
            rethrowIfAsync err
            -- The input handle is unusable; stop instead of retrying forever.
            hPutStrLn stderrHandle ("Error reading message: " ++ show err)
          Right Nothing ->
            -- Transport is closed.
            pure ()
          Right (Just msg) -> do
            handle reportHandlerError (handler msg >>= mapM_ sendMessage)
            loop

      -- Read the next message, or 'Nothing' once end of input is reached.
      -- Blank lines are skipped silently. Lines that fail to decode are
      -- logged and skipped, and reading continues with the next line (no
      -- reply is sent for them).
      readMessage :: IO (Maybe Message)
      readMessage = do
        atEof <- hIsEOF stdinHandle
        if atEof
          then pure Nothing
          else do
            line <- BS8.hGetLine stdinHandle
            if isBlank line
              then readMessage
              else case eitherDecode (BL.fromStrict line) of
                Left err -> do
                  hPutStrLn stderrHandle ("JSON decode error: " ++ err)
                  readMessage
                Right msg -> pure (Just msg)

      -- Send a message through the transport as a single newline-terminated
      -- line, flushing so the peer sees it immediately.
      sendMessage :: Message -> IO ()
      sendMessage msg = do
        BL.hPut stdoutHandle (encode msg)
        BL.hPut stdoutHandle "\n"
        hFlush stdoutHandle

      -- Log a failure raised by the handler or while sending its response.
      -- Asynchronous exceptions are propagated instead of being swallowed.
      reportHandlerError :: SomeException -> IO ()
      reportHandlerError err = do
        rethrowIfAsync err
        hPutStrLn stderrHandle ("Error handling message: " ++ show err)

      -- Re-throw asynchronous exceptions so interrupts and cancellation work.
      rethrowIfAsync :: SomeException -> IO ()
      rethrowIfAsync err =
        case fromException err of
          Just (_ :: SomeAsyncException) -> throwIO err
          Nothing -> pure ()

      -- hGetLine strips the '\n'; a stray '\r' (CRLF input) or spaces alone
      -- do not constitute a message.
      isBlank :: BS8.ByteString -> Bool
      isBlank = BS8.all (\c -> c == ' ' || c == '\t' || c == '\r')