{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
-- |
-- Module      : Network.TLS.Handshake.State
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
module Network.TLS.Handshake.State
    ( HandshakeState(..)
    , HandshakeDigest(..)
    , HandshakeMode13(..)
    , RTT0Status(..)
    , CertReqCBdata
    , HandshakeM
    , newEmptyHandshake
    , runHandshake
    -- * key accessors
    , setPublicKey
    , setPublicPrivateKeys
    , getLocalPublicPrivateKeys
    , getRemotePublicKey
    , setServerDHParams
    , getServerDHParams
    , setServerECDHParams
    , getServerECDHParams
    , setDHPrivate
    , getDHPrivate
    , setGroupPrivate
    , getGroupPrivate
    -- * cert accessors
    , setClientCertSent
    , getClientCertSent
    , setCertReqSent
    , getCertReqSent
    , setClientCertChain
    , getClientCertChain
    , setCertReqToken
    , getCertReqToken
    , setCertReqCBdata
    , getCertReqCBdata
    , setCertReqSigAlgsCert
    , getCertReqSigAlgsCert
    -- * digest accessors
    , addHandshakeMessage
    , updateHandshakeDigest
    , getHandshakeMessages
    , getHandshakeMessagesRev
    , getHandshakeDigest
    , foldHandshakeDigest
    -- * master secret
    , setMasterSecret
    , setMasterSecretFromPre
    -- * misc accessor
    , getPendingCipher
    , setServerHelloParameters
    , setExtendedMasterSec
    , getExtendedMasterSec
    , setNegotiatedGroup
    , getNegotiatedGroup
    , setTLS13HandshakeMode
    , getTLS13HandshakeMode
    , setTLS13RTT0Status
    , getTLS13RTT0Status
    , setTLS13EarlySecret
    , getTLS13EarlySecret
    , setTLS13ResumptionSecret
    , getTLS13ResumptionSecret
    , setCCS13Sent
    , getCCS13Sent
    ) where

import Network.TLS.Util
import Network.TLS.Struct
import Network.TLS.Record.State
import Network.TLS.Packet
import Network.TLS.Crypto
import Network.TLS.Cipher
import Network.TLS.Compression
import Network.TLS.Types
import Network.TLS.Imports
import Control.Monad.State.Strict
import Data.X509 (CertificateChain)
import Data.ByteArray (ByteArrayAccess)

-- | Public-key material collected while the handshake runs.
data HandshakeKeyState = HandshakeKeyState
    { hksRemotePublicKey        :: !(Maybe PubKey)
      -- ^ The peer's public key; filled in by 'setPublicKey'.
    , hksLocalPublicPrivateKeys :: !(Maybe (PubKey, PrivKey))
      -- ^ Our own key pair; filled in by 'setPublicPrivateKeys'.
    } deriving (Show)

-- | The handshake transcript, held either as the raw message bytes
-- ('HandshakeMessages') or as a hash context ('HandshakeDigestContext').
-- A fresh handshake starts with @HandshakeMessages []@.
data HandshakeDigest = HandshakeMessages [ByteString]
                     | HandshakeDigestContext HashCtx
                     deriving (Show)

-- | All state threaded through a single handshake by 'HandshakeM'.
-- Use 'newEmptyHandshake' to build the initial value.
data HandshakeState = HandshakeState
    { hstClientVersion       :: !Version
    , hstClientRandom        :: !ClientRandom
    , hstServerRandom        :: !(Maybe ServerRandom)
    , hstMasterSecret        :: !(Maybe ByteString)
    , hstKeyState            :: !HandshakeKeyState
    , hstServerDHParams      :: !(Maybe ServerDHParams)
    , hstDHPrivate           :: !(Maybe DHPrivate)
    , hstServerECDHParams    :: !(Maybe ServerECDHParams)
    , hstGroupPrivate        :: !(Maybe GroupPrivate)
    , hstHandshakeDigest     :: !HandshakeDigest
    , hstHandshakeMessages   :: [ByteString]
    , hstCertReqToken        :: !(Maybe ByteString)
        -- ^ Set to Just-value when a TLS13 certificate request is received
    , hstCertReqCBdata       :: !(Maybe CertReqCBdata)
        -- ^ Set to Just-value when a certificate request is received
    , hstCertReqSigAlgsCert  :: !(Maybe [HashAndSignatureAlgorithm])
        -- ^ In TLS 1.3, these are separate from the certificate
        -- issuer signature algorithm hints in the callback data.
        -- In TLS 1.2 the same list is overloaded for both purposes.
        -- Not present in TLS 1.1 and earlier
    , hstClientCertSent      :: !Bool
        -- ^ Set to true when a client certificate chain was sent
    , hstCertReqSent         :: !Bool
        -- ^ Set to true when a certificate request was sent.  This applies
        -- only to requests sent during handshake (not post-handshake).
    , hstClientCertChain     :: !(Maybe CertificateChain)
    , hstPendingTxState      :: Maybe RecordState
    , hstPendingRxState      :: Maybe RecordState
    , hstPendingCipher       :: Maybe Cipher
    , hstPendingCompression  :: Compression
    , hstExtendedMasterSec   :: Bool
    , hstNegotiatedGroup     :: Maybe Group
    , hstTLS13HandshakeMode  :: HandshakeMode13
    , hstTLS13RTT0Status     :: !RTT0Status
    , hstTLS13EarlySecret    :: Maybe (BaseSecret EarlySecret)
    , hstTLS13ResumptionSecret :: Maybe (BaseSecret ResumptionSecret)
    , hstCCS13Sent           :: !Bool
    } deriving (Show)

{- | When we receive a CertificateRequest from a server, a just-in-time
   callback is issued to the application to obtain a suitable certificate.
   Somewhat unfortunately, the callback parameters don't abstract away the
   details of the TLS 1.2 Certificate Request message, which combines the
   legacy @certificate_types@ and new @supported_signature_algorithms@
   parameters in a rather subtle way.

   TLS 1.2 also (again unfortunately, in the opinion of the author of this
   comment) overloads the signature algorithms parameter to constrain not only
   the algorithms used in TLS, but also the algorithms used by issuing CAs in
   the X.509 chain.  Best practice is NOT to treat such a restriction as a
   MUST, but rather to take it as merely a preference when a choice exists.  If
   the best chain available does not match the provided signature algorithm
   list, go ahead and use it anyway: it will probably work, and the server may
   not even care about the issuer CAs at all (it may be doing DANE or have
   explicit mappings for the client's public key, ...).

   The TLS 1.3 @CertificateRequest@ message drops @certificate_types@ and no
   longer overloads @supported_signature_algorithms@ to cover X.509.  It also
   includes a new opaque context token that the client must echo back, which
   makes certain client authentication replay attacks more difficult.  We
   store that context separately; it does not need to be presented in the user
   callback.  The certificate signature algorithms preferred by the peer are
   now in the separate @signature_algorithms_cert@ extension, but we cannot
   report these to the application callback without an API change.  The good
   news is that filtering the X.509 signature types is generally unnecessary,
   unwise and difficult.  So this extension is not reported to the callback;
   it is kept separately in the handshake state ('hstCertReqSigAlgsCert').

   As a result, the information we provide to the callback is no longer a
   verbatim copy of the certificate request payload.  In the case of TLS 1.3,
   the 'CertificateType' list is synthetically generated from the server's
   @signature_algorithms@ extension, and the @signature_algorithms_cert@
   extension is not included.

   Since the original TLS 1.2 'CertificateType' has no provision for the newer
   certificate types that have appeared in TLS 1.3, we add some synthetic
   values that have no equivalent in the TLS 1.2 'CertificateType' as
   defined in the IANA
   <https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-2
   TLS ClientCertificateType Identifiers> registry.  These values are inferred
   from the TLS 1.3 @signature_algorithms@ extension, and will allow clients to
   present Ed25519 and Ed448 certificates when these become supported.
-}
type CertReqCBdata =
     ( [CertificateType]                  -- certificate types (synthetic in TLS 1.3)
     , Maybe [HashAndSignatureAlgorithm]  -- signature algorithms, when available
     , [DistinguishedName] )              -- distinguished names from the request

-- | The handshake monad: a strict 'State' computation over 'HandshakeState'.
-- 'Functor', 'Applicative' and 'Monad' are inherited from the wrapped
-- 'State' via GeneralizedNewtypeDeriving.
newtype HandshakeM a = HandshakeM { runHandshakeM :: State HandshakeState a }
    deriving (Functor, Applicative, Monad)

-- | Every 'MonadState' method is forwarded unchanged to the wrapped
-- 'State' computation.
instance MonadState HandshakeState HandshakeM where
    put x   = HandshakeM (put x)
    get     = HandshakeM get
    state f = HandshakeM (state f)

-- | Create a new, empty handshake state for the given client version and
-- client random.
--
-- Initial values: every optional field is 'Nothing', every flag is 'False',
-- the transcript is an empty 'HandshakeMessages' list, the pending
-- compression is 'nullCompression', the TLS 1.3 mode is 'FullHandshake' and
-- the 0-RTT status is 'RTT0None'.
newEmptyHandshake :: Version -> ClientRandom -> HandshakeState
newEmptyHandshake ver crand = HandshakeState
    { hstClientVersion       = ver
    , hstClientRandom        = crand
    , hstServerRandom        = Nothing
    , hstMasterSecret        = Nothing
    , hstKeyState            = HandshakeKeyState Nothing Nothing
    , hstServerDHParams      = Nothing
    , hstDHPrivate           = Nothing
    , hstServerECDHParams    = Nothing
    , hstGroupPrivate        = Nothing
    , hstHandshakeDigest     = HandshakeMessages []
    , hstHandshakeMessages   = []
    , hstCertReqToken        = Nothing
    , hstCertReqCBdata       = Nothing
    , hstCertReqSigAlgsCert  = Nothing
    , hstClientCertSent      = False
    , hstCertReqSent         = False
    , hstClientCertChain     = Nothing
    , hstPendingTxState      = Nothing
    , hstPendingRxState      = Nothing
    , hstPendingCipher       = Nothing
    , hstPendingCompression  = nullCompression
    , hstExtendedMasterSec   = False
    , hstNegotiatedGroup     = Nothing
    , hstTLS13HandshakeMode  = FullHandshake
    , hstTLS13RTT0Status     = RTT0None
    , hstTLS13EarlySecret    = Nothing
    , hstTLS13ResumptionSecret = Nothing
    , hstCCS13Sent           = False
    }

-- | Run a 'HandshakeM' action starting from the given state.  Returns the
-- action's result paired with the resulting state.
runHandshake :: HandshakeState -> HandshakeM a -> (a, HandshakeState)
runHandshake hst f = runState (runHandshakeM f) hst

-- | Store the peer's public key ('hksRemotePublicKey'), replacing any
-- previous value.
setPublicKey :: PubKey -> HandshakeM ()
setPublicKey pk = modify (\hst -> hst { hstKeyState = setPK (hstKeyState hst) })
  where setPK :: HandshakeKeyState -> HandshakeKeyState
        setPK hks = hks { hksRemotePublicKey = Just pk }

-- | Store our own (public, private) key pair ('hksLocalPublicPrivateKeys'),
-- replacing any previous value.
setPublicPrivateKeys :: (PubKey, PrivKey) -> HandshakeM ()
setPublicPrivateKeys keys = modify (\hst -> hst { hstKeyState = setKeys (hstKeyState hst) })
  where setKeys :: HandshakeKeyState -> HandshakeKeyState
        setKeys hks = hks { hksLocalPublicPrivateKeys = Just keys }

-- | Fetch the peer's public key stored by 'setPublicKey'.
--
-- The stored 'Maybe' is unwrapped with the labelled 'fromJust' from
-- Network.TLS.Util ("remote public key").  If no key was set, this
-- presumably raises an error carrying that label.  Because the unwrap is
-- 'fmap'ped, that happens only when the result is demanded.
getRemotePublicKey :: HandshakeM PubKey
getRemotePublicKey =
    fromJust "remote public key" <$> gets (hksRemotePublicKey . hstKeyState)

-- | Fetch our own (public, private) key pair stored by
-- 'setPublicPrivateKeys'.
--
-- Unwrapped with the labelled 'fromJust' ("local public/private key").  A
-- missing pair presumably raises an error with that label once the result
-- is demanded.
getLocalPublicPrivateKeys :: HandshakeM (PubKey, PrivKey)
getLocalPublicPrivateKeys =
    fromJust "local public/private key" <$> gets (hksLocalPublicPrivateKeys . hstKeyState)

-- | Store the server's Diffie-Hellman parameters, replacing any previous
-- value.
setServerDHParams :: ServerDHParams -> HandshakeM ()
setServerDHParams shp = modify (\hst -> hst { hstServerDHParams = Just shp })

-- | Fetch the server DH parameters stored by 'setServerDHParams'.
-- Unwrapped with the labelled 'fromJust' ("server DH params").  The lookup
-- presumably errors, on demand, if nothing was stored.
getServerDHParams :: HandshakeM ServerDHParams
getServerDHParams = fromJust "server DH params" <$> gets hstServerDHParams

-- | Store the server's ECDH parameters, replacing any previous value.
setServerECDHParams :: ServerECDHParams -> HandshakeM ()
setServerECDHParams shp = modify (\hst -> hst { hstServerECDHParams = Just shp })

-- | Fetch the server ECDH parameters stored by 'setServerECDHParams'.
-- Unwrapped with the labelled 'fromJust' ("server ECDH params").  The
-- lookup presumably errors, on demand, if nothing was stored.
getServerECDHParams :: HandshakeM ServerECDHParams
getServerECDHParams = fromJust "server ECDH params" <$> gets hstServerECDHParams

-- | Store the local Diffie-Hellman private value, replacing any previous
-- value.
setDHPrivate :: DHPrivate -> HandshakeM ()
setDHPrivate shp = modify (\hst -> hst { hstDHPrivate = Just shp })

-- | Fetch the DH private value stored by 'setDHPrivate'.
-- Unwrapped with the labelled 'fromJust' ("server DH private").  The
-- lookup presumably errors, on demand, if nothing was stored.
getDHPrivate :: HandshakeM DHPrivate
getDHPrivate = fromJust "server DH private" <$> gets hstDHPrivate

-- | Fetch the key-exchange group private value stored by 'setGroupPrivate'.
-- Unwrapped with the labelled 'fromJust' ("server ECDH private").  The
-- lookup presumably errors, on demand, if nothing was stored.
getGroupPrivate :: HandshakeM GroupPrivate
getGroupPrivate = fromJust "server ECDH private" <$> gets hstGroupPrivate

-- | Store the key-exchange group private value, replacing any previous
-- value.
setGroupPrivate :: GroupPrivate -> HandshakeM ()
setGroupPrivate shp = modify (\hst -> hst { hstGroupPrivate = Just shp })

-- | Set the extended-master-secret flag (initially 'False').
setExtendedMasterSec :: Bool -> HandshakeM ()
setExtendedMasterSec b = modify (\hst -> hst { hstExtendedMasterSec = b })

-- | Read the extended-master-secret flag.
getExtendedMasterSec :: HandshakeM Bool
getExtendedMasterSec = gets hstExtendedMasterSec

-- | Record the negotiated key-exchange group, replacing any previous value.
setNegotiatedGroup :: Group -> HandshakeM ()
setNegotiatedGroup g = modify (\hst -> hst { hstNegotiatedGroup = Just g })

-- | Read the negotiated group.  Returns 'Nothing' if 'setNegotiatedGroup'
-- has not been called.
getNegotiatedGroup :: HandshakeM (Maybe Group)
getNegotiatedGroup = gets hstNegotiatedGroup

-- | Type to show which handshake mode is used in TLS 1.3.
data HandshakeMode13 =
      -- | Full handshake is used.
      FullHandshake
      -- | Full handshake is used with hello retry request.
    | HelloRetryRequest
      -- | Server authentication is skipped.
    | PreSharedKey
      -- | Server authentication is skipped and early data is sent.
    | RTT0
    deriving (Show, Eq)

-- | Set the TLS 1.3 handshake mode (initially 'FullHandshake').
setTLS13HandshakeMode :: HandshakeMode13 -> HandshakeM ()
setTLS13HandshakeMode s = modify (\hst -> hst { hstTLS13HandshakeMode = s })

-- | Read the TLS 1.3 handshake mode.
getTLS13HandshakeMode :: HandshakeM HandshakeMode13
getTLS13HandshakeMode = gets hstTLS13HandshakeMode

-- | Status of TLS 1.3 0-RTT (early data) for the current handshake.
-- A fresh handshake starts at 'RTT0None'.
data RTT0Status = RTT0None
                | RTT0Sent
                | RTT0Accepted
                | RTT0Rejected
                deriving (Show, Eq)

-- | Set the TLS 1.3 0-RTT status (initially 'RTT0None').
setTLS13RTT0Status :: RTT0Status -> HandshakeM ()
setTLS13RTT0Status s = modify (\hst -> hst { hstTLS13RTT0Status = s })

-- | Read the TLS 1.3 0-RTT status.
getTLS13RTT0Status :: HandshakeM RTT0Status
getTLS13RTT0Status = gets hstTLS13RTT0Status

-- | Store the TLS 1.3 early secret, replacing any previous value.
setTLS13EarlySecret :: BaseSecret EarlySecret -> HandshakeM ()
setTLS13EarlySecret secret = modify (\hst -> hst { hstTLS13EarlySecret = Just secret })

-- | Read the TLS 1.3 early secret.  Returns 'Nothing' if it was never set.
getTLS13EarlySecret :: HandshakeM (Maybe (BaseSecret EarlySecret))
getTLS13EarlySecret = gets hstTLS13EarlySecret

-- | Store the TLS 1.3 resumption secret, replacing any previous value.
setTLS13ResumptionSecret :: BaseSecret ResumptionSecret -> HandshakeM ()
setTLS13ResumptionSecret secret =
    modify (\hst -> hst { hstTLS13ResumptionSecret = Just secret })

-- | Read the TLS 1.3 resumption secret.  Returns 'Nothing' if it was never
-- set.
getTLS13ResumptionSecret :: HandshakeM (Maybe (BaseSecret ResumptionSecret))
getTLS13ResumptionSecret = gets hstTLS13ResumptionSecret

-- | Set the 'hstCCS13Sent' flag (initially 'False').  Presumably this records
-- that a TLS 1.3 ChangeCipherSpec has been sent; confirm at call sites.
setCCS13Sent :: Bool -> HandshakeM ()
setCCS13Sent sent = modify (\hst -> hst { hstCCS13Sent = sent })

-- | Read the 'hstCCS13Sent' flag.
getCCS13Sent :: HandshakeM Bool
getCCS13Sent = gets hstCCS13Sent

-- | Set the flag recording that a certificate request was sent during the
-- handshake (post-handshake requests are not tracked here).
setCertReqSent :: Bool -> HandshakeM ()
setCertReqSent b = modify (\hst -> hst { hstCertReqSent = b })

-- | Read whether a certificate request was sent during the handshake.
getCertReqSent :: HandshakeM Bool
getCertReqSent = gets hstCertReqSent

-- | Set the flag recording that a client certificate chain was sent.
setClientCertSent :: Bool -> HandshakeM ()
setClientCertSent b = modify (\hst -> hst { hstClientCertSent = b })

-- | Read the 'hstClientCertSent' flag from the handshake state.
getClientCertSent :: HandshakeM Bool
getClientCertSent = hstClientCertSent <$> get

-- | Store the given certificate chain in 'hstClientCertChain', replacing any
-- previous value.
setClientCertChain :: CertificateChain -> HandshakeM ()
setClientCertChain chain = modify $ \st -> st { hstClientCertChain = Just chain }

-- | Read the certificate chain stored in 'hstClientCertChain', if any.
getClientCertChain :: HandshakeM (Maybe CertificateChain)
getClientCertChain = hstClientCertChain <$> get

--
-- | Overwrite 'hstCertReqToken' (passing 'Nothing' clears it).
setCertReqToken :: Maybe ByteString -> HandshakeM ()
setCertReqToken token = modify (\st -> st { hstCertReqToken = token })

-- | Read 'hstCertReqToken' from the handshake state.
getCertReqToken :: HandshakeM (Maybe ByteString)
getCertReqToken = hstCertReqToken <$> get

--
-- | Overwrite 'hstCertReqCBdata' (passing 'Nothing' clears it).
setCertReqCBdata :: Maybe CertReqCBdata -> HandshakeM ()
setCertReqCBdata cbdata = modify $ \st -> st { hstCertReqCBdata = cbdata }

-- | Read 'hstCertReqCBdata' from the handshake state.
getCertReqCBdata :: HandshakeM (Maybe CertReqCBdata)
getCertReqCBdata = hstCertReqCBdata <$> get

-- | Overwrite 'hstCertReqSigAlgsCert' (passing 'Nothing' clears it).
--
-- Dead code, until we find some use for the extension (presumably
-- signature_algorithms_cert -- confirm).
setCertReqSigAlgsCert :: Maybe [HashAndSignatureAlgorithm] -> HandshakeM ()
setCertReqSigAlgsCert algs = modify (\st -> st { hstCertReqSigAlgsCert = algs })

-- | Read 'hstCertReqSigAlgsCert' from the handshake state.
getCertReqSigAlgsCert :: HandshakeM (Maybe [HashAndSignatureAlgorithm])
getCertReqSigAlgsCert = hstCertReqSigAlgsCert <$> get

--
-- | Return the pending cipher stored in 'hstPendingCipher'.
--
-- Partial: the labelled 'fromJust' is applied with the label
-- @"pending cipher"@, so this fails if no pending cipher has been set yet.
getPendingCipher :: HandshakeM Cipher
getPendingCipher = gets (fromJust "pending cipher" . hstPendingCipher)

-- | Record an encoded handshake message in 'hstHandshakeMessages'.
--
-- The message is prepended, so the stored list is newest-first;
-- 'getHandshakeMessages' returns it in chronological order.
addHandshakeMessage :: ByteString -> HandshakeM ()
addHandshakeMessage content = modify prepend
  where
    prepend st = st { hstHandshakeMessages = content : hstHandshakeMessages st }

-- | All recorded handshake messages, oldest first.
--
-- This is the reverse of the newest-first storage order exposed by
-- 'getHandshakeMessagesRev'.
getHandshakeMessages :: HandshakeM [ByteString]
getHandshakeMessages = reverse <$> getHandshakeMessagesRev

-- | All recorded handshake messages in storage order (newest first).
getHandshakeMessagesRev :: HandshakeM [ByteString]
getHandshakeMessagesRev = hstHandshakeMessages <$> get

-- | Feed one encoded handshake message into the transcript digest.
--
-- While the digest is still a buffered 'HandshakeMessages' list, the message
-- is prepended (newest first).  Once it has become a 'HandshakeDigestContext',
-- the message is absorbed into the hash context instead.
updateHandshakeDigest :: ByteString -> HandshakeM ()
updateHandshakeDigest content = modify $ \st ->
    st { hstHandshakeDigest = absorb (hstHandshakeDigest st) }
  where
    absorb (HandshakeMessages msgs)     = HandshakeMessages (content : msgs)
    absorb (HandshakeDigestContext ctx) = HandshakeDigestContext (hashUpdate ctx content)

-- | Compress the whole transcript with the specified function.  Function @f@
-- takes the handshake digest as input and returns an encoded handshake message
-- to replace the transcript with.
--
-- Both the digest and 'hstHandshakeMessages' are replaced by the single
-- folded message.  If the digest was still a buffered message list, it stays
-- a buffered list (containing only the folded message).  If it was already a
-- hash context, a fresh context for @hashAlg@ is started from the folded
-- message.
foldHandshakeDigest :: Hash -> (ByteString -> ByteString) -> HandshakeM ()
foldHandshakeDigest hashAlg f = modify $ \hs ->
    case hstHandshakeDigest hs of
        HandshakeMessages bytes ->
            -- The buffer is newest-first, so replay it oldest-first.
            let hashCtx  = hashMessages (hashInit hashAlg) (reverse bytes)
                !folded  = f (hashFinal hashCtx)
             in hs { hstHandshakeDigest   = HandshakeMessages [folded]
                   , hstHandshakeMessages = [folded]
                   }
        HandshakeDigestContext hashCtx ->
            let !folded  = f (hashFinal hashCtx)
                hashCtx' = hashUpdate (hashInit hashAlg) folded
             in hs { hstHandshakeDigest   = HandshakeDigestContext hashCtx'
                   , hstHandshakeMessages = [folded]
                   }
  where
    -- Strict left fold over the messages.  Each intermediate context is forced
    -- instead of building the chain of thunks that a lazy 'foldl' creates.
    hashMessages !ctx []       = ctx
    hashMessages !ctx (m : ms) = hashMessages (hashUpdate ctx m) ms

-- | Finalize the running transcript hash without modifying the state.
--
-- Used by the extended master secret derivation.  Fails with
-- @"un-initialized session hash"@ if the digest is still a buffered message
-- list rather than a hash context.
getSessionHash :: HandshakeM ByteString
getSessionHash = gets sessionHash
  where
    sessionHash st = case hstHandshakeDigest st of
        HandshakeMessages _        -> error "un-initialized session hash"
        HandshakeDigestContext ctx -> hashFinal ctx

-- | Compute the Finished verify data for @role@ from the running transcript
-- hash, the pending cipher and the master secret.
--
-- Partial: fails with @"un-initialized handshake digest"@ if the digest is
-- still a buffered message list.  It also fails, via the labelled 'fromJust',
-- if the pending cipher or master secret is unset.
getHandshakeDigest :: Version -> Role -> HandshakeM ByteString
getHandshakeDigest ver role = gets $ \st ->
    case hstHandshakeDigest st of
        HandshakeMessages _ ->
            error "un-initialized handshake digest"
        HandshakeDigestContext hashCtx ->
            finished ver
                     (fromJust "cipher" $ hstPendingCipher st)
                     (fromJust "master secret" $ hstMasterSecret st)
                     hashCtx
  where
    finished
        | role == ClientRole = generateClientFinished
        | otherwise          = generateServerFinished

-- | Generate the master secret from the pre master secret.
--
-- If 'getExtendedMasterSec' is set, the secret is derived with
-- 'generateExtendedMasterSec' over the current session hash
-- ('getSessionHash').  Otherwise it is derived with 'generateMasterSecret'
-- from the client and server randoms.  Either way, the result is installed
-- with 'setMasterSecret' and also returned.
setMasterSecretFromPre :: ByteArrayAccess preMaster
                       => Version   -- ^ chosen transmission version
                       -> Role      -- ^ the role (Client or Server) of the generating side
                       -> preMaster -- ^ the pre master secret
                       -> HandshakeM ByteString
setMasterSecretFromPre ver role premasterSecret = do
    useExtended <- getExtendedMasterSec
    st          <- get
    let cipher = fromJust "cipher" $ hstPendingCipher st
    secret <-
        if useExtended
            then generateExtendedMasterSec ver cipher premasterSecret <$> getSessionHash
            else pure $ generateMasterSecret ver cipher premasterSecret
                            (hstClientRandom st)
                            (fromJust "server random" $ hstServerRandom st)
    setMasterSecret ver role secret
    pure secret

-- | Set master secret and as a side effect generate the key block
-- with all the right parameters, and setup the pending tx/rx state.
--
-- The pending transmit/receive record states come from 'computeKeyBlock'.
setMasterSecret :: Version -> Role -> ByteString -> HandshakeM ()
setMasterSecret ver role masterSecret = modify installKeys
  where
    installKeys st =
        let (txState, rxState) = computeKeyBlock st masterSecret ver role
         in st { hstMasterSecret   = Just masterSecret
               , hstPendingTxState = Just txState
               , hstPendingRxState = Just rxState
               }

-- | Derive the pending (transmit, receive) record states from the master
-- secret.
--
-- The block from 'generateKeyBlock' is split, in order, into: client MAC
-- secret, server MAC secret, client write key, server write key, client write
-- IV, server write IV.  The MAC secrets are zero-length when the bulk
-- algorithm has no separate MAC ('hasMAC').  The local side's material goes
-- into the transmit state and the peer's into the receive state.  Both start
-- with sequence number 0 at crypt level 'CryptMasterSecret'.
computeKeyBlock :: HandshakeState -> ByteString -> Version -> Role -> (RecordState, RecordState)
computeKeyBlock hst masterSecret ver cc = (mkRecordState txCrypt, mkRecordState rxCrypt)
  where
    cipher   = fromJust "cipher" $ hstPendingCipher hst
    bulk     = cipherBulk cipher
    isClient = cc == ClientRole

    macLen
        | hasMAC (bulkF bulk) = hashDigestSize (cipherHash cipher)
        | otherwise           = 0
    keyLen = bulkKeySize bulk
    ivLen  = bulkIVSize bulk

    keyBlock = generateKeyBlock ver cipher
                                (hstClientRandom hst)
                                (fromJust "server random" $ hstServerRandom hst)
                                masterSecret
                                (cipherKeyBlockSize cipher)

    (clientMac, serverMac, clientKey, serverKey, clientIV, serverIV) =
        fromJust "p6" $ partition6 keyBlock (macLen, macLen, keyLen, keyLen, ivLen, ivLen)

    -- Each side encrypts with its own write key and decrypts with the peer's.
    clientDir = if isClient then BulkEncrypt else BulkDecrypt
    serverDir = if isClient then BulkDecrypt else BulkEncrypt

    clientCrypt = CryptState { cstKey       = bulkInit bulk clientDir clientKey
                             , cstIV        = clientIV
                             , cstMacSecret = clientMac }
    serverCrypt = CryptState { cstKey       = bulkInit bulk serverDir serverKey
                             , cstIV        = serverIV
                             , cstMacSecret = serverMac }

    (txCrypt, rxCrypt)
        | isClient  = (clientCrypt, serverCrypt)
        | otherwise = (serverCrypt, clientCrypt)

    mkRecordState cst = RecordState
        { stCryptState  = cst
        , stMacState    = MacState { msSequence = 0 }
        , stCryptLevel  = CryptMasterSecret
        , stCipher      = Just cipher
        , stCompression = hstPendingCompression hst
        }


-- | Record the negotiated ServerHello parameters and switch the transcript
-- from a buffered message list to a running hash.
--
-- The hash algorithm is chosen by 'getHash' from the version and cipher.  The
-- buffered messages are replayed into a fresh context for that algorithm.
-- Fails with @"cannot initialize digest with another digest"@ if the
-- transcript is already a hash context.
setServerHelloParameters :: Version      -- ^ chosen version
                         -> ServerRandom -- ^ stored in 'hstServerRandom'
                         -> Cipher       -- ^ becomes the pending cipher
                         -> Compression  -- ^ becomes the pending compression
                         -> HandshakeM ()
setServerHelloParameters ver sran cipher compression =
    modify $ \hst -> hst
        { hstServerRandom       = Just sran
        , hstPendingCipher      = Just cipher
        , hstPendingCompression = compression
        , hstHandshakeDigest    = updateDigest $ hstHandshakeDigest hst
        }
  where
    hashAlg = getHash ver cipher

    -- The buffer is newest-first, so replay it oldest-first.
    updateDigest (HandshakeMessages bytes)  =
        HandshakeDigestContext $ hashMessages (hashInit hashAlg) (reverse bytes)
    updateDigest (HandshakeDigestContext _) =
        error "cannot initialize digest with another digest"

    -- Strict left fold over the messages.  Each intermediate context is forced
    -- instead of building the chain of thunks that a lazy 'foldl' creates.
    hashMessages !ctx []       = ctx
    hashMessages !ctx (m : ms) = hashMessages (hashUpdate ctx m) ms

-- | Hash algorithm used for the handshake transcript.
--
-- Versions before TLS12 use 'SHA1_MD5'.  The TLS12 Hash is cipher specific:
-- a cipher with no minimum version, or a minimum below TLS12, uses the
-- default 'SHA256'.  Other ciphers use their own 'cipherHash' (some TLS12
-- algorithms use SHA384).
getHash :: Version -> Cipher -> Hash
getHash ver ciph
    | ver < TLS12  = SHA1_MD5
    | legacyCipher = SHA256
    | otherwise    = cipherHash ciph
  where
    -- usable before TLS12, or no minimum version stated
    legacyCipher = maybe True (< TLS12) (cipherMinVer ciph)