-- |
-- Module      : Network.TLS.Handshake.Process
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
-- Processing of received handshake messages.
module Network.TLS.Handshake.Process
    ( processHandshake
    , processHandshake13
    , startHandshake
    ) where

import Network.TLS.Context.Internal
import Network.TLS.Crypto
import Network.TLS.ErrT
import Network.TLS.Extension
import Network.TLS.Handshake.Key
import Network.TLS.Handshake.Random
import Network.TLS.Handshake.Signature
import Network.TLS.Handshake.State
import Network.TLS.Handshake.State13
import Network.TLS.Imports
import Network.TLS.Packet
import Network.TLS.Parameters
import Network.TLS.Sending
import Network.TLS.State
import Network.TLS.Struct
import Network.TLS.Struct13
import Network.TLS.Types (Role(..), invertRole, MasterSecret(..))
import Network.TLS.Util

import Control.Concurrent.MVar
import Control.Monad.IO.Class (liftIO)
import Control.Monad.State.Strict (gets)
import Data.X509 (CertificateChain(..), Certificate(..), getCertificate)
import Data.IORef (writeIORef)

-- | Process a received (pre-TLS 1.3) 'Handshake' message.
--
-- The work happens in this order:
--
--   1. Message-specific pre-processing:
--
--        * 'ClientHello' (only when we are the server): every extension goes
--          through 'processClientExtension'; if secure renegotiation is
--          supported and the cipher list contains the RFC 5746 SCSV (0x00FF),
--          secure renegotiation is switched on; finally a fresh handshake
--          state is started with 'startHandshake' unless the TLS 1.3 HRR flag
--          ('getTLS13HRR') is already set.
--        * 'Certificates': handled by 'processCertificates' (stores the public
--          key of the first certificate, or rejects an empty server chain).
--        * 'Finished': checked by 'processClientFinished'.
--
--   2. If the message is a HelloRetryRequest (see 'isHRR'),
--      'wrapAsMessageHash13' is run on the handshake state.
--   3. The message is passed to 'updateHandshake'.
--   4. 'ClientKeyXchg' (only when we are the server) is handed to
--      'processClientKeyXchg' — note this happens after step 3.
processHandshake :: Context -> Handshake -> IO ()
processHandshake ctx hs = do
    role <- usingState_ ctx isClientContext
    case hs of
        ClientHello cver ran _ cids _ ex _ -> when (role == ServerRole) $ do
            mapM_ (usingState_ ctx . processClientExtension) ex
            -- RFC 5746: secure renegotiation
            -- TLS_EMPTY_RENEGOTIATION_INFO_SCSV: {0x00, 0xFF}
            when (secureRenegotiation && (0xff `elem` cids)) $
                usingState_ ctx $ setSecureRenegotiation True
            hrr <- usingState_ ctx getTLS13HRR
            unless hrr $ startHandshake ctx cver ran
        Certificates certs -> processCertificates role certs
        Finished fdata     -> processClientFinished ctx fdata
        _                  -> return ()
    when (isHRR hs) $ usingHState ctx wrapAsMessageHash13
    void $ updateHandshake ctx ServerRole hs
    case hs of
        ClientKeyXchg content -> when (role == ServerRole) $
            processClientKeyXchg ctx content
        _                     -> return ()
  where
    -- Whether this context is configured to support RFC 5746 secure
    -- renegotiation.
    secureRenegotiation :: Bool
    secureRenegotiation = supportedSecureRenegotiation $ ctxSupported ctx

    -- RFC 5746: the renegotiation_info extension (0xff01). Its payload must
    -- equal the encoding of the client verify data currently recorded in the
    -- TLS state; a mismatch aborts with a handshake_failure alert. All other
    -- extensions are ignored here.
    processClientExtension :: ExtensionRaw -> TLSSt ()
    processClientExtension (ExtensionRaw 0xff01 content)
        | secureRenegotiation = do
            v <- getVerifiedData ClientRole
            let bs = extensionEncode (SecureRenegotiation v Nothing)
            unless (bs `bytesEq` content) $
                throwError $ Error_Protocol
                    ( "client verified data not matching: " ++ show v ++ ":" ++ show content
                    , True
                    , HandshakeFailure
                    )
            setSecureRenegotiation True
    processClientExtension _ = return ()

    -- An empty chain is accepted when we are the server (the client sent no
    -- certificate) and is fatal when we are the client (the server must send
    -- one). Otherwise the public key of the first certificate is stored in
    -- the handshake state.
    processCertificates :: Role -> CertificateChain -> IO ()
    processCertificates ServerRole (CertificateChain []) = return ()
    processCertificates ClientRole (CertificateChain []) =
        throwCore $ Error_Protocol ("server certificate missing", True, HandshakeFailure)
    processCertificates _ (CertificateChain (c:_)) =
        usingHState ctx $ setPublicKey pubkey
      where
        pubkey = certPubKey $ getCertificate c

    -- A HelloRetryRequest is a ServerHello carrying version TLS12 whose
    -- server random satisfies 'isHelloRetryRequest'.
    isHRR :: Handshake -> Bool
    isHRR (ServerHello TLS12 srand _ _ _ _) = isHelloRetryRequest srand
    isHRR _                                 = False

-- | Process a received TLS 1.3 'Handshake13' message: it is passed to
-- 'updateHandshake13' and the bytes that call returns are discarded.
processHandshake13 :: Context -> Handshake13 -> IO ()
processHandshake13 ctx = void . updateHandshake13 ctx

-- | Process the client key exchange message.
--
-- In every case the master secret is derived with 'setMasterSecretFromPre'
-- (using the negotiated version and our role), stored in the handshake state,
-- and logged with 'logKey'.
--
-- * RSA: the decrypted premaster secret must start with the client version
--   sent in ClientHello ('hstClientVersion'), not the negotiated version.
--   If decryption fails, the premaster cannot be decoded, or the versions
--   differ, 48 random bytes (drawn before decryption is attempted) are used
--   as the premaster instead. No error is raised here in that case
--   (cf. RFC 5246, section 7.4.7.1).
-- * DH: the client public value is validated against the server DH
--   parameters; an invalid value aborts with a handshake_failure alert.
-- * ECDH: the client public value is decoded for the server's group and
--   combined with the server private key; failure to decode it or to compute
--   a shared secret aborts with a handshake_failure alert.
processClientKeyXchg :: Context -> ClientKeyXchgAlgorithmData -> IO ()
processClientKeyXchg ctx (CKX_RSA encryptedPremaster) = do
    (rver, role, random) <- usingState_ ctx $
        (,,) <$> getVersion <*> isClientContext <*> genRandom 48
    ePremaster <- decryptRSA ctx encryptedPremaster
    masterSecret <- usingHState ctx $ do
        expectedVer <- gets hstClientVersion
        -- Use the decrypted premaster only when it decodes and carries the
        -- expected client version; every other outcome falls back to the
        -- random bytes.
        let premaster = case ePremaster of
                Right decrypted
                    | Right (ver, _) <- decodePreMasterSecret decrypted
                    , ver == expectedVer -> decrypted
                _ -> random
        setMasterSecretFromPre rver role premaster
    liftIO $ logKey ctx (MasterSecret masterSecret)

processClientKeyXchg ctx (CKX_DH clientDHValue) = do
    (rver, role) <- usingState_ ctx $ (,) <$> getVersion <*> isClientContext

    serverParams <- usingHState ctx getServerDHParams
    let params = serverDHParamsToParams serverParams
    unless (dhValid params $ dhUnwrapPublic clientDHValue) $
        throwCore $ Error_Protocol ("invalid client public key", True, HandshakeFailure)

    dhpriv <- usingHState ctx getDHPrivate
    let premaster = dhGetShared params dhpriv clientDHValue
    masterSecret <- usingHState ctx $ setMasterSecretFromPre rver role premaster
    liftIO $ logKey ctx (MasterSecret masterSecret)

processClientKeyXchg ctx (CKX_ECDH bytes) = do
    ServerECDHParams grp _ <- usingHState ctx getServerECDHParams
    case decodeGroupPublic grp bytes of
        Left _ ->
            throwCore $ Error_Protocol ("client public key cannot be decoded", True, HandshakeFailure)
        Right clipub -> do
            srvpri <- usingHState ctx getGroupPrivate
            case groupGetShared clipub srvpri of
                Nothing ->
                    throwCore $ Error_Protocol ("cannot generate a shared secret on ECDH", True, HandshakeFailure)
                Just premaster -> do
                    (rver, role) <- usingState_ ctx $ (,) <$> getVersion <*> isClientContext
                    masterSecret <- usingHState ctx $ setMasterSecretFromPre rver role premaster
                    liftIO $ logKey ctx (MasterSecret masterSecret)

-- | Verify a received Finished message.
--
-- The expected verify data is computed by 'getHandshakeDigest' for the
-- negotiated version and the peer's role (the inverse of our own role). If
-- the received data differs, the handshake is aborted via 'decryptError'.
-- Otherwise the received data is stored in 'ctxPeerFinished'.
--
-- The comparison uses 'bytesEq', as this module already does for the
-- RFC 5746 verify data, instead of '(/=)'. NOTE(review): 'bytesEq' is
-- intended to avoid the early-exit timing of ordinary ByteString equality;
-- confirm against Network.TLS.Util.
processClientFinished :: Context -> FinishedData -> IO ()
processClientFinished ctx fdata = do
    (cc, ver) <- usingState_ ctx $ (,) <$> isClientContext <*> getVersion
    expected <- usingHState ctx $ getHandshakeDigest ver $ invertRole cc
    unless (expected `bytesEq` fdata) $ decryptError "cannot verify finished"
    writeIORef (ctxPeerFinished ctx) $ Just fdata

-- | Initialize a new handshake context (initial handshake or renegotiation).
--
-- Replaces whatever is held in 'ctxHandshake' with a fresh state built by
-- 'newEmptyHandshake' from the given client version and client random. The
-- previous value returned by 'swapMVar' is discarded.
startHandshake :: Context -> Version -> ClientRandom -> IO ()
startHandshake ctx ver crand =
    void $ swapMVar (ctxHandshake ctx) (Just $ newEmptyHandshake ver crand)