{-# LANGUAGE DeriveAnyClass #-}

module Unison.Auth.CredentialManager
  ( saveCredentials,
    CredentialManager,
    newCredentialManager,
    getCredentials,
    isExpired,
  )
where

import Control.Monad.Trans.Except
import Data.Time.Clock (addUTCTime, diffUTCTime, getCurrentTime)
import Unison.Auth.CredentialFile
import Unison.Auth.Types
import Unison.Prelude
import Unison.Share.Types (CodeserverId)
import UnliftIO qualified

-- | A 'CredentialManager' knows how to load, save, and cache credentials.
--
-- It wraps an in-memory cache of 'Credentials' in an 'UnliftIO.MVar'. Writes
-- go through the credentials file and then replace the cached value while the
-- MVar is held. It is intended to be thread-safe and safe for use across
-- multiple UCM clients.
--
-- Note: the in-memory cache is currently _not_ updated when a different UCM
-- process updates the credentials file. This shouldn't cause problems, since
-- auth is still refreshed whenever a request hits an auth failure.
newtype CredentialManager = CredentialManager (UnliftIO.MVar Credentials)

-- | Saves credentials for the given codeserver to the active profile.
--
-- The update is applied with 'setCodeserverCredentials'. It is written to the
-- credentials file and the in-memory cache via 'modifyCredentials'. The
-- resulting full 'Credentials' value is discarded.
saveCredentials ::
  (UnliftIO.MonadUnliftIO m) =>
  CredentialManager ->
  CodeserverId ->
  CodeserverCredentials ->
  m ()
saveCredentials manager codeserver newCreds =
  void (modifyCredentials manager (setCodeserverCredentials codeserver newCreds))

-- | Atomically update the credential storage file, and update the in-memory cache.
--
-- The currently cached value is deliberately ignored. The transformation is
-- applied to whatever 'atomicallyModifyCredentialsFile' reads from disk, and
-- that result becomes both the new cache contents and the return value.
-- 'UnliftIO.modifyMVar' holds the cache lock for the duration of the file
-- update.
modifyCredentials ::
  (UnliftIO.MonadUnliftIO m) =>
  CredentialManager ->
  (Credentials -> Credentials) ->
  m Credentials
modifyCredentials (CredentialManager cacheVar) transform =
  UnliftIO.modifyMVar cacheVar $ \_staleCache ->
    fmap (\fresh -> (fresh, fresh)) (atomicallyModifyCredentialsFile transform)

-- | Look up the cached credentials for a codeserver.
--
-- Only the in-memory cache is consulted; the credentials file is not re-read.
-- Fails with whatever 'getCodeserverCredentials' reports when the lookup
-- fails. Fails with @ReauthRequired aud@ when the stored credentials are
-- considered expired by 'isExpired'.
getCredentials ::
  (MonadIO m) =>
  CredentialManager ->
  CodeserverId ->
  m (Either CredentialFailure CodeserverCredentials)
getCredentials (CredentialManager cacheVar) aud = runExceptT $ do
  cached <- lift (UnliftIO.readMVar cacheVar)
  found <- except (getCodeserverCredentials aud cached)
  stale <- lift (isExpired found)
  if stale
    then throwE (ReauthRequired aud)
    else pure found

-- | Create a 'CredentialManager' whose cache is seeded from the credentials file.
--
-- The file contents are obtained by passing 'id' through
-- 'atomicallyModifyCredentialsFile'. Whether that also creates or rewrites
-- the file when it is missing depends on that helper (NOTE(review): confirm).
newCredentialManager :: (MonadIO m) => m CredentialManager
newCredentialManager =
  CredentialManager <$> (atomicallyModifyCredentialsFile id >>= UnliftIO.newMVar)

-- | Checks whether 'CodeserverCredentials' are expired.
--
-- The token's lifetime is @expiresIn@, counted from @fetchTime@. Credentials
-- are reported as expired once the time remaining is at most 10% of that
-- lifetime. This includes the case where the expiry time has already passed.
-- Presumably the margin is there so tokens get refreshed before the server
-- starts rejecting them.
isExpired :: (MonadIO m) => CodeserverCredentials -> m Bool
isExpired CodeserverCredentials {fetchTime, tokens = Tokens {expiresIn}} = do
  now <- liftIO getCurrentTime
  let expiresAt = addUTCTime expiresIn fetchTime
      timeLeft = expiresAt `diffUTCTime` now
      refreshWindow = expiresIn * 0.1
  pure (timeLeft <= refreshWindow)