module Unison.Util.CycleTable where

import Data.HashTable.IO (BasicHashTable)
import Data.HashTable.IO qualified as HT
import Data.Hashable (Hashable)
import Data.Mutable qualified as M

-- A hash table paired with a counter that is incremented on every insert.
-- This is an implementation detail of `CyclicEq`, `CyclicOrd`, etc., which use
-- it to compare, hash, or serialize cyclic structures.

-- | A mutable hash table together with a counter of how many times
-- 'insert' has been called on it.
data CycleTable k v = CycleTable
  { -- | The underlying mutable hash table.
    table :: !(BasicHashTable k v),
    -- | Number of 'insert' calls so far. This is not the number of distinct
    -- keys: re-inserting an existing key still increments it.
    sizeRef :: !(M.IOPRef Int)
  }

-- | Create an empty 'CycleTable' whose insert counter starts at 0.
--
-- The argument is forwarded to 'HT.newSized' as the initial capacity hint
-- for the underlying hash table.
new :: Int -> IO (CycleTable k v)
new initialCapacity =
  CycleTable <$> HT.newSized initialCapacity <*> M.newRef 0

-- | Look up the value stored under a key, returning 'Nothing' if the key
-- has never been inserted.
lookup :: (Hashable k, Eq k) => k -> CycleTable k v -> IO (Maybe v)
lookup k t = HT.lookup (table t) k

-- | Store a value under a key, then increment the insert counter.
--
-- The counter is incremented unconditionally, so overwriting an existing
-- key still increases 'size'.
insert :: (Hashable k, Eq k) => k -> v -> CycleTable k v -> IO ()
insert k v t = do
  HT.insert (table t) k v
  -- Strict modify so no thunk is ever left behind in the counter.
  M.modifyRef' (sizeRef t) (+ 1)

-- | Read the current insert counter: the number of 'insert' calls made on
-- this table since it was created with 'new'.
size :: CycleTable k v -> IO Int
size = M.readRef . sizeRef

-- | Insert a key whose value is the current insert counter.
--
-- Because 'insert' bumps the counter afterwards, a fresh table receiving
-- only 'insertEnd' calls assigns 0, 1, 2, ... in call order. Interleaving
-- other 'insert' calls skips numbers accordingly.
insertEnd :: (Hashable k, Eq k) => k -> CycleTable k Int -> IO ()
insertEnd k t = do
  n <- size t
  insert k n t