{-# LANGUAGE DeriveAnyClass #-}

module Unison.Auth.CredentialManager
  ( saveCredentials,
    CredentialManager,
    globalCredentialManager,
    getCodeserverCredentials,
    getOrCreatePersonalKey,
    isExpired,
  )
where

import Control.Monad.Trans.Except
import Data.Map qualified as Map
import Data.Time.Clock (addUTCTime, diffUTCTime, getCurrentTime)
import System.IO.Unsafe (unsafePerformIO)
import Unison.Auth.CredentialFile qualified as CF
import Unison.Auth.PersonalKey (PersonalPrivateKey, generatePersonalKey)
import Unison.Auth.Types hiding (getCodeserverCredentials)
import Unison.Auth.Types qualified as Auth
import Unison.Prelude
import Unison.Share.Types (CodeserverId)
import UnliftIO qualified

-- | A 'CredentialManager' knows how to load, save, and cache credentials.
--
-- It wraps an 'UnliftIO.MVar' holding an in-memory cache of the user's
-- 'Credentials'. The cache is 'Nothing' until the credentials are first read or
-- modified through this manager. Every access goes through 'UnliftIO.modifyMVar',
-- so concurrent use within a single process is serialized. Writes go through
-- 'CF.atomicallyModifyCredentialsFile'; this is what makes the manager intended
-- to be safe for use across multiple UCM clients.
--
-- Note: Currently the in-memory cache is _not_ updated if a different UCM updates
-- the credentials file. This shouldn't pose any problems, since auth will still
-- be refreshed if we encounter any auth failures on requests.
newtype CredentialManager = CredentialManager (UnliftIO.MVar (Maybe Credentials {- Credentials may or may not be initialized -}))

-- | The process-wide 'CredentialManager' singleton.
--
-- Its cache starts out empty ('Nothing'). Credentials are loaded from the
-- credentials file the first time they are read or modified.
--
-- 'unsafePerformIO' is used only to allocate one top-level 'UnliftIO.MVar'.
-- The @NOINLINE@ pragma is what makes this sound. It stops GHC from inlining,
-- and thereby duplicating, the allocation, so every use site shares the same MVar.
globalCredentialManager :: CredentialManager
globalCredentialManager =
  unsafePerformIO (fmap CredentialManager (UnliftIO.newMVar Nothing))
{-# NOINLINE globalCredentialManager #-}

-- | Returns the personal key stored for the active profile, creating one if needed.
--
-- If the active profile has no personal key yet, the function does three things:
--
-- * generates a fresh key with 'generatePersonalKey';
-- * stores it under the active profile;
-- * persists it with the rest of the credentials via 'modifyCredentials'.
--
-- The lookup and the insert happen inside a single 'modifyCredentials' call.
-- As a result, concurrent callers in this process are serialized and cannot
-- both create a key for the same profile.
getOrCreatePersonalKey :: (MonadUnliftIO m) => CredentialManager -> m PersonalPrivateKey
getOrCreatePersonalKey credMan =
  modifyCredentials credMan $ \creds@(Credentials {activeProfile, personalKeys}) ->
    let -- Key already present: leave the credentials untouched.
        keepExisting existingKey = pure (creds, existingKey)
        -- No key yet: mint one and record it under the active profile.
        createAndStore = do
          freshKey <- generatePersonalKey
          let updatedKeys = Map.insert activeProfile freshKey personalKeys
          pure (creds {personalKeys = updatedKeys}, freshKey)
     in maybe createAndStore keepExisting (Map.lookup activeProfile personalKeys)

-- | Saves credentials to the active profile.
--
-- The credentials for codeserver @aud@ are recorded with
-- 'setCodeserverCredentials' inside 'modifyCredentials'. Both the credentials
-- file and the in-memory cache therefore end up holding the updated value.
saveCredentials :: (UnliftIO.MonadUnliftIO m) => CredentialManager -> CodeserverId -> CodeserverCredentials -> m ()
saveCredentials credManager aud creds =
  modifyCredentials credManager storeFor
  where
    -- Nothing to report back to the caller, hence the unit result.
    storeFor current = pure (setCodeserverCredentials aud creds current, ())

-- | Atomically update the credential storage file, and update the in-memory cache.
--
-- The cached value is deliberately ignored. @f@ is applied to the credentials
-- supplied by 'CF.atomicallyModifyCredentialsFile' (presumably the current file
-- contents), not to the cache. The credentials returned by @f@ are written back
-- through 'CF.atomicallyModifyCredentialsFile' and also become the new cached
-- value. The second component of @f@'s result is returned to the caller.
--
-- Everything runs under 'UnliftIO.modifyMVar', so concurrent calls on the same
-- manager are serialized. If @f@ throws, the cache keeps its previous value.
modifyCredentials :: (UnliftIO.MonadUnliftIO m) => CredentialManager -> (Credentials -> m (Credentials, r)) -> m r
modifyCredentials (CredentialManager credsVar) f =
  UnliftIO.modifyMVar credsVar $ \_staleCache -> do
    -- Run the caller's update, and also surface the new credentials so they can
    -- be cached alongside the caller's result.
    let updateAndExpose current = do
          (updated, answer) <- f current
          pure (updated, (updated, answer))
    (newCreds, result) <- CF.atomicallyModifyCredentialsFile updateAndExpose
    pure (Just newCreds, result)

-- | Returns the cached credentials, loading them from the credentials file on first use.
--
-- When the cache is empty, the credentials are fetched through
-- 'CF.atomicallyModifyCredentialsFile' using an update that leaves them
-- unchanged, and the result is stored in the cache. Later calls return the
-- cached value without consulting the file again.
readCredentials :: (UnliftIO.MonadUnliftIO m) => CredentialManager -> m Credentials
readCredentials (CredentialManager credsVar) =
  UnliftIO.modifyMVar credsVar \case
    cached@(Just creds) -> pure (cached, creds)
    Nothing -> do
      loaded <- CF.atomicallyModifyCredentialsFile \current -> pure (current, current)
      pure (Just loaded, loaded)

-- | Looks up the credentials for a codeserver, failing if they are missing or expired.
--
-- Credentials come from the manager's cache, which is loaded from the
-- credentials file on first use. The result is:
--
-- * @Left@ with the 'CredentialFailure' reported by 'Auth.getCodeserverCredentials'
--   when no credentials can be obtained for @aud@;
-- * @Left ('ReauthRequired' aud)@ when 'isExpired' reports the credentials as expired;
-- * @Right@ with the credentials otherwise.
getCodeserverCredentials :: (MonadIO m) => CredentialManager -> CodeserverId -> m (Either CredentialFailure CodeserverCredentials)
getCodeserverCredentials credMan aud = runExceptT do
  allCreds <- liftIO (readCredentials credMan)
  serverCreds <- except (Auth.getCodeserverCredentials aud allCreds)
  expired <- isExpired serverCreds
  if expired
    then throwE (ReauthRequired aud)
    else pure serverCreds

-- | Checks whether 'CodeserverCredentials' should be treated as expired.
--
-- The token lifetime is @expiresIn@, counted from @fetchTime@. The credentials
-- count as expired once the time left before expiry is at most 10% of that
-- lifetime, or has already run out. This presumably exists so that tokens are
-- refreshed a little before they actually lapse.
isExpired :: (MonadIO m) => CodeserverCredentials -> m Bool
isExpired CodeserverCredentials {fetchTime, tokens = Tokens {expiresIn}} = liftIO $ do
  currentTime <- getCurrentTime
  let expiresAt = addUTCTime expiresIn fetchTime
      timeLeft = diffUTCTime expiresAt currentTime
      refreshMargin = expiresIn * 0.1
  pure (timeLeft <= refreshMargin)