module Unison.KindInference.Solve.Monad
  ( Solve (..),
    Env (..),
    SolveState (..),
    Descriptor (..),
    ConstraintMap,
    run,
    emptyState,
    find,
    genStateL,
    runGen,
    addUnconstrainedVar,
  )
where

import Control.Lens (Lens', (%%~))
import Control.Monad.Fix (MonadFix (..))
import Control.Monad.Reader qualified as M
import Control.Monad.State.Strict qualified as M
import Data.Functor.Identity
import Data.List.NonEmpty (NonEmpty)
import Data.Map.Strict qualified as M
import Data.Set qualified as Set
import Unison.KindInference.Constraint.Solved (Constraint (..))
import Unison.KindInference.Generate.Monad (Gen (..))
import Unison.KindInference.Generate.Monad qualified as Gen
import Unison.KindInference.UVar (UVar (..))
import Unison.PatternMatchCoverage.UFMap qualified as U
import Unison.Prelude
import Unison.PrettyPrintEnv (PrettyPrintEnv)
import Unison.Symbol
import Unison.Type qualified as T
import Unison.Var

-- | Read-only environment available to every 'Solve' computation
-- (exposed through its 'M.MonadReader' instance).
data Env = Env
  { -- | Pretty-printing environment. NOTE(review): presumably used to
    -- render types in diagnostics; confirm against callers.
    prettyPrintEnv :: PrettyPrintEnv
  }

-- | Map from each unification variable to its 'Descriptor', stored in a
-- 'U.UFMap' so that a variable can be resolved to a canonical
-- representative (see 'find').
type ConstraintMap v loc =
  U.UFMap (UVar v loc) (Descriptor v loc)

-- | State threaded through kind-constraint solving.
--
-- Once data and effect declarations have been processed, @typeMap@
-- holds an entry for every declaration, and looking the associated
-- @UVar@ up in @constraints@ yields that declaration's kind.
--
-- @unifVars@ and @newUnifVars@ matter when constraint generation is
-- interleaved with solving: generation has to mint fresh unification
-- variables, so it needs to know which ones are already bound
-- (@unifVars@), while @newUnifVars@ collects the uvars that are
-- candidates for kind defaulting (see
-- 'Unison.KindInference.Solve.defaultUnconstrainedVars').
data SolveState v loc = SolveState
  { -- | Symbols already bound as unification variables; shared with
    -- constraint generation through 'genStateL'.
    unifVars :: !(Set Symbol),
    -- | Uvars created during interleaved generation ('runGen') that are
    -- candidates for kind defaulting. NOTE(review): unlike the other
    -- fields this one is lazy.
    newUnifVars :: [UVar v loc],
    -- | Descriptor (and hence constraint) for every registered uvar.
    constraints :: !(ConstraintMap v loc),
    -- | The uvars associated with each type.
    typeMap :: !(Map (T.Type v loc) (NonEmpty (UVar v loc)))
  }

-- | What the solver records for a unification variable in the
-- constraint map.
data Descriptor v loc = Descriptor
  { -- | The solved constraint, or 'Nothing' while the variable is still
    -- unconstrained (the state 'addUnconstrainedVar' puts it in).
    descriptorConstraint :: Maybe (Constraint (UVar v loc) v loc)
  }

-- | The kind-solving monad.
--
-- Represented directly as a function from the read-only 'Env' and the
-- current 'SolveState' to a result paired with the updated state. All
-- instances are borrowed (DerivingVia) from
-- @ReaderT Env (State (SolveState v loc))@ over the strict state monad.
-- GHC can coerce between the two representations because 'M.State' is a
-- state transformer over 'Identity' (presumably why
-- "Data.Functor.Identity" is imported: its constructor must be in scope).
newtype Solve v loc a = Solve
  { -- | Unwrap to the underlying environment- and state-passing function.
    unSolve :: Env -> SolveState v loc -> (a, SolveState v loc)
  }
  -- Sequencing and recursion.
  deriving (Functor, Applicative, Monad, MonadFix)
    via M.ReaderT Env (M.State (SolveState v loc))
  -- Access to the environment and the solver state.
  deriving (M.MonadReader Env, M.MonadState (SolveState v loc))
    via M.ReaderT Env (M.State (SolveState v loc))

-- | Lens that focuses the constraint-generation state inside a
-- 'SolveState', used to interleave constraint generation and solving.
--
-- Viewing builds a 'Gen.GenState' from this state's @unifVars@ and
-- @typeMap@ with an empty @newVars@. Setting copies @unifVars@ and
-- @typeMap@ back and discards the generator's @newVars@, so callers
-- (see 'runGen') must read @newVars@ out before writing back.
genStateL :: Lens' (SolveState v loc) (Gen.GenState v loc)
genStateL f st = fmap writeBack (f focused)
  where
    -- The generator's view of the solver state.
    focused =
      Gen.GenState
        { unifVars = unifVars st,
          typeMap = typeMap st,
          newVars = []
        }
    -- Fold the generator's results back into the solver state.
    writeBack genState =
      st
        { unifVars = Gen.unifVars genState,
          typeMap = Gen.typeMap genState
        }

-- | Run a constraint-generation action inside the solver.
--
-- The generator sees the solver's @unifVars@ and @typeMap@ through
-- 'genStateL' (starting from an empty @newVars@). Afterwards the updated
-- @unifVars@ and @typeMap@ are written back. Every uvar the generator
-- created is then registered as unconstrained via 'addUnconstrainedVar'
-- and prepended to @newUnifVars@ as a kind-defaulting candidate.
runGen :: (Var v) => Gen v loc a -> Solve v loc a
runGen gena = do
  before <- M.get
  let ((result, fresh), after) = (genStateL %%~ Gen.run collect) before
  M.put after
  -- Give every freshly minted uvar an empty entry in the constraint map.
  traverse_ addUnconstrainedVar fresh
  -- Remember them as candidates for kind defaulting.
  M.modify \st -> st {newUnifVars = fresh ++ newUnifVars st}
  pure result
  where
    -- Run the generator, then report the uvars it created next to its result.
    collect = do
      res <- gena
      genSt <- M.get
      pure (res, Gen.newVars genSt)

-- | Register a unification variable in the constraint map with no
-- constraint attached (@descriptorConstraint = Nothing@). 'runGen' calls
-- this for each uvar created during interleaved constraint generation,
-- so that the new uvars are initialized before solving continues.
addUnconstrainedVar :: (Var v) => UVar v loc -> Solve v loc ()
addUnconstrainedVar uvar = do
  st@SolveState {constraints = cmap} <- M.get
  let fresh = Descriptor {descriptorConstraint = Nothing}
  M.put st {constraints = U.insert uvar fresh cmap}

-- | Execute a 'Solve' computation with the given environment and
-- starting state, returning its result together with the final state.
run :: Env -> SolveState v loc -> Solve v loc a -> (a, SolveState v loc)
run env initial (Solve solve) = solve env initial

-- | The solver state to start from: no bound symbols, no defaulting
-- candidates, an empty constraint map and an empty type map.
emptyState :: SolveState v loc
emptyState =
  SolveState
    Set.empty -- unifVars
    [] -- newUnifVars
    U.empty -- constraints
    M.empty -- typeMap

-- | Look up the constraint recorded for a unification variable.
--
-- The lookup goes through 'U.lookupCanon'. The map it returns replaces
-- the stored constraint map (NOTE(review): presumably with compressed
-- paths; confirm in "Unison.PatternMatchCoverage.UFMap"). Calls 'error'
-- when 'U.lookupCanon' yields 'Nothing', presumably meaning the variable
-- was never registered.
find :: (Var v) => UVar v loc -> Solve v loc (Maybe (Constraint (UVar v loc) v loc))
find k = do
  st@SolveState {constraints = cmap} <- M.get
  maybe (error "find: Nothing") (commit st) (U.lookupCanon k cmap)
  where
    -- Store the map returned by the lookup and hand back the constraint.
    commit st (_, _, Descriptor {descriptorConstraint = c}, cmap') = do
      M.put st {constraints = cmap'}
      pure c