{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
module Data.Massiv.Core.Operations (
FoldNumeric (..),
defaultPowerSumArray,
defaultUnsafeDotProduct,
defaultFoldArray,
Numeric (..),
defaultUnsafeLiftArray,
defaultUnsafeLiftArray2,
NumericFloat (..),
) where
import Data.Massiv.Core.Common
-- | Numeric reductions over an array with representation @r@ and element
-- type @e@.
--
-- Instances must provide 'foldArray', 'powerSumArray' and
-- 'unsafeDotProduct'. 'sumArray' and 'productArray' default to
-- 'foldArray' with the additive and multiplicative identities.
class (Size r, Num e) => FoldNumeric r e where
  {-# MINIMAL foldArray, powerSumArray, unsafeDotProduct #-}

  -- | Sum of all elements, by default @'foldArray' (+) 0@.
  sumArray :: Index ix => Array r ix e -> e
  sumArray = foldArray (+) 0
  {-# INLINE sumArray #-}

  -- | Product of all elements, by default @'foldArray' (*) 1@.
  productArray :: Index ix => Array r ix e -> e
  productArray = foldArray (*) 1
  {-# INLINE productArray #-}

  -- | Raise every element to the given 'Int' power and sum the results.
  powerSumArray :: Index ix => Array r ix e -> Int -> e

  -- | Dot product of two arrays. The \"unsafe\" prefix marks that callers
  -- are presumably responsible for passing arrays of matching size
  -- (NOTE(review): confirm against instances).
  unsafeDotProduct :: Index ix => Array r ix e -> Array r ix e -> e

  -- | Combine all elements with a binary function, starting from the given
  -- accumulator.
  foldArray :: Index ix => (e -> e -> e) -> e -> Array r ix e -> e
-- | Default implementation of 'unsafeDotProduct'.
--
-- Walks both arrays by linear index and accumulates the sum of pairwise
-- products in a strict accumulator, starting from @0@.
--
-- The loop bound is the element count of the /first/ array only, and
-- elements are read with 'unsafeLinearIndex'. The second array is therefore
-- assumed to hold at least as many elements as the first; nothing here
-- checks that.
defaultUnsafeDotProduct ::
     (Num e, Index ix, Source r e) => Array r ix e -> Array r ix e -> e
defaultUnsafeDotProduct a1 a2 = go 0 0
  where
    -- Element count of the first array; drives the loop bound.
    !len = totalElem (size a1)
    -- Strict accumulator (bang) so no chain of (+) thunks builds up.
    go !acc i
      | i < len = go (acc + unsafeLinearIndex a1 i * unsafeLinearIndex a2 i) (i + 1)
      | otherwise = acc
{-# INLINE defaultUnsafeDotProduct #-}
-- | Default implementation of 'powerSumArray'.
--
-- Computes the sum of @x ^ p@ over every element @x@, visiting elements in
-- linear-index order with a strict accumulator that starts at @0@.
--
-- Prelude's '^' raises an error for a negative exponent, so @p@ must be
-- non-negative whenever the array is non-empty. For an empty array '^' is
-- never evaluated and the result is @0@.
defaultPowerSumArray :: (Index ix, Source r e, Num e) => Array r ix e -> Int -> e
defaultPowerSumArray arr p = go 0 0
  where
    !len = totalElem (size arr)
    -- (^) binds tighter than (+): this is acc + (x ^ p).
    go !acc i
      | i < len = go (acc + unsafeLinearIndex arr i ^ p) (i + 1)
      | otherwise = acc
{-# INLINE defaultPowerSumArray #-}
-- | Default implementation of 'foldArray'.
--
-- A strict left fold in linear-index order:
-- @f (... (f (f initAcc x0) x1) ...) x(n-1)@.
-- An empty array yields @initAcc@. The accumulator is forced at every step.
defaultFoldArray :: (Index ix, Source r e) => (e -> e -> e) -> e -> Array r ix e -> e
defaultFoldArray f !initAcc arr = go initAcc 0
  where
    !len = totalElem (size arr)
    go !acc i
      | i < len = go (f acc (unsafeLinearIndex arr i)) (i + 1)
      | otherwise = acc
{-# INLINE defaultFoldArray #-}
-- | Element-wise arithmetic on arrays.
--
-- Instances must provide 'unsafeLiftArray' and 'unsafeLiftArray2'. Every
-- other method has a default expressed through one of those two.
class FoldNumeric r e => Numeric r e where
  {-# MINIMAL unsafeLiftArray, unsafeLiftArray2 #-}

  -- | Add a scalar to every element: each @x@ becomes @x + e@.
  plusScalar :: Index ix => Array r ix e -> e -> Array r ix e
  plusScalar arr e = unsafeLiftArray (+ e) arr
  {-# INLINE plusScalar #-}

  -- | Subtract a scalar from every element: each @x@ becomes @x - e@.
  minusScalar :: Index ix => Array r ix e -> e -> Array r ix e
  minusScalar arr e = unsafeLiftArray (subtract e) arr
  {-# INLINE minusScalar #-}

  -- | Subtract every element from a scalar: each @x@ becomes @e - x@.
  scalarMinus :: Index ix => e -> Array r ix e -> Array r ix e
  scalarMinus e arr = unsafeLiftArray (e -) arr
  {-# INLINE scalarMinus #-}

  -- | Multiply every element by a scalar: each @x@ becomes @x * e@.
  multiplyScalar :: Index ix => Array r ix e -> e -> Array r ix e
  multiplyScalar arr e = unsafeLiftArray (* e) arr
  {-# INLINE multiplyScalar #-}

  -- | Absolute value of every element.
  absPointwise :: Index ix => Array r ix e -> Array r ix e
  absPointwise = unsafeLiftArray abs
  {-# INLINE absPointwise #-}

  -- | Element-wise sum of two arrays.
  additionPointwise :: Index ix => Array r ix e -> Array r ix e -> Array r ix e
  additionPointwise = unsafeLiftArray2 (+)
  {-# INLINE additionPointwise #-}

  -- | Element-wise difference of two arrays (first minus second).
  subtractionPointwise :: Index ix => Array r ix e -> Array r ix e -> Array r ix e
  subtractionPointwise = unsafeLiftArray2 (-)
  {-# INLINE subtractionPointwise #-}

  -- | Element-wise product of two arrays.
  multiplicationPointwise :: Index ix => Array r ix e -> Array r ix e -> Array r ix e
  multiplicationPointwise = unsafeLiftArray2 (*)
  {-# INLINE multiplicationPointwise #-}

  -- | Raise every element to an 'Int' power via Prelude '^'. That operator
  -- errors on a negative exponent.
  powerPointwise :: Index ix => Array r ix e -> Int -> Array r ix e
  powerPointwise arr pow = unsafeLiftArray (^ pow) arr
  {-# INLINE powerPointwise #-}

  -- | Apply a function to every element.
  unsafeLiftArray :: Index ix => (e -> e) -> Array r ix e -> Array r ix e

  -- | Combine two arrays element by element with a binary function. The
  -- \"unsafe\" prefix presumably means matching sizes are the caller's
  -- responsibility (NOTE(review): confirm against instances).
  unsafeLiftArray2 :: Index ix => (e -> e -> e) -> Array r ix e -> Array r ix e -> Array r ix e
-- | Default implementation of 'unsafeLiftArray'.
--
-- Builds a new array with the same computation strategy ('getComp') and the
-- same size as the input. Its element at linear index @i@ is @f@ applied to
-- the input's element at @i@.
defaultUnsafeLiftArray ::
     (Load r ix e, Source r e) => (e -> e) -> Array r ix e -> Array r ix e
defaultUnsafeLiftArray f arr =
  makeArrayLinear (getComp arr) (size arr) (f . unsafeLinearIndex arr)
{-# INLINE defaultUnsafeLiftArray #-}
-- | Default implementation of 'unsafeLiftArray2'.
--
-- Builds a new array whose element at linear index @i@ is
-- @f (a1 at i) (a2 at i)@. The computation strategy is the combination
-- (@<>@) of both inputs' strategies.
--
-- The result size is taken from @a1@ only, and elements are read with
-- 'unsafeLinearIndex'. @a2@ is therefore assumed to have at least as many
-- elements as @a1@; nothing here checks that.
defaultUnsafeLiftArray2 ::
     (Load r ix e, Source r e)
  => (e -> e -> e)
  -> Array r ix e
  -> Array r ix e
  -> Array r ix e
defaultUnsafeLiftArray2 f a1 a2 =
  -- The space in "\ !i" is required: "\!" would lex as an operator.
  makeArrayLinear (getComp a1 <> getComp a2) (size a1) $ \ !i ->
    f (unsafeLinearIndex a1 i) (unsafeLinearIndex a2 i)
{-# INLINE defaultUnsafeLiftArray2 #-}
-- | Element-wise floating-point operations on arrays.
--
-- Every method has a default built on 'unsafeLiftArray' or
-- 'unsafeLiftArray2' from 'Numeric'.
class (Numeric r e, Floating e) => NumericFloat r e where

  -- | Divide every element by a scalar: each @x@ becomes @x / e@.
  divideScalar :: Index ix => Array r ix e -> e -> Array r ix e
  divideScalar arr e = unsafeLiftArray (/ e) arr
  {-# INLINE divideScalar #-}

  -- | Divide a scalar by every element: each @x@ becomes @e / x@.
  scalarDivide :: Index ix => e -> Array r ix e -> Array r ix e
  scalarDivide e arr = unsafeLiftArray (e /) arr
  {-# INLINE scalarDivide #-}

  -- | Element-wise quotient of two arrays (first divided by second).
  divisionPointwise :: Index ix => Array r ix e -> Array r ix e -> Array r ix e
  divisionPointwise = unsafeLiftArray2 (/)
  {-# INLINE divisionPointwise #-}

  -- | Reciprocal of every element.
  recipPointwise :: Index ix => Array r ix e -> Array r ix e
  recipPointwise = unsafeLiftArray recip
  {-# INLINE recipPointwise #-}

  -- | Square root of every element.
  sqrtPointwise :: Index ix => Array r ix e -> Array r ix e
  sqrtPointwise = unsafeLiftArray sqrt
  {-# INLINE sqrtPointwise #-}