-- | This module contains various utilities related to the implementation of record types.
module Unison.DataDeclaration.Records
  ( generateRecordAccessors,
  )
where

import Data.List.NonEmpty (pattern (:|))
import Data.List.NonEmpty qualified as List (NonEmpty)
import Data.Set qualified as Set
import Unison.ABT qualified as ABT
import Unison.ConstructorReference (GConstructorReference (..))
import Unison.Pattern qualified as Pattern
import Unison.Prelude
import Unison.Reference (TypeReference)
import Unison.Term (Term)
import Unison.Term qualified as Term
import Unison.Var (Var)
import Unison.Var qualified as Var

-- | Generate the getter, setter and modifier terms for every field of a record type.
--
-- For each field, in the order given, three @(name, annotation, term)@ triples are
-- produced. For a @Point@ record with fields @x@, @y@, @z@, the terms generated for
-- field @y@ are:
--
-- > -- getter, named by:   namespaced (typename :| [fname])
-- > point -> case point of Point _ y _ -> y
-- > -- setter, named by:   namespaced (typename :| [fname, "set"])
-- > y' point -> case point of Point x _ z -> Point x y' z
-- > -- modifier, named by: namespaced (typename :| [fname, "modify"])
-- > f point -> case point of Point x y z -> Point x (f y) z
--
-- The record argument (@point@ above) is named by uncapitalizing the type name.
-- The value is matched and rebuilt through constructor @0@ of the given type
-- reference. This assumes that constructor's arguments line up, position by
-- position, with @fields@.
generateRecordAccessors ::
  (Semigroup a, Var v) =>
  -- | Builds an accessor's name from its path segments: the type name, the field
  -- name and, for setters/modifiers, a trailing @set@ / @modify@ segment.
  (List.NonEmpty v -> v) ->
  -- | Derives the annotation of the generated terms from a field's annotation.
  (a -> a) ->
  -- | The record's fields (variable and annotation), in constructor-argument order.
  [(v, a)] ->
  -- | The record type's name.
  v ->
  -- | Reference to the record type; its constructor 0 is matched and rebuilt.
  TypeReference ->
  [(v, a, Term v a)]
generateRecordAccessors namespaced generatedAnn fields typename typ =
  concatMap (uncurry accessorsFor) indexedFields
  where
    -- Each field paired with its 0-based position among the constructor's arguments.
    indexedFields = fields `zip` [(0 :: Int) ..]

    -- Name of the record argument every accessor takes, e.g. @point@ for @Point@.
    argname = Var.uncapitalize typename

    -- Names a generated binder must not collide with: the record argument and
    -- every field variable that the match patterns may bind.
    reserved = Set.fromList (argname : map fst fields)

    -- A variable based on @base@ that is fresh with respect to 'reserved',
    -- rebuilt from its rendered name via 'Var.name' / 'Var.named'.
    freshName base = Var.named . Var.name $ Var.freshIn reserved base

    -- The three accessors for the field @fname@ at position @i@.
    accessorsFor (fname, fieldAnn) i =
      [ (namespaced (typename :| [fname]), ann, get),
        (namespaced (typename :| [fname, Var.named "set"]), ann, set),
        (namespaced (typename :| [fname, Var.named "modify"]), ann, modify)
      ]
      where
        ann = generatedAnn fieldAnn
        conref = ConstructorReference typ 0
        pat = Pattern.Constructor ann conref

        -- case argname of Point <cargs> -> rhs
        matchRecord cargs rhs =
          Term.match
            ann
            (Term.var ann argname)
            [Term.MatchCase (pat cargs) Nothing rhs]

        -- point -> case point of Point _ y _ -> y
        get =
          Term.lam ann (ann, argname) $ matchRecord cargs rhs
          where
            -- [_, y, _]: bind only the field being read.
            cargs =
              [ if j == i then Pattern.Var ann else Pattern.Unbound ann
                | (_, j) <- indexedFields
              ]
            -- y -> y
            rhs = ABT.abs' ann fname (Term.var ann fname)

        -- y' point -> case point of Point x _ z -> Point x y' z
        set =
          Term.lam' ann [(ann, fname'), (ann, argname)] $ matchRecord cargs rhs
          where
            -- y': the replacement value, based on the field name but fresh.
            fname' = freshName fname
            -- [x, _, z]: bind every field except the one being replaced.
            cargs =
              [ if j == i then Pattern.Unbound ann else Pattern.Var ann
                | (_, j) <- indexedFields
              ]
            -- x z -> Point x y' z
            rhs =
              foldr
                (ABT.abs' ann)
                (Term.constructor ann conref `Term.apps'` vargs)
                [fieldVar | ((fieldVar, _), j) <- indexedFields, j /= i]
            -- [x, y', z]
            vargs =
              [ Term.var ann (if j == i then fname' else fieldVar)
                | ((fieldVar, _), j) <- indexedFields
              ]

        -- f point -> case point of Point x y z -> Point x (f y) z
        modify =
          Term.lam' ann [(ann, fname'), (ann, argname)] $ matchRecord cargs rhs
          where
            -- f: the update function, fresh with respect to the record's names.
            fname' = freshName (Var.named "f")
            -- [x, y, z]: bind every field.
            cargs = [Pattern.Var ann | _ <- fields]
            -- x y z -> Point x (f y) z
            rhs =
              foldr
                (ABT.abs' ann)
                (Term.constructor ann conref `Term.apps'` vargs)
                (map fst fields)
            -- [x, f y, z]
            vargs =
              [ if j == i
                  then Term.apps' (Term.var ann fname') [Term.var ann fieldVar]
                  else Term.var ann fieldVar
                | ((fieldVar, _), j) <- indexedFields
              ]