-- Copyright (C) 2014-2022  Fraser Tweedale
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TemplateHaskell #-}

{-|

JOSE error types and helpers, plus the 'JOSE' monad transformer for
running computations that may fail with these errors.

-}
module Crypto.JOSE.Error
  (
  -- * Running JOSE computations
    runJOSE
  , unwrapJOSE
  , JOSE(..)

  -- * Base error type and class
  , Error(..)
  , AsError(..)

  -- * JOSE compact serialisation errors
  , InvalidNumberOfParts(..), expectedParts, actualParts
  , CompactTextError(..)
  , CompactDecodeError(..)
  , _CompactInvalidNumberOfParts
  , _CompactInvalidText

  ) where

import Numeric.Natural

import Control.Monad.Except (MonadError(..), ExceptT, runExceptT)
import Control.Monad.Trans (MonadIO(liftIO), MonadTrans(lift))
import qualified Crypto.PubKey.RSA as RSA
import Crypto.Error (CryptoError)
import Crypto.Random (MonadRandom(..))
import Control.Lens (Getter, to)
import Control.Lens.TH (makeClassyPrisms, makePrisms)
import qualified Data.Text as T
import qualified Data.Text.Encoding.Error as T


-- | The wrong number of parts was found when decoding a
-- compact JOSE object.
--
-- Use 'expectedParts' and 'actualParts' to read the two counts.
--
data InvalidNumberOfParts = InvalidNumberOfParts
  Natural  -- ^ number of parts expected
  Natural  -- ^ number of parts actually found
  deriving (Eq)

-- | Human-readable message, e.g. @Expected 3 parts; got 2@.
--
-- Only 'show' is defined, so the default 'showsPrec' is used. When this
-- value is nested inside another 'Show' instance, the message is
-- therefore rendered without surrounding parentheses.
instance Show InvalidNumberOfParts where
  show (InvalidNumberOfParts n m) =
    "Expected " <> show n <> " parts; got " <> show m

-- | Get the expected or actual number of parts.
--
-- 'expectedParts' reads the first field of 'InvalidNumberOfParts' and
-- 'actualParts' reads the second. Both are read-only lenses ('Getter's).
expectedParts, actualParts :: Getter InvalidNumberOfParts Natural
expectedParts = to $ \(InvalidNumberOfParts n _) -> n
actualParts   = to $ \(InvalidNumberOfParts _ n) -> n


-- | Bad UTF-8 data in a compact object, at the specified index.
data CompactTextError = CompactTextError
  Natural             -- ^ index of the part that failed to decode
  T.UnicodeException  -- ^ the underlying text decoding error
  deriving (Eq)

-- | Human-readable message naming the offending part and the decoding
-- error, e.g. @Invalid text at part 1: ...@.
instance Show CompactTextError where
  show (CompactTextError n s) =
    "Invalid text at part " <> show n <> ": " <> show s


-- | An error when decoding a JOSE compact object.
-- JSON decoding errors that occur during compact object processing
-- are reported as 'JSONDecodeError' instead.
--
data CompactDecodeError
  = CompactInvalidNumberOfParts InvalidNumberOfParts
  -- ^ The object did not contain the expected number of parts
  | CompactInvalidText CompactTextError
  -- ^ A part of the object was not valid text
  deriving (Eq)
-- Generates the '_CompactInvalidNumberOfParts' and '_CompactInvalidText'
-- prisms.
makePrisms ''CompactDecodeError

-- | Human-readable message: a short category prefix followed by the
-- 'show' of the wrapped error.
instance Show CompactDecodeError where
  show (CompactInvalidNumberOfParts e) = "Invalid number of parts: " <> show e
  show (CompactInvalidText e)          = "Invalid text: " <> show e



-- | All the errors that can occur.
--
-- The 'makeClassyPrisms' splice after this declaration generates the
-- 'AsError' class: an '_Error' prism plus one prism per constructor. Code
-- can use it to throw or match these errors in any error type that has an
-- 'AsError' instance.
--
data Error
  = AlgorithmNotImplemented   -- ^ A requested algorithm is not implemented
  | AlgorithmMismatch String  -- ^ A requested algorithm cannot be used
  | KeyMismatch T.Text        -- ^ Wrong type of key was given
  | KeySizeTooSmall           -- ^ Key size is too small
  | OtherPrimesNotSupported   -- ^ RSA private key with >2 primes not supported
  | RSAError RSA.Error        -- ^ RSA encryption, decryption or signing error
  | CryptoError CryptoError   -- ^ Various cryptonite library error cases
  | CompactDecodeError CompactDecodeError
  -- ^ Error decoding a compact serialisation: wrong number of parts,
  --   or invalid text in a part
  | JSONDecodeError String    -- ^ JSON (Aeson) decoding error
  | NoUsableKeys              -- ^ No usable keys were found in the key store
  | JWSCritUnprotected
  -- NOTE(review): undocumented; the name suggests a "crit" header parameter
  -- was found outside the protected header. Confirm before documenting.
  | JWSNoValidSignatures
  -- ^ 'AnyValidated' policy active, and no valid signature encountered
  | JWSInvalidSignature
  -- ^ 'AllValidated' policy active, and invalid signature encountered
  | JWSNoSignatures
  -- ^ 'AllValidated' policy active, and there were no signatures on object
  --   that matched the allowed algorithms
  deriving (Eq, Show)
makeClassyPrisms ''Error


-- | A monad transformer for computations over @m@ that may fail with an
-- error of type @e@. It is a thin wrapper around @'ExceptT' e m@. Use
-- 'runJOSE' to run it, or 'unwrapJOSE' to recover the 'ExceptT' value.
newtype JOSE e m a
  = JOSE (ExceptT e m a)

-- | Run the 'JOSE' computation.  Result is an @Either e a@
-- where @e@ is the error type (typically 'Error' or 'Crypto.JWT.JWTError'):
-- @Left@ if the computation failed, @Right@ with the result otherwise.
runJOSE :: JOSE e m a -> m (Either e a)
runJOSE (JOSE m) = runExceptT m

-- | Get the inner 'ExceptT' value of the 'JOSE' computation.
-- Typically 'runJOSE' would be preferred, unless you specifically
-- need an 'ExceptT' value.
unwrapJOSE :: JOSE e m a -> ExceptT e m a
unwrapJOSE (JOSE m) = m


-- | Maps over the successful result by delegating to the wrapped
-- 'ExceptT'.
instance (Functor m) => Functor (JOSE e m) where
  fmap f (JOSE ma) = JOSE (fmap f ma)

-- | Delegates to the wrapped 'ExceptT'. The @Monad m@ constraint (rather
-- than @Applicative m@) is needed because the 'Applicative' instance of
-- 'ExceptT' itself requires @Monad m@.
instance (Monad m) => Applicative (JOSE e m) where
  pure = JOSE . pure
  JOSE mf <*> JOSE ma = JOSE (mf <*> ma)

-- | Delegates to the wrapped 'ExceptT'.
instance (Monad m) => Monad (JOSE e m) where
  -- The continuation yields a JOSE value, so it is unwrapped back to
  -- ExceptT before binding in the underlying monad.
  JOSE ma >>= f = JOSE (ma >>= unwrapJOSE . f)

-- | Lifts an action of the base monad via the 'ExceptT' transformer.
instance MonadTrans (JOSE e) where
  lift = JOSE . lift

-- | Throwing and catching errors of type @e@ by delegating to the wrapped
-- 'ExceptT'.
instance (Monad m) => MonadError e (JOSE e m) where
  throwError = JOSE . throwError
  -- The handler returns a JOSE value; unwrap it so it can be passed to
  -- ExceptT's catchError.
  catchError (JOSE m) handle = JOSE (catchError m (unwrapJOSE . handle))

-- | Lifts 'IO' actions through the wrapped 'ExceptT'.
instance (MonadIO m) => MonadIO (JOSE e m) where
  liftIO = JOSE . liftIO

-- | Random bytes are drawn from the base monad and lifted into 'JOSE'.
instance (MonadRandom m) => MonadRandom (JOSE e m) where
  getRandomBytes = lift . getRandomBytes