module Unison.Runtime.Crypto.Rsa
  ( parseRsaPublicKey,
    parseRsaPrivateKey,
    rsaErrorToText,
  )
where

import Crypto.Number.Basic qualified as Crypto
import Crypto.PubKey.RSA qualified as RSA
import Data.ASN1.BinaryEncoding qualified as ASN1
import Data.ASN1.BitArray qualified as ASN1
import Data.ASN1.Encoding qualified as ASN1
import Data.ASN1.Error qualified as ASN1
import Data.ASN1.Types qualified as ASN1
import Data.ByteString qualified as BS
import Data.ByteString.Lazy qualified as BSL
import Unison.Util.Text (Text)
import Unison.Util.Text qualified as Util.Text

-- | Parse a RSA public key from a ByteString.
--
-- The input must be the raw DER encoding of an X.509 @SubjectPublicKeyInfo@
-- whose payload is a PKCS #1 @RSAPublicKey@ (modulus and public exponent).
-- The bytes are handed straight to the DER decoder, so hex text (such as the
-- @xxd -p@ output below) has to be turned back into bytes before this function
-- sees it — presumably by writing it as a hex bytes literal on the Unison side
-- (NOTE(review): confirm against the caller).
--
-- Such a key can be produced with:
--
-- > # generate a RSA key of a given size
-- > openssl genrsa -out private_key.pem <size>
-- > # output the public key in DER format as a hex string
-- > openssl rsa -in private_key.pem -outform DER -pubout | xxd -p
--
-- Only the @rsaEncryption@ algorithm identifier (OID 1.2.840.113549.1.1.1)
-- with NULL parameters is accepted; any other algorithm is reported as an
-- unexpected outer structure. All failures are returned as 'Left' messages.
parseRsaPublicKey :: BS.ByteString -> Either Text RSA.PublicKey
parseRsaPublicKey bs = case ASN1.decodeASN1 ASN1.DER (BSL.fromStrict bs) of
  Left err ->
    Left $ "rsa: cannot decode as an ASN.1 structure. " <> asn1ErrorToText err
  Right asn1 ->
    case asn1 of
      -- SubjectPublicKeyInfo ::= SEQUENCE {
      --   algorithm        SEQUENCE { OID, NULL },
      --   subjectPublicKey BIT STRING }
      [ ASN1.Start ASN1.Sequence,
        ASN1.Start ASN1.Sequence,
        ASN1.OID oid,
        ASN1.Null,
        ASN1.End ASN1.Sequence,
        ASN1.BitString (ASN1.BitArray _ bits),
        ASN1.End ASN1.Sequence
        ]
          -- Without this check, any algorithm with NULL parameters whose key
          -- happens to decode as two INTEGERs would be silently treated as RSA.
          | oid == rsaEncryptionOid -> parseRsaPublicKeyInner bits
      other ->
        Left
          ( "rsa: unexpected ASN.1 outer structure for a RSA public key"
              <> Util.Text.pack (show other)
          )
  where
    -- Object identifier of rsaEncryption (PKCS #1).
    rsaEncryptionOid :: [Integer]
    rsaEncryptionOid = [1, 2, 840, 113549, 1, 1, 1]

    -- Decode the BIT STRING payload, expected to be
    -- RSAPublicKey ::= SEQUENCE { modulus INTEGER, publicExponent INTEGER }.
    -- The bit-length field of the BIT STRING is ignored; its bytes are decoded as-is.
    parseRsaPublicKeyInner :: BS.ByteString -> Either Text RSA.PublicKey
    parseRsaPublicKeyInner bits = case ASN1.decodeASN1 ASN1.DER (BSL.fromStrict bits) of
      Left err ->
        Left $ "rsa: cannot decode as an ASN.1 inner structure. " <> asn1ErrorToText err
      Right [ASN1.Start ASN1.Sequence, ASN1.IntVal n, ASN1.IntVal e, ASN1.End ASN1.Sequence] ->
        Right
          RSA.PublicKey
            { RSA.public_size = Crypto.numBytes n, -- modulus size in bytes
              RSA.public_n = n,
              RSA.public_e = e
            }
      Right other ->
        Left
          ( "rsa: unexpected ASN.1 inner structure for a RSA public key"
              <> Util.Text.pack (show other)
          )

-- | Parse a RSA private key from a ByteString.
--
-- The input must be the raw DER encoding of a PKCS #8 @PrivateKeyInfo@
-- (version 0). Its OCTET STRING payload must be a PKCS #1 @RSAPrivateKey@.
-- The bytes are handed straight to the DER decoder, so hex text (such as the
-- @xxd -p@ output below) has to be turned back into bytes before this function
-- sees it (NOTE(review): confirm against the Unison-side caller).
--
-- Such a key can be produced with:
--
-- > # generate a RSA key of a given size
-- > openssl genrsa -out private_key.pem <size>
-- > # output the DER format as a hex string
-- > openssl rsa -in private_key.pem -outform DER | xxd -p
--
-- NOTE(review): depending on the OpenSSL version, @openssl rsa -outform DER@
-- may emit PKCS #1 rather than PKCS #8 framing, and this parser rejects
-- PKCS #1. @openssl pkcs8 -topk8 -nocrypt -in private_key.pem -outform DER@
-- emits PKCS #8.
--
-- Restrictions:
--
-- * Only the @rsaEncryption@ algorithm identifier (OID 1.2.840.113549.1.1.1)
--   with NULL parameters is accepted.
-- * The optional PKCS #8 attributes field is not supported.
-- * Multi-prime keys (extra @otherPrimeInfos@) are not supported.
--
-- All failures are returned as 'Left' messages.
parseRsaPrivateKey :: BS.ByteString -> Either Text RSA.PrivateKey
parseRsaPrivateKey bs = case ASN1.decodeASN1 ASN1.DER (BSL.fromStrict bs) of
  Left err ->
    Left $ "Error decoding ASN.1: " <> asn1ErrorToText err
  Right asn1 ->
    case asn1 of
      -- PrivateKeyInfo ::= SEQUENCE {
      --   version             INTEGER (0),
      --   privateKeyAlgorithm SEQUENCE { OID, NULL },
      --   privateKey          OCTET STRING }
      [ ASN1.Start ASN1.Sequence,
        ASN1.IntVal 0,
        ASN1.Start ASN1.Sequence,
        ASN1.OID oid,
        ASN1.Null,
        ASN1.End ASN1.Sequence,
        ASN1.OctetString keyBytes,
        ASN1.End ASN1.Sequence
        ]
          -- Without this check, any algorithm with NULL parameters whose key
          -- happens to decode as nine INTEGERs would be treated as RSA.
          | oid == rsaEncryptionOid -> parseRsaPrivateKeyInner keyBytes
      other ->
        Left
          ( "rsa: unexpected outer ASN.1 structure for a RSA private key"
              <> Util.Text.pack (show other)
          )
  where
    -- Object identifier of rsaEncryption (PKCS #1).
    rsaEncryptionOid :: [Integer]
    rsaEncryptionOid = [1, 2, 840, 113549, 1, 1, 1]

    -- Decode the OCTET STRING payload, expected to be
    -- RSAPrivateKey ::= SEQUENCE { version, n, e, d, p, q, dP, dQ, qInv }.
    -- The inner version INTEGER is ignored.
    parseRsaPrivateKeyInner :: BS.ByteString -> Either Text RSA.PrivateKey
    parseRsaPrivateKeyInner keyBytes = case ASN1.decodeASN1 ASN1.DER (BSL.fromStrict keyBytes) of
      Left err ->
        Left $ "Error decoding inner ASN.1: " <> asn1ErrorToText err
      Right
        [ ASN1.Start ASN1.Sequence,
          ASN1.IntVal _,
          ASN1.IntVal n,
          ASN1.IntVal e,
          ASN1.IntVal d,
          ASN1.IntVal p,
          ASN1.IntVal q,
          ASN1.IntVal dP,
          ASN1.IntVal dQ,
          ASN1.IntVal qinv,
          ASN1.End ASN1.Sequence
          ] ->
          Right
            RSA.PrivateKey
              { RSA.private_pub =
                  RSA.PublicKey
                    { RSA.public_size = Crypto.numBytes n, -- modulus size in bytes
                      RSA.public_n = n,
                      RSA.public_e = e
                    },
                RSA.private_d = d,
                RSA.private_p = p,
                RSA.private_q = q,
                RSA.private_dP = dP,
                RSA.private_dQ = dQ,
                RSA.private_qinv = qinv
              }
      Right other ->
        Left
          ( "rsa: unexpected inner ASN.1 structure for a RSA private key"
              <> Util.Text.pack (show other)
          )

-- | Render an 'ASN1.ASN1Error' as a human-readable message.
--
-- Every constructor is handled. Constructors that carry a 'String' detail
-- append it after a @": "@ separator.
asn1ErrorToText :: ASN1.ASN1Error -> Text
asn1ErrorToText = \case
  ASN1.StreamUnexpectedEOC ->
    "Unexpected EOC in the stream"
  ASN1.StreamInfinitePrimitive ->
    "Invalid primitive with infinite length in a stream"
  ASN1.StreamConstructionWrongSize ->
    "A construction goes over the size specified in the header"
  ASN1.StreamUnexpectedSituation s ->
    "An unexpected situation has come up parsing an ASN1 event stream: " <> Util.Text.pack s
  ASN1.ParsingHeaderFail s ->
    "Parsing an invalid header: " <> Util.Text.pack s
  ASN1.ParsingPartial ->
    "Parsing is not finished, the key is not complete"
  ASN1.TypeNotImplemented s ->
    "Decoding of a type that is not implemented: " <> Util.Text.pack s
  ASN1.TypeDecodingFailed s ->
    "Decoding of a known type failed: " <> Util.Text.pack s
  ASN1.TypePrimitiveInvalid s ->
    "Invalid primitive type: " <> Util.Text.pack s
  ASN1.PolicyFailed policyName reason ->
    -- The trailing space after "reason:" matches the other ": " separators;
    -- without it, the message read "reason:xyz".
    "Policy failed. Policy name: "
      <> Util.Text.pack policyName
      <> ", reason: "
      <> Util.Text.pack reason

-- | Turn an 'RSA.Error' into a user-facing message.
--
-- Every message is prefixed with @rsa:@, and every 'RSA.Error' constructor
-- is mapped to its own message.
rsaErrorToText :: RSA.Error -> Text
rsaErrorToText rsaError =
  case rsaError of
    RSA.MessageSizeIncorrect ->
      "rsa: The message to decrypt is not of the correct size (need to be == private_size)"
    RSA.MessageTooLong ->
      "rsa: The message to encrypt is too long"
    RSA.MessageNotRecognized ->
      "rsa: The message decrypted doesn't have a PKCS15 structure (0 2 .. 0 msg)"
    RSA.SignatureTooLong ->
      "rsa: The message's digest is too long"
    RSA.InvalidParameters ->
      "rsa: Some parameters lead to breaking assumptions"