-- Module for working with "PersonalKeys" in Unison Auth.
--
-- A Personal Key is just an Ed25519 EdDSA key pair.
-- We use the private key to make assertions on behalf of the user,
-- such as signing comments.
--
-- We can then register the public key with the user's Share account
-- to link the key to that user.

module Unison.Auth.PersonalKey
  ( PersonalPrivateKey,
    encodePrivateKey,
    PersonalPublicKey,
    publicKey,
    generatePersonalKey,
    personalKeyThumbprint,
    signWithPersonalKey,
    verifyWithPersonalKey,
    PersonalKeySignature (..),
  )
where

import Control.Monad.Error.Class
import Control.Monad.Trans.Except
import Crypto.JOSE qualified as JOSE
import Crypto.JOSE.JWA.JWK qualified as JWA
import Crypto.JOSE.JWK (JWK, KeyMaterialGenParam (OKPGenParam), OKPCrv (Ed25519), genJWK)
import Crypto.JOSE.JWK qualified as JWK
import Crypto.JOSE.JWS qualified as JWS
import Crypto.Random
import Data.Aeson (ToJSON)
import Data.Aeson qualified as Aeson
import Data.Aeson.Types (Value)
import Data.ByteArray qualified as ByteArray
import Data.ByteString qualified as BS
import Data.ByteString.Base64.URL qualified as Base64URL
import Data.Text.Encoding qualified as Text
import Unison.KeyThumbprint (KeyThumbprint (..))
import Unison.Prelude

-- | A personal private key: the 'JWK' holding the user's signing key pair.
--
-- Only decoding ('Aeson.FromJSON') is derived. Encoding goes through
-- 'encodePrivateKey' so that serialising private material is always explicit.
newtype PersonalPrivateKey = PersonalPrivateKey {_personalPrivateKeyJWK :: JWK}
  deriving stock (Eq)
  deriving newtype (Aeson.FromJSON)

-- | Thumbprint identifying a personal key, computed from its underlying JWK.
personalKeyThumbprint :: PersonalPrivateKey -> KeyThumbprint
personalKeyThumbprint = jwkThumbprint . _personalPrivateKeyJWK

-- | SHA-256 thumbprint of a JWK (via jose's 'JWK.thumbprint'), rendered as
-- unpadded base64url text.
jwkThumbprint :: JWK.JWK -> KeyThumbprint
jwkThumbprint jwk =
  -- base64url output is pure ASCII, so decoding it as UTF-8 cannot fail.
  KeyThumbprint (Text.decodeUtf8 (Base64URL.encodeUnpadded digestBytes))
  where
    -- The raw digest bytes, copied into a strict ByteString for encoding.
    digestBytes :: BS.ByteString
    digestBytes = ByteArray.convert (jwk ^. JWK.thumbprint @JWK.SHA256)

-- | Encode the private JWK as a JSON 'Value'.
--
-- 'PersonalPrivateKey' intentionally has no 'ToJSON' instance. Any code that
-- writes out the private key must therefore call this function by name.
encodePrivateKey :: PersonalPrivateKey -> Value
encodePrivateKey = Aeson.toJSON . _personalPrivateKeyJWK

-- | Extract the public half of a personal key pair.
--
-- Calls 'error' if 'JWK.asPublicKey' yields 'Nothing' for the stored JWK.
-- NOTE(review): keys decoded through the 'Aeson.FromJSON' instance are not
-- checked to be Ed25519. Confirm that no key reaching this function can make
-- 'JWK.asPublicKey' return 'Nothing'.
publicKey :: PersonalPrivateKey -> PersonalPublicKey
publicKey (PersonalPrivateKey jwk) =
  maybe
    (error "publicKey: Failed to extract public key from private key. This should never happen.")
    PersonalPublicKey
    (jwk ^. JWK.asPublicKey)

-- | The public half of a personal key. It serialises exactly like the
-- underlying 'JWK' (the instance is derived via the newtype).
newtype PersonalPublicKey = PersonalPublicKey {_personalPublicKeyJWK :: JWK}
  deriving newtype (ToJSON)

-- | Generate a fresh Ed25519 personal key.
--
-- The key is marked for signature use ('JWK.Sig') with the EdDSA algorithm.
-- Its key ID (@kid@) is set to the key's own thumbprint text.
generatePersonalKey :: (MonadIO m) => m PersonalPrivateKey
generatePersonalKey = liftIO $ do
  freshKey <- genJWK @IO (OKPGenParam Ed25519)
  let tagged =
        freshKey
          & JWK.jwkUse .~ Just JWK.Sig
          & JWK.jwkAlg .~ Just (JWK.JWSAlg JWS.EdDSA)
      -- The thumbprint is taken after use/alg are set, as in the original flow.
      keyId = thumbprintToText (jwkThumbprint tagged)
  pure (PersonalPrivateKey (tagged & JWK.jwkKid .~ Just keyId))

-- | Raw signature bytes, as produced by 'signWithPersonalKey' and checked by
-- 'verifyWithPersonalKey'.
newtype PersonalKeySignature = PersonalKeySignature {unPersonalKeySignature :: ByteString}
  deriving stock (Eq, Show)

-- | Monad in which jose's 'JWA.sign' and 'JWA.verify' are run.
--
-- 'JWA.sign' needs one monad that is both 'MonadRandom' and @'MonadError' e@.
-- @ExceptT JOSE.Error IO@ provides the error handling but has no
-- 'MonadRandom' instance, so this newtype wraps it and is given a
-- 'MonadRandom' instance in this module. 'JWA.verify' only needs
-- 'MonadError'; it is run in this monad too so both paths share 'runSignM'.
newtype SignM a = SignM {_unSignM :: ExceptT JOSE.Error IO a}
  deriving newtype
    ( Functor,
      Applicative,
      Monad,
      MonadIO,
      MonadError JOSE.Error
    )

-- | Execute a 'SignM' action in any 'MonadIO'. jose failures come back as
-- 'Left' instead of being thrown.
runSignM :: (MonadIO m) => SignM a -> m (Either JOSE.Error a)
runSignM (SignM action) = liftIO (runExceptT action)

-- Random bytes come from IO's 'MonadRandom' instance and are lifted through
-- the 'ExceptT' layer, so this step never yields a 'JOSE.Error'.
instance MonadRandom SignM where
  getRandomBytes n = SignM (liftIO (getRandomBytes n))

-- | Sign arbitrary bytes with a personal private key using EdDSA.
--
-- Any error reported by jose while signing is returned as 'Left'.
--
-- >>> key <- generatePersonalKey
-- >>> let msg = "Hello, world!"
-- >>> signature <- fromRight (error "failed to sign") <$> signWithPersonalKey key msg
-- >>> verifyWithPersonalKey (publicKey key) msg signature
-- True
signWithPersonalKey :: (MonadIO m) => PersonalPrivateKey -> BS.ByteString -> m (Either JOSE.Error PersonalKeySignature)
signWithPersonalKey (PersonalPrivateKey jwk) bytes =
  runSignM (PersonalKeySignature <$> JWA.sign @SignM JWS.EdDSA material bytes)
  where
    -- The key material jose signs with, pulled out of the JWK.
    material = jwk ^. JWS.jwkMaterial

-- | Verify an EdDSA signature made with a personal private key.
--
-- Returns 'True' only when jose accepts the signature. Any jose error raised
-- during verification is collapsed to 'False' rather than propagated.
verifyWithPersonalKey :: (MonadIO m) => PersonalPublicKey -> BS.ByteString -> PersonalKeySignature -> m Bool
verifyWithPersonalKey (PersonalPublicKey jwk) bytes (PersonalKeySignature signature) = do
  outcome <-
    runSignM $
      JWA.verify @JOSE.Error @SignM JWS.EdDSA (jwk ^. JWS.jwkMaterial) bytes signature
  pure (fromRight False outcome)