{-# LANGUAGE DeriveAnyClass #-}
module Unison.Auth.CredentialManager
( saveCredentials,
CredentialManager,
newCredentialManager,
getCredentials,
isExpired,
)
where
import Control.Monad.Trans.Except
import Data.Time.Clock (addUTCTime, diffUTCTime, getCurrentTime)
import Unison.Auth.CredentialFile
import Unison.Auth.Types
import Unison.Prelude
import Unison.Share.Types (CodeserverId)
import UnliftIO qualified
-- | Handle to the in-process cache of 'Credentials'.
--
-- The cache is held in an 'UnliftIO.MVar'. Within this module it is read with
-- 'UnliftIO.readMVar' and replaced only under 'UnliftIO.modifyMVar', so
-- updates made through this handle do not interleave with one another.
newtype CredentialManager
  = CredentialManager (UnliftIO.MVar Credentials)
-- | Record @creds@ as the credentials for the codeserver @aud@.
--
-- The update is applied through 'modifyCredentials', which runs it via
-- 'atomicallyModifyCredentialsFile' and then refreshes the in-memory cache.
-- The resulting 'Credentials' value is not needed by callers, so it is
-- discarded.
saveCredentials :: (UnliftIO.MonadUnliftIO m) => CredentialManager -> CodeserverId -> CodeserverCredentials -> m ()
saveCredentials credManager aud creds =
  void . modifyCredentials credManager $ setCodeserverCredentials aud creds
-- | Apply @f@ to the stored credentials and update the in-memory cache.
--
-- The transformation is performed by 'atomicallyModifyCredentialsFile'. The
-- previously cached value is ignored: whatever that call returns becomes the
-- new cache contents and is also returned to the caller.
--
-- The work runs inside 'UnliftIO.modifyMVar'. Concurrent calls through the
-- same 'CredentialManager' are therefore serialised, and if the file update
-- throws, the cache keeps its old value.
modifyCredentials :: (UnliftIO.MonadUnliftIO m) => CredentialManager -> (Credentials -> Credentials) -> m Credentials
modifyCredentials (CredentialManager credsVar) f =
  UnliftIO.modifyMVar credsVar $ \_staleCache -> do
    newCreds <- atomicallyModifyCredentialsFile f
    pure (newCreds, newCreds)
-- | Look up the cached credentials for the codeserver @aud@.
--
-- Only the in-memory cache is consulted; the credentials file is not re-read.
-- The result is one of:
--
-- * @Left err@: 'getCodeserverCredentials' reported failure @err@ for @aud@.
-- * @Left ('ReauthRequired' aud)@: credentials exist, but 'isExpired'
--   considers them expired.
-- * @Right creds@: the stored, still-valid credentials.
getCredentials :: (MonadIO m) => CredentialManager -> CodeserverId -> m (Either CredentialFailure CodeserverCredentials)
getCredentials (CredentialManager credsVar) aud = runExceptT $ do
  creds <- lift (UnliftIO.readMVar credsVar)
  codeserverCreds <- except (getCodeserverCredentials aud creds)
  expired <- lift (isExpired codeserverCreds)
  if expired
    then throwE (ReauthRequired aud)
    else pure codeserverCreds
-- | Create a 'CredentialManager' whose cache is seeded from the credentials
-- file.
--
-- The initial value is obtained by calling 'atomicallyModifyCredentialsFile'
-- with 'id', so no transformation is applied to the stored data.
-- NOTE(review): the behaviour when the file does not yet exist is defined by
-- 'atomicallyModifyCredentialsFile'; confirm there.
newCredentialManager :: (MonadIO m) => m CredentialManager
newCredentialManager = do
  credentials <- atomicallyModifyCredentialsFile id
  credentialsVar <- UnliftIO.newMVar credentials
  pure (CredentialManager credentialsVar)
-- | Decide whether codeserver credentials should be treated as expired.
--
-- The tokens expire @expiresIn@ after @fetchTime@. This function reports
-- 'True' once the time remaining until that point is at most 10% of
-- @expiresIn@. Credentials whose expiry time has already passed have a
-- negative remaining time and are always reported as expired.
isExpired :: (MonadIO m) => CodeserverCredentials -> m Bool
isExpired CodeserverCredentials {fetchTime, tokens = Tokens {expiresIn}} = liftIO $ do
  now <- getCurrentTime
  let expTime = addUTCTime expiresIn fetchTime
      remainingTime = diffUTCTime expTime now
      -- Treat tokens as expired during the last 10% of their lifetime.
      -- Presumably this is so they get refreshed before they actually stop
      -- working.
      threshold = expiresIn * 0.1
  pure (threshold >= remainingTime)