module Unison.Typechecker.Variance where

import Control.Monad.State.Strict
import Data.Foldable (foldl', traverse_)
import Data.Graph (flattenSCC, stronglyConnComp)
import Data.Map.Strict (Map)
import Data.Map.Strict qualified as Map
import Data.Set qualified as Set
import Unison.DataDeclaration
import Unison.Reference
import Unison.Type
import Unison.Typechecker.TypeLookup (TypeLookup (..))
import Unison.Var (Var, freshIn)

-- | Polarity of a variable occurrence.
--
-- This type serves two purposes: it tracks the ambient polarity while a
-- type is walked, and it records each occurrence of a variable so that the
-- overall variance of a parameter can be solved from its occurrences later.
--
--   * 'Positive' – covariant position
--   * 'Negative' – contravariant position
--   * 'Exact'    – invariant position
--   * @'As' v@   – same polarity as whatever parameter @v@ ends up having
--   * @'Op' v@   – opposite polarity to whatever parameter @v@ ends up having
--
-- The derived 'Ord' instance (which depends on the constructor order below)
-- lets polarities be stored in sets.
data Polarity v
  = Positive
  | Negative
  | Exact
  | As v
  | Op v
  deriving (Eq, Ord, Show)

-- | Reverse a polarity, as happens when moving into a negative position
-- (for example the domain of an arrow). Invariance is its own reverse,
-- and the symbolic forms 'As' and 'Op' swap with each other.
inv :: Polarity v -> Polarity v
inv = \case
  Positive -> Negative
  Negative -> Positive
  Exact -> Exact
  As x -> Op x
  Op x -> As x

-- | @act outer inner@ is the polarity that @inner@ takes on when placed in
-- a position whose polarity is @outer@: a positive context leaves it
-- unchanged and a negative context reverses it.
act :: Polarity v -> Polarity v -> Polarity v
act outer inner = case outer of
  Positive -> inner
  Negative -> inv inner
  -- An invariant context, and (conservatively) any symbolic 'As' / 'Op'
  -- context, yields invariance.
  -- TODO: revisit the symbolic cases
  _ -> Exact

-- | Concrete (solved) variance information for a type parameter.
--
--   * 'Any' – no occurrences were recorded, so any variance is acceptable
--   * 'Pos' – covariant
--   * 'Neg' – contravariant
--   * 'Inv' – invariant
data Variance
  = Any
  | Pos
  | Neg
  | Inv
  deriving (Eq, Ord, Show)

-- | Seed variance table used by 'fromTypeLookup': both builtin list and
-- immutable array types are declared covariant in their single parameter.
defaultVariances :: Map Reference [Variance]
defaultVariances =
  Map.fromList [(ref, [Pos]) | ref <- covariantBuiltins]
  where
    covariantBuiltins = [listRef, iarrayRef]

-- | Find the known parameter variances for the head of a type application.
-- Only a bare type reference can be looked up; any other head yields
-- 'Nothing'.
lookupVariance :: Map Reference [Variance] -> Type v a -> Maybe [Variance]
lookupVariance table ty = case ty of
  Ref' r -> Map.lookup r table
  _ -> Nothing

-- | Merge several occurrence maps. When a variable appears in more than one
-- map, its polarity lists are concatenated in input order.
combine :: (Ord v) => [Map v [Polarity v]] -> Map v [Polarity v]
combine = Map.unionsWith (++)

-- | Record the polarity of every variable occurrence in a type.
--
-- The walk starts in a positive position. @prev@ holds variances already
-- known for types outside the current declaration group; @group@ maps each
-- type of the group to its parameters. An argument to a group member is
-- walked under @'act' pol ('As' p)@, where @p@ is the matching parameter,
-- so its occurrences are recorded relative to that (not yet solved)
-- parameter.
collectVariance ::
  (Var v) =>
  Map Reference [Variance] ->
  Map Reference [v] ->
  Type v a ->
  Map v [Polarity v]
collectVariance prev group = go Positive
  where
    merge l r = Map.unionWith (++) l r

    go pol ty = case ty of
      -- the domain of an arrow reverses polarity, the codomain keeps it
      Arrow' dom cod -> merge (go (inv pol) dom) (go pol cod)
      Effect1' eff res -> merge (go pol eff) (go pol res)
      Apps' hd args -> goApps pol hd args
      Ann' inner _ -> go pol inner
      Effects' effs -> combine (map (go pol) effs)
      ForallsNamed' _ body -> go pol body
      IntroOuterNamed' _ body -> go pol body
      Var' x -> Map.singleton x [pol]
      _ -> Map.empty

    goApps pol hd args
      -- a type from the group being inferred: relate each argument to the
      -- corresponding parameter
      | Ref' r <- hd,
        Just params <- Map.lookup r group =
          combine [go (act pol (As p)) arg | (p, arg) <- zip params args]
      -- a type whose parameter variances are already known
      | Just known <- lookupVariance prev hd =
          combine (go pol hd : zipWith (byVariance pol) known args)
      -- if it's not in the info we have, assume invariant
      | otherwise =
          combine (go pol hd : map (go Exact) args)

    -- for 'any variant' positions, we don't need to keep inferring
    byVariance pol variance arg = case variance of
      Any -> Map.empty
      Pos -> go pol arg
      Neg -> go (inv pol) arg
      Inv -> go Exact arg

-- | Record occurrence polarities across every constructor of a data
-- declaration. Each constructor type is split into its arrow components
-- (fields and result), and each component is walked by 'collectVariance'
-- from a positive position.
collectDeclVariance ::
  (Var v, Show a) =>
  Map Reference [Variance] ->
  Map Reference [v] ->
  DataDeclaration v a ->
  Map v [Polarity v]
collectDeclVariance vars group decl =
  combine
    [ collectVariance vars group piece
      | (_, ctorType) <- constructors decl,
        piece <- pieces ctorType
    ]
  where
    -- drop the outer quantifier and split an arrow type into its
    -- components; any other type is a single piece
    pieces (ForallsNamedOpt' _ (Arrows' ts)) = ts
    pieces t = [t]

-- | Normalise the polarities recorded for variable @v@.
--
-- Duplicates are removed and the trivial self-reference @'As' v@ is
-- dropped. Any combination that forces invariance collapses to
-- @['Exact']@; otherwise the remaining polarities are returned in sorted
-- order.
simplify :: (Var v) => v -> [Polarity v] -> [Polarity v]
simplify v ps
  | forcesInvariance = [Exact]
  | otherwise = Set.toList s
  where
    s = Set.delete (As v) (Set.fromList ps)
    has p = Set.member p s
    forcesInvariance =
      -- invariant overrides everything
      has Exact
        -- both positive and negative is invariant
        || (has Positive && has Negative)
        -- a variable that must be its own opposite is invariant
        || has (Op v)

-- | Expand symbolic polarities by one level using the current map.
-- @'As' x@ is replaced by the polarities recorded for @x@, and @'Op' x@ by
-- their reverses; concrete polarities are kept unchanged.
chain :: (Var v) => Map v [Polarity v] -> [Polarity v] -> [Polarity v]
chain m ps = concatMap expand ps
  where
    -- If an `As` or `Op` is not in the map, we will never be able to
    -- find it. All variables should have been initialized to at least
    -- `x -> As x` by the result types of constructors, so if a
    -- variable isn't in the map, assume the worst and use invariant.
    recorded x = Map.findWithDefault [Exact] x m
    expand (As x) = recorded x
    expand (Op x) = map inv (recorded x)
    expand p = [p]

-- | Convert every variable's polarities into a concrete 'Variance', but
-- only if each list is already in a finished form: empty, or a single
-- 'Exact', 'Positive' or 'Negative'. Returns 'Nothing' if any list still
-- has symbolic or multiple entries.
checkFinished :: Map v [Polarity v] -> Maybe (Map v Variance)
checkFinished = traverse toVariance
  where
    toVariance ps = case ps of
      [] -> Just Any
      [Exact] -> Just Inv
      [Positive] -> Just Pos
      [Negative] -> Just Neg
      _ -> Nothing

-- | Solve the occurrence constraints of a declaration group, producing a
-- concrete 'Variance' for every recorded variable.
--
-- Each round expands symbolic polarities by one level ('chain') and then
-- normalises every list ('simplify'). Solving stops as soon as every list
-- is in a finished form ('checkFinished').
--
-- Some groups never reach a finished form. For instance, three mutually
-- recursive types that each pass their parameter on to the next keep
-- producing 'As' references to one another, so waiting for
-- 'checkFinished' alone would loop forever. A round never loses a
-- concrete polarity (it is either kept or absorbed into 'Exact'), and
-- after the first round every list is a sorted, duplicate-free list over
-- a finite alphabet, so the iteration must eventually revisit a state.
-- Once that happens, every polarity reachable through the remaining
-- 'As' / 'Op' references is already present in each list. The references
-- then carry no further information and are ignored when the variance is
-- read off.
solve :: (Var v) => Map v [Polarity v] -> Map v Variance
solve = go Set.empty
  where
    go seen current
      | Just done <- checkFinished current = done
      -- no progress is possible any more; resolve from concrete polarities
      | current `Set.member` seen = settle <$> current
      | otherwise =
          go
            (Set.insert current seen)
            (Map.mapWithKey simplify (chain current <$> current))

    -- Read a variance off the concrete polarities of a list; symbolic
    -- 'As' / 'Op' entries never match and so are ignored.
    settle ps
      | Exact `elem` ps = Inv
      | Positive `elem` ps, Negative `elem` ps = Inv
      | Positive `elem` ps = Pos
      | Negative `elem` ps = Neg
      | otherwise = Any

-- | Infer variances for one group of (mutually recursive) data
-- declarations. @vars@ holds variances already known for types outside
-- the group.
--
-- The parameters of every declaration are first renamed apart
-- ('freshenGroup'). Occurrences are then collected from all constructors
-- and solved together. A declaration whose parameters cannot all be found
-- in the solution is left out of the result.
inferDeclGroupVariance ::
  (Var v, Show a) =>
  Map Reference [Variance] ->
  Map Reference (DataDeclaration v a) ->
  Map Reference [Variance]
inferDeclGroupVariance vars (freshenGroup -> group) =
  resolveGroup . solve . combine $
    -- Merge with 'combine' (concatenating polarity lists), not the 'Map'
    -- Monoid: that is a left-biased union and would silently discard one
    -- declaration's occurrences whenever two declarations contribute the
    -- same variable key (e.g. a quantified variable inside a constructor
    -- sharing a name with another declaration's parameter).
    map (collectDeclVariance vars groupVars . snd) (Map.elems group)
  where
    groupVars = fst <$> group
    -- look up the solved variance of every parameter of each declaration
    resolveGroup solved =
      Map.mapMaybe (traverse (`Map.lookup` solved) . fst) group

-- | Rename the parameters of every declaration in a group apart from each
-- other. A single set of used names is threaded through the whole group.
-- Each result pairs the fresh parameter list with the renamed declaration.
freshenGroup ::
  (Var v) =>
  Map Reference (DataDeclaration v a) ->
  Map Reference ([v], DataDeclaration v a)
freshenGroup group = evalState (mapM freshDecl group) Set.empty

-- | Give each parameter of a declaration a name that is not in the state's
-- set of used names, and mark the new name as used. The declaration is
-- renamed to match. Returns the fresh parameters, in the original
-- parameter order, with the renamed declaration.
freshDecl ::
  (Var v) =>
  DataDeclaration v a ->
  State (Set.Set v) ([v], DataDeclaration v a)
freshDecl dd = do
  let params = bound dd
  fresh <- mapM claim params
  let renaming = Map.fromList (zip params fresh)
      rename x = Map.findWithDefault x x renaming
  return (fresh, vmap' rename dd)
  where
    claim u = do
      used <- get
      let x = freshIn used u
      put (Set.insert x used)
      return x

-- | Infer variances for a collection of data declarations, starting from
-- the known variances in @boot@.
--
-- Declarations are grouped into strongly connected components of their
-- type dependencies. 'stronglyConnComp' returns components in reverse
-- topological order, so each group is inferred after the groups it depends
-- on. Each group's results are added to the accumulated map used for later
-- groups; on a key clash, entries already in the accumulated map win.
inferDeclVariances ::
  (Var v, Show a) =>
  Map Reference [Variance] ->
  Map Reference (DataDeclaration v a) ->
  Map Reference [Variance]
inferDeclVariances boot decls = foldl' step boot components
  where
    node entry@(r, dd) = (entry, r, Set.toList (typeDependencies dd))
    components = stronglyConnComp (map node (Map.toList decls))
    step known scc =
      Map.union known $
        inferDeclGroupVariance known (Map.fromList (flattenSCC scc))

-- | Infer variances for every data declaration in a 'TypeLookup', seeded
-- with 'defaultVariances'.
fromTypeLookup ::
  (Var v, Show a) => TypeLookup v a -> Map Reference [Variance]
fromTypeLookup tl = inferDeclVariances defaultVariances (dataDecls tl)