{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeOperators #-}

module Unison.LSP.HandlerUtils where

import Control.Lens
import Control.Monad.Reader
import Data.Map qualified as Map
import Language.LSP.Protocol.Lens as LSP
import Language.LSP.Protocol.Message qualified as Msg
import Language.LSP.Protocol.Types
import Unison.Debug qualified as Debug
import Unison.LSP.Types
import Unison.Prelude
import UnliftIO (race_)
import UnliftIO.Concurrent (forkIO)
import UnliftIO.Exception (finally, mask_)
import UnliftIO.MVar
import UnliftIO.STM
import UnliftIO.Timeout (timeout)

-- | Cancels an in-flight request.
--
-- Looks up the canceller registered for @lspId@ in the environment's cancellation map.
-- If one is found, it is removed from the map and run. If none is registered, this does nothing.
cancelRequest :: (Int32 |? Text) -> Lsp ()
cancelRequest lspId = do
  cancellers <- asks cancellationMapVar
  -- Take the canceller out of the map inside a single transaction, so that concurrent
  -- cancellations of the same id cannot both obtain it.
  canceller <- atomically do
    registered <- readTVar cancellers
    case Map.lookup lspId registered of
      Nothing -> pure (pure ())
      Just found -> do
        writeTVar cancellers (Map.delete lspId registered)
        pure found
  liftIO canceller

-- | Handler middleware that debug-logs traffic under the 'Debug.LSP' flag.
--
-- The incoming request is logged with the label @"Request"@ before it is handed to @handler@.
-- Every response the handler produces is logged with the label @"Response"@ just before it is
-- forwarded to @respond@.
withDebugging ::
  (Show (Msg.TRequestMessage message), Show (Msg.ErrorData message), Show (Msg.MessageResult message)) =>
  (Msg.TRequestMessage message -> (Either (Msg.TResponseError message) (Msg.MessageResult message) -> Lsp ()) -> Lsp ()) ->
  Msg.TRequestMessage message ->
  (Either (Msg.TResponseError message) (Msg.MessageResult message) -> Lsp ()) ->
  Lsp ()
withDebugging handler message respond =
  Debug.debugM Debug.LSP "Request" message
    *> handler message loggingRespond
  where
    -- Log the response, then pass it on unchanged.
    loggingRespond response =
      Debug.debugM Debug.LSP "Response" response
        *> respond response

-- | Handler middleware to add the ability for the client to cancel long-running in-flight requests.
--
-- The wrapped handler runs on a forked thread and races against two things:
--
-- * a client cancellation, delivered through the canceller this registers for the request id
--   (see 'cancelRequest');
-- * a @t@-millisecond timeout, when @mayTimeoutMillis@ is @Just t@.
--
-- A client-cancelled request is answered with 'LSPErrorCodes_RequestCancelled'. A timed-out
-- request is answered with 'LSPErrorCodes_ServerCancelled'. The request's canceller is removed
-- from the cancellation map however the request finishes.
--
-- All paths reply through one guarded responder, so at most one response is sent per request.
-- This holds even if a cancellation or timeout lands just after the handler has already responded.
withCancellation ::
  forall message.
  Maybe Int ->
  (Msg.TRequestMessage message -> (Either (Msg.TResponseError message) (Msg.MessageResult message) -> Lsp ()) -> Lsp ()) ->
  Msg.TRequestMessage message ->
  (Either (Msg.TResponseError message) (Msg.MessageResult message) -> Lsp ()) ->
  Lsp ()
withCancellation mayTimeoutMillis handler message respond = do
  let reqId :: Int32 |? Text
      reqId = case message ^. LSP.id of
        Msg.IdInt i -> InL i
        Msg.IdString s -> InR s
  -- Filled by whichever path replies first; any later reply attempt is dropped.
  responded <- newEmptyMVar
  let respondOnce :: Either (Msg.TResponseError message) (Msg.MessageResult message) -> Lsp ()
      respondOnce response =
        -- Masked so the racing thread cannot be killed (by 'race_' or 'timeout') after it has
        -- claimed the right to reply but before the reply is sent. Otherwise the request could
        -- be left unanswered.
        mask_ do
          isFirst <- tryPutMVar responded ()
          if isFirst then respond response else pure ()
  -- The server itself seems to be single-threaded, so we need to fork in order to be able to
  -- process cancellation requests while still computing some other response
  void . forkIO $ flip finally (removeFromMap reqId) do
    withTimeout respondOnce $ race_ (waitForCancel respondOnce reqId) (handler message respondOnce)
  where
    -- Drop this request's canceller from the shared map once the request has finished.
    removeFromMap :: (Int32 |? Text) -> Lsp ()
    removeFromMap reqId = do
      cancelMapVar <- asks cancellationMapVar
      atomically $ modifyTVar' cancelMapVar $ Map.delete reqId

    -- Run @action@, bounded by the configured timeout if there is one. The timeout is given in
    -- milliseconds; 'timeout' takes microseconds. On expiry, reply with a server-cancelled error.
    withTimeout :: (Either (Msg.TResponseError message) (Msg.MessageResult message) -> Lsp ()) -> Lsp () -> Lsp ()
    withTimeout reply action =
      case mayTimeoutMillis of
        Nothing -> action
        Just t ->
          timeout (t * 1000) action >>= \case
            Nothing -> reply $ serverCancelErr "Timeout"
            Just () -> pure ()

    -- Error response for a request the client asked us to cancel.
    clientCancelErr :: Text -> Either (Msg.TResponseError message) b
    clientCancelErr msg = Left $ Msg.TResponseError (InL LSPErrorCodes_RequestCancelled) msg Nothing

    -- Error response for a request the server gave up on (e.g. it timed out).
    serverCancelErr :: Text -> Either (Msg.TResponseError message) b
    serverCancelErr msg = Left $ Msg.TResponseError (InL LSPErrorCodes_ServerCancelled) msg Nothing

    -- Registers a canceller for @reqId@, blocks until it is triggered, then replies with a
    -- client-cancelled error.
    --
    -- The canceller is added only after the request has started, deliberately. A message can
    -- be cancelled before its canceller is registered no matter what we do, and deferring the
    -- registration means requests don't block on contention for the cancellation map. Most
    -- requests should complete "instantly" anyway.
    waitForCancel :: (Either (Msg.TResponseError message) (Msg.MessageResult message) -> Lsp ()) -> (Int32 |? Text) -> Lsp ()
    waitForCancel reply reqId = do
      barrier <- newEmptyMVar
      -- 'tryPutMVar' never blocks, so triggering the canceller more than once is harmless.
      let canceller = void $ tryPutMVar barrier ()
      cancelMapVar <- asks cancellationMapVar
      atomically do
        modifyTVar' cancelMapVar (Map.insert reqId canceller)
      readMVar barrier
      let msg = "Request Cancelled by client"
      Debug.debugLogM Debug.LSP msg
      reply (clientCancelErr "Request cancelled by client")