module Unison.Util.Set
  ( asSingleton,
    difference1,
    differenceMap,
    foldCommutativeM,
    insertMaybe,
    intersects,
    mapMaybe,
    symmetricDifference,
    Unison.Util.Set.traverse,
    Unison.Util.Set.for,
    flatMap,
    filterM,
    forMaybe,
    thenInsert,
    thenInsertMaybe,
  )
where

import Data.Function ((&))
import Data.Functor ((<&>))
import Data.Map.Internal qualified as Map.Internal (Map (..))
import Data.Map.Strict (Map)
import Data.Maybe qualified as Maybe
import Data.Set (Set)
import Data.Set qualified as Set
import Data.Set.Internal qualified as Set.Internal (Set (..), merge)
import Unison.Util.Monoid (foldMapM)

-- | Get the only member of a set, iff it's a singleton.
--
-- Returns @Nothing@ both for the empty set and for any set with two or more
-- members.
asSingleton :: Set a -> Maybe a
asSingleton xs
  -- 'Set.size' is O(1); 'Set.lookupMin' keeps this total (no partial 'Set.findMin').
  | Set.size xs == 1 = Set.lookupMin xs
  | otherwise = Nothing

-- | Set difference, but return @Nothing@ if the difference is empty.
--
-- @difference1 xs ys@ is @Just (Set.difference xs ys)@ when at least one member
-- of @xs@ is absent from @ys@, and @Nothing@ otherwise.
difference1 :: (Ord a) => Set a -> Set a -> Maybe (Set a)
difference1 xs ys
  | Set.null zs = Nothing
  | otherwise = Just zs
  where
    zs = Set.difference xs ys

-- | Like 'Set.difference', but the second argument is a map: every member of
-- the set that is a key of the map is removed. The map's values are never
-- inspected.
--
-- The algorithm splits the set around the map's root key, recurses into both
-- halves against the map's left and right subtrees, and returns the original
-- set unchanged (sharing it) when nothing was removed.
differenceMap :: (Ord k) => Set k -> Map k a -> Set k
differenceMap Set.Internal.Tip _ = Set.Internal.Tip
differenceMap xs Map.Internal.Tip = xs
differenceMap xs (Map.Internal.Bin _ k _ yl yr)
  -- Neither half lost anything (and 'Set.split' dropped no @k@), so the
  -- original set is already the answer; reuse it instead of rebuilding.
  | Set.size zl + Set.size zr == Set.size xs = xs
  -- 'Set.Internal.merge' requires every element of its left argument to be
  -- smaller than every element of its right one, which 'Set.split' guarantees.
  | otherwise = Set.Internal.merge zl zr
  where
    -- Elements strictly below / strictly above @k@; @k@ itself is excluded.
    (xl, xr) = Set.split k xs
    !zl = differenceMap xl yl
    !zr = differenceMap xr yr

-- | Fold a set strictly with a monadic "commutative" combining function that
-- doesn't receive the elements in any particular order.
--
-- The accumulator is forced to WHNF after every step. The set's internal tree
-- is walked with an explicit stack: each node's element is combined first, then
-- its left subtree, then its right subtree. The result is therefore only
-- meaningful when the combining function does not care about element order.
foldCommutativeM :: (Monad m) => (a -> b -> m b) -> b -> Set a -> m b
foldCommutativeM f z xs = go z [xs]
  where
    go !acc (Set.Internal.Bin _ x l r : rest) = do
      !acc' <- f x acc
      go acc' (l : r : rest)
    go !acc (Set.Internal.Tip : rest) = go acc rest
    go !acc [] = pure acc

-- | Insert the value carried by a @Just@ into the set; @Nothing@ leaves the set
-- unchanged.
insertMaybe :: (Ord a) => Maybe a -> Set a -> Set a
insertMaybe mx xs = maybe xs (`Set.insert` xs) mx

-- | Get whether two sets intersect, i.e. share at least one member.
-- Two sets intersect exactly when they are not 'Set.disjoint'.
intersects :: (Ord a) => Set a -> Set a -> Bool
intersects xs ys = not (Set.disjoint xs ys)

-- | The members that belong to exactly one of the two sets:
-- @(a \\ b) ∪ (b \\ a)@.
symmetricDifference :: (Ord a) => Set a -> Set a -> Set a
symmetricDifference a b = Set.difference a b `Set.union` Set.difference b a

-- | Map a function that may fail over a set, keeping only the @Just@ results.
--
-- The result can be smaller than the input, both because members mapped to
-- @Nothing@ are dropped and because distinct members may map to the same value.
mapMaybe :: (Ord b) => (a -> Maybe b) -> Set a -> Set b
mapMaybe f = Set.fromList . Maybe.mapMaybe f . Set.toList

-- | Run an effectful function that may fail on every member of a set and
-- collect the @Just@ results into a new set.
--
-- Effects run in ascending order of the input members, because 'Set.toList'
-- yields ascending order and list traversal sequences effects left to right.
forMaybe :: (Ord b, Applicative f) => Set a -> (a -> f (Maybe b)) -> f (Set b)
forMaybe xs f =
  Prelude.traverse f (Set.toList xs)
    <&> \ys -> ys & Maybe.catMaybes & Set.fromList

-- | Like 'Prelude.traverse', but over a 'Set'. The results are collected into
-- a new set, so distinct inputs that produce equal outputs collapse.
--
-- Effects run in ascending order of the input members.
traverse :: (Applicative f, Ord b) => (a -> f b) -> Set a -> f (Set b)
traverse f xs = Set.fromList <$> Prelude.traverse f (Set.toList xs)

-- | 'Unison.Util.Set.traverse' with its arguments flipped: the set comes first,
-- the effectful function second.
for :: (Ord b, Applicative f) => Set a -> (a -> f b) -> f (Set b)
for xs f = Unison.Util.Set.traverse f xs

-- | Map every member of a set to a set and return the union of all results.
flatMap :: (Ord b) => (a -> Set b) -> Set a -> Set b
flatMap f = Set.unions . map f . Set.toList

-- | Keep only the members of a set that satisfy a monadic predicate.
--
-- Each member is turned into a singleton (kept) or the empty set (dropped),
-- and the pieces are combined with 'foldMapM'.
-- NOTE(review): effect order is whatever 'foldMapM' uses when walking the set,
-- presumably ascending; confirm in "Unison.Util.Monoid".
filterM :: (Ord a, Monad m) => (a -> m Bool) -> Set a -> m (Set a)
filterM p = foldMapM keepIf
  where
    keepIf x =
      p x <&> \keep ->
        if keep then Set.singleton x else Set.empty

-- | 'Set.insert' with its arguments flipped: the set comes first, the element
-- to insert second.
thenInsert :: (Ord a) => Set a -> a -> Set a
thenInsert xs x = Set.insert x xs

-- | Insert the value carried by a @Just@ into the set; @Nothing@ leaves the set
-- unchanged. This is the flipped-argument counterpart of 'insertMaybe'.
thenInsertMaybe :: (Ord a) => Set a -> Maybe a -> Set a
thenInsertMaybe xs = maybe xs (`Set.insert` xs)