-- | Kind inference for Unison
--
-- Unison has Type, ->, and Ability kinds
--
-- An algorithm sketch: first, break all decls into strongly connected
-- components in reverse topological order, so that each component comes
-- after the components it depends on. Then, for each component,
-- generate kind constraints that arise from the constructors in the
-- decl to discover constraints on the decl vars. These constraints
-- are then given to a constraint solver that determines a unique kind
-- for each type variable. Unconstrained variables are defaulted to
-- kind Type (just like Haskell 98). This is done by 'inferDecls'.
--
-- Afterwards, the 'SolveState' holds the kinds of all decls and we
-- can check that type annotations in terms that may mention the
-- decls are well-kinded with 'kindCheckAnnotations'.
module Unison.KindInference
  ( inferDecls,
    kindCheckAnnotations,
    KindError,
  )
where

import Data.Foldable (foldlM)
import Data.Graph (flattenSCC, stronglyConnCompR)
import Data.List.NonEmpty (NonEmpty (..))
import Data.Map.Strict qualified as Map
import Unison.Codebase.BuiltinAnnotation (BuiltinAnnotation)
import Unison.DataDeclaration
import Unison.KindInference.Generate (declComponentConstraints, termConstraints)
import Unison.KindInference.Solve (KindError, defaultUnconstrainedVars, initialState, step, verify)
import Unison.KindInference.Solve.Monad (Env (..), SolveState, run, runGen)
import Unison.Prelude
import Unison.PrettyPrintEnv qualified as PrettyPrintEnv
import Unison.Reference
import Unison.Term qualified as Term
import Unison.Var qualified as Var

-- | Check that all annotations in a term are well-kinded.
--
-- Constraints are generated from the term with 'termConstraints' and
-- solved with 'step' on top of the given 'SolveState'. This is presumably
-- the state returned by 'inferDecls', so the kinds of any decls the
-- annotations mention are already recorded. The solved state is
-- discarded: only success or the accumulated 'KindError's are returned.
kindCheckAnnotations ::
  forall v loc.
  (Var.Var v, Ord loc, Show loc, BuiltinAnnotation loc) =>
  PrettyPrintEnv.PrettyPrintEnv ->
  SolveState v loc ->
  Term.Term v loc ->
  Either (NonEmpty (KindError v loc)) ()
kindCheckAnnotations ppe st t =
  let env = Env ppe
      -- Generation runs inside the solver monad so that fresh
      -- variables are drawn from (and recorded in) the incoming state.
      (cs, st') = run env st (runGen $ termConstraints t)
   in step env st' cs $> ()

-- | Infer the kinds of all decl vars.
--
-- Decls are processed one strongly connected component at a time, in the
-- order produced by 'intoComponents' (reverse topological order). Each
-- component is therefore solved in a state that already records the
-- kinds of the decls it depends on. After every component has been
-- solved, the final state is checked with 'verify'. Any variables left
-- unconstrained are then defaulted by 'defaultUnconstrainedVars'.
-- The first failing component or check short-circuits with its
-- 'KindError's.
inferDecls ::
  forall v loc.
  (Var.Var v, BuiltinAnnotation loc, Ord loc, Show loc) =>
  PrettyPrintEnv.PrettyPrintEnv ->
  Map Reference (Decl v loc) ->
  Either (NonEmpty (KindError v loc)) (SolveState v loc)
inferDecls ppe declMap =
  defaultUnconstrainedVars <$> handleComponents (intoComponents declMap)
  where
    env :: Env
    env = Env ppe

    -- Generate constraints for a single component and solve them on top
    -- of the state accumulated from the components processed before it.
    handleComponent ::
      SolveState v loc ->
      [(Reference, Decl v loc)] ->
      Either (NonEmpty (KindError v loc)) (SolveState v loc)
    handleComponent s c =
      let (cs, st) = run env s (runGen $ declComponentConstraints c)
       in step env st cs

    -- Thread the solve state through every component, starting from the
    -- initial state, and then verify the final result.
    handleComponents ::
      [[(Reference, Decl v loc)]] ->
      Either (NonEmpty (KindError v loc)) (SolveState v loc)
    handleComponents = verify <=< foldlM handleComponent (initialState env)

-- | Break the decls into strongly connected components in reverse
-- topological order.
--
-- Each component is listed after the components it references, so
-- dependencies come before their dependents. Mutually recursive decls
-- share a component.
intoComponents :: forall v a. (Ord v) => Map Reference (Decl v a) -> [[(Reference, Decl v a)]]
intoComponents declMap =
  map toPair . flattenSCC <$> stronglyConnCompR graphInput
  where
    -- One graph node per decl, keyed by its 'Reference'. The node has an
    -- edge to every type reference the decl depends on. The node order
    -- (ascending key) matches the original 'Map.foldrWithKey' construction.
    graphInput :: [(Decl v a, Reference, [Reference])]
    graphInput =
      [(decl, ref, declReferences decl) | (ref, decl) <- Map.toList declMap]

    -- Drop the edge list and put the reference first.
    toPair :: (Decl v a, Reference, [Reference]) -> (Reference, Decl v a)
    toPair (decl, ref, _) = (ref, decl)

    declReferences :: Decl v a -> [Reference]
    declReferences = toList . typeDependencies . asDataDecl