{-# LANGUAGE CPP #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE PatternSynonyms #-}
module Data.Massiv.Core.Index.Stride (
Stride (SafeStride),
pattern Stride,
unStride,
oneStride,
toLinearIndexStride,
strideStart,
strideSize,
) where
import Control.DeepSeq (NFData)
import Data.Massiv.Core.Index.Internal
import System.Random.Stateful (Random, Uniform (..), UniformRange (..))
-- | An index-shaped stride: each component is used by the functions in this
-- module as the per-dimension divisor/step for the matching index component.
--
-- Prefer the 'Stride' pattern for construction; it clamps every component to
-- at least @1@. The 'SafeStride' constructor is exported too, but it wraps its
-- argument without any check. A zero component built that way makes the
-- 'div' / 'mod' calls in 'strideSize', 'strideStart' and
-- 'toLinearIndexStride' divide by zero.
newtype Stride ix = SafeStride ix
  deriving (Eq, Ord, NFData)
-- | Bidirectional pattern synonym for 'Stride'.
--
-- Matching only unwraps the underlying index. Construction passes the index
-- through @'liftIndex' ('max' 1)@, so every component smaller than @1@ is
-- replaced by @1@. The raw 'SafeStride' constructor performs no such clamping.
pattern Stride :: Index ix => ix -> Stride ix
pattern Stride ix <- SafeStride ix
  where
    Stride ix = SafeStride (liftIndex (max 1) ix)

-- A single pattern covers the only constructor, so matches on 'Stride' alone
-- are exhaustive.
{-# COMPLETE Stride #-}
-- | Renders as @Stride \<ix\>@. The wrapped index is shown at precedence @1@.
-- The whole rendering is handed to 'showsPrecWrapped' together with the
-- outer precedence @n@, which presumably decides whether to parenthesise it.
instance Index ix => Show (Stride ix) where
  showsPrec n (SafeStride ix) =
    showsPrecWrapped n (showString "Stride " . showsPrec 1 ix)
-- | Draws the wrapped index uniformly from the range
-- @('pureIndex' 1, 'pureIndex' 'maxBound')@. Assuming 'pureIndex' fills
-- every component with its argument, every generated component is >= 1.
instance (UniformRange ix, Index ix) => Uniform (Stride ix) where
  uniformM g = SafeStride <$> uniformRM (pureIndex 1, pureIndex maxBound) g
  {-# INLINE uniformM #-}
-- | Draws the wrapped index uniformly between the two wrapped bounds.
-- No clamping happens here, so the result has positive components only if
-- the bounds do. This holds for bounds built with the 'Stride' pattern.
instance UniformRange ix => UniformRange (Stride ix) where
  uniformRM (SafeStride lo, SafeStride hi) g = SafeStride <$> uniformRM (lo, hi) g
  {-# INLINE uniformRM #-}
-- | 'Random' instance with no explicit methods. It relies entirely on the
-- class's default methods, which (in @random >= 1.2@) go through the
-- 'Uniform' and 'UniformRange' instances for 'Stride'.
instance (Index ix, UniformRange ix) => Random (Stride ix)
-- | Unwrap a 'Stride' and return the underlying index unchanged.
unStride :: Stride ix -> ix
unStride (SafeStride ix) = ix
{-# INLINE unStride #-}
-- | Move an index forward, component by component, onto a multiple of the
-- stride.
--
-- Each component @i@ with stride component @s@ becomes
-- @i + ((s - i \`mod\` s) \`mod\` s)@. For positive @s@, a component already
-- divisible by @s@ is kept as is, and any other component is rounded up to
-- the next multiple of @s@. A zero stride component makes 'mod' throw a
-- divide-by-zero exception. That can only happen when the stride was built
-- with the raw 'SafeStride' constructor.
strideStart :: Index ix => Stride ix -> ix -> ix
strideStart (SafeStride stride) ix = liftIndex2 roundUp ix stride
  where
    -- distance to the next multiple of s, taken mod s so it is 0 when i is
    -- already aligned
    roundUp i s = i + (s - i `mod` s) `mod` s
{-# INLINE strideStart #-}
-- | Adjust a size according to a stride.
--
-- Each size component @n@ with stride component @s@ becomes
-- @(n - 1) \`div\` s + 1@. For positive @s@ this is the number of positions
-- @0, s, 2s, ..@ that lie below @n@, and it yields @0@ when @n@ is @0@.
strideSize :: Index ix => Stride ix -> Sz ix -> Sz ix
strideSize (SafeStride stride) (SafeSz sz) = SafeSz (liftIndex2 count sz stride)
  where
    count n s = (n - 1) `div` s + 1
{-# INLINE strideSize #-}
-- | Linear index of a position under a stride.
--
-- Every component of the index is divided with 'div' by the matching stride
-- component. The resulting index is then linearised with 'toLinearIndex'
-- against the given size.
toLinearIndexStride
  :: Index ix
  => Stride ix -- ^ stride whose components divide the index
  -> Sz ix     -- ^ size used for linearisation (NOTE(review): presumably the strided size — confirm with callers)
  -> ix        -- ^ index before division by the stride
  -> Int
toLinearIndexStride (SafeStride stride) sz ix =
  toLinearIndex sz (liftIndex2 div ix stride)
{-# INLINE toLinearIndexStride #-}
-- | The unit stride, built from @'pureIndex' 1@ (a stride of @1@ in every
-- component, assuming 'pureIndex' fills all components).
oneStride :: Index ix => Stride ix
oneStride = SafeStride (pureIndex 1)
{-# INLINE oneStride #-}