module Unison.Auth.Tokens where
import Control.Monad.Except
import Data.Aeson qualified as Aeson
import Data.ByteString.Char8 qualified as BSC
import Data.Text qualified as Text
import Data.Time.Clock (getCurrentTime)
import Network.HTTP.Client qualified as HTTP
import Network.HTTP.Client.TLS qualified as HTTP
import Network.HTTP.Types qualified as Network
import Unison.Auth.CredentialManager
import Unison.Auth.Discovery (fetchDiscoveryDoc)
import Unison.Auth.Types
import Unison.Auth.UserInfo (getUserInfo)
import Unison.Prelude
import Unison.Share.Types (CodeserverId)
import UnliftIO qualified
-- | Supplies an 'AccessToken' for the given 'CodeserverId', or the
-- 'CredentialFailure' explaining why none could be produced.
type TokenProvider =
  CodeserverId ->
  IO (Either CredentialFailure AccessToken)
-- | Build a 'TokenProvider' backed by the given 'CredentialManager'.
--
-- For the requested host the provider loads the stored credentials. If
-- 'isExpired' says they are still valid, the stored access token is returned
-- unchanged. Otherwise the provider:
--
--   1. fetches the discovery document from the stored discovery URI,
--   2. exchanges the stored refresh token for new tokens ('performTokenRefresh'),
--   3. fetches the user info for the new access token ('getUserInfo'), and
--   4. persists the refreshed credentials with 'saveCredentials',
--
-- and then returns the new access token.
--
-- Every step that fails with a 'CredentialFailure' is surfaced as 'Left'.
-- Only that exception type is caught here; exceptions of any other type
-- raised by the helpers propagate to the caller.
-- NOTE(review): confirm callers expect e.g. HTTP exceptions to escape.
newTokenProvider :: CredentialManager -> TokenProvider
newTokenProvider manager host = UnliftIO.try @_ @CredentialFailure $ do
  creds@CodeserverCredentials {tokens, discoveryURI} <-
    throwEitherM $ getCredentials manager host
  let Tokens {accessToken = currentAccessToken} = tokens
  expired <- isExpired creds
  if expired
    then do
      discoveryDoc <- throwEitherM $ fetchDiscoveryDoc discoveryURI
      -- Timestamp stored alongside the refreshed tokens; it is taken before
      -- the refresh request is sent.
      fetchTime <- getCurrentTime
      newTokens@(Tokens {accessToken = newAccessToken}) <-
        throwEitherM $ performTokenRefresh discoveryDoc tokens
      userInfo <- throwEitherM $ getUserInfo discoveryDoc newAccessToken
      saveCredentials manager host $
        codeserverCredentials discoveryURI newTokens fetchTime userInfo
      pure newAccessToken
    else pure currentAccessToken
-- | Exchange the refresh token held in 'Tokens' for a fresh set of tokens.
--
-- The request is a @POST@ of the OAuth 2.0 refresh-token grant to the
-- discovery document's token endpoint.
-- Specification: https://datatracker.ietf.org/doc/html/rfc6749#section-6
--
-- Results:
--
--   * @Left ('RefreshFailure' _)@ when no refresh token is stored.
--   * @Left ('InvalidTokenResponse' tokenEndpoint _)@ when the endpoint
--     answers with a status of 300 or above, or with a body that does not
--     decode as 'Tokens'.
--   * @Right tokens@ otherwise. If the response omits a refresh token, the
--     previously stored one is carried over.
--
-- Exceptions thrown while building or sending the HTTP request are not
-- caught here; they propagate through @m@.
performTokenRefresh :: (MonadIO m) => DiscoveryDoc -> Tokens -> m (Either CredentialFailure Tokens)
performTokenRefresh DiscoveryDoc {tokenEndpoint} (Tokens {refreshToken = currentRefreshToken}) =
  runExceptT $ case currentRefreshToken of
    Nothing ->
      throwError . RefreshFailure $
        Text.pack "Unable to refresh authentication, please run auth.login and try again."
    Just rt -> do
      req <- liftIO $ HTTP.requestFromURI tokenEndpoint
      let baseReq =
            req
              { HTTP.method = "POST",
                HTTP.requestHeaders = [("Accept", "application/json")]
              }
          -- Applied after the header update so the form's content type is
          -- added on top of the Accept header.
          fullReq =
            HTTP.urlEncodedBody
              [ ("grant_type", "refresh_token"),
                -- BSC.pack truncates each Char to 8 bits. That is only safe
                -- because RFC 6749 restricts refresh tokens to ASCII (VSCHAR).
                ("refresh_token", BSC.pack (Text.unpack rt))
              ]
              baseReq
      httpManager <- liftIO HTTP.getGlobalManager
      resp <- liftIO $ HTTP.httpLbs fullReq httpManager
      newTokens <- case HTTP.responseStatus resp of
        status
          | status < Network.status300 ->
              either
                (throwError . InvalidTokenResponse tokenEndpoint . Text.pack)
                pure
                (Aeson.eitherDecode @Tokens (HTTP.responseBody resp))
          | otherwise ->
              throwError . InvalidTokenResponse tokenEndpoint $
                "Received " <> tShow status <> " response from token endpoint"
      -- The server MAY rotate the refresh token (RFC 6749 §6). Use the new one
      -- if present; otherwise keep the one we already had.
      pure newTokens {refreshToken = refreshToken newTokens <|> currentRefreshToken}