module Unison.KindInference.Generate.Monad
  ( Gen (..),
    GenState (..),
    GeneratedConstraint,
    run,
    freshVar,
    pushType,
    popType,
    scopedType,
    lookupType,
  )
where

import Control.Monad.State.Strict
import Data.Functor.Compose
import Data.List.NonEmpty (NonEmpty ((:|)))
import Data.List.NonEmpty qualified as NonEmpty
import Data.Map.Strict qualified as Map
import Data.Set qualified as Set
import Unison.KindInference.Constraint.Provenance (Provenance)
import Unison.KindInference.Constraint.Unsolved (Constraint (..))
import Unison.KindInference.UVar (UVar (..))
import Unison.Prelude
import Unison.Symbol
import Unison.Type qualified as T
import Unison.Var

-- | An unsolved kind-inference 'Constraint' as emitted by constraint
-- generation, specialised to @UVar v loc@ variables and 'Provenance'.
type GeneratedConstraint v loc =
  Constraint (UVar v loc) v loc Provenance

-- | State threaded through the 'Gen' monad.
data GenState v loc = GenState
  { -- | Symbols already allocated to unification variables; 'freshVar'
    -- picks each new symbol with 'freshIn' against this set and then adds it.
    unifVars :: !(Set Symbol),
    -- | For each type pushed with 'pushType', a stack of the 'UVar's pushed
    -- for it, most recent first. 'lookupType' reads the top of the stack and
    -- 'popType' removes it.
    typeMap :: !(Map (T.Type v loc) (NonEmpty (UVar v loc))),
    -- | 'UVar's created by 'freshVar', most recent first.
    newVars :: [UVar v loc]
  }
  deriving stock (Generic)

-- | The constraint-generation monad: a state-passing function over
-- 'GenState'.
--
-- Every instance is borrowed from (strict) 'State' via @DerivingVia@. That
-- is possible because @Gen v loc a@ is coercible to
-- @State (GenState v loc) a@: 'State' is a newtype over
-- @s -> Identity (a, s)@, and 'Identity' is itself a newtype.
newtype Gen v loc a = Gen {unGen :: GenState v loc -> (a, GenState v loc)}
  deriving
    (Functor, Applicative, Monad, MonadState (GenState v loc))
    via State (GenState v loc)

-- | Run a 'Gen' computation from the given initial state, returning its
-- result paired with the final state.
run :: Gen v loc a -> GenState v loc -> (a, GenState v loc)
run gen initial = unGen gen initial

-- | Allocate a new unification variable associated with @typ@.
--
-- The variable's symbol is chosen by 'freshIn' against the symbols already
-- recorded in 'unifVars'. That symbol is then added to 'unifVars', and the
-- resulting 'UVar' is prepended to 'newVars' before being returned.
freshVar :: (Var v) => T.Type v loc -> Gen v loc (UVar v loc)
freshVar typ = state \st@GenState {unifVars = taken, newVars = pending} ->
  let sym :: Symbol
      sym = freshIn taken (typed (Inference Other))
      fresh = UVar sym typ
   in (fresh, st {unifVars = Set.insert sym taken, newVars = fresh : pending})

-- | Allocate a fresh 'UVar' for @t@ (via 'freshVar') and push it onto the
-- stack of variables kept for @t@ in 'typeMap'. A new singleton stack is
-- started when @t@ has no entry yet.
--
-- A fresh variable is always allocated, even when @t@ already has an entry.
-- Because it goes on top of the stack, 'lookupType' returns it until it is
-- removed with 'popType'.
--
-- Returns the newly allocated variable.
pushType :: (Var v) => T.Type v loc -> Gen v loc (UVar v loc)
pushType t = do
  GenState {typeMap} <- get
  -- Running 'Map.alterF' in @Compose (Gen v loc) ((,) (UVar v loc))@ lets the
  -- variable be allocated inside the map update and handed back alongside
  -- the rebuilt map.
  (var, newTypeMap) <-
    getCompose $ Map.alterF (Compose . push) t typeMap
  modify \st -> st {typeMap = newTypeMap}
  pure var
  where
    -- Whether or not @t@ already has a stack, the new variable goes on top:
    -- allocate it once, then either start a singleton stack or cons onto
    -- the existing one.
    push existing =
      (\v -> (v, Just (maybe (v :| []) (NonEmpty.cons v) existing)))
        <$> freshVar t

-- | Return the most recently pushed 'UVar' for @t@ (the top of its stack in
-- 'typeMap'), or 'Nothing' when @t@ has no entry.
lookupType :: (Var v) => T.Type v loc -> Gen v loc (Maybe (UVar v loc))
lookupType t = gets (fmap NonEmpty.head . Map.lookup t . typeMap)

-- | Pop the most recently pushed 'UVar' for @t@ off its stack in 'typeMap'.
--
-- If that leaves the stack empty, the entry for @t@ is deleted. If @t@ has
-- no entry at all, 'typeMap' is left as it was.
popType :: (Var v) => T.Type v loc -> Gen v loc ()
popType t = modify \st -> st {typeMap = Map.alter dropTop t (typeMap st)}
  where
    -- Drop the top of the stack. 'NonEmpty.nonEmpty' maps an exhausted stack
    -- to 'Nothing', which makes 'Map.alter' delete the key.
    dropTop :: Maybe (NonEmpty a) -> Maybe (NonEmpty a)
    dropTop stack = stack >>= NonEmpty.nonEmpty . NonEmpty.tail

-- | Run @m@ with @t@ pushed into the context.
--
-- A fresh 'UVar' is pushed for @t@ (see 'pushType') and passed to @m@. After
-- @m@ finishes, @t@'s top variable is popped again (see 'popType'), and @m@'s
-- result is returned.
scopedType :: (Var v) => T.Type v loc -> (UVar v loc -> Gen v loc r) -> Gen v loc r
scopedType t m = pushType t >>= \s -> m s <* popType t