{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MagicHash #-}
module Unison.Runtime.SparseVector where
import Control.Monad.ST (ST)
import Data.Bits ((.&.), (.|.))
import Data.Bits qualified as B
import Data.Vector.Unboxed qualified as UV
import Data.Vector.Unboxed.Mutable qualified as MUV
import GHC.Exts qualified as Exts
import Prelude hiding (unzip)
-- | A sparse vector, denotationally a partial map @Nat -> Maybe a@.
--
-- The representation pairs a bitset with a dense unboxed vector. Bit @n@ of
-- 'indices' is set when position @n@ holds a value. 'elements' stores those
-- values ordered by position, lowest first, so it is expected to hold
-- @popCount indices@ elements. For example, @[(1, x), (5, y)]@ is encoded as
-- indices @0b100010@ with elements @[x, y]@.
data SparseVector bits a = SparseVector
  { -- | Bitset of the occupied positions.
    indices :: !bits,
    -- | Values at the occupied positions, ordered by ascending position.
    elements :: !(UV.Vector a)
  }
-- | Transform every stored value while leaving the set of occupied positions
-- untouched.
--
-- Denotationally: @map f v n = f <$> v n@.
map :: (UV.Unbox a, UV.Unbox b) => (a -> b) -> SparseVector bits a -> SparseVector bits b
map f (SparseVector occupied values) = SparseVector occupied (UV.map f values)
-- | Restrict a sparse vector to the positions whose bit is set in @bits@.
--
-- Denotationally, a mask is a @Nat -> Bool@, so
-- @mask ok v n = if ok n then v n else Nothing@.
--
-- The result's occupied positions are @indices a .&. bits@. Its elements are
-- the elements of @a@ at those positions, still ordered by position.
mask ::
  forall a bits.
  (UV.Unbox a, B.FiniteBits bits) =>
  bits ->
  SparseVector bits a ->
  SparseVector bits a
mask bits a =
  if indices' == indices a
    then a -- every occupied position survives the mask: nothing to drop
    else SparseVector indices' $ UV.create $ do
      vec <- MUV.new (B.popCount indices')
      go vec (indices a) 0 0
  where
    indices' :: bits
    indices' = indices a .&. bits

    eas :: UV.Vector a
    eas = elements a

    -- Walk the occupied positions of @a@ from lowest to highest, using
    -- absolute bit positions (bits are cleared, never shifted, so the mask
    -- and the walk stay aligned).
    --   @rest@: occupied positions of @a@ not yet visited
    --   @i@:    index in 'eas' of the element at the lowest bit of @rest@
    --           (one element per set bit already visited)
    --   @k@:    next free slot in @out@
    -- The walk stops once no remaining position survives the mask. At that
    -- point all @popCount indices'@ slots have been written.
    go :: MUV.STVector s a -> bits -> Int -> Int -> ST s (MUV.STVector s a)
    go !out !rest !i !k
      | rest .&. bits == B.zeroBits = pure out
      | B.testBit bits p = do
          MUV.write out k (eas UV.! i)
          go out rest' (i + 1) (k + 1)
      | otherwise = go out rest' (i + 1) k
      where
        p = B.countTrailingZeros rest
        rest' = B.clearBit rest p
-- | Combine two sparse vectors pointwise over the positions they share.
--
-- Denotationally: @zipWith f a b n = f <$> a n <*> b n@, i.e. the result is
-- occupied exactly on the intersection of the two inputs' positions.
--
-- When both inputs have the same bitset (first by a cheap pointer check, then
-- by '=='), the element vectors are zipped directly. Otherwise each side is
-- first 'mask'ed down to the common positions.
zipWith ::
  (UV.Unbox a, UV.Unbox b, UV.Unbox c, B.FiniteBits bits) =>
  (a -> b -> c) ->
  SparseVector bits a ->
  SparseVector bits b ->
  SparseVector bits c
zipWith f a b
  | sameShape = SparseVector indA (UV.zipWith f (elements a) (elements b))
  | otherwise =
      SparseVector
        common
        (UV.zipWith f (elements (mask common a)) (elements (mask common b)))
  where
    indA = indices a
    indB = indices b
    sameShape = indA `eq` indB || indA == indB
    common = indA .&. indB
-- | Project the first component of every stored pair, keeping the same
-- occupied positions.
_1 :: (UV.Unbox a, UV.Unbox b) => SparseVector bits (a, b) -> SparseVector bits a
_1 v = fst (unzip v)
-- | Project the second component of every stored pair, keeping the same
-- occupied positions.
_2 :: (UV.Unbox a, UV.Unbox b) => SparseVector bits (a, b) -> SparseVector bits b
_2 v = snd (unzip v)
-- | Split a sparse vector of pairs into two sparse vectors that share the
-- original set of occupied positions.
unzip ::
  (UV.Unbox a, UV.Unbox b) =>
  SparseVector bits (a, b) ->
  (SparseVector bits a, SparseVector bits b)
unzip (SparseVector occupied pairs) =
  (SparseVector occupied lefts, SparseVector occupied rights)
  where
    (lefts, rights) = UV.unzip pairs
-- | Pointwise selection between two sparse vectors.
--
-- Denotationally: @choose bits t f n = if bits n then t n else f n@.
-- An all-clear selector returns @f@ and an all-set selector returns @t@,
-- both unchanged. Otherwise each side is masked to its share of positions
-- and the two halves are merged.
choose ::
  (B.FiniteBits bits, UV.Unbox a) =>
  bits ->
  SparseVector bits a ->
  SparseVector bits a ->
  SparseVector bits a
choose bits t f
  | bits == B.zeroBits = f
  | inverse == B.zeroBits = t
  | otherwise = merge (mask bits t) (mask inverse f)
  where
    inverse = B.complement bits
-- | Left-biased union of two sparse vectors.
--
-- Denotationally: @merge a b n = a n <|> b n@. A position occupied in either
-- input is occupied in the result. Where both inputs are occupied, the
-- element of @a@ is kept.
merge ::
  forall a bits.
  (B.FiniteBits bits, UV.Unbox a) =>
  SparseVector bits a ->
  SparseVector bits a ->
  SparseVector bits a
merge a b = SparseVector indices' merged
  where
    indA, indB, indices' :: bits
    indA = indices a
    indB = indices b
    indices' = indA .|. indB

    eas, ebs :: UV.Vector a
    eas = elements a
    ebs = elements b

    merged :: UV.Vector a
    merged = UV.create $ do
      vec <- MUV.new (B.popCount indices')
      go vec indices' 0 0 0

    -- Visit every position of the union, lowest first, using absolute bit
    -- positions (bits are cleared, never shifted).
    --   @rest@: union positions not yet visited
    --   @i@/@j@: index of the next unconsumed element of @a@/@b@
    --   @k@:    next free slot in @out@
    -- The walk must continue until the union is exhausted. Stopping when only
    -- one input runs out would leave slots of the freshly allocated (and
    -- uninitialised) output vector unwritten.
    go :: MUV.STVector s a -> bits -> Int -> Int -> Int -> ST s (MUV.STVector s a)
    go !out !rest !i !j !k
      | rest == B.zeroBits = pure out
      | inA = do
          -- Left bias: take a's element, and skip b's element at this
          -- position if there is one.
          MUV.write out k (eas UV.! i)
          go out rest' (i + 1) (if inB then j + 1 else j) (k + 1)
      | otherwise = do
          MUV.write out k (ebs UV.! j)
          go out rest' i (j + 1) (k + 1)
      where
        p = B.countTrailingZeros rest
        rest' = B.clearBit rest p
        inA = B.testBit indA p
        inB = B.testBit indB p
-- | Conservative identity test: 'True' only when 'Exts.reallyUnsafePtrEquality#'
-- reports that both arguments are the same heap object.
--
-- A 'False' result says nothing about structural equality, so callers should
-- treat this only as a cheap fast path before a real comparison.
eq :: a -> a -> Bool
eq lhs rhs =
  case Exts.reallyUnsafePtrEquality# lhs rhs of
    1# -> True
    _ -> False
{-# INLINE eq #-}