{-# LANGUAGE OverloadedStrings #-}

-- |
-- Module      : Network.TLS.Handshake.State13
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
module Network.TLS.Handshake.State13
       ( CryptLevel ( CryptEarlySecret
                    , CryptHandshakeSecret
                    , CryptApplicationSecret
                    )
       , TrafficSecret
       , getTxState
       , getRxState
       , setTxState
       , setRxState
       , clearTxState
       , clearRxState
       , setHelloParameters13
       , transcriptHash
       , wrapAsMessageHash13
       , PendingAction(..)
       , setPendingActions
       , popPendingAction
       ) where

import Control.Concurrent.MVar
import Control.Monad.State
import qualified Data.ByteString as B
import Data.IORef
import Data.List (foldl')
import Network.TLS.Cipher
import Network.TLS.Compression
import Network.TLS.Context.Internal
import Network.TLS.Crypto
import Network.TLS.Handshake.State
import Network.TLS.KeySchedule (hkdfExpandLabel)
import Network.TLS.Record.State
import Network.TLS.Struct
import Network.TLS.Imports
import Network.TLS.Types
import Network.TLS.Util

-- | Snapshot of the transmit-side record state: the cipher's hash, the
-- cipher itself, the current encryption level and the installed secret.
getTxState :: Context -> IO (Hash, Cipher, CryptLevel, ByteString)
getTxState context = getXState context ctxTxState

-- | Snapshot of the receive-side record state: the cipher's hash, the
-- cipher itself, the current encryption level and the installed secret.
getRxState :: Context -> IO (Hash, Cipher, CryptLevel, ByteString)
getRxState context = getXState context ctxRxState

-- | Read one direction of the record state and project out the values the
-- TLS 1.3 key schedule needs.
--
-- The selector picks the direction (e.g. 'ctxTxState' or 'ctxRxState').
-- The result is @(hash of the cipher, cipher, crypt level, secret)@. The
-- secret is the 'cstMacSecret' field of the crypt state, which is where
-- this module's setter stores the traffic secret.
--
-- The fields are bound lazily. If no cipher is installed, only the hash and
-- cipher components are bottom, and they fail with a labelled error when
-- forced. The level and secret stay usable.
getXState :: Context
          -> (Context -> MVar RecordState)
          -> IO (Hash, Cipher, CryptLevel, ByteString)
getXState ctx func = do
    tx <- readMVar (func ctx)
    let -- Network.TLS.Util.fromJust: same laziness as the former
        -- irrefutable `Just` pattern, but with a meaningful error message.
        usedCipher = fromJust "getXState: stCipher" (stCipher tx)
        usedHash   = cipherHash usedCipher
        level      = stCryptLevel tx
        secret     = cstMacSecret (stCryptState tx)
    return (usedHash, usedCipher, level, secret)

-- | Wrapped traffic secrets that can be installed into a record direction.
class TrafficSecret ty where
    -- | Split a wrapped secret into its encryption level and secret bytes.
    fromTrafficSecret :: ty -> (CryptLevel, ByteString)

-- | The level comes from the phantom type parameter via 'getCryptLevel'.
instance HasCryptLevel a => TrafficSecret (AnyTrafficSecret a) where
    fromTrafficSecret wrapped = case wrapped of
        AnyTrafficSecret bytes -> (getCryptLevel wrapped, bytes)

-- | The level comes from the phantom type parameter via 'getCryptLevel'.
instance HasCryptLevel a => TrafficSecret (ClientTrafficSecret a) where
    fromTrafficSecret wrapped = case wrapped of
        ClientTrafficSecret bytes -> (getCryptLevel wrapped, bytes)

-- | The level comes from the phantom type parameter via 'getCryptLevel'.
instance HasCryptLevel a => TrafficSecret (ServerTrafficSecret a) where
    fromTrafficSecret wrapped = case wrapped of
        ServerTrafficSecret bytes -> (getCryptLevel wrapped, bytes)

-- | Install a traffic secret on the transmit side. The bulk cipher is
-- initialised for encryption.
setTxState :: TrafficSecret ty => Context -> Hash -> Cipher -> ty -> IO ()
setTxState ctx h cipher ts = setXState ctxTxState BulkEncrypt ctx h cipher ts

-- | Install a traffic secret on the receive side. The bulk cipher is
-- initialised for decryption.
setRxState :: TrafficSecret ty => Context -> Hash -> Cipher -> ty -> IO ()
setRxState ctx h cipher ts = setXState ctxRxState BulkDecrypt ctx h cipher ts

-- | Unwrap a typed traffic secret into its level and bytes, then delegate to
-- 'setXState'' to rebuild the record state for the selected direction.
setXState :: TrafficSecret ty
          => (Context -> MVar RecordState) -> BulkDirection
          -> Context -> Hash -> Cipher -> ty
          -> IO ()
setXState func encOrDec ctx h cipher ts =
    setXState' func encOrDec ctx h cipher lvl secret
  where
    (lvl, secret) = fromTrafficSecret ts

-- | Replace the record state of one direction with a fresh state derived
-- from @secret@.
--
-- The write key and IV are expanded from the secret with 'hkdfExpandLabel'
-- using the labels @"key"@ and @"iv"@ and an empty context. The key is as
-- long as the bulk cipher's key size. The IV is at least 8 bytes, or the
-- bulk IV plus explicit-IV size if that is larger. The sequence number
-- restarts at 0, compression is set to 'nullCompression', and the traffic
-- secret itself is kept in 'cstMacSecret'. The previous contents of the
-- MVar are discarded.
setXState' :: (Context -> MVar RecordState) -> BulkDirection
           -> Context -> Hash -> Cipher -> CryptLevel -> ByteString
           -> IO ()
setXState' func encOrDec ctx h cipher lvl secret =
    modifyMVar_ (func ctx) $ \_ -> return freshState
  where
    bulk = cipherBulk cipher
    expand label = hkdfExpandLabel h secret label ""
    writeKey = expand "key" (bulkKeySize bulk)
    writeIV  = expand "iv" (max 8 (bulkIVSize bulk + bulkExplicitIV bulk))
    freshState = RecordState
        { stCryptState  = CryptState
            { cstKey       = bulkInit bulk encOrDec writeKey
            , cstIV        = writeIV
            , cstMacSecret = secret
            }
        , stMacState    = MacState { msSequence = 0 }
        , stCryptLevel  = lvl
        , stCipher      = Just cipher
        , stCompression = nullCompression
        }

-- | Remove the cipher from the transmit-side record state.
clearTxState :: Context -> IO ()
clearTxState ctx = clearXState ctxTxState ctx

-- | Remove the cipher from the receive-side record state.
clearRxState :: Context -> IO ()
clearRxState ctx = clearXState ctxRxState ctx

-- | Set 'stCipher' to 'Nothing' in the selected record state. All other
-- fields are left untouched.
clearXState :: (Context -> MVar RecordState) -> Context -> IO ()
clearXState func ctx = modifyMVar_ (func ctx) (pure . dropCipher)
  where
    dropCipher st = st { stCipher = Nothing }

-- | Record the negotiated TLS 1.3 cipher in the handshake state.
--
-- On the first call (no pending cipher yet), this does three things:
--
-- * stores the cipher;
-- * sets the pending compression to 'nullCompression';
-- * replays the buffered handshake messages into a hash context for the
--   cipher's hash algorithm.
--
-- On later calls (presumably after a HelloRetryRequest), it succeeds only
-- if the same cipher is chosen again. Otherwise it returns an
-- 'IllegalParameter' protocol error.
setHelloParameters13 :: Cipher -> HandshakeM (Either TLSError ())
setHelloParameters13 cipher = do
    hst <- get
    case hstPendingCipher hst of
        Nothing -> do
            put hst
                { hstPendingCipher      = Just cipher
                , hstPendingCompression = nullCompression
                , hstHandshakeDigest    = updateDigest $ hstHandshakeDigest hst
                }
            return $ Right ()
        Just oldcipher
            | cipher == oldcipher -> return $ Right ()
            | otherwise ->
                return $ Left $ Error_Protocol
                    ("TLS 1.3 cipher changed after hello retry", True, IllegalParameter)
  where
    hashAlg = cipherHash cipher
    -- Messages are buffered newest-first, so reverse them before hashing.
    -- foldl' keeps the hash context evaluated step by step instead of
    -- building a chain of pending hashUpdate thunks.
    updateDigest (HandshakeMessages bytes) =
        HandshakeDigestContext $ foldl' hashUpdate (hashInit hashAlg) $ reverse bytes
    updateDigest (HandshakeDigestContext _) =
        error "cannot initialize digest with another digest"

-- When a HelloRetryRequest is sent or received, the existing transcript must be
-- wrapped in a "message_hash" construct.  See RFC 8446 section 4.4.1.  This
-- applies to key-schedule computations as well as the ones for PSK binders.
--
-- The wrapper is the handshake type byte 254 (message_hash), then a 3-byte
-- length whose two high bytes are 0 and whose low byte is the digest length,
-- then the digest itself. NOTE(review): assumes foldHandshakeDigest replaces
-- the running digest with the hash of this wrapper applied to the current
-- digest; confirm in Network.TLS.Handshake.State.
wrapAsMessageHash13 :: HandshakeM ()
wrapAsMessageHash13 =
    getPendingCipher >>= \cipher ->
        foldHandshakeDigest (cipherHash cipher) messageHash
  where
    messageHash digest =
        "\254\0\0" `B.append` B.cons (fromIntegral (B.length digest)) digest

-- | Finalise the current transcript hash.
--
-- Fails with an error if there is no handshake state, or if the digest is
-- still a list of buffered messages (i.e. no cipher has initialised a hash
-- context yet).
transcriptHash :: MonadIO m => Context -> m ByteString
transcriptHash ctx = do
    mstate <- getHState ctx
    case hstHandshakeDigest (fromJust "HState" mstate) of
      HandshakeMessages _              -> error "un-initialized handshake digest"
      HandshakeDigestContext digestCtx -> return (hashFinal digestCtx)

-- | Overwrite the context's queue of pending actions.
setPendingActions :: Context -> [PendingAction] -> IO ()
setPendingActions ctx actions = writeIORef (ctxPendingActions ctx) actions

-- | Remove and return the first pending action, or 'Nothing' if the queue is
-- empty. The read and the write are separate IORef operations, not one
-- atomic update.
popPendingAction :: Context -> IO (Maybe PendingAction)
popPendingAction ctx = do
    pending <- readIORef ref
    case pending of
        []            -> return Nothing
        (next : rest) -> do
            writeIORef ref rest
            return (Just next)
  where
    ref = ctxPendingActions ctx