module Unison.LSP.Util.IntersectionMap
  ( -- * Intersection map
    intersectionsFromList,
    intersectionsSingleton,
    IntersectionRange (..),
    IntersectionMap,
    smallestIntersection,

    -- * Keyed intersection map
    KeyedIntersectionMap,
    keyedFromList,
    keyedSingleton,
    keyedSmallestIntersection,
  )
where

import Data.List qualified as List
import Data.Map qualified as Map
import Language.LSP.Protocol.Types qualified as LSP
import Unison.Prelude
import Unison.Util.List (safeHead)

-- | An intersection map where intersections are partitioned by a key.
--
-- Every key @k@ owns its own 'IntersectionMap'; a lookup for one key never
-- considers ranges that were stored under a different key.
newtype KeyedIntersectionMap k pos a = KeyedIntersectionMap (Map k (IntersectionMap pos a))
  deriving stock (Show, Eq)

-- | Merges two keyed maps key by key. When both operands contain the same key,
-- their per-key 'IntersectionMap's are combined with '<>'.
instance (Ord k, Ord pos) => Semigroup (KeyedIntersectionMap k pos a) where
  KeyedIntersectionMap a <> KeyedIntersectionMap b = KeyedIntersectionMap (Map.unionWith (<>) a b)

-- | The empty keyed map: no keys, hence no ranges.
instance (Ord k, Ord pos) => Monoid (KeyedIntersectionMap k pos a) where
  mempty = KeyedIntersectionMap Map.empty

-- | Build a 'KeyedIntersectionMap' from a list of @(key, (range, value))@ entries.
--
-- Entries that share a key are collected into that key's 'IntersectionMap'.
-- Each entry is first turned into a singleton map and the singletons are then
-- merged with '<>' via 'Map.fromListWith', which passes the newer entry as the
-- left operand; since the underlying map union is left-biased, a later entry
-- for the same key and range replaces an earlier one.
keyedFromList :: (Ord k, IntersectionRange pos) => [(k, ((pos, pos), a))] -> KeyedIntersectionMap k pos a
keyedFromList elems =
  KeyedIntersectionMap $
    elems
      & fmap (\(k, (range, v)) -> (k, intersectionsSingleton range v))
      & Map.fromListWith (<>)

-- | A 'KeyedIntersectionMap' holding exactly one range/value pair under the given key.
keyedSingleton :: (Ord k, IntersectionRange pos) => k -> (pos, pos) -> a -> KeyedIntersectionMap k pos a
keyedSingleton k range a = keyedFromList [(k, (range, a))]

-- | NOTE: Assumes that ranges only NEST and never overlap, which is an invariant that should
-- be maintained by the ABT annotations.
--
-- Returns the value associated with the tightest span which intersects with the given position,
-- considering only the ranges stored under the given key.
--
-- Yields 'Nothing' when the key is absent or when no range under that key
-- intersects the position.
keyedSmallestIntersection :: (Ord k, IntersectionRange pos) => k -> pos -> KeyedIntersectionMap k pos a -> Maybe ((pos, pos), a)
keyedSmallestIntersection k p (KeyedIntersectionMap m) = do
  -- A missing key short-circuits to Nothing in the Maybe monad.
  intersections <- Map.lookup k m
  smallestIntersection p intersections

-- | A collection of values indexed by the @(start, end)@ range they cover.
--
-- Use 'smallestIntersection' to find the tightest stored range that contains a
-- given position.
newtype IntersectionMap pos a = IntersectionMap (Map (pos, pos) a)
  deriving stock (Show, Eq)

-- | Union of the two underlying maps. The union is left-biased: when both
-- operands contain the same range, the left operand's value is kept.
instance (Ord pos) => Semigroup (IntersectionMap pos a) where
  IntersectionMap a <> IntersectionMap b = IntersectionMap (a <> b)

-- | The empty intersection map: contains no ranges.
instance (Ord pos) => Monoid (IntersectionMap pos a) where
  mempty = IntersectionMap mempty

-- | Class for types that can be used as ranges for intersection maps.
--
-- A range is represented as a @(start, end)@ pair of positions.
class Ord pos => IntersectionRange pos where
  -- | Returns true if the position lies within the given @(start, end)@ range.
  -- Whether the bounds are inclusive is up to the instance; the
  -- @LSP.Position@ instance in this module treats both ends as inclusive.
  intersects :: pos -> (pos, pos) -> Bool

  -- | Returns true if the first bound is tighter than the second.
  isTighterThan :: (pos, pos) -> (pos, pos) -> Bool

-- | Ranges over LSP positions, compared by line first and by character only
-- on the lines where the range begins or ends.
instance IntersectionRange LSP.Position where
  -- A position is inside a range when its line is within the range's lines and,
  -- on a boundary line, its character is within the corresponding boundary
  -- column. Both ends are inclusive.
  intersects (LSP.Position l c) ((LSP.Position lStart cStart), (LSP.Position lEnd cEnd)) =
    (l >= lStart && l <= lEnd)
      && if
          -- Single-line range: both column bounds apply.
          | l == lStart && l == lEnd -> c >= cStart && c <= cEnd
          -- On the first line only the start column constrains the position.
          | l == lStart -> c >= cStart
          -- On the last line only the end column constrains the position.
          | l == lEnd -> c <= cEnd
          -- Strictly interior lines are fully covered.
          | otherwise -> True

  -- Range A is tighter than range B when A sits inside B. If both ranges span
  -- exactly the same lines the columns decide; otherwise only the lines are
  -- compared, which is sufficient under the assumption that ranges nest.
  ((LSP.Position lStartA cStartA), (LSP.Position lEndA cEndA)) `isTighterThan` ((LSP.Position lStartB cStartB), (LSP.Position lEndB cEndB)) =
    if lStartA == lStartB && lEndA == lEndB
      then cStartA >= cStartB && cEndA <= cEndB
      else lStartA >= lStartB && lEndA <= lEndB

-- | Construct an intersection map from a list of ranges and values.
--
-- If the same range occurs more than once, only one value is retained for it:
-- the one from the last occurrence in the list (the behaviour of 'Map.fromList').
intersectionsFromList :: (Ord pos) => [((pos, pos), a)] -> IntersectionMap pos a
intersectionsFromList elems =
  IntersectionMap $ Map.fromList elems

-- | An intersection map containing exactly one range and its value.
intersectionsSingleton :: (pos, pos) -> a -> IntersectionMap pos a
intersectionsSingleton range a = IntersectionMap $ Map.singleton range a

-- | NOTE: Assumes that ranges only NEST and never overlap, which is an invariant that should
-- be maintained by the ABT annotations.
--
-- Returns the value associated with the tightest span which intersects with the given position,
-- or 'Nothing' if no stored range intersects it.
--
-- >>> smallestIntersection (LSP.Position 5 1) (intersectionsFromList [((LSP.Position 1 1, LSP.Position 3 1), "a"), ((LSP.Position 2 1, LSP.Position 8 1), "b"), ((LSP.Position 4 1, LSP.Position 6 1), "c")])
-- Just ((Position {_line = 4, _character = 1},Position {_line = 6, _character = 1}),"c")
-- >>> smallestIntersection (LSP.Position 5 3) (intersectionsFromList [((LSP.Position 1 1, LSP.Position 3 1), "a"), ((LSP.Position 4 2, LSP.Position 6 5), "b"), ((LSP.Position 4 1, LSP.Position 6 6), "c"), ((LSP.Position 7 1, LSP.Position 9 1), "d")])
-- Just ((Position {_line = 4, _character = 2},Position {_line = 6, _character = 5}),"b")
smallestIntersection :: IntersectionRange pos => pos -> IntersectionMap pos a -> Maybe ((pos, pos), a)
smallestIntersection p (IntersectionMap bounds) =
  bounds
    -- Keep only the ranges that actually contain the position.
    & Map.filterWithKey (\b _ -> p `intersects` b)
    & Map.toList
    -- Order the survivors tightest-first and take the head.
    & List.sortBy cmp
    & safeHead
  where
    -- This comparator never returns EQ and only forms a proper ordering when
    -- the candidate ranges are nested within one another (see the NOTE above).
    cmp (a, _) (b, _) =
      if a `isTighterThan` b
        then LT
        else GT