{-# LANGUAGE CPP #-}
-- |
-- Module      : Network.TLS.Backend
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
-- A Backend represents a unified way to do IO on different
-- types without burdening our calling API with multiple
-- ways to initialize a new context.
--
-- Typically, a backend provides:
--
-- * a way to read data
-- * a way to write data
-- * a way to close the stream
-- * a way to flush the stream
--
module Network.TLS.Backend
    ( HasBackend(..)
    , Backend(..)
    ) where

import Network.TLS.Imports
import qualified Data.ByteString as B
import System.IO (Handle, hSetBuffering, BufferMode(..), hFlush, hClose)

#ifdef INCLUDE_NETWORK
import qualified Network.Socket as Network (Socket, close)
import qualified Network.Socket.ByteString as Network
#if defined(__GLASGOW_HASKELL__) && WINDOWS
-- Only needed by the Windows recv workaround in 'safeRecv'.
import Control.Concurrent (forkIO, newEmptyMVar, putMVar, takeMVar)
import qualified Control.Exception as E
#endif
#endif

#ifdef INCLUDE_HANS
import qualified Data.ByteString.Lazy as L
import qualified Hans.NetworkStack as Hans
#endif

-- | Connection IO backend.
--
-- A record of the four IO actions needed from the underlying transport:
-- flushing, closing, sending and receiving.
data Backend = Backend
    { backendFlush :: IO ()                -- ^ Flush the connection sending buffer, if any.
    , backendClose :: IO ()                -- ^ Close the connection.
    , backendSend  :: ByteString -> IO ()  -- ^ Send a bytestring through the connection.
    , backendRecv  :: Int -> IO ByteString -- ^ Receive up to the specified number of bytes from the
                                           --   connection; may return fewer if the stream ends first.
    }

-- | Types that can serve as the transport underneath a connection.
class HasBackend a where
    -- | The 'Backend' record of IO actions for this transport.
    getBackend        :: a -> Backend
    -- | Prepare the transport for use; the 'Handle' instance, for example,
    -- uses this to switch off buffering.
    initializeBackend :: a -> IO ()

-- | A 'Backend' is trivially its own backend: there is nothing to
-- initialize, and 'getBackend' returns the record unchanged.
instance HasBackend Backend where
    initializeBackend _ = return ()
    getBackend = id

#ifdef INCLUDE_NETWORK
#if defined(__GLASGOW_HASKELL__) && WINDOWS
-- Socket recv and accept calls on Windows platform cannot be interrupted when compiled with -threaded.
-- See https://ghc.haskell.org/trac/ghc/ticket/5797 for details.
-- The following enables simple workaround
#define SOCKET_ACCEPT_RECV_WORKAROUND
#endif

-- | Receive from a socket with the same contract as 'Network.recv'.
--
-- Normally this is just 'Network.recv'. When the Windows workaround is
-- enabled, the blocking receive runs in a forked thread and the caller waits
-- on an 'MVar' instead. In that mode an 'E.IOException' raised by the receive
-- is reported as an empty result, the same way the caller sees end of stream.
--
-- The whole definition is guarded by @INCLUDE_NETWORK@ because it mentions
-- 'Network.Socket', which is only imported when network support is built in.
safeRecv :: Network.Socket -> Int -> IO ByteString
#ifndef SOCKET_ACCEPT_RECV_WORKAROUND
safeRecv = Network.recv
#else
safeRecv s buf = do
    var <- newEmptyMVar
    -- The forked thread performs the blocking receive; waiting on the MVar
    -- keeps the calling thread interruptible.
    _ <- forkIO $ (Network.recv s buf `E.catch` onIOError) >>= putMVar var
    takeMVar var
  where
    -- Explicit signature selects the exception type without needing
    -- ScopedTypeVariables.
    onIOError :: E.IOException -> IO ByteString
    onIOError _ = return B.empty
#endif
#endif

#ifdef INCLUDE_NETWORK
-- | Backend over a socket from the @network@ package.
--
-- Flushing is a no-op, closing uses 'Network.close', and sending uses
-- 'Network.sendAll'. Receiving calls 'safeRecv' repeatedly until the
-- requested number of bytes has been collected or a receive returns an empty
-- chunk. The result can therefore be shorter than requested.
instance HasBackend Network.Socket where
    initializeBackend _ = return ()
    getBackend sock = Backend (return ()) (Network.close sock) (Network.sendAll sock) recvAll
      where
        recvAll :: Int -> IO ByteString
        recvAll n = B.concat <$> loop n

        -- Gather chunks while bytes are still owed; an empty chunk ends the
        -- loop early.
        loop :: Int -> IO [ByteString]
        loop 0    = return []
        loop left = do
            r <- safeRecv sock left
            if B.null r
                then return []
                else (r :) <$> loop (left - B.length r)
#endif

#ifdef INCLUDE_HANS
-- | Backend over a Hans socket. Flushing is a no-op and closing delegates to
-- 'Hans.close'.
instance HasBackend Hans.Socket where
    initializeBackend _ = return ()
    getBackend sock = Backend
        { backendFlush = return ()
        , backendClose = Hans.close sock
        , backendSend  = pushAll
        , backendRecv  = \want -> pullUpTo (fromIntegral want) L.empty
        }
      where
        -- Keep handing the unsent suffix to 'Hans.sendBytes' until all of it
        -- has been accepted. A write that reports 0 bytes also ends the loop.
        -- NOTE(review): on a 0-byte write the unsent tail is silently dropped.
        pushAll chunk = do
            sent <- fromIntegral <$> Hans.sendBytes sock (L.fromStrict chunk)
            if sent /= 0 && sent /= B.length chunk
                then pushAll (B.drop sent chunk)
                else return ()

        -- Append lazy pieces until nothing remains to be read or a read comes
        -- back empty, then return the gathered bytes as one strict string.
        pullUpTo 0         got = return (L.toStrict got)
        pullUpTo remaining got = do
            piece <- Hans.recvBytes sock remaining
            if L.null piece
                then return (L.toStrict got)
                else pullUpTo (remaining - L.length piece) (got `L.append` piece)
#endif

-- | Backend over a 'Handle'.
--
-- Initialization switches the handle to 'NoBuffering'. Flush and close map
-- to 'hFlush' and 'hClose'. Send and receive use 'B.hPut' and 'B.hGet'.
instance HasBackend Handle where
    initializeBackend handle = hSetBuffering handle NoBuffering
    getBackend handle = Backend (hFlush handle) (hClose handle) (B.hPut handle) (B.hGet handle)