{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE TypeFamilies #-}

module Unison.Util.Websockets
  ( withQueues,
    Queues (..),
    MsgOrError (..),
    withCodeserverWebsocket,
  )
where

import Codec.Serialise qualified as CBOR
import Control.Applicative
import Control.Concurrent.STM.TBMQueue
import Control.Lens (Profunctor (..))
import Control.Monad
import Data.Text qualified as Text
import Data.Text.Encoding qualified as Text
import Ki.Unlifted qualified as Ki
import Network.Socket
import Network.WebSockets
import Network.WebSockets qualified as WS
import Unison.Debug qualified as Debug
import Unison.Prelude
import Unison.Share.Types
import UnliftIO
import Wuss qualified

-- | A websocket viewed as a pair of bounded queues: values of type @i@ are
-- sent to the other side, values of type @o@ are received from it.
data Queues i o = Queues
  { -- | Receive the next message from the other side.
    -- Yields 'Nothing' once the connection is closed.
    receive :: STM (Maybe o),
    -- | Send a message to the other side.
    -- Yields 'False' if the connection is closed.
    send :: i -> STM Bool
  }

-- | 'Queues' consumes @i@ and produces @o@, so it is contravariant in the
-- first parameter and covariant in the second: 'dimap' pre-composes every
-- outgoing message with @pre@ and maps @post@ over every received message
-- (a closed connection, 'Nothing', passes through unchanged).
instance Profunctor Queues where
  dimap :: forall a b c d. (a -> b) -> (c -> d) -> Queues b c -> Queues a d
  dimap pre post queues =
    Queues
      { receive = fmap post <$> receive queues,
        send = \outgoing -> send queues (pre outgoing)
      }

-- | Expose a websocket 'Connection' as a pair of bounded queues for the
-- duration of @action@.
--
-- Two worker threads are forked in a "Ki" scope:
--
-- * a receive worker that reads messages from the connection into a queue of
--   capacity @inputBuffer@, closing that queue when the other side sends a
--   close request, and
-- * a send worker that writes messages from a queue of capacity
--   @outputBuffer@ to the connection, batching whatever is already queued
--   into a single 'sendBinaryDatas' call.
--
-- If the receive worker hits any other 'ConnectionException' before @action@
-- finishes, @action@ is cancelled and that exception is returned as 'Left'.
-- Otherwise the send queue is closed and fully flushed, a close message is
-- sent, and any messages received while the connection shuts down are
-- returned alongside the action's result.
withQueues ::
  forall i o m a.
  (MonadUnliftIO m, WebSocketsData i, WebSocketsData o) =>
  -- | Capacity of the queue of messages received from the other side.
  Int ->
  -- | Capacity of the queue of messages to send to the other side.
  Int ->
  Connection ->
  (Queues i o -> m a) ->
  m (Either ConnectionException (a, [o {- Any leftover messages received from the other side after we've indicated we want to shut down. -}]))
withQueues inputBuffer outputBuffer conn action = Ki.scoped $ \scope -> do
  receiveQ <- liftIO $ newTBMQueueIO inputBuffer
  sendQ <- liftIO $ newTBMQueueIO outputBuffer
  connectionErrorVar <- liftIO newEmptyTMVarIO
  let sendMsg msg = do
        writeTBMQueue sendQ msg
        -- A write to a closed TBMQueue is dropped (per stm-chans docs), so
        -- report whether the queue was still open.
        not <$> isClosedTBMQueue sendQ
      queues = Queues {receive = readTBMQueue receiveQ, send = sendMsg}
  _ <- Ki.fork scope $ recvWorker connectionErrorVar receiveQ
  sendWorkerThread <- Ki.fork scope $ sendWorker sendQ
  let waitConnectionError = atomically $ readTMVar connectionErrorVar
  race waitConnectionError (action queues) >>= \case
    -- The connection failed before the action finished; the action was cancelled.
    Left err -> pure (Left err)
    Right result -> do
      -- Close the send queue, then wait until everything queued has been sent.
      atomically $ closeTBMQueue sendQ
      atomically $ Ki.await sendWorkerThread
      -- Now start the close handshake and collect anything still in flight.
      leftovers <- selfClose connectionErrorVar receiveQ
      pure $ Right (result, leftovers)
  where
    -- Send a close message, then collect every message that still arrives until
    -- either the receive queue is closed (the other side answered with a close
    -- request) or the receive worker has recorded a connection error.
    selfClose :: TMVar ConnectionException -> TBMQueue o -> m [o]
    selfClose errVar receiveQ = do
      liftIO $ sendClose conn ("Done" :: Text)
      let drainMessages :: m [o]
          drainMessages = do
            next <-
              atomically $
                readTBMQueue receiveQ
                  -- On a connection error the receive worker stops WITHOUT
                  -- closing the queue; without this fallback the read above
                  -- would block forever. Buffered messages are still drained
                  -- first because the left branch is tried first.
                  <|> (Nothing <$ readTMVar errVar)
            case next of
              Nothing -> pure []
              Just msg -> (msg :) <$> drainMessages
      drainMessages

    -- Move messages from the connection into the receive queue until the
    -- connection ends. A close request from the other side closes the queue;
    -- any other connection error is stored in @errVar@ (the first one wins)
    -- and the worker stops.
    recvWorker :: TMVar ConnectionException -> TBMQueue o -> m ()
    recvWorker errVar q = do
      stop <- UnliftIO.handle onConnectionException $ do
        msg <- liftIO $ receiveData conn
        atomically $ writeTBMQueue q msg
        pure False
      unless stop $ recvWorker errVar q
      where
        onConnectionException :: ConnectionException -> m Bool
        onConnectionException = \case
          CloseRequest {} -> do
            -- No more messages will arrive; closing the queue lets readers see 'Nothing'.
            atomically $ closeTBMQueue q
            pure True
          -- Anything else is exceptional: report it through the error var.
          err -> do
            atomically . void $ tryPutTMVar errVar err
            pure True

    -- Write queued messages to the connection in batches until the send queue
    -- has been closed and emptied.
    sendWorker :: TBMQueue i -> m ()
    sendWorker q = do
      let flushQ :: STM ([i], Bool)
          flushQ =
            -- 'readTBMQueue' retries while the queue is open and empty, so this
            -- blocks until at least one message (or the close) is available.
            readTBMQueue q >>= \case
              Nothing -> pure ([], True)
              Just outMsg -> do
                -- Greedily take whatever else is already queued, stopping
                -- (without blocking) as soon as the queue is empty.
                (outMsgs, isClosed) <- flushQ <|> pure ([], False)
                pure (outMsg : outMsgs, isClosed)
      (outMsgs, isClosed) <- atomically flushQ
      liftIO $ sendBinaryDatas conn outMsgs
      unless isClosed $ sendWorker q

-- | Connect a websocket to the codeserver at the given URI.
-- The action will be called with a 'Queues' to send and receive messages;
-- when the action completes, the websocket connection will be closed.
--
-- If @tokenProvider@ returns 'Left', the connection is attempted without an
-- @Authorization@ header. The debug log of the connection attempt lists only
-- header names, never their values, so the bearer token is not written to logs.
withCodeserverWebsocket ::
  forall m i o r e.
  (MonadUnliftIO m, WebSocketsData i, WebSocketsData o) =>
  -- | Capacity used for both the send and the receive queue.
  Int ->
  CodeserverURI ->
  -- | Supplies the bearer token for the codeserver.
  (CodeserverId -> IO (Either e Text)) ->
  -- | Request path on the codeserver.
  String ->
  (Queues i o -> m r) ->
  m (Either ConnectionException (r, [o {- Any leftover messages received from the server after we've indicated we want to shut down. -}]))
withCodeserverWebsocket msgBufferSize codeserver tokenProvider codeserverPath action = do
  let host = codeserverRegName codeserver
      connectionOptions =
        WS.defaultConnectionOptions
          { WS.connectionCompressionOptions = WS.PermessageDeflateCompression WS.defaultPermessageDeflate
          }
  headers <-
    liftIO (tokenProvider (codeserverIdFromCodeserverURI codeserver)) <&> \case
      Left {} -> []
      Right token -> [("Authorization", "Bearer " <> Text.encodeUtf8 token)]
  let -- Never log header values: the Authorization header carries a credential.
      loggableHeaders = fst <$> headers
      runClient app = case codeserverScheme codeserver of
        Https -> do
          let tlsPort :: PortNumber
              tlsPort = 443
              port = maybe tlsPort fromIntegral (codeserverPort codeserver)
          Debug.debugLogM Debug.Websockets $
            "Connecting to codeserver via WSS: " <> show (host, port, codeserverPath, loggableHeaders)
          Wuss.runSecureClientWith host port codeserverPath connectionOptions headers app
        Http -> do
          let defaultPort :: Int
              defaultPort = 80
              port = maybe defaultPort id (codeserverPort codeserver)
              fixedHost = case host of
                -- The haskell ws client has issues with "localhost"
                "localhost" -> "127.0.0.1"
                _ -> host
          Debug.debugLogM Debug.Websockets $
            "Connecting to codeserver via WS: " <> show (fixedHost, port, codeserverPath, loggableHeaders)
          WS.runClientWith fixedHost port codeserverPath connectionOptions headers app
  toIO <- askRunInIO
  liftIO $ withSocketsDo $ runClient $ \conn ->
    withQueues msgBufferSize msgBufferSize conn $ \queues ->
      toIO (action queues)

-- | Payload type for websocket messages: either a message or an error.
data MsgOrError err a
  = -- | A regular message.
    Msg !a
  | -- | An application-level error value.
    UserErr !err
  | -- | The payload could not be decoded; carries a description of the
    -- failure (this is what the 'WebSocketsData' instance produces when CBOR
    -- decoding fails).
    DeserialiseFailure !Text
  deriving (Show, Eq, Ord)

-- | Encoded as an 'Int' tag (0 = 'Msg', 1 = 'UserErr', 2 = 'DeserialiseFailure')
-- followed by the encoded payload. Any other tag fails to decode.
--
-- Roundtrip test:
-- >>> import qualified Codec.Serialise as CBOR
-- >>> CBOR.deserialise (CBOR.serialise (Msg "test" :: MsgOrError Text Text)) == Msg "test"
-- True
-- >>> CBOR.deserialise (CBOR.serialise (UserErr "error" :: MsgOrError Text Text)) == UserErr "error"
-- True
-- >>> CBOR.deserialise (CBOR.serialise (DeserialiseFailure "bad" :: MsgOrError Text Text)) == DeserialiseFailure "bad"
-- True
instance (CBOR.Serialise a, CBOR.Serialise err) => CBOR.Serialise (MsgOrError err a) where
  -- NOTE(review): the tag numbers are part of the encoded format, so presumably
  -- they must stay in sync with whatever decodes these messages on the other side.
  encode = \case
    Msg a -> CBOR.encode (0 :: Int) <> CBOR.encode a
    UserErr e -> CBOR.encode (1 :: Int) <> CBOR.encode e
    DeserialiseFailure msg -> CBOR.encode (2 :: Int) <> CBOR.encode msg

  decode = do
    tag <- CBOR.decode @Int
    case tag of
      0 -> Msg <$> CBOR.decode
      1 -> UserErr <$> CBOR.decode
      2 -> DeserialiseFailure <$> CBOR.decode
      _ -> fail $ "Unknown MsgOrError tag: " <> show tag

-- | Messages travel as CBOR (via the 'CBOR.Serialise' instance). Decoding
-- failures are reported as a 'DeserialiseFailure' value rather than thrown.
--
-- Roundtrip test:
-- >>> import qualified Network.WebSockets as WS
-- >>> let msgVal = Msg "test" :: MsgOrError Text Text
-- >>> WS.fromLazyByteString (WS.toLazyByteString msgVal) == msgVal
-- True
-- >>> let errVal = UserErr "whoops" :: MsgOrError Text Text
-- >>> WS.fromLazyByteString (WS.toLazyByteString errVal) == errVal
-- True
--
-- >>> let errVal = DeserialiseFailure "whoops" :: MsgOrError Text Text
-- >>> WS.fromLazyByteString (WS.toLazyByteString errVal) == errVal
-- True
-- >>> let dataMsg = WS.Binary (WS.toLazyByteString msgVal)
-- >>> WS.fromDataMessage dataMsg == msgVal
-- True
instance (CBOR.Serialise msg, CBOR.Serialise e) => WebSocketsData (MsgOrError e msg) where
  fromLazyByteString =
    either (DeserialiseFailure . Text.pack . show) id . CBOR.deserialiseOrFail

  toLazyByteString = CBOR.serialise

  -- Text and binary frames are decoded the same way; the second field of a
  -- 'WS.Text' frame is ignored.
  fromDataMessage = \case
    WS.Text bytes _ -> WS.fromLazyByteString bytes
    WS.Binary bytes -> WS.fromLazyByteString bytes