-- Based on: http://semantic-domain.blogspot.com/2015/03/abstract-binding-trees.html
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}

module Unison.Hashing.V2.ABT (Unison.ABT.Term, hash, hashComponents) where

import Data.List hiding (cycle, find)
import Data.List qualified as List (sort)
import Data.Map qualified as Map
import Data.Set qualified as Set
import Unison.ABT
import Unison.Hash (Hash)
import Unison.Hashing.V2.Tokenizable (Hashable1, hash1)
import Unison.Hashing.V2.Tokenizable qualified as Hashable
import Unison.Prelude
import Prelude hiding (abs, cycle)

-- | Hash a strongly connected component and sort its definitions into a
-- canonical order.
--
-- Returns the hash of the whole component together with its @(name, term)@
-- pairs, sorted by a per-name hash. Three steps:
--
--   1. 'doHashCycle' hashes the members as a cycle. This yields the member
--      hashes in canonical order and an environment that binds every name.
--   2. Each name is hashed as the member hashes (shared by every name) plus
--      the hash of that name looked up in the environment.
--   3. The pairs are sorted by those per-name hashes. The overall hash is
--      accumulated from the sorted per-name hashes.
hashComponent ::
  forall a f v.
  (Functor f, Hashable1 f, Foldable f, Eq v, Show v, Ord v) =>
  Map.Map v (Term f v a) ->
  (Hash, [(v, Term f v a)])
hashComponent byName =
  let ts = Map.toList byName
      -- First, compute a canonical hash ordering of the component, as well as an environment in which we can hash
      -- individual names.
      (hashes, env) = doHashCycle [] ts
      -- Construct a list of tokens that is shared by all members of the component. They are disambiguated only by their
      -- name that gets tumbled into the hash.
      commonTokens :: [Hashable.Token]
      commonTokens = Hashable.Tag 1 : map Hashable.Hashed hashes
      -- Use a helper function that hashes a single term given its name, now that we have an environment in which we can
      -- look the name up, as well as the common tokens.
      hashName :: v -> Hash
      hashName v = Hashable.accumulate (commonTokens ++ [Hashable.Hashed (hash' env (var v :: Term f v ()))])
      (hashes', permutedTerms) =
        ts
          -- Pair each term with its hash
          & map (\t@(name, _) -> (hashName name, t))
          -- Sort again to get the final canonical ordering
          & sortOn fst
          & unzip
      overallHash :: Hash
      overallHash = Hashable.accumulate (map Hashable.Hashed hashes')
   in (overallHash, permutedTerms)

-- | Group the definitions into strongly connected components and hash
-- each component. Substitute the hash of each component into subsequent
-- components (using the `termFromHash` function). Requires that the
-- overall component has no free variables.
--
-- Each result pairs a component's hash with its members in canonical order
-- (as produced by 'hashComponent'). In every returned term, references to
-- members of this component and of all earlier components are replaced by
-- @termFromHash h i@. Here @h@ is the hash of the referenced member's
-- component and @i@ is its position in that component's canonical order.
--
-- Calls 'error' (when the result is forced) if any term mentions a variable
-- that is not a key of @termsByName@.
hashComponents ::
  (Functor f, Hashable1 f, Foldable f, Eq v, Show v, Var v) =>
  (Hash -> Word64 -> Term f v ()) ->
  Map.Map v (Term f v a) ->
  [(Hash, [(v, Term f v a)])]
hashComponents termFromHash termsByName =
  let bound = Set.fromList (Map.keys termsByName)
      escapedVars = Set.unions (freeVars <$> Map.elems termsByName) `Set.difference` bound
      -- NOTE(review): this relies on 'components' emitting each SCC after the SCCs it
      -- depends on, so that 'prevHashes' already covers every outside reference — confirm
      -- against Unison.ABT.
      sccs = components (Map.toList termsByName)
      go _ [] = []
      go prevHashes (component : rest) =
        -- Replace references to already-hashed components before hashing this one.
        let sub = substsInheritAnnotation (Map.toList prevHashes)
            (h, sortedComponent) = hashComponent $ Map.fromList [(v, sub t) | (v, t) <- component]
            -- Each member is referred to by its component's hash plus its canonical position.
            curHashes = Map.fromList [(v, termFromHash h i) | ((v, _), i) <- sortedComponent `zip` [0 ..]]
            newHashes = prevHashes `Map.union` curHashes
            newHashesL = Map.toList newHashes
            -- Also replace the component's own (recursive) references in the returned terms.
            sortedComponent' = [(v, substsInheritAnnotation newHashesL t) | (v, t) <- sortedComponent]
         in (h, sortedComponent') : go newHashes rest
   in if Set.null escapedVars
        then go Map.empty sccs
        else
          error $
            "can't hashComponents if bindings have free variables:\n  "
              ++ show (map show (Set.toList escapedVars))
              ++ "\n  "
              ++ show (map show (Map.keys termsByName))

-- | Hash a term that has no free variables. It is hashed in an empty
-- environment, so any free variable makes 'hash'' call 'error'.
--
-- We ignore annotations in the `Term`, as these should never affect the
-- meaning of the term.
hash ::
  forall f v a.
  (Functor f, Hashable1 f, Eq v, Show v) =>
  Term f v a ->
  Hash
hash = hash' []

-- | Hash a term relative to an environment of enclosing binders.
--
-- The environment is a stack with the innermost binder first:
--
--   * @Right v@ is a single variable, pushed when passing under an
--     abstraction.
--   * @Left vs@ is a group of names that is matched as a whole. It is used by
--     'doHashCycle' while the canonical order of a cycle is still being
--     computed.
--
-- A variable hashes to the position of the first entry that binds it. Its
-- name therefore never reaches the hash, and all members of one @Left@ group
-- hash alike. Annotations are never inspected.
--
-- Calls 'error' if a variable is bound by no entry of the environment.
hash' ::
  forall f v a.
  (Functor f, Hashable1 f, Eq v, Show v) =>
  [Either [v] v] ->
  Term f v a ->
  Hash
hash' env = \case
  Var' v -> maybe unknownVar hashInt ind
    where
      -- Does this environment entry bind @v@?
      binds (Left cycle) = v `elem` cycle
      binds (Right v') = v == v'
      ind = findIndex binds env
      hashInt :: Int -> Hash
      hashInt i = Hashable.accumulate [Hashable.Nat $ fromIntegral i]
      unknownVar :: Hash
      unknownVar =
        error $
          "unknown var in environment: "
            ++ show v
            ++ " environment = "
            ++ show env
  Cycle' vs t -> hash1 (hashCycle vs env) notInCycle t
    where
      -- NOTE(review): the Hashable1 instance is expected to hash everything under a
      -- cycle through the cycle-hashing function (first argument). This fallback was
      -- a bare 'undefined'. It now fails with context if it is ever reached.
      notInCycle :: Term f v a -> Hash
      notInCycle _ =
        error "Unison.Hashing.V2.ABT.hash': non-cycle hash function used under a Cycle'"
  Abs'' v t -> hash' (Right v : env) t
  -- Children of an ordinary node share the current environment. The cycle-hashing
  -- function handed to 'hash1' returns the child hashes sorted.
  Tm' t -> hash1 (\ts -> (List.sort (map (hash' env) ts), hash' env)) (hash' env) t
  where
    -- Hash the members of a cycle via 'doHashCycle'. Also return a hash function
    -- whose environment binds the cycle's names in canonical order.
    hashCycle :: [v] -> [Either [v] v] -> [Term f v a] -> ([Hash], Term f v a -> Hash)
    hashCycle cycle cycleEnv ts =
      let (ts', env') = doHashCycle cycleEnv (zip cycle ts)
       in (ts', hash' env')

-- | @doHashCycle env terms@ hashes cycle @terms@ in environment @env@, and returns the canonical ordering of the hashes
-- of those terms, as well as an updated environment with each of the terms' bindings in the canonical ordering.
--
-- This happens in two passes:
--
--   1. The terms are sorted by their hash in an environment where every cycle
--      name shares a single entry, so a reference to any member hashes the
--      same. Terms whose keys tie keep their input order, because 'sortOn' is
--      stable.
--   2. The terms are re-hashed in an environment that gives each name its own
--      entry, in that sorted order. The first sorted name ends up at the front
--      of the returned environment.
doHashCycle ::
  forall a f v.
  (Eq v, Functor f, Hashable1 f, Show v) =>
  [Either [v] v] ->
  [(v, Term f v a)] ->
  ([Hash], [Either [v] v])
doHashCycle env namedTerms =
  (map (hash' newEnv) permutedTerms, newEnv)
  where
    names = map fst namedTerms
    -- The environment in which we compute the canonical permutation of terms
    permutationEnv = Left names : env
    (permutedNames, permutedTerms) =
      namedTerms
        & sortOn (hash' permutationEnv . snd)
        & unzip
    -- The new environment, which includes the names of all of the terms in the cycle, now that we have computed their
    -- canonical ordering
    newEnv = map Right permutedNames ++ env