module Unison.Typechecker.Components (minimize, minimize') where

import Control.Arrow ((&&&))
import Data.Function (on)
import Data.List (groupBy, sortBy)
import Data.List.NonEmpty (NonEmpty)
import Data.List.NonEmpty qualified as Nel
import Data.Map qualified as Map
import Data.Set qualified as Set
import Unison.ABT qualified as ABT
import Unison.Prelude
import Unison.Term (Term')
import Unison.Term qualified as Term
import Unison.Var (Var)
import Unison.Var qualified as Var

-- | Split a block's bindings into components using 'ABT.components'.
-- 'minimize' uses this for top-level blocks. Presumably each component
-- is a group of mutually dependent bindings (a cycle); confirm against
-- 'ABT.components'.
unordered :: (Var v) => [(v, Term' vt v a)] -> [[(v, Term' vt v a)]]
unordered bindings = ABT.components bindings

-- | Split a block's bindings into components using
-- 'ABT.orderedComponents'. 'minimize' uses this for non-top-level
-- blocks. NOTE(review): presumably the components come out in
-- dependency order, so that the nested lets built by 'minimize' bind
-- each component before it is used; confirm against
-- 'ABT.orderedComponents'.
ordered :: (Var v) => [(v, Term' vt v a)] -> [[(v, Term' vt v a)]]
ordered bindings = ABT.orderedComponents bindings

-- | Algorithm for minimizing cycles of a `let rec`. This can
-- improve generalization during typechecking and may also be more
-- efficient for execution.
--
-- For instance:
--
-- minimize (let rec id x = x; g = id 42; y = id "hi" in g)
-- ==>
-- Right (Just (let id x = x; g = id 42; y = id "hi" in g))
--
-- Gets rid of the let rec and replaces it with an ordinary `let`, such
-- that `id` is suitably generalized.
--
-- Returns @Right Nothing@ when the term does not match
-- 'Term.LetRecNamedAnnotatedTop'', and @Right (Just t)@ with the
-- rewritten block otherwise.
--
-- Fails on the left if there are duplicate definitions. Each reported
-- element pairs a variable bound more than once with the annotations of
-- all of its bindings. Variables whose name is @_@ may be bound
-- repeatedly and are never reported.
minimize ::
  (Var v) =>
  Term' vt v a ->
  Either (NonEmpty (v, [a])) (Maybe (Term' vt v a))
minimize (Term.LetRecNamedAnnotatedTop' isTop blockAnn bs body) =
  let bindings = first snd <$> bs
      -- Sort by variable so equal variables are adjacent, then bucket them,
      -- pairing each variable with the annotations of all its bindings.
      -- `groupBy` never yields an empty group; routing each group through
      -- `Nel.nonEmpty` keeps this total without relying on `head`.
      sorted = sortBy (compare `on` fst) bindings
      grouped =
        [ (fst . Nel.head &&& map (ABT.annotation . snd) . Nel.toList) grp
        | Just grp <- Nel.nonEmpty <$> groupBy ((==) `on` fst) sorted
        ]
      -- `_` may be bound any number of times; anything else bound more
      -- than once is an error.
      isDuplicate (v, anns)
        | Var.name v == "_" = False
        | otherwise = length anns > 1
      dupes = filter isDuplicate grouped
   in case Nel.nonEmpty dupes of
        Just ds -> Left ds
        Nothing ->
          let cs0 = if isTop then unordered bindings else ordered bindings
              -- within a cycle, we put the lambdas first, so
              -- unguarded definitions can refer to these lambdas, example:
              --
              --   foo x = blah + 1 + x
              --   blah = foo 10
              --
              -- Here `foo` and `blah` are part of a cycle, but putting `foo`
              -- first at least lets the program run (though it has an infinite
              -- loop).
              cs = sortOn (\(_, tm) -> Term.arity tm == 0) <$> cs0
              varAnnotations = Map.fromList ((\((ann, v), _) -> (v, ann)) <$> bs)
              -- Every variable in `cs` comes from `bs`, so this lookup is not
              -- expected to fail; `msg` reports the impossible case.
              msg v = error $ "Components.minimize " <> show (v, Map.keys varAnnotations)
              annotationFor v = fromMaybe (msg v) $ Map.lookup v varAnnotations
              annotatedVar v = (annotationFor v, v)
              -- When introducing a nested let/let rec, we use the annotation
              -- of the variable that starts off that let/let rec.
              -- A lone binding stays a `let rec` only if it refers to itself.
              mklet [(hdv, hdb)] rest
                | Set.member hdv (ABT.freeVars hdb) =
                    Term.letRec isTop blockAnn [(annotatedVar hdv, hdb)] rest
                | otherwise =
                    Term.singleLet isTop blockAnn (annotationFor hdv) (hdv, hdb) rest
              mklet component@(_ : _) rest =
                Term.letRec isTop blockAnn (first annotatedVar <$> component) rest
              mklet [] rest = rest
           in Right . Just . foldr mklet body $ cs
minimize _ = Right Nothing

-- | Like 'minimize', but returns the input term unchanged when
-- 'minimize' has nothing to rewrite (@Right Nothing@). Duplicate
-- definitions are still reported on the left.
minimize' ::
  (Var v) => Term' vt v a -> Either (NonEmpty (v, [a])) (Term' vt v a)
minimize' term =
  case minimize term of
    Left duplicates -> Left duplicates
    Right rewritten -> Right (fromMaybe term rewritten)