module Unison.Util.Cache
  ( Cache,
    cache,
    nullCache,
    semispaceCache,
    lookup,
    insert,
    apply,
    applyDefined,
  )
where

import Control.Monad (when)
import Control.Monad.IO.Class (liftIO)
import Data.Foldable (for_)
import Data.Functor (($>))
import Data.Map qualified as Map
import UnliftIO (MonadIO, atomically, modifyTVar', newTVarIO, readTVar, readTVarIO, writeTVar)
import Prelude hiding (lookup)

-- | A key/value cache, represented as a pair of 'IO' operations so that
-- different storage strategies can sit behind the same interface.
data Cache k v = Cache
  { -- | Return the value currently associated with a key, if any.
    lookup_ :: k -> IO (Maybe v),
    -- | Offer a key/value pair to the cache.
    insert_ :: k -> v -> IO ()
  }

-- | Look up a key by running the cache's 'lookup_' action, lifted into
-- any 'MonadIO'. Returns 'Nothing' when the cache reports no entry.
lookup :: (MonadIO m) => Cache k v -> k -> m (Maybe v)
lookup c k = liftIO (lookup_ c k)

-- | Offer a key/value pair to the cache by running its 'insert_' action,
-- lifted into any 'MonadIO'. Whether an existing entry is replaced is up
-- to the particular cache.
insert :: (MonadIO m) => Cache k v -> k -> v -> m ()
insert c k v = liftIO (insert_ c k v)

-- | Create a cache of unbounded size.
--
-- Entries are kept in a single 'Map.Map' held in a @TVar@ and are never
-- evicted. Inserting a key that is already present leaves the stored
-- value untouched.
cache :: (MonadIO m, Ord k) => m (Cache k v)
cache = do
  t <- newTVarIO Map.empty
  let lookupIO k = Map.lookup k <$> readTVarIO t
      insertIO k v = do
        -- Fast path: a key that is already cached needs no transaction.
        m <- readTVarIO t
        case Map.lookup k m of
          Nothing ->
            -- The snapshot above may be stale by the time the transaction
            -- runs, so keep any value a concurrent writer stored for this
            -- key instead of overwriting it.
            atomically $ modifyTVar' t (Map.insertWith (\_new old -> old) k v)
          Just _ -> pure ()
  pure $ Cache lookupIO insertIO

-- | A cache that stores nothing: every lookup returns 'Nothing' and every
-- insert is a no-op.
nullCache :: Cache k v
nullCache = Cache (const (pure Nothing)) (\_ _ -> pure ())

-- | Create a cache of bounded size. Once the newer generation reaches
-- @maxSize@ entries, it becomes the older generation and the previous
-- older generation is dropped, so keys that have not been inserted or
-- looked up recently are forgotten. Unlike LRU caching, where every hit
-- must update LRU info, a hit in the newer generation only reads.
--
-- A @maxSize@ of 0 yields 'nullCache'.
semispaceCache :: (MonadIO m, Ord k) => Word -> m (Cache k v)
semispaceCache 0 = pure nullCache
semispaceCache maxSize = do
  -- Analogous to semispace GC, keep 2 maps: gen0 and gen1
  -- `insert k v` is done in gen0
  --   if full, gen1 = gen0; gen0 = Map.empty
  -- `lookup k` is done in gen0; then gen1
  --   if found in gen0, return immediately
  --   if found in gen1, `insert k v`, then return
  -- Thus, older keys not recently looked up are forgotten
  gen0 <- newTVarIO Map.empty
  gen1 <- newTVarIO Map.empty
  let lookupIO k = do
        m0 <- readTVarIO gen0
        case Map.lookup k m0 of
          hit@(Just _) -> pure hit
          Nothing -> do
            m1 <- readTVarIO gen1
            case Map.lookup k m1 of
              Nothing -> pure Nothing
              -- Found only in the older generation: copy it into gen0 so
              -- it survives the next generation flip.
              Just v -> insertIO k v $> Just v
      insertIO k v = atomically $ do
        modifyTVar' gen0 (Map.insert k v)
        m0 <- readTVar gen0
        -- gen0 is full: it becomes gen1 (replacing the old gen1) and a
        -- fresh, empty gen0 is started.
        when (fromIntegral (Map.size m0) >= maxSize) $ do
          writeTVar gen1 m0
          writeTVar gen0 Map.empty
  pure $ Cache lookupIO insertIO

-- | Cached function application: if the key @k@ is in the cache, return
-- the cached value without calling @f@; otherwise run @f k@, insert the
-- result into the cache, and return it.
apply :: (MonadIO m) => Cache k v -> (k -> m v) -> k -> m v
apply c f k =
  lookup c k >>= \case
    Just v -> pure v
    Nothing -> do
      v <- f k
      insert c k v
      pure v

-- | Cached function application which only caches values for
-- which @f k@ is non-empty. For instance, if @g@ is 'Maybe',
-- and @f x@ returns 'Nothing', this won't be cached.
--
-- Useful when we think that missing results for @f@ may be
-- later filled in so we don't want to cache missing results.
--
-- On a cache hit the cached value is returned wrapped with @g@'s 'pure',
-- and @f@ is not called.
applyDefined ::
  (MonadIO m, Applicative g, Traversable g) =>
  Cache k v ->
  (k -> m (g v)) ->
  k ->
  m (g v)
applyDefined c f k =
  lookup c k >>= \case
    Just v -> pure (pure v)
    Nothing -> do
      gv <- f k
      -- only populate the cache if f returns a non-empty result
      for_ gv (insert c k)
      pure gv