-- |
-- Module      : Network.TLS.Receiving
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
-- The Receiving module contains functions for unmarshalling received packets
-- according to the TLS state.
--
{-# LANGUAGE FlexibleContexts #-}

module Network.TLS.Receiving
    ( processPacket
    , processPacket13
    ) where

import Network.TLS.Cipher
import Network.TLS.Context.Internal
import Network.TLS.ErrT
import Network.TLS.Handshake.State
import Network.TLS.Imports
import Network.TLS.Packet
import Network.TLS.Packet13
import Network.TLS.Record
import Network.TLS.State
import Network.TLS.Struct
import Network.TLS.Struct13
import Network.TLS.Util
import Network.TLS.Wire

import Control.Concurrent.MVar
import Control.Monad.State.Strict

-- | Decode a plaintext record received on a pre-TLS-1.3 connection into a
-- 'Packet'. Decoding failures are returned as 'Left'.
--
-- * Application-data and alert records are decoded without touching any
--   state.
-- * A ChangeCipherSpec record is checked with 'decodeChangeCipherSpec'; on
--   success the receive side is switched to the pending record state via
--   'switchRxEncryption'.
-- * A handshake record may hold several handshake messages, or only part of
--   one. A trailing partial message is saved as a continuation in
--   'stHandshakeRecordCont' and resumed when the next handshake record is
--   processed.
-- * A 'ProtocolType_DeprecatedHandshake' record is decoded with
--   'decodeDeprecatedHandshake' into a single-message 'Handshake'.
processPacket :: Context -> Record Plaintext -> IO (Either TLSError Packet)
processPacket _ (Record ProtocolType_AppData _ fragment) =
    return $ Right $ AppData $ fragmentGetBytes fragment

processPacket _ (Record ProtocolType_Alert _ fragment) =
    return $ Alert <$> decodeAlerts (fragmentGetBytes fragment)

processPacket ctx (Record ProtocolType_ChangeCipherSpec _ fragment) =
    case decodeChangeCipherSpec (fragmentGetBytes fragment) of
        Left err -> return $ Left err
        Right _  -> do
            switchRxEncryption ctx
            return $ Right ChangeCipherSpec

processPacket ctx (Record ProtocolType_Handshake ver fragment) = do
    -- The key-exchange type of the pending cipher (if there is a handshake
    -- state and a pending cipher) is handed to 'decodeHandshake' through
    -- 'CurrentParams'.
    mHState <- getHState ctx
    let keyxchg = cipherKeyExchange <$> (mHState >>= hstPendingCipher)
    usingState ctx $ do
        let currentParams = CurrentParams
                { cParamsVersion     = ver
                , cParamsKeyXchgType = keyxchg
                }
        -- Take (and clear) any continuation left by a previous record, then
        -- parse as many handshake messages as this record holds.
        mCont <- gets stHandshakeRecordCont
        modify (\st -> st { stHandshakeRecordCont = Nothing })
        Handshake <$> parseMany currentParams mCont (fragmentGetBytes fragment)
  where
    -- Decode handshake messages from @bs@, starting with the saved
    -- continuation when one is given. A trailing partial message is stored
    -- back into 'stHandshakeRecordCont' and adds nothing to the result;
    -- any decoding error aborts the whole record.
    parseMany currentParams mCont bs =
        case fromMaybe decodeHandshakeRecord mCont bs of
            GotError err -> throwError err
            GotPartial cont -> do
                modify (\st -> st { stHandshakeRecordCont = Just cont })
                return []
            GotSuccess (ty, content) ->
                either throwError (return . (: [])) $
                    decodeHandshake currentParams ty content
            GotSuccessRemaining (ty, content) left ->
                case decodeHandshake currentParams ty content of
                    Left err -> throwError err
                    -- Only the first message can resume a continuation; the
                    -- rest of the record is parsed from scratch.
                    Right hh -> (hh :) <$> parseMany currentParams Nothing left

processPacket _ (Record ProtocolType_DeprecatedHandshake _ fragment) =
    case decodeDeprecatedHandshake (fragmentGetBytes fragment) of
        Left err -> return $ Left err
        Right hs -> return $ Right $ Handshake [hs]

-- | Make the pending receive record state (taken from the handshake state's
-- 'hstPendingRxState') the context's active receive state in 'ctxRxState'.
-- 'processPacket' calls this after accepting a ChangeCipherSpec record.
--
-- The pending state is unwrapped with "Network.TLS.Util"'s 'fromJust', so a
-- missing pending state is a failure. The unwrapped value is forced with
-- '$!' inside 'modifyMVar_' so that failure is raised here, with the
-- previous receive state restored by 'modifyMVar_'. Otherwise an
-- unevaluated, failing thunk would be stored in 'ctxRxState' and only
-- explode on some later read.
switchRxEncryption :: Context -> IO ()
switchRxEncryption ctx = do
    rx <- usingHState ctx (gets hstPendingRxState)
    modifyMVar_ (ctxRxState ctx) $ \_ -> return $! fromJust "rx-state" rx

----------------------------------------------------------------

-- | Decode a plaintext record received on a TLS 1.3 connection into a
-- 'Packet13'. Decoding failures are returned as 'Left'.
--
-- * A ChangeCipherSpec record is validated and reported as
--   'ChangeCipherSpec13'; unlike 'processPacket', no record state is
--   switched.
-- * Application-data and alert records are decoded without touching any
--   state.
-- * A handshake record may hold several handshake messages, or only part of
--   one. A trailing partial message is saved in 'stHandshakeRecordCont13'
--   and resumed when the next handshake record is processed.
-- * A 'ProtocolType_DeprecatedHandshake' record is always rejected.
processPacket13 :: Context -> Record Plaintext -> IO (Either TLSError Packet13)
processPacket13 _ (Record ProtocolType_ChangeCipherSpec _ fragment) =
    -- RFC 8446 section 5: the only acceptable change_cipher_spec content is
    -- the single byte 0x01; any other value must abort the handshake. The
    -- payload used to be accepted unchecked. It is now validated with the
    -- same decoder 'processPacket' uses (which presumably checks for exactly
    -- 0x01 -- confirm in Network.TLS.Packet). The error returned is whatever
    -- that decoder reports, not specifically an unexpected_message alert.
    return $ ChangeCipherSpec13 <$ decodeChangeCipherSpec (fragmentGetBytes fragment)
processPacket13 _ (Record ProtocolType_AppData _ fragment) =
    return $ Right $ AppData13 $ fragmentGetBytes fragment
processPacket13 _ (Record ProtocolType_Alert _ fragment) =
    return $ Alert13 <$> decodeAlerts (fragmentGetBytes fragment)
processPacket13 ctx (Record ProtocolType_Handshake _ fragment) =
    usingState ctx $ do
        -- Take (and clear) any continuation left by a previous record, then
        -- parse as many handshake messages as this record holds.
        mCont <- gets stHandshakeRecordCont13
        modify (\st -> st { stHandshakeRecordCont13 = Nothing })
        Handshake13 <$> parseMany mCont (fragmentGetBytes fragment)
  where
    -- Decode handshake messages from @bs@, starting with the saved
    -- continuation when one is given. A trailing partial message is stored
    -- back into 'stHandshakeRecordCont13' and adds nothing to the result;
    -- any decoding error aborts the whole record.
    parseMany mCont bs =
        case fromMaybe decodeHandshakeRecord13 mCont bs of
            GotError err -> throwError err
            GotPartial cont -> do
                modify (\st -> st { stHandshakeRecordCont13 = Just cont })
                return []
            GotSuccess (ty, content) ->
                either throwError (return . (: [])) $ decodeHandshake13 ty content
            GotSuccessRemaining (ty, content) left ->
                case decodeHandshake13 ty content of
                    Left err -> throwError err
                    -- Only the first message can resume a continuation; the
                    -- rest of the record is parsed from scratch.
                    Right hh -> (hh :) <$> parseMany Nothing left
processPacket13 _ (Record ProtocolType_DeprecatedHandshake _ _) =
    return $ Left $ Error_Packet "deprecated handshake packet 1.3"