module Unison.Auth.Tokens where

import Control.Monad.Except
import Data.Aeson qualified as Aeson
import Data.ByteString.Char8 qualified as BSC
import Data.Text qualified as Text
import Data.Time.Clock (getCurrentTime)
import Network.HTTP.Client qualified as HTTP
import Network.HTTP.Client.TLS qualified as HTTP
import Network.HTTP.Types qualified as Network
import System.Environment (lookupEnv)
import Unison.Auth.CredentialManager
import Unison.Auth.CredentialManager qualified as CredMan
import Unison.Auth.Discovery (fetchDiscoveryDoc)
import Unison.Auth.Types
import Unison.Auth.UserInfo (getUserInfo)
import Unison.Prelude
import Unison.Share.Types (CodeserverId)
import UnliftIO qualified

-- | Supplies an 'AccessToken' for the host identified by a 'CodeserverId'.
--
-- An implementation is free to refresh an expired access token behind the
-- scenes when a refresh token is available. Failures are reported as a
-- 'CredentialFailure' rather than thrown.
type TokenProvider =
  CodeserverId ->
  IO (Either CredentialFailure AccessToken)

-- | Name of the environment variable that, when set, supplies the access token
-- for every request made through the authenticated HTTP client (that is, every
-- codeserver interaction).
--
-- Handy in scripted contexts and when running transcripts against a codeserver.
accessTokenEnvVarKey :: String
accessTokenEnvVarKey =
  "UNISON_SHARE_ACCESS_TOKEN"

-- | Build a 'TokenProvider' backed by the given 'CredentialManager'.
--
-- For a given host the access token is chosen as follows:
--
--   1. If the environment variable named by 'accessTokenEnvVarKey' is set, its
--      value is returned verbatim and no refresh is attempted.
--   2. Otherwise the stored credentials for the host are loaded. If 'isExpired'
--      reports them as expired, the following happens:
--      * the discovery document is fetched from the stored discovery URI;
--      * the tokens are refreshed with 'performTokenRefresh';
--      * user info is re-fetched with the new access token;
--      * the resulting credentials are persisted with 'saveCredentials'.
--
-- Any 'CredentialFailure' thrown along the way (via 'throwEitherM') is caught
-- and returned as 'Left'.
newTokenProvider :: CredentialManager -> TokenProvider
newTokenProvider manager host =
  UnliftIO.try @_ @CredentialFailure $
    maybe tokenFromStore (pure . Text.pack) =<< lookupEnv accessTokenEnvVarKey
  where
    -- Access token taken from the credential store, refreshed first if expired.
    tokenFromStore :: IO AccessToken
    tokenFromStore = do
      creds@CodeserverCredentials {tokens, discoveryURI} <-
        throwEitherM (CredMan.getCodeserverCredentials manager host)
      expired <- isExpired creds
      if not expired
        then
          let Tokens {accessToken = storedToken} = tokens
           in pure storedToken
        else do
          doc <- throwEitherM (fetchDiscoveryDoc discoveryURI)
          -- Timestamp taken before the refresh request is sent and stored with
          -- the new tokens (presumably the basis for later expiry checks —
          -- verify against 'codeserverCredentials').
          fetchedAt <- getCurrentTime
          refreshed@Tokens {accessToken = freshToken} <-
            throwEitherM (performTokenRefresh doc tokens)
          info <- throwEitherM (getUserInfo doc freshToken)
          saveCredentials manager host (codeserverCredentials discoveryURI refreshed fetchedAt info)
          pure freshToken

-- | Exchange the current refresh token for a fresh set of 'Tokens' at the
-- discovery document's token endpoint.
--
-- Specification: https://datatracker.ietf.org/doc/html/rfc6749#section-6
--
-- The request is a @POST@ whose form-encoded body carries
-- @grant_type=refresh_token@ and the current refresh token, with
-- @Accept: application/json@.
--
-- Results:
--
--   * @Left ('RefreshFailure' …)@ when the given 'Tokens' carry no refresh token.
--   * @Left ('InvalidTokenResponse' …)@ when the endpoint answers with a non-2xx
--     status, or with a body that does not decode as 'Tokens'.
--   * @Right tokens@ otherwise.
--
-- NOTE(review): failures raised by http-client itself are not converted to
-- 'CredentialFailure' here. Examples are an unparseable endpoint URI (thrown
-- by 'HTTP.requestFromURI', which is 'MonadThrow') or transport errors from
-- 'HTTP.httpLbs'. They propagate as exceptions; confirm callers expect this.
performTokenRefresh ::
  (MonadIO m) =>
  DiscoveryDoc ->
  Tokens ->
  m (Either CredentialFailure Tokens)
performTokenRefresh DiscoveryDoc {tokenEndpoint} (Tokens {refreshToken = currentRefreshToken}) =
  runExceptT $
    case currentRefreshToken of
      Nothing ->
        throwError (RefreshFailure . Text.pack $ "Unable to refresh authentication, please run auth.login and try again.")
      Just rt -> do
        req <- liftIO $ HTTP.requestFromURI tokenEndpoint
        -- urlEncodedBody also sets the form Content-Type header on the request.
        let addFormData =
              HTTP.urlEncodedBody
                [ ("grant_type", "refresh_token"),
                  ("refresh_token", BSC.pack . Text.unpack $ rt)
                ]
        let fullReq =
              addFormData $
                req {HTTP.method = "POST", HTTP.requestHeaders = [("Accept", "application/json")]}
        -- The token endpoint is called without our own auth, hence the global manager.
        unauthenticatedHttpClient <- liftIO HTTP.getGlobalManager
        resp <- liftIO $ HTTP.httpLbs fullReq unauthenticatedHttpClient
        newTokens <- case HTTP.responseStatus resp of
          status
            -- RFC 6749 §5.1: a successful token response is a 2xx with a JSON body.
            | Network.statusIsSuccessful status ->
                case Aeson.eitherDecode @Tokens (HTTP.responseBody resp) of
                  Left err -> throwError (InvalidTokenResponse tokenEndpoint (Text.pack err))
                  Right decoded -> pure decoded
            | otherwise ->
                throwError . InvalidTokenResponse tokenEndpoint $
                  "Received " <> tShow status <> " response from token endpoint"
        -- According to the spec, servers may or may not update the refresh token itself.
        -- If updated we need to replace it, if not updated we keep the existing one.
        pure $ newTokens {refreshToken = refreshToken newTokens <|> currentRefreshToken}