module Unison.Auth.CredentialManager
  ( saveCredentials,
    CredentialManager,
    globalCredentialManager,
    newCredentialManager,
    getCodeserverCredentials,
    getOrCreatePersonalKey,
    isExpired,
  )
where

import Control.Concurrent.MVar (MVar, modifyMVar, newMVar)
import Control.Monad.Catch (MonadMask)
import Control.Monad.Trans.Except
import Data.Map qualified as Map
import Data.Time.Clock (addUTCTime, diffUTCTime, getCurrentTime)
import System.IO.Unsafe (unsafePerformIO)
import Unison.Auth.CredentialFile qualified as CF
import Unison.Auth.PersonalKey (PersonalPrivateKey, generatePersonalKey)
import Unison.Auth.Types hiding (getCodeserverCredentials)
import Unison.Auth.Types qualified as Auth
import Unison.Prelude
import Unison.Share.Types (CodeserverId)
import UnliftIO qualified

-- | A 'CredentialManager' loads, saves, and caches credentials.
--
-- It is thread-safe, and safe to use when several UCM clients share the same credentials file.
--
-- __Note__: the in-memory cache is currently /not/ refreshed when a different UCM rewrites the credentials file.
--           This shouldn't cause problems, since auth is still refreshed whenever a request runs into an auth
--           failure.
data CredentialManager = CredentialManager
  { -- | In-memory cache of the credentials. 'Nothing' means nothing has been cached yet, in which case
    -- readers fall back to the credentials file.
    credsVar :: MVar (Maybe Credentials),
    -- | Path of the credentials file this manager reads from and writes to.
    file :: FilePath
  }

-- | The process-wide 'CredentialManager' singleton.
--
-- It is created with @'newCredentialManager' Nothing@, i.e. it uses the credentials file location that
-- 'newCredentialManager' picks when no explicit path is given.
--
-- The manager is built through 'unsafePerformIO' the first time this value is forced. The @NOINLINE@ pragma
-- is required: it stops GHC from inlining the definition, which could otherwise construct more than one
-- manager (each with its own cache).
globalCredentialManager :: CredentialManager
globalCredentialManager = unsafePerformIO (newCredentialManager Nothing)
{-# NOINLINE globalCredentialManager #-}

-- | Returns the personal key stored for the active profile.
--
-- When the active profile has no personal key yet, a new one is generated, stored under the active profile
-- (both in the credentials file and in the in-memory cache), and returned.
getOrCreatePersonalKey :: CredentialManager -> IO PersonalPrivateKey
getOrCreatePersonalKey credMan = modifyCredentials credMan lookupOrGenerate
  where
    -- Runs inside the credential update, so lookup and insertion happen as one step.
    lookupOrGenerate :: Credentials -> IO (Credentials, PersonalPrivateKey)
    lookupOrGenerate creds@(Credentials {activeProfile, personalKeys}) =
      maybe generateAndStore keepExisting (Map.lookup activeProfile personalKeys)
      where
        -- A key already exists: hand it back and leave the credentials as they are.
        keepExisting :: PersonalPrivateKey -> IO (Credentials, PersonalPrivateKey)
        keepExisting existingKey = pure (creds, existingKey)

        -- No key yet: create one and register it for the active profile.
        generateAndStore :: IO (Credentials, PersonalPrivateKey)
        generateAndStore = do
          freshKey <- generatePersonalKey
          pure (creds {personalKeys = Map.insert activeProfile freshKey personalKeys}, freshKey)

-- | Saves the given codeserver credentials to the active profile.
--
-- The credentials are recorded for codeserver @aud@ via 'setCodeserverCredentials'; the change is written to
-- the credentials file and reflected in the in-memory cache.
saveCredentials :: CredentialManager -> CodeserverId -> CodeserverCredentials -> IO ()
saveCredentials credManager aud creds =
  modifyCredentials credManager \stored ->
    pure (setCodeserverCredentials aud creds stored, ())

-- | Atomically updates the credential storage file and then refreshes the in-memory cache.
--
-- The callback receives the credentials as currently stored in the file (the cached value is deliberately
-- ignored) and returns the new credentials together with a result. The new credentials are written back to
-- the file, stored in the cache, and the result is returned to the caller.
modifyCredentials :: (MonadMask m, UnliftIO.MonadUnliftIO m) => CredentialManager -> (Credentials -> m (Credentials, r)) -> m r
modifyCredentials (CredentialManager {credsVar, file}) f =
  UnliftIO.modifyMVar credsVar \_cached -> do
    (updated, result) <- CF.atomicallyModifyCredentialsFile applyAndEcho file
    pure (Just updated, result)
  where
    -- Run the caller's update, and additionally echo the new credentials out of the file
    -- modification so they can be put into the cache afterwards.
    applyAndEcho onDisk = do
      (updated, result) <- f onDisk
      pure (updated, (updated, result))

-- | Returns the current credentials, preferring the in-memory cache.
--
-- If nothing is cached, the credentials are read from the credentials file (by running a modification that
-- leaves the file contents unchanged) and the cache is filled with what was read.
readCredentials :: CredentialManager -> IO Credentials
readCredentials (CredentialManager {credsVar, file}) =
  modifyMVar credsVar \cached -> do
    creds <- maybe loadFromFile pure cached
    pure (Just creds, creds)
  where
    -- Read the file's credentials without changing them.
    loadFromFile :: IO Credentials
    loadFromFile = CF.atomicallyModifyCredentialsFile (\onDisk -> pure (onDisk, onDisk)) file

-- | Looks up the credentials for codeserver @aud@.
--
-- Returns 'Left' with the failure reported by 'Auth.getCodeserverCredentials' when the lookup fails, and
-- @'Left' ('ReauthRequired' aud)@ when credentials are found but 'isExpired' reports them as expired.
getCodeserverCredentials :: CredentialManager -> CodeserverId -> IO (Either CredentialFailure CodeserverCredentials)
getCodeserverCredentials credMan aud = runExceptT do
  creds <- lift (readCredentials credMan)
  codeserverCreds <- except (Auth.getCodeserverCredentials aud creds)
  expired <- lift (isExpired codeserverCreds)
  if expired
    then throwE (ReauthRequired aud)
    else pure codeserverCreds

-- | Creates a 'CredentialManager' backed by the given credentials file.
--
-- With 'Nothing', the path is obtained from 'CF.getCredentialJSONFilePath'. The credentials currently in the
-- file are read up front (without changing them) and used to fill the in-memory cache.
newCredentialManager :: Maybe FilePath -> IO CredentialManager
newCredentialManager mfile = do
  file <- case mfile of
    Just explicitPath -> pure explicitPath
    Nothing -> CF.getCredentialJSONFilePath
  initial <- CF.atomicallyModifyCredentialsFile (\onDisk -> pure (onDisk, onDisk)) file
  credsVar <- newMVar (Just initial)
  pure CredentialManager {credsVar, file}

-- | Checks whether 'CodeserverCredentials' are expired.
--
-- The expiry time is the fetch time plus the tokens' @expiresIn@. Credentials are reported as expired as soon
-- as the time remaining until expiry is no more than 10% of @expiresIn@, which includes the case where the
-- expiry time has already passed.
isExpired :: CodeserverCredentials -> IO Bool
isExpired CodeserverCredentials {fetchTime, tokens = Tokens {expiresIn}} =
  withinRefreshWindow <$> getCurrentTime
  where
    -- Moment at which the tokens stop being valid.
    expiresAt = addUTCTime expiresIn fetchTime
    -- Final 10% of the token lifetime; inside it the credentials already count as expired.
    refreshWindow = expiresIn * 0.1
    withinRefreshWindow now = diffUTCTime expiresAt now <= refreshWindow