{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}

-- |
-- Module      : Data.Massiv.Core.Operations
-- Copyright   : (c) Alexey Kuleshevich 2019-2022
-- License     : BSD3
-- Maintainer  : Alexey Kuleshevich <lehins@yandex.ru>
-- Stability   : experimental
-- Portability : non-portable
module Data.Massiv.Core.Operations (
  FoldNumeric (..),
  defaultPowerSumArray,
  defaultUnsafeDotProduct,
  defaultFoldArray,
  Numeric (..),
  defaultUnsafeLiftArray,
  defaultUnsafeLiftArray2,
  NumericFloat (..),
) where

import Data.Massiv.Core.Common

-- | Numeric reductions over an array with representation @r@ and elements of
-- type @e@.
--
-- Instances must supply 'foldArray', 'powerSumArray' and 'unsafeDotProduct';
-- 'sumArray' and 'productArray' have defaults expressed through 'foldArray'.
class (Size r, Num e) => FoldNumeric r e where
  {-# MINIMAL foldArray, powerSumArray, unsafeDotProduct #-}

  -- | Add up all elements of the array. By default this folds with '(+)',
  -- starting from an accumulator of @0@.
  --
  -- @since 0.5.6
  sumArray :: Index ix => Array r ix e -> e
  sumArray = foldArray (\acc x -> acc + x) 0
  {-# INLINE sumArray #-}

  -- | Multiply all elements of the array together. By default this folds with
  -- '(*)', starting from an accumulator of @1@.
  --
  -- @since 0.5.6
  productArray :: Index ix => Array r ix e -> e
  productArray = foldArray (\acc x -> acc * x) 1
  {-# INLINE productArray #-}

  -- | Raise every element to the given non-negative power and add up the
  -- results.
  --
  -- @since 0.5.7
  powerSumArray :: Index ix => Array r ix e -> Int -> e

  -- | Dot product of two arrays computed without any extraneous checks;
  -- callers should not rely on it validating its arguments.
  --
  -- @since 0.5.6
  unsafeDotProduct :: Index ix => Array r ix e -> Array r ix e -> e

  -- | Reduce an array with a combining function and an initial accumulator.
  --
  -- @since 0.5.6
  foldArray :: Index ix => (e -> e -> e) -> e -> Array r ix e -> e

-- | Reference implementation of 'unsafeDotProduct'.
--
-- Walks both arrays by linear index from @0@ upwards, accumulating
-- @acc + a1[k] * a2[k]@ strictly. The number of steps comes from the size of
-- the first array only, and elements are read with 'unsafeLinearIndex', so
-- nothing here checks that the second array is large enough.
defaultUnsafeDotProduct
  :: (Num e, Index ix, Source r e) => Array r ix e -> Array r ix e -> e
defaultUnsafeDotProduct a1 a2 = step 0 0
  where
    !n = totalElem (size a1)
    step !total k =
      if k < n
        then step (total + unsafeLinearIndex a1 k * unsafeLinearIndex a2 k) (k + 1)
        else total
{-# INLINE defaultUnsafeDotProduct #-}

-- | Reference implementation of 'powerSumArray'.
--
-- Raises each element to the power @p@ with '(^)' and sums the results
-- strictly, visiting elements in linear-index order starting at @0@.
-- '(^)' throws an error for a negative exponent, so @p@ is expected to be
-- non-negative whenever the array is non-empty.
defaultPowerSumArray :: (Index ix, Source r e, Num e) => Array r ix e -> Int -> e
defaultPowerSumArray arr p = sumFrom 0 0
  where
    !count = totalElem (size arr)
    sumFrom !total j
      | j >= count = total
      | otherwise = sumFrom (total + (unsafeLinearIndex arr j ^ p)) (j + 1)
{-# INLINE defaultPowerSumArray #-}

-- | Reference implementation of 'foldArray'.
--
-- A strict left fold: the accumulator (forced on entry and after every step)
-- is combined with each element as @f acc x@, in linear-index order starting
-- at @0@. An empty array yields the initial accumulator unchanged.
defaultFoldArray :: (Index ix, Source r e) => (e -> e -> e) -> e -> Array r ix e -> e
defaultFoldArray f !initAcc arr = loop initAcc 0
  where
    !n = totalElem (size arr)
    loop !acc k
      | k >= n = acc
      | otherwise = loop (f acc (unsafeLinearIndex arr k)) (k + 1)
{-# INLINE defaultFoldArray #-}

-- | Element-wise arithmetic on arrays with representation @r@ and elements of
-- type @e@.
--
-- Instances must supply 'unsafeLiftArray' and 'unsafeLiftArray2'; every other
-- method has a default built on top of those two.
class FoldNumeric r e => Numeric r e where
  {-# MINIMAL unsafeLiftArray, unsafeLiftArray2 #-}

  -- | Add a scalar to every element (@x + s@).
  plusScalar :: Index ix => Array r ix e -> e -> Array r ix e
  plusScalar arr s = unsafeLiftArray (\x -> x + s) arr
  {-# INLINE plusScalar #-}

  -- | Subtract a scalar from every element (@x - s@).
  minusScalar :: Index ix => Array r ix e -> e -> Array r ix e
  minusScalar arr s = unsafeLiftArray (\x -> x - s) arr
  {-# INLINE minusScalar #-}

  -- | Subtract every element from a scalar (@s - x@).
  scalarMinus :: Index ix => e -> Array r ix e -> Array r ix e
  scalarMinus s arr = unsafeLiftArray (\x -> s - x) arr
  {-# INLINE scalarMinus #-}

  -- | Multiply every element by a scalar (@x * s@).
  multiplyScalar :: Index ix => Array r ix e -> e -> Array r ix e
  multiplyScalar arr s = unsafeLiftArray (\x -> x * s) arr
  {-# INLINE multiplyScalar #-}

  -- | Absolute value of every element.
  absPointwise :: Index ix => Array r ix e -> Array r ix e
  absPointwise = unsafeLiftArray abs
  {-# INLINE absPointwise #-}

  -- | Element-by-element sum of two arrays.
  additionPointwise :: Index ix => Array r ix e -> Array r ix e -> Array r ix e
  additionPointwise = unsafeLiftArray2 (\x y -> x + y)
  {-# INLINE additionPointwise #-}

  -- | Element-by-element difference of two arrays (first minus second).
  subtractionPointwise :: Index ix => Array r ix e -> Array r ix e -> Array r ix e
  subtractionPointwise = unsafeLiftArray2 (\x y -> x - y)
  {-# INLINE subtractionPointwise #-}

  -- | Element-by-element product of two arrays (first times second).
  multiplicationPointwise :: Index ix => Array r ix e -> Array r ix e -> Array r ix e
  multiplicationPointwise = unsafeLiftArray2 (\x y -> x * y)
  {-# INLINE multiplicationPointwise #-}

  -- TODO:
  --  - rename to powerScalar
  --  - add? powerPointwise :: Array r ix e -> Array r ix Int -> Array r ix e

  -- | Raise every element to the given power using '(^)'. The exponent is
  -- expected to be non-negative, since '(^)' throws an error for a negative
  -- one when an element is evaluated.
  powerPointwise :: Index ix => Array r ix e -> Int -> Array r ix e
  powerPointwise arr n = unsafeLiftArray (\x -> x ^ n) arr
  {-# INLINE powerPointwise #-}

  -- | Apply a function to every element of the array.
  unsafeLiftArray :: Index ix => (e -> e) -> Array r ix e -> Array r ix e

  -- | Combine two arrays element by element with the given function. The two
  -- arrays are expected to agree in size; this is not checked.
  unsafeLiftArray2 :: Index ix => (e -> e -> e) -> Array r ix e -> Array r ix e -> Array r ix e

-- | Reference implementation of 'unsafeLiftArray'.
--
-- Builds a new array with the same computation strategy and size as the input,
-- where the element at linear index @i@ is @f@ applied to the input's element
-- at @i@.
defaultUnsafeLiftArray
  :: (Load r ix e, Source r e) => (e -> e) -> Array r ix e -> Array r ix e
defaultUnsafeLiftArray f arr = makeArrayLinear comp sz elemAt
  where
    comp = getComp arr
    sz = size arr
    elemAt i = f (unsafeLinearIndex arr i)
{-# INLINE defaultUnsafeLiftArray #-}

-- | Reference implementation of 'unsafeLiftArray2'.
--
-- The result takes its size from the first array and its computation strategy
-- from combining both inputs' strategies with '(<>)'. The element at linear
-- index @i@ is @f (a1[i]) (a2[i])@. The second array's size is never
-- consulted, so it must be at least as large as the first.
defaultUnsafeLiftArray2
  :: (Load r ix e, Source r e)
  => (e -> e -> e)
  -> Array r ix e
  -> Array r ix e
  -> Array r ix e
defaultUnsafeLiftArray2 f a1 a2 = makeArrayLinear comp (size a1) combineAt
  where
    comp = getComp a1 <> getComp a2
    combineAt !i = f (unsafeLinearIndex a1 i) (unsafeLinearIndex a2 i)
{-# INLINE defaultUnsafeLiftArray2 #-}

-- | Element-wise operations that require 'Floating' elements. Every method has
-- a default defined through 'unsafeLiftArray' or 'unsafeLiftArray2'.
class (Numeric r e, Floating e) => NumericFloat r e where
  -- | Divide every element by a scalar (@x / s@).
  divideScalar :: Index ix => Array r ix e -> e -> Array r ix e
  divideScalar arr s = unsafeLiftArray (\x -> x / s) arr
  {-# INLINE divideScalar #-}

  -- | Divide a scalar by every element (@s / x@).
  scalarDivide :: Index ix => e -> Array r ix e -> Array r ix e
  scalarDivide s arr = unsafeLiftArray (\x -> s / x) arr
  {-# INLINE scalarDivide #-}

  -- | Element-by-element quotient of two arrays (first divided by second).
  divisionPointwise :: Index ix => Array r ix e -> Array r ix e -> Array r ix e
  divisionPointwise = unsafeLiftArray2 (\x y -> x / y)
  {-# INLINE divisionPointwise #-}

  -- | Reciprocal of every element.
  recipPointwise :: Index ix => Array r ix e -> Array r ix e
  recipPointwise = unsafeLiftArray recip
  {-# INLINE recipPointwise #-}

  -- | Square root of every element.
  sqrtPointwise :: Index ix => Array r ix e -> Array r ix e
  sqrtPointwise = unsafeLiftArray sqrt
  {-# INLINE sqrtPointwise #-}

-- floorPointwise :: (Index ix, Integral a) => Array r ix e -> Array r ix a
-- floorPointwise = unsafeLiftArray floor
-- {-# INLINE floorPointwise #-}

-- ceilingPointwise :: (Index ix, Integral a) => Array r ix e -> Array r ix a
-- ceilingPointwise = unsafeLiftArray ceiling
-- {-# INLINE ceilingPointwise #-}

-- class Equality r e where

--   unsafeEq :: Index ix => Array r ix e -> Array r ix e -> Bool

--   unsafeEqPointwise :: Index ix => Array r ix e -> Array r ix e -> Array r ix Bool

-- class Relation r e where

--   unsafePointwiseLT :: Array r ix e -> Array r ix e -> Array r ix Bool
--   unsafePointwiseLTE :: Array r ix e -> Array r ix e -> Array r ix Bool

--   unsafePointwiseGT :: Array r ix e -> Array r ix e -> Array r ix Bool
--   unsafePointwiseGTE :: Array r ix e -> Array r ix e -> Array r ix Bool

--   unsafePointwiseMin :: Array r ix e -> Array r ix e -> Array r ix e
--   unsafePointwiseMax :: Array r ix e -> Array r ix e -> Array r ix e

--   unsafeMinimum :: Array r ix e -> e

--   unsafeMaximum :: Array r ix e -> e