module Unison.Codebase.Editor.HandleInput.AuthLogin
  ( authLogin,
    ensureAuthenticatedWithCodeserver,
  )
where

import Control.Concurrent.MVar
import Control.Monad.Reader
import Crypto.Hash qualified as Crypto
import Crypto.Random (getRandomBytes)
import Data.Aeson qualified as Aeson
import Data.ByteArray.Encoding qualified as BE
import Data.ByteString.Char8 qualified as BSC
import Data.Text qualified as Text
import Data.Text.Encoding qualified as Text
import Data.Time.Clock (getCurrentTime)
import Network.HTTP.Client (urlEncodedBody)
import Network.HTTP.Client qualified as HTTP
import Network.HTTP.Client.TLS qualified as HTTP
import Network.HTTP.Types
import Network.URI (URI (..), parseURI)
import Network.Wai
import Network.Wai qualified as Wai
import Network.Wai.Handler.Warp qualified as Warp
import U.Codebase.Sqlite.Queries qualified as Q
import Unison.Auth.CredentialManager (getCodeserverCredentials, saveCredentials)
import Unison.Auth.Discovery (discoveryURIForCodeserver, fetchDiscoveryDoc)
import Unison.Auth.Types
  ( Code,
    CodeserverCredentials (..),
    CredentialFailure (..),
    DiscoveryDoc (..),
    OAuthState,
    PKCEChallenge,
    PKCEVerifier,
    Tokens (..),
    UserInfo,
    codeserverCredentials,
  )
import Unison.Auth.UserInfo (getUserInfo)
import Unison.Cli.Monad (Cli)
import Unison.Cli.Monad qualified as Cli
import Unison.Codebase.Editor.Output qualified as Output
import Unison.Debug qualified as Debug
import Unison.Prelude
import Unison.Share.Types
import UnliftIO qualified
import Web.Browser qualified as Web

-- | The OAuth @client_id@ UCM identifies itself with, both when building the
-- authorization URL and when exchanging the authorization code for tokens.
ucmOAuthClientID :: ByteString
ucmOAuthClientID = BSC.pack "ucm"

-- | Look up stored credentials for the given codeserver and return their
-- 'UserInfo'. If the credential manager can't produce credentials (any
-- 'Left'), fall back to the interactive 'authLogin' flow instead.
ensureAuthenticatedWithCodeserver :: CodeserverURI -> Cli UserInfo
ensureAuthenticatedWithCodeserver codeserverURI = do
  Cli.Env {credentialManager} <- ask
  let codeserverId = codeserverIdFromCodeserverURI codeserverURI
  storedCredentials <- liftIO (getCodeserverCredentials credentialManager codeserverId)
  case storedCredentials of
    Left _ -> authLogin codeserverURI
    Right CodeserverCredentials {userInfo} -> pure userInfo

-- | Direct the user through an authentication flow with the given server and
-- store the credentials in the credential manager from the 'Cli.Env'.
--
-- Outline:
--
--   1. Fetch the server's discovery document.
--   2. Start a temporary local HTTP server ('authTransferServer') whose
--      @/redirect@ handler exchanges the received authorization code for tokens.
--   3. Report the authorization URL ('Output.InitiateAuthFlow') and try to open
--      it in the user's browser, then wait for the handler's result.
--   4. Fetch the user's info, clear the temp entity tables, save the
--      credentials and report 'Output.Success'.
--
-- Any 'CredentialFailure' along the way is reported with
-- 'Output.CredentialFailureMsg' and ends the command early.
authLogin :: CodeserverURI -> Cli UserInfo
authLogin host = do
  Cli.Env {credentialManager} <- ask
  httpClient <- liftIO HTTP.getGlobalManager
  let discoveryURI = discoveryURIForCodeserver host
  doc@(DiscoveryDoc {authorizationEndpoint, tokenEndpoint}) <- bailOnFailure (fetchDiscoveryDoc discoveryURI)
  Debug.debugM Debug.Auth "Discovery Doc" doc
  authResultVar <- liftIO (newEmptyMVar @(Either CredentialFailure Tokens))
  -- The redirect_uri depends on the port, so we need to spin up the server first, but
  -- we can't spin up the server without the code-handler which depends on the redirect_uri.
  -- So, annoyingly we just embed an MVar which will be filled as soon as the server boots up,
  -- and it all works out fine.
  redirectURIVar <- liftIO newEmptyMVar
  (verifier, challenge, state) <- generateParams
  let codeHandler :: (Code -> Maybe URI -> (Response -> IO ResponseReceived) -> IO ResponseReceived)
      codeHandler code mayNextURI respond = do
        redirectURI <- readMVar redirectURIVar
        result <- exchangeCode httpClient tokenEndpoint code verifier redirectURI
        respReceived <- case result of
          Left err -> do
            Debug.debugM Debug.Auth "Auth Error" err
            respond $ Wai.responseLBS internalServerError500 [] "Something went wrong, please try again."
          Right _ ->
            case mayNextURI of
              -- No page to forward to, so serve a plain success page. A 302
              -- without a Location header is not a usable redirect.
              Nothing -> respond $ Wai.responseLBS ok200 [] "Authorization successful. You may close this page and return to UCM."
              Just nextURI ->
                respond $
                  Wai.responseLBS
                    found302
                    [("LOCATION", BSC.pack $ show @URI nextURI)]
                    "Authorization successful. You may close this page and return to UCM."
        -- Wait until we've responded to the browser before putting the result,
        -- otherwise the server will shut down prematurely.
        -- 'tryPutMVar' (rather than 'putMVar') keeps only the first result; if
        -- /redirect is hit again (e.g. a page refresh), this handler would
        -- otherwise block forever on the already-full MVar.
        _ <- tryPutMVar authResultVar result
        pure respReceived
  fetchTime <- liftIO getCurrentTime
  tokens@(Tokens {accessToken}) <-
    Cli.with (Warp.withApplication (pure $ authTransferServer codeHandler)) \port -> do
      let redirectURI :: String
          redirectURI = "http://localhost:" <> show port <> "/redirect"
      liftIO (putMVar redirectURIVar redirectURI)
      let authorizationKickoff = authURI authorizationEndpoint redirectURI state challenge
      Cli.respond . Output.InitiateAuthFlow $ authorizationKickoff
      -- Try to open the browser in the background while we block until the
      -- code handler has produced a result.
      bailOnFailure . liftIO $ UnliftIO.withAsync (Web.openBrowser (show authorizationKickoff)) \_ -> readMVar authResultVar
  userInfo <- bailOnFailure (getUserInfo doc accessToken)
  let codeserverId = codeserverIdFromCodeserverURI host
  let creds = codeserverCredentials discoveryURI tokens fetchTime userInfo
  -- Before saving new credentials we clear the temp entity caches,
  -- this is to handle the case that the user logged into a new user and that they have
  -- some hashJWTs for a different user around which won't work against the new user
  -- credentials.
  --
  -- It also means that if the server changes signing-keys the user will simply get
  -- "unauthenticated", call `auth.login`, and that will clear out any hashjwts signed with
  -- the old key.
  Cli.runTransaction Q.clearTempEntityTables
  liftIO (saveCredentials credentialManager codeserverId creds)
  Cli.respond Output.Success
  pure userInfo
  where
    -- Run an action and abort the command with a credential-failure message
    -- if it returns 'Left'.
    bailOnFailure :: IO (Either CredentialFailure a) -> Cli a
    bailOnFailure action = Cli.ioE action \err -> do
      Cli.returnEarly (Output.CredentialFailureMsg err)

-- | A server in the format expected for a Wai 'Application'.
-- This is a temporary server which is spun up only until we get a code back from the
-- auth server.
--
-- A @GET /redirect@ request with a @code@ query parameter is handed to the
-- callback, together with the optional @next@ parameter parsed as a 'URI'.
-- Every other request, including one whose @code@ is not valid UTF-8,
-- gets a 404.
authTransferServer :: (Code -> Maybe URI -> (Response -> IO ResponseReceived) -> IO ResponseReceived) -> Request -> (Response -> IO ResponseReceived) -> IO ResponseReceived
authTransferServer callback req respond =
  case (requestMethod req, pathInfo req, getQueryParams req) of
    ("GET", ["redirect"], (Just code, maybeNextURI)) -> do
      callback code maybeNextURI respond
    _ -> respond (responseLBS status404 [] "Not Found")
  where
    -- Extract the (UTF-8 decoded) @code@ and the parsed @next@ URI from the
    -- request's query string.
    getQueryParams :: Request -> (Maybe Text, Maybe URI)
    getQueryParams request =
      let query = queryString request
          -- A key may be absent (Nothing) or present without a value (Just Nothing).
          lookupParam name = join (Prelude.lookup name query)
          -- The query string comes straight from the browser, so decode it
          -- totally: invalid UTF-8 is treated as a missing code instead of
          -- throwing inside the request handler.
          code = lookupParam "code" >>= either (const Nothing) Just . Text.decodeUtf8'
          nextURI = lookupParam "next" >>= parseURI . BSC.unpack
       in (code, nextURI)

-- | Build the URL the user visits to start authorization. It is the server's
-- authorization endpoint with the OAuth / PKCE query parameters appended one
-- at a time, in the order listed below.
authURI :: URI -> String -> OAuthState -> PKCEChallenge -> URI
authURI authEndpoint redirectURI state challenge =
  foldl (\uri (key, value) -> addQueryParam key value uri) authEndpoint queryParams
  where
    queryParams :: [(ByteString, ByteString)]
    queryParams =
      [ ("state", state),
        ("redirect_uri", BSC.pack redirectURI),
        ("response_type", "code"),
        ("scope", "openid cloud sync"),
        ("client_id", ucmOAuthClientID),
        ("code_challenge", challenge),
        ("code_challenge_method", "S256")
      ]

-- | Append one @key=value@ pair after any parameters already in the URI's
-- query string. The whole query is re-parsed and re-rendered; 'renderQuery'
-- with 'True' emits the leading @?@.
addQueryParam :: ByteString -> ByteString -> URI -> URI
addQueryParam key val uri = uri {uriQuery = BSC.unpack renderedQuery}
  where
    currentParams = parseQuery (BSC.pack (uriQuery uri))
    renderedQuery = renderQuery True (currentParams ++ [(key, Just val)])

-- | Generate fresh PKCE and OAuth parameters:
--
--   * verifier: 50 random bytes, base64url-encoded without padding;
--   * challenge: the SHA-256 digest of the verifier, base64url-encoded without padding;
--   * state: 12 random bytes, base64url-encoded without padding.
generateParams :: (MonadIO m) => m (PKCEVerifier, PKCEChallenge, OAuthState)
generateParams = liftIO $ do
  verifier <- randomBase64URL 50
  let challenge = BE.convertToBase BE.Base64URLUnpadded (Crypto.hashWith Crypto.SHA256 verifier)
  state <- randomBase64URL 12
  pure (verifier, challenge, state)
  where
    -- The given number of random bytes, encoded as unpadded base64url.
    randomBase64URL :: Int -> IO ByteString
    randomBase64URL byteCount =
      BE.convertToBase BE.Base64URLUnpadded <$> (getRandomBytes byteCount :: IO ByteString)

-- | Exchange an authorization code for tokens by POSTing a form-encoded
-- @authorization_code@ grant (with the PKCE verifier, redirect URI and client
-- id) to the token endpoint.
--
-- Returns 'Left' ('InvalidTokenResponse') when:
--
--   * the endpoint answers with a status of 300 or above;
--   * the response body doesn't decode as 'Tokens';
--   * the HTTP request itself fails with an 'HTTP.HttpException'.
--
-- The last case is converted rather than rethrown because 'authLogin' only
-- publishes its result after this returns. An escaping exception would leave
-- it waiting on that result indefinitely.
exchangeCode ::
  (MonadIO m) =>
  HTTP.Manager ->
  URI ->
  Code ->
  PKCEVerifier ->
  String ->
  m (Either CredentialFailure Tokens)
exchangeCode httpClient tokenEndpoint code verifier redirectURI =
  liftIO . UnliftIO.handle onHttpException $ do
    req <- HTTP.requestFromURI tokenEndpoint
    let addFormData =
          urlEncodedBody
            [ ("code", Text.encodeUtf8 code),
              ("code_verifier", verifier),
              ("grant_type", "authorization_code"),
              ("redirect_uri", BSC.pack redirectURI),
              ("client_id", ucmOAuthClientID)
            ]
    let fullReq = addFormData $ req {HTTP.method = "POST", HTTP.requestHeaders = [("Accept", "application/json")]}
    resp <- HTTP.httpLbs fullReq httpClient
    case HTTP.responseStatus resp of
      status
        | status < status300 ->
            pure $ case Aeson.eitherDecode @Tokens (HTTP.responseBody resp) of
              Left err -> Left (InvalidTokenResponse tokenEndpoint (Text.pack err))
              Right tokens -> Right tokens
        | otherwise ->
            pure . Left . InvalidTokenResponse tokenEndpoint $
              "Received " <> tShow status <> " response from token endpoint"
  where
    -- Invalid token-endpoint URI, connection failures, timeouts, etc.
    onHttpException :: HTTP.HttpException -> IO (Either CredentialFailure Tokens)
    onHttpException err =
      pure . Left . InvalidTokenResponse tokenEndpoint $
        "Failed to contact token endpoint: " <> tShow err