{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE StandaloneKindSignatures #-}

-- This module wraps the operations in the primitive package so that
-- bounds checks can be toggled on during the build for debugging
-- purposes. It re-exports the entire API for the three array types
-- needed (`Array`, `ByteArray` and `PrimArray`), and adds wrappers for the
-- operations that are unchecked in the underlying primitive package.
--
-- Checking is toggled using the `arraychecks` flag.
module Unison.Runtime.Array
  ( module EPA,
    byteArrayToList,
    readArray,
    writeArray,
    copyArray,
    copyMutableArray,
    cloneMutableArray,
    readByteArray,
    writeByteArray,
    indexByteArray,
    copyByteArray,
    copyMutableByteArray,
    moveByteArray,
    readPrimArray,
    writePrimArray,
    indexPrimArray,
  )
where

import Control.Monad.Primitive
import Data.Kind (Constraint)
import Data.Primitive.Array as EPA hiding
  ( cloneMutableArray,
    copyArray,
    copyMutableArray,
    readArray,
    writeArray,
  )
import Data.Primitive.Array qualified as PA
import Data.Primitive.ByteArray as EPA hiding
  ( copyByteArray,
    copyMutableByteArray,
    indexByteArray,
    moveByteArray,
    readByteArray,
    writeByteArray,
  )
import Data.Primitive.ByteArray qualified as PA
import Data.Primitive.PrimArray as EPA hiding
  ( indexPrimArray,
    readPrimArray,
    writePrimArray,
  )
import Data.Primitive.PrimArray qualified as PA
import Data.Primitive.Types
import Data.Word (Word8)
import GHC.IsList (toList)

#ifdef ARRAY_CHECK
import GHC.Stack

-- | Constraint carried by every checked operation. With ARRAY_CHECK enabled
-- it is 'HasCallStack', so the 'error' raised by a failed bounds check
-- carries the call stack of the offending call site.
type CheckCtx :: Constraint
type CheckCtx = HasCallStack

-- Short aliases that keep the checker signatures below readable.
type MA = MutableArray
type MBA = MutableByteArray
type A = Array
type BA = ByteArray

-- | Bounds-check an element access into a boxed mutable array.
--
-- The wrapped operation runs only when the index lies in @[0, size)@;
-- otherwise this aborts via 'error', naming the operation and the index.
checkIMArray
  :: CheckCtx
  => String
  -> (MA s a -> Int -> r)
  -> MA s a -> Int -> r
checkIMArray opName op marr ix
  | ix >= 0 && ix < sizeofMutableArray marr = op marr ix
  | otherwise =
      error (opName ++ " unsafe check out of bounds: " ++ show ix)
{-# inline checkIMArray #-}

-- | Bounds-check a copy of @l@ elements from the immutable array @src@
-- (starting at offset @s@) into the mutable array @dst@ (starting at @d@).
--
-- Both offsets and the length must be non-negative, and each range must fit
-- inside its array. Otherwise this aborts via 'error', naming the operation
-- and the offending @(d, s, l)@ triple. A negative length is rejected
-- explicitly. Without that check, an out-of-range offset paired with a
-- negative length could slip past the end-of-range comparisons.
checkCArray
  :: CheckCtx
  => String
  -> (MA s a -> Int -> A a -> Int -> Int -> r)
  -> MA s a -> Int -> A a -> Int -> Int -> r
checkCArray name f dst d src s l
  | d < 0
  || s < 0
  || l < 0
  || sizeofMutableArray dst < d + l
  || sizeofArray src < s + l
  = error $ name ++ " unsafe check out of bounds: " ++ show (d, s, l)
  | otherwise = f dst d src s l
{-# inline checkCArray #-}

-- | Bounds-check a copy of @l@ elements from the mutable array @src@
-- (starting at offset @s@) into the mutable array @dst@ (starting at @d@).
--
-- Both offsets and the length must be non-negative, and each range must fit
-- inside its array. Otherwise this aborts via 'error', naming the operation
-- and the offending @(d, s, l)@ triple. The explicit @l < 0@ test stops an
-- out-of-range offset combined with a negative length from passing.
checkCMArray
  :: CheckCtx
  => String
  -> (MA s a -> Int -> MA s a -> Int -> Int -> r)
  -> MA s a -> Int -> MA s a -> Int -> Int -> r
checkCMArray name f dst d src s l
  | d < 0
  || s < 0
  || l < 0
  || sizeofMutableArray dst < d + l
  || sizeofMutableArray src < s + l
  = error $ name ++ " unsafe check out of bounds: " ++ show (d, s, l)
  | otherwise = f dst d src s l
{-# inline checkCMArray #-}

-- | Bounds-check an operation over the slice @[o, o + l)@ of a boxed
-- mutable array (used for 'cloneMutableArray').
--
-- The offset and length must be non-negative and the slice must end within
-- the array. Otherwise this aborts via 'error', naming the operation and the
-- offending @(o, l)@ pair. The message keeps the same
-- "<name> unsafe check out of bounds: " format as the other checkers.
checkRMArray
  :: CheckCtx
  => String
  -> (MA s a -> Int -> Int -> r)
  -> MA s a -> Int -> Int -> r
checkRMArray name f arr o l
  | o < 0
  || l < 0
  || sizeofMutableArray arr < o + l
  = error $ name ++ " unsafe check out of bounds: " ++ show (o, l)
  | otherwise = f arr o l
{-# inline checkRMArray #-}

-- | Bounds-check an element read from an immutable byte array.
--
-- The index counts elements of type @a@, not bytes. It must be below the
-- number of whole @a@-sized elements that fit in the array. The @a@ value is
-- only handed to 'sizeOf'; the wrappers in this module pass 'undefined'.
checkIBArray
  :: CheckCtx
  => Prim a
  => String
  -> a
  -> (ByteArray -> Int -> r)
  -> ByteArray -> Int -> r
checkIBArray opName proxy op bytes ix
  | ix >= 0, ix < capacity = op bytes ix
  | otherwise =
      error (opName ++ " unsafe check out of bounds: " ++ show ix)
  where
    capacity = sizeofByteArray bytes `quot` sizeOf proxy
{-# inline checkIBArray #-}

-- | Bounds-check an element access into a mutable byte array.
--
-- The index counts elements of type @a@, not bytes. It must be below the
-- number of whole @a@-sized elements that fit in the array. The @a@ value is
-- only handed to 'sizeOf'; the wrappers in this module pass 'undefined'.
checkIMBArray
  :: CheckCtx
  => Prim a
  => String
  -> a
  -> (MutableByteArray s -> Int -> r)
  -> MutableByteArray s -> Int -> r
checkIMBArray opName proxy op mbytes ix
  | ix >= 0, ix < capacity = op mbytes ix
  | otherwise =
      error (opName ++ " unsafe check out of bounds: " ++ show ix)
  where
    capacity = sizeofMutableByteArray mbytes `quot` sizeOf proxy
{-# inline checkIMBArray #-}

-- | Bounds-check a copy of @l@ bytes from the immutable byte array @src@
-- (starting at byte offset @s@) into the mutable byte array @dst@
-- (starting at byte offset @d@).
--
-- Both offsets and the length must be non-negative, and each byte range
-- must fit inside its array. Otherwise this aborts via 'error', naming the
-- operation and the offending @(d, s, l)@ triple. The explicit @l < 0@ test
-- stops an out-of-range offset combined with a negative length from passing.
checkCBArray
  :: CheckCtx
  => String
  -> (MBA s -> Int -> BA -> Int -> Int -> r)
  -> MBA s -> Int -> BA -> Int -> Int -> r
checkCBArray name f dst d src s l
  | d < 0
  || s < 0
  || l < 0
  || sizeofMutableByteArray dst < d + l
  || sizeofByteArray src < s + l
  = error $ name ++ " unsafe check out of bounds: " ++ show (d, s, l)
  | otherwise = f dst d src s l
{-# inline checkCBArray #-}

-- | Bounds-check a copy or move of @l@ bytes from the mutable byte array
-- @src@ (starting at byte offset @s@) into the mutable byte array @dst@
-- (starting at byte offset @d@).
--
-- Both offsets and the length must be non-negative, and each byte range
-- must fit inside its array. Otherwise this aborts via 'error', naming the
-- operation and the offending @(d, s, l)@ triple. The explicit @l < 0@ test
-- stops an out-of-range offset combined with a negative length from passing.
checkCMBArray
  :: CheckCtx
  => String
  -> (MBA s -> Int -> MBA s -> Int -> Int -> r)
  -> MBA s -> Int -> MBA s -> Int -> Int -> r
checkCMBArray name f dst d src s l
  | d < 0
  || s < 0
  || l < 0
  || sizeofMutableByteArray dst < d + l
  || sizeofMutableByteArray src < s + l
  = error $ name ++ " unsafe check out of bounds: " ++ show (d, s, l)
  | otherwise = f dst d src s l
{-# inline checkCMBArray #-}

-- | Bounds-check an element read from an immutable primitive array.
--
-- The wrapped operation runs only when the index lies in @[0, size)@, where
-- size is the element count; otherwise this aborts via 'error', naming the
-- operation and the index.
checkIPArray
  :: CheckCtx
  => Prim a
  => String
  -> (PrimArray a -> Int -> r)
  -> PrimArray a -> Int -> r
checkIPArray opName op parr ix
  | ix >= 0 && ix < sizeofPrimArray parr = op parr ix
  | otherwise =
      error (opName ++ " unsafe check out of bounds: " ++ show ix)
{-# inline checkIPArray #-}

-- | Bounds-check an element access into a mutable primitive array.
--
-- The wrapped operation runs only when the index lies in @[0, size)@, where
-- size is the element count; otherwise this aborts via 'error', naming the
-- operation and the index.
checkIMPArray
  :: CheckCtx
  => Prim a
  => String
  -> (MutablePrimArray s a -> Int -> r)
  -> MutablePrimArray s a -> Int -> r
checkIMPArray opName op mparr ix
  | ix >= 0 && ix < sizeofMutablePrimArray mparr = op mparr ix
  | otherwise =
      error (opName ++ " unsafe check out of bounds: " ++ show ix)
{-# inline checkIMPArray #-}

#else
-- | With ARRAY_CHECK disabled, the check context is the empty constraint.
type CheckCtx :: Constraint
type CheckCtx = ()

-- With ARRAY_CHECK disabled, every index/range checker ignores the
-- operation name and returns the wrapped operation unchanged. Each wrapper
-- therefore behaves exactly like the underlying primitive call.
checkIMArray, checkIMPArray, checkIPArray :: String -> r -> r
checkCArray, checkCMArray, checkRMArray :: String -> r -> r
checkIMArray _ = id
checkIMPArray _ = id
checkIPArray _ = id
checkCArray _ = id
checkCMArray _ = id
checkRMArray _ = id
{-# INLINE checkIMArray #-}
{-# INLINE checkIMPArray #-}
{-# INLINE checkIPArray #-}
{-# INLINE checkCArray #-}
{-# INLINE checkCMArray #-}
{-# INLINE checkRMArray #-}

-- Unchecked byte-array checkers. The index checkers also take an
-- element-type proxy, which the checked build uses only for its size. Here
-- it is ignored and never forced, so callers may pass 'undefined'.
checkIBArray, checkIMBArray :: String -> a -> r -> r
checkCBArray, checkCMBArray :: String -> r -> r
checkIBArray _ _ = id
checkIMBArray _ _ = id
checkCBArray _ = id
checkCMBArray _ = id
{-# INLINE checkIBArray #-}
{-# INLINE checkIMBArray #-}
{-# INLINE checkCBArray #-}
{-# INLINE checkCMBArray #-}
#endif

-- | Read the element at the given index of a boxed mutable array.
--
-- Wraps 'PA.readArray'. When built with ARRAY_CHECK, the index is first
-- checked to lie in @[0, size)@; otherwise no check is performed.
readArray ::
  (CheckCtx) =>
  (PrimMonad m) =>
  MutableArray (PrimState m) a ->
  Int ->
  m a
readArray = checkIMArray "readArray" PA.readArray
{-# INLINE readArray #-}

-- | Write a value at the given index of a boxed mutable array.
--
-- Wraps 'PA.writeArray'. When built with ARRAY_CHECK, the index is first
-- checked to lie in @[0, size)@; otherwise no check is performed.
writeArray ::
  (CheckCtx) =>
  (PrimMonad m) =>
  MutableArray (PrimState m) a ->
  Int ->
  a ->
  m ()
writeArray = checkIMArray "writeArray" PA.writeArray
{-# INLINE writeArray #-}

-- | Copy a slice of an immutable array into a mutable array.
--
-- The arguments are the destination, destination offset, source, source
-- offset and element count. Wraps 'PA.copyArray'. When built with
-- ARRAY_CHECK, both ranges are checked to fit inside their arrays.
copyArray ::
  (CheckCtx) =>
  (PrimMonad m) =>
  MutableArray (PrimState m) a ->
  Int ->
  Array a ->
  Int ->
  Int ->
  m ()
copyArray = checkCArray "copyArray" PA.copyArray
{-# INLINE copyArray #-}

-- | Copy the slice described by an offset and a length out of a boxed
-- mutable array into a new mutable array.
--
-- Wraps 'PA.cloneMutableArray'. When built with ARRAY_CHECK, the slice is
-- checked to lie within the source array.
cloneMutableArray ::
  (CheckCtx) =>
  (PrimMonad m) =>
  MutableArray (PrimState m) a ->
  Int ->
  Int ->
  m (MutableArray (PrimState m) a)
cloneMutableArray = checkRMArray "cloneMutableArray" PA.cloneMutableArray
{-# INLINE cloneMutableArray #-}

-- | Copy a slice of one boxed mutable array into another.
--
-- The arguments are the destination, destination offset, source, source
-- offset and element count. Wraps 'PA.copyMutableArray'. When built with
-- ARRAY_CHECK, both ranges are checked to fit inside their arrays.
copyMutableArray ::
  (CheckCtx) =>
  (PrimMonad m) =>
  MutableArray (PrimState m) a ->
  Int ->
  MutableArray (PrimState m) a ->
  Int ->
  Int ->
  m ()
copyMutableArray = checkCMArray "copyMutableArray" PA.copyMutableArray
{-# INLINE copyMutableArray #-}

-- | Read the element of type @a@ at the given element index of a mutable
-- byte array.
--
-- Wraps 'PA.readByteArray'. When built with ARRAY_CHECK, the index is
-- checked against the number of whole @a@-sized elements in the array. The
-- checker only needs @a@'s size, so 'undefined' serves as a type proxy,
-- selected with @\@a@.
readByteArray ::
  forall a m.
  (CheckCtx) =>
  (PrimMonad m) =>
  (Prim a) =>
  MutableByteArray (PrimState m) ->
  Int ->
  m a
readByteArray = checkIMBArray @a "readByteArray" undefined PA.readByteArray
{-# INLINE readByteArray #-}

-- | Write a value of type @a@ at the given element index of a mutable byte
-- array.
--
-- Wraps 'PA.writeByteArray'. When built with ARRAY_CHECK, the index is
-- checked against the number of whole @a@-sized elements in the array. The
-- checker only needs @a@'s size, so 'undefined' serves as a type proxy,
-- selected with @\@a@.
writeByteArray ::
  forall a m.
  (CheckCtx) =>
  (PrimMonad m) =>
  (Prim a) =>
  MutableByteArray (PrimState m) ->
  Int ->
  a ->
  m ()
writeByteArray = checkIMBArray @a "writeByteArray" undefined PA.writeByteArray
{-# INLINE writeByteArray #-}

-- | Read the element of type @a@ at the given element index of an immutable
-- byte array.
--
-- Wraps 'PA.indexByteArray'. When built with ARRAY_CHECK, the index is
-- checked against the number of whole @a@-sized elements in the array. The
-- checker only needs @a@'s size, so 'undefined' serves as a type proxy,
-- selected with @\@a@.
indexByteArray ::
  forall a.
  (CheckCtx) =>
  (Prim a) =>
  ByteArray ->
  Int ->
  a
indexByteArray = checkIBArray @a "indexByteArray" undefined PA.indexByteArray
{-# INLINE indexByteArray #-}

-- | Copy a range of an immutable byte array into a mutable byte array.
--
-- The arguments are the destination, destination offset, source, source
-- offset and length. The check treats the offsets and length as byte
-- counts. Wraps 'PA.copyByteArray'. When built with ARRAY_CHECK, both
-- ranges are checked to fit inside their arrays.
copyByteArray ::
  (CheckCtx) =>
  (PrimMonad m) =>
  MutableByteArray (PrimState m) ->
  Int ->
  ByteArray ->
  Int ->
  Int ->
  m ()
copyByteArray = checkCBArray "copyByteArray" PA.copyByteArray
{-# INLINE copyByteArray #-}

-- | Copy a range of one mutable byte array into another.
--
-- The arguments are the destination, destination offset, source, source
-- offset and length. The check treats the offsets and length as byte
-- counts. Wraps 'PA.copyMutableByteArray'. When built with ARRAY_CHECK,
-- both ranges are checked to fit inside their arrays.
copyMutableByteArray ::
  (CheckCtx) =>
  (PrimMonad m) =>
  MutableByteArray (PrimState m) ->
  Int ->
  MutableByteArray (PrimState m) ->
  Int ->
  Int ->
  m ()
copyMutableByteArray = checkCMBArray "copyMutableByteArray" PA.copyMutableByteArray
{-# INLINE copyMutableByteArray #-}

-- | Move a range of one mutable byte array into another.
--
-- The arguments are the destination, destination offset, source, source
-- offset and length. The check treats the offsets and length as byte
-- counts. Wraps 'PA.moveByteArray'. When built with ARRAY_CHECK, both
-- ranges are checked to fit inside their arrays.
moveByteArray ::
  (CheckCtx) =>
  (PrimMonad m) =>
  MutableByteArray (PrimState m) ->
  Int ->
  MutableByteArray (PrimState m) ->
  Int ->
  Int ->
  m ()
moveByteArray = checkCMBArray "moveByteArray" PA.moveByteArray
{-# INLINE moveByteArray #-}

-- | Read the element at the given index of a mutable primitive array.
--
-- Wraps 'PA.readPrimArray'. When built with ARRAY_CHECK, the index is
-- first checked to lie in @[0, size)@; otherwise no check is performed.
readPrimArray ::
  (CheckCtx) =>
  (PrimMonad m) =>
  (Prim a) =>
  MutablePrimArray (PrimState m) a ->
  Int ->
  m a
readPrimArray = checkIMPArray "readPrimArray" PA.readPrimArray
{-# INLINE readPrimArray #-}

-- | Write a value at the given index of a mutable primitive array.
--
-- Wraps 'PA.writePrimArray'. When built with ARRAY_CHECK, the index is
-- first checked to lie in @[0, size)@; otherwise no check is performed.
writePrimArray ::
  (CheckCtx) =>
  (PrimMonad m) =>
  (Prim a) =>
  MutablePrimArray (PrimState m) a ->
  Int ->
  a ->
  m ()
writePrimArray = checkIMPArray "writePrimArray" PA.writePrimArray
{-# INLINE writePrimArray #-}

-- | Read the element at the given index of an immutable primitive array.
--
-- Wraps 'PA.indexPrimArray'. When built with ARRAY_CHECK, the index is
-- first checked to lie in @[0, size)@; otherwise no check is performed.
indexPrimArray ::
  (CheckCtx) =>
  (Prim a) =>
  PrimArray a ->
  Int ->
  a
indexPrimArray = checkIPArray "indexPrimArray" PA.indexPrimArray
{-# INLINE indexPrimArray #-}

-- | Convert a 'ByteArray' to the list of its bytes, using the array's
-- 'IsList' instance.
byteArrayToList :: ByteArray -> [Word8]
byteArrayToList = toList