{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE Strict #-}
{-# LANGUAGE StrictData #-}

module Unison.Util.CyclicEq where

import Data.Sequence qualified as S
import Data.Vector (Vector)
import Data.Vector qualified as V
import Unison.Prelude
import Unison.Util.CycleTable qualified as CT

{-
 Typeclass used for comparing potentially cyclic types for equality.
 Cyclic types may refer to themselves indirectly, so something is needed to
 prevent an infinite loop in these cases. The basic idea: when a subexpression
 is first examined, its "id" (represented as some `Int`) may be added to the
 mutable hash table along with its position. The next time that same id is
 encountered, it will be compared based on this position.
 -}
class CyclicEq a where
  -- Compare two values for equality in `IO`, threading two mutable cycle
  -- tables through the comparison.
  --
  -- Each table is a map from `Ref` ID to position in the stream.
  -- If a ref is encountered again, we use its mapped ID.
  --
  -- NOTE(review): presumably the first table tracks refs seen in the first
  -- value and the second table refs seen in the second value -- confirm
  -- against the instances for ref-carrying types.
  cyclicEq :: CT.CycleTable Int Int -> CT.CycleTable Int Int -> a -> a -> IO Bool

-- | Compare two pairs of values, where the first components only support
-- ordinary 'Eq' and the second components are compared with 'cyclicEq'.
--
-- The first components @a1@ and @a2@ are compared with '=='. Only when they
-- are equal are the second components @b1@ and @b2@ compared with 'cyclicEq'
-- (using the cycle tables @h1@ and @h2@); otherwise the result is 'False'
-- and 'cyclicEq' is never invoked.
bothEq' ::
  (Eq a, CyclicEq b) =>
  CT.CycleTable Int Int ->
  CT.CycleTable Int Int ->
  a ->
  a ->
  b ->
  b ->
  IO Bool
bothEq' h1 h2 a1 a2 b1 b2
  | a1 == a2 = cyclicEq h1 h2 b1 b2
  | otherwise = pure False

-- | Compare two pairs of values component-wise using 'cyclicEq' for both
-- components.
--
-- The first components @a1@ and @a2@ are compared first. The second
-- components @b1@ and @b2@ are compared only if the first comparison
-- returned 'True'; otherwise the result is 'False' and the second
-- comparison is skipped. Both comparisons share the cycle tables @h1@ and
-- @h2@.
bothEq ::
  (CyclicEq a, CyclicEq b) =>
  CT.CycleTable Int Int ->
  CT.CycleTable Int Int ->
  a ->
  a ->
  b ->
  b ->
  IO Bool
bothEq h1 h2 a1 a2 b1 b2 = do
  firstsEqual <- cyclicEq h1 h2 a1 a2
  if firstsEqual
    then cyclicEq h1 h2 b1 b2
    else pure False

-- | Lists are equal when they have the same length and their elements are
-- pairwise equal under 'cyclicEq'. Elements are compared front to back and
-- the comparison stops at the first mismatch.
instance (CyclicEq a) => CyclicEq [a] where
  -- Both non-empty: compare the heads, then (only if equal) the tails.
  cyclicEq h1 h2 (x : xs) (y : ys) = bothEq h1 h2 x y xs ys
  -- Both exhausted at the same time: equal.
  cyclicEq _ _ [] [] = pure True
  -- One list ran out before the other: different lengths.
  cyclicEq _ _ _ _ = pure False

-- | Sequences are equal when they have the same length and their elements
-- are pairwise equal under 'cyclicEq'.
--
-- The lengths are checked first so that sequences of different sizes are
-- rejected without comparing any elements; the element-wise comparison is
-- delegated to the list instance via 'toList'.
instance (CyclicEq a) => CyclicEq (S.Seq a) where
  cyclicEq h1 h2 xs ys
    | S.length xs == S.length ys = cyclicEq h1 h2 (toList xs) (toList ys)
    | otherwise = pure False

-- | Vectors are equal when they have the same length and their elements are
-- pairwise equal under 'cyclicEq'.
--
-- The lengths are checked first; if they differ the result is 'False'
-- without comparing any elements. Otherwise the elements are compared by
-- index from @0@ upwards, stopping at the first mismatch.
instance (CyclicEq a) => CyclicEq (Vector a) where
  cyclicEq h1 h2 xs ys
    | V.length xs /= V.length ys = pure False
    | otherwise = go 0 h1 h2 xs ys
    where
      -- Index-driven loop over both vectors. @i@ is the next index to
      -- compare. The indexing with 'V.!' is in bounds for both vectors:
      -- @i@ is below the length of @vs@, and the caller has already
      -- established that both vectors have the same length.
      go !i !t1 !t2 !vs !ws
        | i >= V.length vs = pure True
        | otherwise = do
            same <- cyclicEq t1 t2 (vs V.! i) (ws V.! i)
            if same
              then go (i + 1) t1 t2 vs ws
              else pure False