{-# LANGUAGE FunctionalDependencies #-}

module Unison.PatternMatchCoverage.Class
  ( Pmc (..),
    EnumeratedConstructors (..),
    traverseConstructorTypes,
  )
where

import Control.Monad.Fix (MonadFix)
import Data.Map (Map)
import Data.Map qualified as Map
import Unison.ConstructorReference (ConstructorReference)
import Unison.PatternMatchCoverage.ListPat (ListPat)
import Unison.PrettyPrintEnv (PrettyPrintEnv)
import Unison.Type (Type)
import Unison.Var (Var)

-- | The queries that pattern-match coverage checking needs to make,
-- abstracted over the monad @m@ that answers them.
--
-- The functional dependency @m -> vt v loc@ means the choice of monad
-- fixes @vt@ (the variable type inside 'Type'), @v@ (the type returned
-- by 'fresh'), and @loc@.
class
  (Ord loc, Var vt, Var v, MonadFix m) =>
  Pmc vt v loc m
    | m -> vt v loc
  where
  -- | Enumerate the constructors of the given type.
  getConstructors :: Type vt loc -> m (EnumeratedConstructors vt v loc)

  -- | Look up the argument types of one particular constructor of the
  -- given type.
  getConstructorVarTypes :: Type vt loc -> ConstructorReference -> m [Type vt loc]

  -- | Produce a fresh variable.
  fresh :: m v

  -- | Obtain the 'PrettyPrintEnv' to use.
  getPrettyPrintEnv :: m PrettyPrintEnv

-- | The constructors of a type, in the form returned by
-- 'getConstructors'.
data EnumeratedConstructors vt v loc
  = -- | A type with an explicit list of constructors; each entry pairs a
    -- variable and a 'ConstructorReference' with a 'Type'.
    ConstructorType [(v, ConstructorReference, Type vt loc)]
  | -- | An ability type. The first field is a standalone 'Type' (named
    -- @resultType@ in 'traverseConstructorTypes'); the map associates
    -- each 'ConstructorReference' with a variable and a 'Type'.
    AbilityType (Type vt loc) (Map ConstructorReference (v, Type vt loc))
  | -- | A sequence type, described by list patterns each paired with a
    -- list of types.
    SequenceType [(ListPat, [Type vt loc])]
  | -- | The boolean type.
    BooleanType
  | -- | Any type not covered by the other constructors.
    OtherType
  deriving stock (Show)

-- | Traverse the per-constructor types held in an
-- 'EnumeratedConstructors', rebuilding it from the callback's results.
--
-- The callback receives a constructor's variable, its
-- 'ConstructorReference', and its current 'Type', and returns the
-- replacement type in @f@.
--
-- * 'ConstructorType': the callback runs on every entry, in list order.
--
-- * 'AbilityType': the callback runs on every map entry, in ascending
--   key order; the leading result 'Type' is kept as-is.
--
-- * 'SequenceType', 'BooleanType', 'OtherType': the callback is never
--   invoked and the value is returned unchanged via 'pure' (the types
--   inside 'SequenceType' are not visited).
traverseConstructorTypes ::
  (Applicative f) =>
  (v -> ConstructorReference -> Type vt loc -> f (Type vt loc)) ->
  EnumeratedConstructors vt v loc ->
  f (EnumeratedConstructors vt v loc)
traverseConstructorTypes f = \case
  ConstructorType xs ->
    ConstructorType <$> traverse (\(v, cr, t) -> (v,cr,) <$> f v cr t) xs
  AbilityType resultType m ->
    -- 'Map.traverseWithKey' runs the effects in ascending key order, the
    -- same order a right fold with 'Map.insert' yields, but it reuses the
    -- existing tree instead of paying an O(log n) insert per entry.
    AbilityType resultType
      <$> Map.traverseWithKey (\cr (v, t) -> (v,) <$> f v cr t) m
  SequenceType x -> pure (SequenceType x)
  BooleanType -> pure BooleanType
  OtherType -> pure OtherType