module Unison.Auth.Tokens where

import Control.Monad.Except
import Data.Aeson qualified as Aeson
import Data.ByteString.Char8 qualified as BSC
import Data.Text qualified as Text
import Data.Time.Clock (getCurrentTime)
import Network.HTTP.Client qualified as HTTP
import Network.HTTP.Client.TLS qualified as HTTP
import Network.HTTP.Types qualified as Network
import Unison.Auth.CredentialManager
import Unison.Auth.Discovery (fetchDiscoveryDoc)
import Unison.Auth.Types
import Unison.Auth.UserInfo (getUserInfo)
import Unison.Prelude
import Unison.Share.Types (CodeserverId)
import UnliftIO qualified

-- | Given a 'CodeserverId', provide a valid 'AccessToken' for the associated host.
--
-- A 'TokenProvider' may automatically refresh the access token when a refresh
-- token is available. Failing to obtain a token is reported in the result as a
-- 'Left' 'CredentialFailure'.
type TokenProvider = CodeserverId -> IO (Either CredentialFailure AccessToken)

-- | Creates a 'TokenProvider' using the given 'CredentialManager'.
--
-- The resulting provider looks up the stored credentials for the requested
-- codeserver and:
--
--   * returns the stored access token as-is while 'isExpired' reports 'False';
--   * otherwise fetches the discovery document, refreshes the tokens with
--     'performTokenRefresh', fetches the user info with the new access token,
--     saves the rebuilt credentials through the manager, and returns the new
--     access token.
--
-- Every fallible step is unwrapped with 'throwEitherM' and the whole action is
-- wrapped in 'UnliftIO.try' at type 'CredentialFailure', so a
-- 'CredentialFailure' raised inside the block comes back as a 'Left' (this
-- presumes 'throwEitherM' throws the 'Left' value, as its type and name
-- suggest). Exceptions of any other type are not caught here.
newTokenProvider :: CredentialManager -> TokenProvider
newTokenProvider manager host = UnliftIO.try @_ @CredentialFailure $ do
  creds@CodeserverCredentials {tokens, discoveryURI} <- throwEitherM $ getCredentials manager host
  let Tokens {accessToken = currentAccessToken} = tokens
  expired <- isExpired creds
  if expired
    then do
      discoveryDoc <- throwEitherM $ fetchDiscoveryDoc discoveryURI
      -- The timestamp is taken before the refresh request is issued and is
      -- stored alongside the new tokens.
      fetchTime <- getCurrentTime
      newTokens@(Tokens {accessToken = newAccessToken}) <- throwEitherM $ performTokenRefresh discoveryDoc tokens
      userInfo <- throwEitherM $ getUserInfo discoveryDoc newAccessToken
      -- Persist the refreshed credentials before handing the token back.
      saveCredentials manager host (codeserverCredentials discoveryURI newTokens fetchTime userInfo)
      pure newAccessToken
    else pure currentAccessToken

-- | Exchange the current refresh token for a new set of 'Tokens' by POSTing a
-- @refresh_token@ grant to the discovery document's token endpoint.
--
-- Specification: https://datatracker.ietf.org/doc/html/rfc6749#section-6
--
-- Results:
--
--   * 'Left' 'RefreshFailure' when the given 'Tokens' carry no refresh token;
--   * 'Left' 'InvalidTokenResponse' when the endpoint answers with a status of
--     300 or above, or when the response body does not decode as 'Tokens';
--   * 'Right' with the decoded tokens otherwise. If the response carries no
--     refresh token of its own, the current one is kept.
--
-- Any exception raised by 'HTTP.requestFromURI' or 'HTTP.httpLbs' is not
-- caught here.
performTokenRefresh :: (MonadIO m) => DiscoveryDoc -> Tokens -> m (Either CredentialFailure Tokens)
performTokenRefresh DiscoveryDoc {tokenEndpoint} (Tokens {refreshToken = currentRefreshToken}) = runExceptT $
  case currentRefreshToken of
    Nothing ->
      throwError $ RefreshFailure "Unable to refresh authentication, please run auth.login and try again."
    Just rt -> do
      req <- liftIO $ HTTP.requestFromURI tokenEndpoint
      let addFormData =
            HTTP.urlEncodedBody
              [ ("grant_type", "refresh_token"),
                -- NOTE(review): Char8 packing keeps only the low 8 bits of each
                -- character; RFC 6749 (Appendix A.17) limits refresh tokens to
                -- printable ASCII, so this should be lossless for conforming tokens.
                ("refresh_token", BSC.pack . Text.unpack $ rt)
              ]
      let fullReq = addFormData $ req {HTTP.method = "POST", HTTP.requestHeaders = [("Accept", "application/json")]}
      -- The refresh request is sent with the shared global manager; no
      -- Authorization header is attached to it.
      unauthenticatedHttpClient <- liftIO HTTP.getGlobalManager
      resp <- liftIO $ HTTP.httpLbs fullReq unauthenticatedHttpClient
      newTokens <- case HTTP.responseStatus resp of
        status
          | status < Network.status300 ->
              case Aeson.eitherDecode @Tokens (HTTP.responseBody resp) of
                Left err -> throwError (InvalidTokenResponse tokenEndpoint (Text.pack err))
                Right a -> pure a
          | otherwise ->
              throwError . InvalidTokenResponse tokenEndpoint $
                "Received " <> tShow status <> " response from token endpoint"
      -- According to the spec, servers may or may not update the refresh token itself.
      -- If updated we need to replace it, if not updated we keep the existing one.
      pure $ newTokens {refreshToken = refreshToken newTokens <|> currentRefreshToken}