{-# LANGUAGE RecordWildCards #-}
-- Manipulating JWT claims with addClaim etc. directly is deprecated, so we'll need to fix that eventually.
-- The new way appears to be to define custom types with JSON instances and use those to encode/decode the JWT;
-- see https://github.com/frasertweedale/hs-jose/issues/116
-- https://github.com/unisonweb/unison/issues/5153
{-# OPTIONS_GHC -Wno-deprecations #-}

-- | Hash-related types in the Share API.
module Unison.Share.API.Hash
  ( -- * Hash types
    HashJWT (..),
    hashJWTHash,
    HashJWTClaims (..),
    DecodedHashJWT (..),
    decodeHashJWT,
    decodeHashJWTClaims,
    decodedHashJWTHash,
  )
where

import Control.Lens (folding, ix, (^?))
import Crypto.JWT qualified as Jose
import Data.Aeson
import Data.Aeson qualified as Aeson
import Data.ByteArray.Encoding qualified as BE
import Data.ByteString.Lazy qualified as BL
import Data.Text qualified as Text
import Data.Text.Encoding qualified as Text
import Servant.Auth.JWT qualified as Servant.Auth
import Unison.Hash32 (Hash32)
import Unison.Hash32.Orphans.Aeson ()
import Unison.Prelude

-- | A hash JWT in its encoded, textual form.
--
-- The wrapped 'Text' is kept exactly as received; nothing is validated on construction. The 'ToJSON' and 'FromJSON'
-- instances are the ones of the underlying 'Text' (newtype-derived).
newtype HashJWT = HashJWT {unHashJWT :: Text}
  deriving newtype (Show, Eq, Ord, ToJSON, FromJSON)

-- | Grab the hash out of a hash JWT.
--
-- This decodes the whole JWT, then throws away the claims; use it if you really only need the hash!
--
-- Decoding goes through 'decodeHashJWT', so this calls 'error' on a malformed JWT and does not verify the JWT.
hashJWTHash :: HashJWT -> Hash32
hashJWTHash =
  decodedHashJWTHash . decodeHashJWT

-- | The claims carried by a hash JWT.
data HashJWTClaims = HashJWTClaims
  { -- | The hash the JWT is about (serialised under the key @"h"@).
    hash :: Hash32,
    -- | An optional user identifier (serialised under the key @"u"@).
    userId :: Maybe Text
  }
  deriving stock (Show, Eq, Ord)

-- | The value of the @"t"@ (type) claim that marks a JWT as a hash JWT.
--
-- Adding a type tag to the jwt prevents users from using jwts we issue for other things
-- in this spot. All of our jwts should have a type parameter of some kind.
hashJWTType :: String
hashJWTType = "hj"

-- | Encodes the claims by adding three claims to an empty claims set: @"h"@ (the hash), @"u"@ (the optional user
-- id, added even when it is 'Nothing') and @"t"@ (always 'hashJWTType').
instance Servant.Auth.ToJWT HashJWTClaims where
  encodeJWT (HashJWTClaims h u) =
    Jose.emptyClaimsSet
      & Jose.addClaim "h" (toJSON h)
      & Jose.addClaim "u" (toJSON u)
      & Jose.addClaim "t" (toJSON hashJWTType)

-- | Decodes the claims from the unregistered claims of a claims set.
--
-- Decoding fails with @Left "Invalid HashJWTClaims"@ unless all of the following hold:
--
-- * @"h"@ is present and parses as a 'Hash32';
-- * @"u"@ is present and parses as a @Maybe Text@;
-- * @"t"@ is present, parses as a 'String', and equals 'hashJWTType'.
instance Servant.Auth.FromJWT HashJWTClaims where
  decodeJWT claims = maybe (Left "Invalid HashJWTClaims") pure $ do
    hash <- claims ^? Jose.unregisteredClaims . ix "h" . folding fromJSON
    userId <- claims ^? Jose.unregisteredClaims . ix "u" . folding fromJSON
    -- Reject JWTs that were issued for some other purpose.
    case claims ^? Jose.unregisteredClaims . ix "t" . folding fromJSON of
      Just t | t == hashJWTType -> pure ()
      _ -> empty
    pure HashJWTClaims {..}

-- | Serialises the claims as a JSON object with the keys @"h"@ (hash), @"u"@ (user id) and @"t"@ (always
-- 'hashJWTType').
instance ToJSON HashJWTClaims where
  toJSON (HashJWTClaims hash userId) =
    object
      [ "h" .= hash,
        "u" .= userId,
        "t" .= hashJWTType
      ]

-- | Parses the claims from a JSON object, reading the hash from @"h"@ and the user id from @"u"@.
--
-- Note that this parser does not look at the @"t"@ key at all, unlike the 'Servant.Auth.FromJWT' instance, which
-- rejects claims whose type tag is not 'hashJWTType'.
instance FromJSON HashJWTClaims where
  parseJSON = Aeson.withObject "HashJWTClaims" \obj -> do
    hash <- obj .: "h"
    userId <- obj .: "u"
    pure HashJWTClaims {..}

-- | A decoded hash JWT that retains the original encoded JWT.
data DecodedHashJWT = DecodedHashJWT
  { -- | The claims decoded from the JWT.
    claims :: HashJWTClaims,
    -- | The encoded JWT the claims were decoded from.
    hashJWT :: HashJWT
  }
  deriving stock (Eq, Ord, Show)

-- | Decode a hash JWT, pairing the decoded claims with the encoded JWT they came from.
--
-- The claims are decoded with 'decodeHashJWTClaims', so the JWT is not verified, and forcing the @claims@ field of
-- the result calls 'error' if the JWT is malformed.
decodeHashJWT :: HashJWT -> DecodedHashJWT
decodeHashJWT hashJWT =
  DecodedHashJWT
    { claims = decodeHashJWTClaims hashJWT,
      hashJWT
    }

-- | ATTENTION: THIS DOES NOT VERIFY THE JWT
--
-- Decode the claims out of a hash JWT. Only the middle of the three @.@-separated segments is looked at; the other
-- two are ignored.
--
-- This function is partial. It calls 'error' (with the offending JWT in the message) when:
--
-- * the text does not consist of exactly three @.@-separated segments;
-- * the middle segment is not valid unpadded URL-safe base64;
-- * the decoded middle segment is not JSON that parses as 'HashJWTClaims'.
decodeHashJWTClaims :: (HasCallStack) => HashJWT -> HashJWTClaims
decodeHashJWTClaims (HashJWT text) =
  Text.splitOn "." text
    -- Pick out the middle segment.
    & \case
      [_, body, _] -> body
      _ -> error $ "decodeHashJWTClaims: Encountered invalid JWT: " <> show text
    & Text.encodeUtf8
    & BE.convertFromBase BE.Base64URLUnpadded
    & fromRight (error $ "decodeHashJWTClaims: Encountered invalid JWT, bad base64 in body: " <> show text)
    & BL.fromStrict
    & Aeson.decode @HashJWTClaims
    & fromMaybe (error $ "decodeHashJWTClaims: Encountered invalid JWT, failed to decode claims: " <> show text)

-- | Grab the hash out of a decoded hash JWT, i.e. the @hash@ field of its claims.
decodedHashJWTHash :: DecodedHashJWT -> Hash32
decodedHashJWTHash DecodedHashJWT {claims = HashJWTClaims {hash}} =
  hash