{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
module Data.Massiv.Array.Stencil.Convolution (
makeConvolutionStencil,
makeConvolutionStencilFromKernel,
makeCorrelationStencil,
makeCorrelationStencilFromKernel,
) where
import Data.Massiv.Array.Ops.Fold (ifoldlS)
import Data.Massiv.Array.Stencil.Internal
import Data.Massiv.Core.Common
import GHC.Exts (inline)
-- | Build a convolution 'Stencil' from a kernel described as a fold.
--
-- @relStencil@ is handed a step function @step offset weight acc@ and the
-- initial accumulator @0@; presumably it folds @step@ over every kernel
-- element. Each call contributes @value * weight@, where the value is read at
-- the output index /minus/ @offset@. That subtraction is what makes this a
-- convolution rather than a correlation.
--
-- The centre stored in the resulting 'Stencil' is mirrored per axis. For an
-- axis of size @n@ and centre @c@ it becomes @(n - 1) - c@. This matches the
-- fact that elements are read at @ix - offset@ rather than @ix + offset@.
makeConvolutionStencil
  :: (Index ix, Num e)
  => Sz ix
  -- ^ Size of the stencil
  -> ix
  -- ^ Centre, before mirroring
  -> ((ix -> e -> e -> e) -> e -> e)
  -- ^ Kernel as a fold over @(offset, weight)@ pairs
  -> Stencil ix e e
makeConvolutionStencil !sz !sCenter relStencil =
  Stencil sz mirroredCenter convolveAt
  where
    -- Per axis: (size - 1) - centre.
    lastIx = liftIndex (subtract 1) (unSz sz)
    !mirroredCenter = liftIndex2 (-) lastIx sCenter
    -- Only the second of the two getters given to the stencil function is used.
    convolveAt _ readAt !ix = inline relStencil step 0
      where
        step !offset !weight !acc =
          readAt (liftIndex2 (-) ix offset) * weight + acc
        {-# INLINE step #-}
    {-# INLINE convolveAt #-}
{-# INLINE makeConvolutionStencil #-}
-- | Build a convolution 'Stencil' whose kernel is a manifest array.
--
-- Each time the stencil is applied, every element of @kArr@ is visited with
-- 'ifoldlS', so the kernel need not be known at compile time.
--
-- The kernel centre is the kernel size halved with 'quot' along every axis.
-- Kernel element @kIx@ is multiplied with the value read at
-- @(ix + centre) - kIx@. Because of that, the centre stored in the resulting
-- 'Stencil' is the mirrored one, @(size - 1) - centre@.
makeConvolutionStencilFromKernel
  :: (Manifest r e, Index ix, Num e)
  => Array r ix e
  -- ^ Kernel
  -> Stencil ix e e
makeConvolutionStencilFromKernel kArr = Stencil sz mirroredCenter convolveAt
  where
    !sz@(Sz kSize) = size kArr
    !center = liftIndex (`quot` 2) kSize
    !lastIx = liftIndex (subtract 1) kSize
    !mirroredCenter = liftIndex2 (-) lastIx center
    -- Only the first of the two getters given to the stencil function is used.
    convolveAt readAt _ !ix = ifoldlS addTerm 0 kArr
      where
        !anchor = liftIndex2 (+) ix center
        addTerm !acc !kIx !weight =
          readAt (liftIndex2 (-) anchor kIx) * weight + acc
        {-# INLINE addTerm #-}
    {-# INLINE convolveAt #-}
{-# INLINE makeConvolutionStencilFromKernel #-}
-- | Build a cross-correlation 'Stencil' from a kernel described as a fold.
--
-- @relStencil@ is handed a step function @step offset weight acc@ and the
-- initial accumulator @0@; presumably it folds @step@ over every kernel
-- element. Each call contributes @value * weight@, where the value is read at
-- the output index /plus/ @offset@.
--
-- The given size and centre are stored in the 'Stencil' unchanged.
makeCorrelationStencil
  :: (Index ix, Num e)
  => Sz ix
  -- ^ Size of the stencil
  -> ix
  -- ^ Centre of the stencil
  -> ((ix -> e -> e -> e) -> e -> e)
  -- ^ Kernel as a fold over @(offset, weight)@ pairs
  -> Stencil ix e e
makeCorrelationStencil !sSz !sCenter relStencil =
  Stencil sSz sCenter correlateAt
  where
    -- Only the second of the two getters given to the stencil function is used.
    correlateAt _ readAt !ix =
      let step !offset !weight !acc =
            readAt (liftIndex2 (+) ix offset) * weight + acc
          {-# INLINE step #-}
       in inline relStencil step 0
    {-# INLINE correlateAt #-}
{-# INLINE makeCorrelationStencil #-}
-- | Build a cross-correlation 'Stencil' whose kernel is a manifest array.
--
-- Each time the stencil is applied, every element of @kArr@ is visited with
-- 'ifoldlS'.
--
-- The kernel centre is the kernel size halved with 'div' along every axis, and
-- it is stored in the 'Stencil' unchanged. Kernel element @kIx@ is multiplied
-- with the value read at @(ix - centre) + kIx@.
makeCorrelationStencilFromKernel
  :: (Manifest r e, Index ix, Num e)
  => Array r ix e
  -- ^ Kernel
  -> Stencil ix e e
makeCorrelationStencilFromKernel kArr =
  let !sz = size kArr
      !center = liftIndex (`div` 2) (unSz sz)
      -- Only the first of the two getters given to the stencil function is used.
      correlateAt readAt _ !ix =
        let !origin = liftIndex2 (-) ix center
            addTerm !acc !kIx !weight =
              readAt (liftIndex2 (+) origin kIx) * weight + acc
            {-# INLINE addTerm #-}
         in ifoldlS addTerm 0 kArr
      {-# INLINE correlateAt #-}
   in Stencil sz center correlateAt
{-# INLINE makeCorrelationStencilFromKernel #-}