module Unison.Util.Set
  ( asSingleton,
    difference1,
    intersects,
    mapMaybe,
    symmetricDifference,
    Unison.Util.Set.traverse,
    flatMap,
    filterM,
    forMaybe,
  )
where

import Data.Function ((&))
import Data.Functor ((<&>))
import Data.Maybe qualified as Maybe
import Data.Set (Set)
import Data.Set qualified as Set
import Unison.Util.Monoid (foldMapM)

-- | Return the sole element of a set.
--
-- Yields 'Nothing' when the set is empty or holds two or more elements.
asSingleton :: Set a -> Maybe a
asSingleton xs =
  case Set.toList xs of
    [onlyElem] -> Just onlyElem
    _ -> Nothing

-- | Like 'Set.difference' (elements of the first set that are not in the
-- second), but yields 'Nothing' instead of an empty set.
difference1 :: (Ord a) => Set a -> Set a -> Maybe (Set a)
difference1 xs ys
  | Set.null remaining = Nothing
  | otherwise = Just remaining
  where
    remaining = Set.difference xs ys

-- | 'True' exactly when the two sets have at least one element in common
-- (the negation of 'Set.disjoint').
intersects :: (Ord a) => Set a -> Set a -> Bool
intersects xs = not . Set.disjoint xs

-- | Elements that belong to exactly one of the two sets: everything in
-- @a@ but not @b@, together with everything in @b@ but not @a@.
symmetricDifference :: (Ord a) => Set a -> Set a -> Set a
symmetricDifference a b = Set.union onlyInA onlyInB
  where
    onlyInA = Set.difference a b
    onlyInB = Set.difference b a

-- | Apply @f@ to every element, keeping only the 'Just' results.
--
-- The output is rebuilt with @b@'s ordering, so results that compare equal
-- collapse to a single element.
mapMaybe :: (Ord b) => (a -> Maybe b) -> Set a -> Set b
mapMaybe f xs = Set.fromList [y | x <- Set.toList xs, Just y <- [f x]]

-- | Effectful, argument-flipped 'mapMaybe' for sets.
--
-- Runs @f@ on each element in ascending order (via 'Set.toList'), drops the
-- 'Nothing' results, and collects the 'Just' payloads into a new set.
forMaybe :: (Ord b, Applicative f) => Set a -> (a -> f (Maybe b)) -> f (Set b)
forMaybe xs f =
  Set.toList xs
    & Prelude.traverse f
    <&> \results -> Set.fromList (Maybe.catMaybes results)

-- | Run an applicative action on every element, in ascending order (via
-- 'Set.toList'), and gather the results into a new set.
--
-- Results that compare equal collapse to a single element.
traverse :: (Applicative f, Ord b) => (a -> f b) -> Set a -> f (Set b)
traverse f xs = Set.fromList <$> Prelude.traverse f (Set.toList xs)

-- | Map each element to a set and union all of the resulting sets.
flatMap :: (Ord b) => (a -> Set b) -> Set a -> Set b
flatMap f = Set.foldl' step Set.empty
  where
    -- Strict left fold of a left-biased union, visiting elements in
    -- ascending order.
    step acc x = acc `Set.union` f x

-- | Monadic counterpart of 'Set.filter'.
--
-- Runs the predicate on every element of the set (through 'foldMapM') and
-- keeps those for which it returns 'True'.
filterM :: (Ord a, Monad m) => (a -> m Bool) -> Set a -> m (Set a)
filterM p = foldMapM keepIf
  where
    -- A kept element contributes a singleton; a rejected one contributes
    -- the empty set, which is the identity for the set monoid's union.
    keepIf x =
      p x <&> \keep ->
        if keep then Set.singleton x else Set.empty