{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE UndecidableInstances #-}

module Unison.Util.Recursion
  ( Algebra,
    Recursive (..),
    cataM,
    para,
    Fix (..),
  )
where

import Control.Arrow ((&&&))
import Control.Comonad.Cofree (Cofree ((:<)))
import Control.Comonad.Trans.Cofree (CofreeF)
import Control.Comonad.Trans.Cofree qualified as CofreeF
import Control.Monad ((<=<))

-- | An algebra for the functor @f@: collapses one layer of @f@, whose
-- positions already hold @carrier@ values, into a single @carrier@.
type Algebra f carrier = f carrier -> carrier

-- | Types @t@ that can be taken apart and rebuilt one layer at a time through
-- the base functor @f@ (the functional dependency fixes @f@ for each @t@).
--
-- Instances must provide 'embed' and at least one of 'cata' or 'project'
-- (see the @MINIMAL@ pragma). The default 'cata' is defined via 'project'
-- and the default 'project' via 'cata', so defining neither would loop.
class Recursive t f | t -> f where
  -- | Fold a @t@ with an 'Algebra': the algebra receives a layer whose
  -- children have already been replaced by their folded results.
  cata :: (Algebra f a) -> t -> a
  default cata :: (Functor f) => (f a -> a) -> t -> a
  -- Peel off one layer, fold each child recursively, then combine.
  cata φ = φ . fmap (cata φ) . project

  -- | Expose the outermost layer of a @t@, leaving its children as @t@s.
  project :: t -> f t
  default project :: (Functor f) => t -> f t
  -- Rebuilds the structure with 'embed' through a full 'cata'; an instance
  -- that can unwrap a layer directly may prefer to define 'project' itself.
  project = cata (fmap embed)

  -- | Wrap one layer of @f@, whose children are already @t@s, into a @t@.
  embed :: f t -> t

  {-# MINIMAL embed, (cata | project) #-}

-- | Monadic fold: like 'cata', but the algebra returns an action.
--
-- At each layer the children's actions are first combined with 'sequenceA'
-- (in the order given by @f@'s 'Traversable' instance), and the algebra is
-- then run on the collected results.
cataM :: (Recursive t f, Traversable f, Monad m) => (f a -> m a) -> t -> m a
cataM φ = cata $ φ <=< sequenceA

-- | Paramorphism: like 'cata', but at each layer the algebra sees every child
-- both as a subterm and as that subterm's folded result.
--
-- Implemented as a 'cata' over pairs. The first component rebuilds the current
-- subterm with 'embed' (applied to the children's first components); the
-- second runs @φ@. 'snd' then keeps only the folded result.
para :: (Recursive t f, Functor f) => (f (t, a) -> a) -> t -> a
para φ = snd . cata (embed . fmap fst &&& φ)

-- | Fixed point of a functor: each 'Fix' wraps exactly one layer of @f@,
-- whose positions are again @Fix f@ values.
newtype Fix f = Fix (f (Fix f))

-- These instances lift @f@'s instance to 'Fix' through a quantified
-- constraint: e.g. showing a @Fix f@ needs @Show (f x)@ for every showable @x@.
deriving instance (forall x. (Show x) => Show (f x)) => Show (Fix f)

deriving instance (forall x. (Eq x) => Eq (f x)) => Eq (Fix f)

-- 'Eq' is a superclass of 'Ord'. It is requested directly because the
-- quantified 'Ord' constraint only applies to 'Ord' arguments and so cannot
-- discharge the quantified 'Eq' constraint that @Eq (Fix f)@ needs.
deriving instance (forall x. (Ord x) => Ord (f x), Eq (Fix f)) => Ord (Fix f)

-- | A 'Fix' wraps exactly one layer, so 'embed' is the 'Fix' constructor and
-- 'project' is its inverse. 'cata' uses the class default, which goes
-- through 'project'.
instance (Functor f) => Recursive (Fix f) f where
  embed = Fix
  project (Fix f) = f

-- | Views a 'Cofree' as a recursive type whose base functor is the annotated
-- layer 'CofreeF'. 'embed' and 'project' just convert between the two @:<@
-- constructors, and 'cata' uses the class default.
--
--  __NB__: `Cofree` from “free” is lazy, so this instance is technically
--  partial: a 'Cofree' value may be infinite, and folding one with 'cata'
--  then does not terminate.
instance (Functor f) => Recursive (Cofree f a) (CofreeF f a) where
  embed (a CofreeF.:< fco) = a :< fco
  project (a :< fco) = a CofreeF.:< fco