{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeOperators #-}

module Unison.LSP.HandlerUtils where

import Control.Lens
import Control.Monad.Reader
import Data.Map qualified as Map
import Language.LSP.Protocol.Lens as LSP
import Language.LSP.Protocol.Message qualified as Msg
import Language.LSP.Protocol.Types
import Unison.Debug qualified as Debug
import Unison.LSP.Types
import Unison.Prelude
import UnliftIO (race_)
import UnliftIO.Concurrent (forkIO)
import UnliftIO.Exception (finally)
import UnliftIO.MVar
import UnliftIO.STM
import UnliftIO.Timeout (timeout)

-- | Cancel the in-flight request registered under the given id.
--
-- The canceller stored under @lspId@ in the cancellation map is looked up and
-- removed within a single STM transaction, and is then run in 'IO' once the
-- transaction has committed. If nothing is registered under @lspId@ the map is
-- left untouched and this is a no-op.
cancelRequest :: (Int32 |? Text) -> Lsp ()
cancelRequest lspId = do
  cancelMapVar <- asks cancellationMapVar
  runCanceller <- atomically $ do
    registered <- readTVar cancelMapVar
    case Map.lookup lspId registered of
      -- Unknown id: nothing to remove, hand back an action that does nothing.
      Nothing -> pure (pure ())
      Just canceller -> do
        -- Drop the entry so the same canceller cannot be fetched twice.
        writeTVar cancelMapVar (Map.delete lspId registered)
        pure canceller
  liftIO runCanceller

-- | Handler middleware that logs each request and its response under the
-- 'Debug.LSP' debug flag.
--
-- The incoming message is logged with the label @"Request"@, then the wrapped
-- @handler@ is run with a responder that logs the response with the label
-- @"Response"@ before forwarding it to the original @respond@ callback.
withDebugging ::
  (Show (Msg.TRequestMessage message), Show (Msg.MessageResult message)) =>
  (Msg.TRequestMessage message -> (Either Msg.ResponseError (Msg.MessageResult message) -> Lsp ()) -> Lsp ()) ->
  Msg.TRequestMessage message ->
  (Either Msg.ResponseError (Msg.MessageResult message) -> Lsp ()) ->
  Lsp ()
withDebugging handler message respond = do
  Debug.debugM Debug.LSP "Request" message
  handler message loggingRespond
  where
    -- Log the response, then pass it on to the caller-supplied responder.
    loggingRespond response = do
      Debug.debugM Debug.LSP "Response" response
      respond response

-- | Handler middleware to add the ability for the client to cancel long-running in-flight requests.
--
-- The wrapped @handler@ is run on a forked thread and raced against a waiter
-- that blocks until the canceller registered for this request's id (in the map
-- held by 'cancellationMapVar') is fired. Whichever finishes first wins:
--
-- * if the handler finishes first, it has produced its own response;
-- * if the canceller fires first, the handler is abandoned and the client is
--   sent a @RequestCancelled@ error.
--
-- When @mayTimeoutMillis@ is @Just t@, the whole race is additionally bounded
-- by @t@ milliseconds; on expiry the client is sent a @ServerCancelled@ error.
--
-- Regardless of how the forked work ends, the request's entry is removed from
-- the cancellation map afterwards.
withCancellation ::
  forall message.
  Maybe Int ->
  (Msg.TRequestMessage message -> (Either Msg.ResponseError (Msg.MessageResult message) -> Lsp ()) -> Lsp ()) ->
  Msg.TRequestMessage message ->
  (Either Msg.ResponseError (Msg.MessageResult message) -> Lsp ()) ->
  Lsp ()
withCancellation mayTimeoutMillis handler message respond = do
  let reqId = case message ^. LSP.id of
        Msg.IdInt i -> InL i
        Msg.IdString s -> InR s
  -- The server itself seems to be single-threaded, so we need to fork in order to be able to
  -- process cancellation requests while still computing some other response
  void . forkIO $
    withTimeout (race_ (waitForCancel reqId) (handler message respond))
      `finally` removeFromMap reqId
  where
    -- Remove this request's canceller (if any) from the cancellation map.
    -- Uses the strict 'modifyTVar'' so that deletions are applied immediately
    -- instead of accumulating as unevaluated thunks inside the shared TVar.
    removeFromMap :: (Int32 |? Text) -> Lsp ()
    removeFromMap reqId = do
      cancelMapVar <- asks cancellationMapVar
      atomically $ modifyTVar' cancelMapVar (Map.delete reqId)

    -- Run the action under the configured timeout, if there is one. The
    -- configured value is multiplied by 1000 before being handed to 'timeout'.
    -- If the action is cut short, reply with a server-side cancellation error.
    withTimeout :: Lsp () -> Lsp ()
    withTimeout action =
      case mayTimeoutMillis of
        Nothing -> action
        Just t ->
          timeout (t * 1000) action >>= \case
            Nothing -> respond $ serverCancelErr "Timeout"
            Just () -> pure ()

    -- Error response used when the client asked for the request to be cancelled.
    clientCancelErr :: Text -> Either Msg.ResponseError b
    clientCancelErr msg = Left $ Msg.ResponseError (InL LSPErrorCodes_RequestCancelled) msg Nothing

    -- Error response used when the server itself gives up on the request.
    serverCancelErr :: Text -> Either Msg.ResponseError b
    serverCancelErr msg = Left $ Msg.ResponseError (InL LSPErrorCodes_ServerCancelled) msg Nothing

    -- I intentionally defer adding the canceller until after we've started the request,
    -- No matter what it's possible for a message to be cancelled before the
    -- canceller has been added, but this means we're not blocking the request waiting for
    -- contention on the cancellation map on every request.
    -- The majority of requests should be fast enough to complete "instantly" anyways.
    waitForCancel :: (Int32 |? Text) -> Lsp ()
    waitForCancel reqId = do
      barrier <- newEmptyMVar
      -- 'tryPutMVar' never blocks, so firing the canceller more than once is harmless.
      let canceller = void $ tryPutMVar barrier ()
      cancelMapVar <- asks cancellationMapVar
      atomically $ modifyTVar' cancelMapVar (Map.insert reqId canceller)
      -- Block here until the canceller is fired (or this thread is killed
      -- because the other side of the race finished first).
      readMVar barrier
      Debug.debugLogM Debug.LSP "Request Cancelled by client"
      respond (clientCancelErr "Request cancelled by client")