{-# LANGUAGE BangPatterns #-}
-- used for unsafe pointer equality
{-# LANGUAGE MagicHash #-}

module Unison.Runtime.SparseVector where

import Control.Monad.ST (ST)
import Data.Bits ((.&.), (.|.))
import Data.Bits qualified as B
import Data.Vector.Unboxed qualified as UV
import Data.Vector.Unboxed.Mutable qualified as MUV
import GHC.Exts qualified as Exts
import Prelude hiding (unzip)

-- A sparse vector denotes a `Nat -> Maybe a`.
--
-- It is stored as a bitset of the defined indices together with a dense
-- vector holding one element per set bit, ordered from the lowest index up.
-- Ex: `[(1,a), (5,b)]` is encoded as (100010, [a,b])
data SparseVector bits a = SparseVector
  { -- Bit n is set iff index n is defined.
    indices :: !bits,
    -- Exactly one element per set bit of `indices`, lowest index first.
    elements :: !(UV.Vector a)
  }

-- todo: instance (UV.Unbox n, B.FiniteBits bits, Num n)
--   => Num (SparseVector bits n)

-- Apply a function to every defined element; the set of defined indices is
-- left untouched.
--
-- Denotationally: `map f v n = f <$> v n`
map :: (UV.Unbox a, UV.Unbox b) => (a -> b) -> SparseVector bits a -> SparseVector bits b
map f (SparseVector shape xs) = SparseVector shape (UV.map f xs)

-- Denotationally, a mask is a `Nat -> Bool`, so this implementation
-- means: `mask ok v n = if ok n then v n else Nothing`
--
-- The result is defined exactly on the indices present both in `v` and in
-- the mask, with elements kept in ascending index order.
mask ::
  forall a bits.
  (UV.Unbox a, B.FiniteBits bits) =>
  bits ->
  SparseVector bits a ->
  SparseVector bits a
mask keep v
  -- The mask is a superset of v's indices, so nothing would be dropped.
  | kept == indices v = v
  | otherwise =
      SparseVector kept $
        UV.create $ do
          out <- MUV.new (B.popCount kept)
          go out (indices v) 0 0
  where
    kept = indices v .&. keep
    eas = elements v
    -- Walk v's set bits from the lowest upwards, clearing each one after
    -- visiting it. Bits are tested at their absolute position, so the two
    -- bitsets can never drift out of alignment, and no right shift is used
    -- (`shiftR` sign-extends signed `bits`, which would never reach zero).
    --   remaining: v's indices not visited yet
    --   i: position in `eas` of the element for the lowest remaining bit
    --   k: next write position in `out`
    go :: MUV.STVector s a -> bits -> Int -> Int -> ST s (MUV.STVector s a)
    go !out !remaining !i !k
      -- stop once no unvisited index of v survives the mask
      | remaining .&. keep == B.zeroBits = pure out
      | otherwise =
          let p = B.countTrailingZeros remaining
              remaining' = B.clearBit remaining p
           in if B.testBit keep p
                then do
                  MUV.write out k (eas UV.! i)
                  go out remaining' (i + 1) (k + 1)
                else go out remaining' (i + 1) k

-- Combine two sparse vectors pointwise.
--
-- Denotationally: `zipWith f a b n = f <$> a n <*> b n`, in other words,
-- the result is defined only on the intersection of the two shapes.
zipWith ::
  (UV.Unbox a, UV.Unbox b, UV.Unbox c, B.FiniteBits bits) =>
  (a -> b -> c) ->
  SparseVector bits a ->
  SparseVector bits b ->
  SparseVector bits c
zipWith f xs ys
  | sameShape = SparseVector ixs (UV.zipWith f (elements xs) (elements ys))
  | otherwise =
      SparseVector shared $
        UV.zipWith f (elements (mask shared xs)) (elements (mask shared ys))
  where
    ixs = indices xs
    iys = indices ys
    -- cheap pointer check first, then a structural comparison
    sameShape = ixs `eq` iys || ixs == iys
    shared = ixs .&. iys

-- First components of a sparse vector of pairs; the shape is unchanged.
_1 :: (UV.Unbox a, UV.Unbox b) => SparseVector bits (a, b) -> SparseVector bits a
_1 v = case unzip v of
  (firsts, _) -> firsts

-- Second components of a sparse vector of pairs; the shape is unchanged.
_2 :: (UV.Unbox a, UV.Unbox b) => SparseVector bits (a, b) -> SparseVector bits b
_2 v = case unzip v of
  (_, seconds) -> seconds

-- Split a sparse vector of pairs into two sparse vectors that share the
-- original index bitset.
--
-- Denotationally: `unzip p = (\n -> fst <$> p n, \n -> snd <$> p n)`
unzip ::
  (UV.Unbox a, UV.Unbox b) =>
  SparseVector bits (a, b) ->
  (SparseVector bits a, SparseVector bits b)
unzip (SparseVector shape pairs) = (SparseVector shape firsts, SparseVector shape seconds)
  where
    (firsts, seconds) = UV.unzip pairs

-- Select, index by index, from the first vector where the selector bit is
-- set and from the second where it is clear.
--
-- Denotationally: `choose bs a b n = if bs n then a n else b n`
choose ::
  (B.FiniteBits bits, UV.Unbox a) =>
  bits ->
  SparseVector bits a ->
  SparseVector bits a ->
  SparseVector bits a
choose sel whenSet whenClear
  | allClear = whenClear
  | allSet = whenSet
  -- a mix of set and clear bits: restrict each side, then combine
  | otherwise = merge (mask sel whenSet) (mask inverted whenClear)
  where
    inverted = B.complement sel
    allClear = B.zeroBits == sel
    allSet = inverted == B.zeroBits

-- Denotationally: `merge a b n = a n <|> b n`
--
-- The result is defined on the union of both shapes; where both vectors
-- define an index, the element from `a` wins.
merge ::
  forall a bits.
  (B.FiniteBits bits, UV.Unbox a) =>
  SparseVector bits a ->
  SparseVector bits a ->
  SparseVector bits a
merge a b = SparseVector indices' merged
  where
    indsA = indices a
    indsB = indices b
    indices' = indsA .|. indsB
    eas = elements a
    ebs = elements b
    merged :: UV.Vector a
    merged = UV.create $ do
      out <- MUV.new (B.popCount indices')
      go out indices' 0 0 0
    -- Walk the union's set bits from the lowest upwards, clearing each one
    -- after visiting it. Every union bit yields exactly one output element,
    -- so the loop runs until the union is exhausted, including after one
    -- side has run out of indices (otherwise part of `out` would stay
    -- uninitialised). Bits are tested at absolute positions, so no
    -- sign-extending `shiftR` is involved.
    --   i / j: position in `eas` / `ebs` of the next unconsumed element
    --   k: next write position in `out`
    go :: MUV.STVector s a -> bits -> Int -> Int -> Int -> ST s (MUV.STVector s a)
    go !out !remaining !i !j !k
      | remaining == B.zeroBits = pure out
      | otherwise = do
          let p = B.countTrailingZeros remaining
              inA = B.testBit indsA p
              inB = B.testBit indsB p
          -- p is in the union, so at least one side defines it; prefer `a`
          MUV.write out k (if inA then eas UV.! i else ebs UV.! j)
          go
            out
            (B.clearBit remaining p)
            (if inA then i + 1 else i)
            (if inB then j + 1 else j)
            (k + 1)

-- Pointer equality a la Scala's `eq`: `True` means both arguments are the
-- very same heap object. `False` only means "not known to be the same
-- object" (GHC documents `reallyUnsafePtrEquality#` as allowing false
-- negatives), so callers should fall back to a structural comparison.
eq :: a -> a -> Bool
eq x y = Exts.isTrue# (Exts.reallyUnsafePtrEquality# x y)
{-# INLINE eq #-}