module Unison.Auth.Tokens where
import Control.Monad.Except
import Data.Aeson qualified as Aeson
import Data.ByteString.Char8 qualified as BSC
import Data.Text qualified as Text
import Data.Time.Clock (getCurrentTime)
import Network.HTTP.Client qualified as HTTP
import Network.HTTP.Client.TLS qualified as HTTP
import Network.HTTP.Types qualified as Network
import System.Environment (lookupEnv)
import Unison.Auth.CredentialManager
import Unison.Auth.CredentialManager qualified as CredMan
import Unison.Auth.Discovery (fetchDiscoveryDoc)
import Unison.Auth.Types
import Unison.Auth.UserInfo (getUserInfo)
import Unison.Prelude
import Unison.Share.Types (CodeserverId)
import UnliftIO qualified
-- | Looks up an 'AccessToken' for the given code server. Credential problems
-- are reported as a 'CredentialFailure' in 'Left'.
type TokenProvider =
  CodeserverId ->
  IO (Either CredentialFailure AccessToken)
-- | Name of the environment variable that, when set, supplies a Share access
-- token directly. 'newTokenProvider' returns its value without consulting or
-- refreshing stored credentials.
accessTokenEnvVarKey :: String
accessTokenEnvVarKey = "UNISON_SHARE_ACCESS_TOKEN"
-- | Build a 'TokenProvider' backed by the given 'CredentialManager'.
--
-- Resolution order for a given host:
--
--   1. If the environment variable named by 'accessTokenEnvVarKey' is set,
--      its value is returned as-is. It is not checked for expiry and stored
--      credentials are not consulted. A set-but-empty variable is not treated
--      specially.
--   2. Otherwise the stored 'CodeserverCredentials' for the host are loaded.
--      If they are not expired, the stored access token is returned.
--   3. If they are expired, the discovery document is fetched, the tokens are
--      refreshed via 'performTokenRefresh', user info is re-fetched, and the
--      new credentials are saved before the new access token is returned.
--
-- Only exceptions of type 'CredentialFailure' are turned into 'Left'. Any other
-- exception raised along the way propagates to the caller.
newTokenProvider :: CredentialManager -> TokenProvider
newTokenProvider manager host = UnliftIO.try @_ @CredentialFailure $ do
  mayShareAccessToken <- fmap Text.pack <$> lookupEnv accessTokenEnvVarKey
  case mayShareAccessToken of
    -- Named so it does not shadow the 'accessToken' record field used below.
    Just envAccessToken -> pure envAccessToken
    Nothing -> do
      creds@CodeserverCredentials {tokens, discoveryURI} <-
        throwEitherM $ CredMan.getCodeserverCredentials manager host
      let Tokens {accessToken = currentAccessToken} = tokens
      expired <- isExpired creds
      if expired
        then do
          discoveryDoc <- throwEitherM $ fetchDiscoveryDoc discoveryURI
          -- Recorded before the refresh request. It is presumably used by
          -- 'codeserverCredentials' to anchor token expiry (TODO confirm).
          fetchTime <- getCurrentTime
          newTokens@(Tokens {accessToken = newAccessToken}) <-
            throwEitherM $ performTokenRefresh discoveryDoc tokens
          userInfo <- throwEitherM $ getUserInfo discoveryDoc newAccessToken
          saveCredentials manager host (codeserverCredentials discoveryURI newTokens fetchTime userInfo)
          pure newAccessToken
        else pure currentAccessToken
-- | Exchange the refresh token held in the given 'Tokens' for a fresh set of
-- tokens at the discovery document's token endpoint. This uses the OAuth 2.0
-- refresh-token grant.
--
-- Specification: https://datatracker.ietf.org/doc/html/rfc6749#section-6
--
-- Results:
--
--   * @Left ('RefreshFailure' _)@ when there is no refresh token to use.
--   * @Left ('InvalidTokenResponse' tokenEndpoint _)@ when the endpoint replies
--     with a status of 300 or above, or with a body that does not decode as
--     'Tokens'.
--   * @Right tokens@ otherwise. If the decoded response carries no refresh
--     token, the previous one is kept.
--
-- Exceptions thrown while building or sending the HTTP request (for example by
-- 'HTTP.requestFromURI' or 'HTTP.httpLbs') are not caught here. They propagate
-- out of @m@ rather than becoming a 'Left'.
performTokenRefresh :: (MonadIO m) => DiscoveryDoc -> Tokens -> m (Either CredentialFailure Tokens)
performTokenRefresh DiscoveryDoc {tokenEndpoint} (Tokens {refreshToken = currentRefreshToken}) = runExceptT $
  case currentRefreshToken of
    Nothing ->
      throwError $ RefreshFailure "Unable to refresh authentication, please run auth.login and try again."
    Just rt -> do
      req <- liftIO $ HTTP.requestFromURI tokenEndpoint
      let addFormData =
            HTTP.urlEncodedBody
              [ ("grant_type", "refresh_token"),
                -- RFC 6749 Appendix A.17 restricts refresh tokens to visible
                -- ASCII, so Char8 packing is lossless for conforming tokens.
                ("refresh_token", BSC.pack (Text.unpack rt))
              ]
      -- 'HTTP.urlEncodedBody' is applied last. It adds its own Content-Type
      -- header alongside the Accept header set here.
      let fullReq =
            addFormData $
              req {HTTP.method = "POST", HTTP.requestHeaders = [("Accept", "application/json")]}
      -- No Authorization header is attached. This presumes a public client whose
      -- only credential is the refresh token itself (RFC 6749 §6), so confirm.
      unauthenticatedHttpClient <- liftIO HTTP.getGlobalManager
      resp <- liftIO $ HTTP.httpLbs fullReq unauthenticatedHttpClient
      newTokens <- case HTTP.responseStatus resp of
        status
          | status < Network.status300 ->
              case Aeson.eitherDecode @Tokens (HTTP.responseBody resp) of
                Left err -> throwError (InvalidTokenResponse tokenEndpoint (Text.pack err))
                Right decoded -> pure decoded
          | otherwise ->
              throwError . InvalidTokenResponse tokenEndpoint $
                "Received " <> tShow status <> " response from token endpoint"
      -- Servers may or may not issue a new refresh token. If one was returned
      -- it replaces the old one; otherwise the existing one is kept.
      pure $ newTokens {refreshToken = refreshToken newTokens <|> currentRefreshToken}