module Unison.Runtime.Crypto.P256
  ( derivePublicKey,
    signSha256,
    verifySha256,
  )
where

import Crypto.Hash qualified as Hash
import Crypto.Number.Serialize (i2ospOf_, os2ip)
import Crypto.PubKey.ECC.ECDSA qualified as ECDSA
import Crypto.PubKey.ECC.Prim qualified as ECC
import Crypto.PubKey.ECC.Types qualified as ECC
import Data.ByteString qualified as BS
import Data.Word (Word8)
import Unison.Util.Text (Text)

-- | The NIST P-256 curve (SEC 2 name @secp256r1@), looked up in the
-- library's table of named curves. Every private key, public key and point
-- constructed in this module is tagged with this curve.
curve :: ECC.Curve
curve = ECC.getCurveByName p256
  where
    p256 = ECC.SEC_p256r1

-- | The order @n@ of the P-256 base point, taken from the curve's common
-- parameters. Private scalars and signature components accepted by this
-- module must lie in the range @[1, n - 1]@.
curveOrder :: Integer
curveOrder = ECC.ecc_n . ECC.common_curve $ curve

-- | Byte width of one fixed-size, big-endian 256-bit value. It is used for
-- the private scalar, for each public-key coordinate, and for each half of
-- an @r || s@ signature.
coordinateBytes :: Int
coordinateBytes = 256 `div` 8

-- | Leading byte of an uncompressed SEC1 point encoding (@0x04 || X || Y@).
-- This is the only public-key form produced or accepted by this module.
uncompressedPrefix :: Word8
uncompressedPrefix = 4

-- | Derive the 65-byte uncompressed SEC1 public key for a 32-byte,
-- big-endian P-256 private scalar.
--
-- Returns 'Left' with a descriptive message when the private key bytes are
-- rejected. This happens for the wrong length, a zero scalar, or a scalar
-- not below the curve order.
derivePublicKey :: BS.ByteString -> Either Text BS.ByteString
derivePublicKey privateKeyBytes =
  encodePublicKey . toPublicKey <$> parsePrivateKey privateKeyBytes

-- | Sign @message@ with ECDSA over P-256, using SHA-256 as the message
-- digest.
--
-- The nonce comes from the library's 'ECDSA.deterministicNonce', which
-- takes SHA-256 as its DRBG hash. The function is pure, so the same key and
-- message always produce the same signature. The result is the 64-byte
-- fixed-width big-endian @r || s@ encoding. No low-S normalisation is
-- applied here.
--
-- Returns 'Left' only if the private key bytes are rejected.
signSha256 :: BS.ByteString -> BS.ByteString -> Either Text BS.ByteString
signSha256 privateKeyBytes message =
  encodeSignature . signWithKey <$> parsePrivateKey privateKeyBytes
  where
    digest :: Hash.Digest Hash.SHA256
    digest = Hash.hashWith Hash.SHA256 message

    -- 'ECDSA.signDigestWith' yields 'Nothing' for a nonce that produces a
    -- degenerate signature. 'ECDSA.deterministicNonce' keeps deriving
    -- candidate nonces until the callback succeeds.
    signWithKey :: ECDSA.PrivateKey -> ECDSA.Signature
    signWithKey key =
      ECDSA.deterministicNonce Hash.SHA256 key digest $ \nonce ->
        ECDSA.signDigestWith nonce key digest

-- | Verify a 64-byte @r || s@ ECDSA P-256 signature over @message@.
-- The message is hashed with SHA-256 inside 'ECDSA.verify'.
--
-- 'Left' is returned only for malformed input. The public key is checked
-- first, then the signature. A well-formed signature that does not match
-- yields @Right False@.
verifySha256 :: BS.ByteString -> BS.ByteString -> BS.ByteString -> Either Text Bool
verifySha256 publicKeyBytes message signatureBytes =
  checkWith <$> parsePublicKey publicKeyBytes <*> parseSignature signatureBytes
  where
    checkWith publicKey signature =
      ECDSA.verify Hash.SHA256 publicKey signature message

-- | Decode a private key given as exactly 'coordinateBytes' (32) big-endian
-- bytes.
--
-- The decoded scalar must satisfy @0 < d < curveOrder@. Each rejection
-- produces its own error message.
parsePrivateKey :: BS.ByteString -> Either Text ECDSA.PrivateKey
parsePrivateKey bytes =
  if BS.length bytes /= coordinateBytes
    then Left "p256: private key must be 32 bytes"
    else checkScalar (os2ip bytes)
  where
    checkScalar scalar
      | scalar <= 0 = Left "p256: private key scalar must be non-zero"
      | scalar >= curveOrder = Left "p256: private key scalar is out of range"
      | otherwise = Right (ECDSA.PrivateKey curve scalar)

-- | Decode and validate an uncompressed SEC1 public key @0x04 || X || Y@.
-- @X@ and @Y@ are each 'coordinateBytes' (32) big-endian bytes.
--
-- Checks run in this order, and the first failure determines the message:
--
--   1. the input is exactly @1 + 2 * coordinateBytes@ (65) bytes long;
--   2. the first byte is 'uncompressedPrefix' (0x04);
--   3. the decoded point passes 'ECC.isPointValid' for 'curve';
--   4. multiplying the point by 'curveOrder' yields the point at infinity,
--      i.e. it lies in the prime-order subgroup.
parsePublicKey :: BS.ByteString -> Either Text ECDSA.PublicKey
parsePublicKey bytes
  | BS.length bytes /= 1 + 2 * coordinateBytes =
      Left "p256: public key must be 65 bytes in uncompressed SEC1 form"
  | otherwise =
      -- Total match instead of the partial 'BS.head' / 'BS.tail'. The
      -- 'Nothing' case cannot occur after the length guard, but is still
      -- handled rather than crashing.
      case BS.uncons bytes of
        Just (prefix, coordinates)
          | prefix == uncompressedPrefix -> parsePoint coordinates
        _ ->
          Left "p256: public key must start with 0x04 (uncompressed SEC1 form)"
  where
    -- The coordinates always decode to an affine 'ECC.Point', never to
    -- 'ECC.PointO'. A separate point-at-infinity guard would be unreachable.
    parsePoint :: BS.ByteString -> Either Text ECDSA.PublicKey
    parsePoint coordinates
      | not (ECC.isPointValid curve point) =
          Left "p256: public key point is not on the curve"
      | cofactor /= 1 && ECC.pointMul curve curveOrder point /= ECC.PointO =
          Left "p256: public key point is not in the prime-order subgroup"
      | otherwise =
          Right (ECDSA.PublicKey curve point)
      where
        (xBytes, yBytes) = BS.splitAt coordinateBytes coordinates
        point = ECC.Point (os2ip xBytes) (os2ip yBytes)

    -- The curve group has order h * n. When the cofactor h is 1, the group
    -- has prime order n, so every valid affine point already has order n.
    -- The scalar multiplication above could never fail, so it is skipped to
    -- avoid a full point multiplication on every key parse. It still runs
    -- for any curve with a cofactor other than 1.
    cofactor = ECC.ecc_h (ECC.common_curve curve)

-- | Decode a 64-byte signature @r || s@. Each half is 'coordinateBytes' (32)
-- big-endian bytes.
--
-- Both components must satisfy @0 < v < curveOrder@. The @r@ component is
-- checked before @s@.
parseSignature :: BS.ByteString -> Either Text ECDSA.Signature
parseSignature bytes
  | BS.length bytes /= 2 * coordinateBytes =
      Left "p256: signature must be 64 bytes"
  | not (isScalar r) =
      Left "p256: signature r component is out of range"
  | not (isScalar s) =
      Left "p256: signature s component is out of range"
  | otherwise =
      Right (ECDSA.Signature r s)
  where
    isScalar v = 0 < v && v < curveOrder

    -- Lazily bound, so nothing is decoded until the length guard has passed.
    (r, s) = case BS.splitAt coordinateBytes bytes of
      (rPart, sPart) -> (os2ip rPart, os2ip sPart)

-- | Compute the public point @d * G@ for a private key, where @G@ is the
-- base point of 'curve'.
--
-- NOTE(review): the result is always tagged with the module-level 'curve'.
-- The private key's own curve field is not consulted. Within this module
-- every private key comes from 'parsePrivateKey', which uses 'curve'.
toPublicKey :: ECDSA.PrivateKey -> ECDSA.PublicKey
toPublicKey =
  ECDSA.PublicKey curve . ECC.pointBaseMul curve . ECDSA.private_d

-- | Serialise a public key as uncompressed SEC1: the 'uncompressedPrefix'
-- byte followed by @X@ and @Y@. Each coordinate is 'coordinateBytes' (32)
-- big-endian bytes, 65 bytes in total.
--
-- Calls 'error' if the key's point is at infinity. That case has no
-- uncompressed encoding here. It is not expected for keys built by
-- 'toPublicKey' from a scalar in @[1, n - 1]@.
encodePublicKey :: ECDSA.PublicKey -> BS.ByteString
encodePublicKey publicKey = encodePoint (ECDSA.public_q publicKey)
  where
    encodePoint ECC.PointO =
      error "p256: attempted to encode point at infinity"
    encodePoint (ECC.Point x y) =
      BS.cons
        uncompressedPrefix
        (i2ospOf_ coordinateBytes x <> i2ospOf_ coordinateBytes y)

-- | Serialise a signature as the fixed-width concatenation @r || s@. Each
-- component is 'coordinateBytes' (32) big-endian bytes, 64 bytes in total.
encodeSignature :: ECDSA.Signature -> BS.ByteString
encodeSignature (ECDSA.Signature r s) =
  i2ospOf_ coordinateBytes r <> i2ospOf_ coordinateBytes s