{-# LANGUAGE ExistentialQuantification #-}

module Unison.Runtime.Foreign.Dynamic where

import Control.Exception
import Control.Monad (unless, when)
import Foreign.ForeignPtr
import Foreign.LibFFI.FFITypes
import Foreign.LibFFI.Internal
import Foreign.Marshal
import Foreign.Ptr
import Foreign.Storable qualified as Store
import Unison.Runtime.FFI.DLL

-- | Primitive types that may appear in a foreign function signature.
--
-- Each constructor is mapped to a libffi type descriptor by 'encodeType'.
data FFType
  = -- | signed 8-bit integer
    I8
  | -- | signed 16-bit integer
    I16
  | -- | signed 32-bit integer
    I32
  | -- | signed 64-bit integer
    I64
  | -- | unsigned 8-bit integer
    U8
  | -- | unsigned 16-bit integer
    U16
  | -- | unsigned 32-bit integer
    U32
  | -- | unsigned 64-bit integer
    U64
  | -- | single precision float
    F32
  | -- | double precision float
    D64
  | -- | no value; as an argument it is only accepted as the sole entry of
    -- the argument list (see 'adjustSpec')
    Void
  | -- | passed to C as a pointer; rejected as a result type by 'prepareSpec'
    -- (presumably a mutable byte array on the runtime side)
    MBArr
  | -- | raw pointer
    Ptr
  deriving (Eq, Ord, Show)

-- | Signature of a foreign function: its argument types, in call order, and
-- its result type.
data FFSpec = FFSpec
  { -- | argument types, in call order
    ffArgs :: ![FFType],
    -- | result type
    ffResult :: !FFType
  }
  deriving (Eq, Ord, Show)

-- | A foreign signature together with the libffi call interface prepared for
-- it. Values are produced by 'prepareSpec'.
data CSpec = CSpec
  { -- | memory holding the prepared libffi @ffi_cif@
    cInterface :: !(ForeignPtr CIF),
    -- | number of arguments the call interface was prepared for
    numArgs :: !Int,
    -- | the signature the call interface was prepared from
    ffSpec :: !FFSpec
  }

-- | A dynamically loaded foreign function that is ready to be called.
--
-- The type of the function pointer is hidden existentially; the information
-- needed to call it correctly is carried by 'cSpec'.
data CDynFunc = forall a.
  CDynFunc
  { -- | display name, used by the 'Show' instance
    cName :: String,
    -- | prepared call interface and signature
    cSpec :: {-# UNPACK #-} !CSpec,
    -- | address of the foreign function
    cFun :: !(FunPtr a)
  }

-- | Argument types of a loaded foreign function, read from its 'CSpec'.
cffArgs :: CDynFunc -> [FFType]
cffArgs = ffArgs . ffSpec . cSpec

-- | Result type of a loaded foreign function, read from its 'CSpec'.
cffResult :: CDynFunc -> FFType
cffResult = ffResult . ffSpec . cSpec

-- | Renders only the function's name, wrapped in angle brackets; the call
-- interface and function pointer are not shown.
instance Show CDynFunc where
  show f = "<" ++ cName f ++ ">"

-- | Map an 'FFType' to the corresponding libffi type descriptor.
--
-- Both 'MBArr' and 'Ptr' are described to libffi as plain pointers.
encodeType :: FFType -> Ptr CType
encodeType I8 = ffi_type_sint8
encodeType I16 = ffi_type_sint16
encodeType I32 = ffi_type_sint32
encodeType I64 = ffi_type_sint64
encodeType U8 = ffi_type_uint8
encodeType U16 = ffi_type_uint16
encodeType U32 = ffi_type_uint32
encodeType U64 = ffi_type_uint64
encodeType D64 = ffi_type_double
encodeType F32 = ffi_type_float
encodeType Void = ffi_type_void
encodeType MBArr = ffi_type_pointer
encodeType Ptr = ffi_type_pointer

-- | Write the libffi type descriptors for a list of types into consecutive
-- pointer-sized slots starting at the given address.
--
-- The destination must have room for one pointer per list element; nothing
-- is written for an empty list.
encodeTypes :: [FFType] -> Ptr (Ptr CType) -> IO ()
encodeTypes [] !_ = pure ()
encodeTypes (t : ts) !p = do
  Store.poke p $ encodeType t
  encodeTypes ts (plusPtr p sz)
  where
    -- Size of one slot. 'Store.sizeOf' never inspects its argument, so the
    -- 'undefined' is not evaluated.
    sz = Store.sizeOf (undefined :: Ptr CType)

-- | Failures that can occur while preparing a foreign call interface.
data PrepException
  = -- | 'Void' appeared in an argument list other than as its only entry
    BadVoid
  | -- | the result type is not supported ('MBArr' cannot be returned)
    BadResult
  | -- | libffi rejected the signature (@ffi_prep_cif@ did not return
    -- @ffi_ok@)
    BadInit
  deriving (Show)

instance Exception PrepException

-- | Normalise the argument list of a signature.
--
-- An argument list consisting of exactly @['Void']@ is treated as "no
-- arguments" and replaced by the empty list. 'Void' anywhere else in the
-- argument list is an error and raises 'BadVoid'. Any other spec is returned
-- unchanged. The result type is never altered.
adjustSpec :: FFSpec -> IO FFSpec
adjustSpec sp@(FFSpec args r)
  | [Void] <- args = pure $ FFSpec [] r
  | Void `elem` args = throwIO BadVoid
  | otherwise = pure sp

-- | Build the libffi call interface for a foreign signature.
--
-- The signature is first normalised with 'adjustSpec' (which may throw
-- 'BadVoid'). A result type of 'MBArr' is rejected with 'BadResult'. If
-- @ffi_prep_cif@ does not report @ffi_ok@, 'BadInit' is thrown.
--
-- The returned 'CSpec' carries the adjusted signature, its argument count,
-- and the prepared call interface.
prepareSpec :: FFSpec -> IO CSpec
prepareSpec spec = do
  ffSpec@(FFSpec args ret) <- adjustSpec spec

  when (ret == MBArr) $
    throwIO BadResult

  let numArgs = length args
      n = fromIntegral numArgs
      ptrSize = Store.sizeOf (nullPtr :: Ptr CType)
      ptrAlign = Store.alignment (nullPtr :: Ptr CType)
      -- Offset of the argument-type array inside the allocation: directly
      -- after the cif, rounded up to pointer alignment.
      argOffset = ((sizeOf_cif + ptrAlign - 1) `div` ptrAlign) * ptrAlign

  -- ffi_prep_cif stores the address of the argument-type array in the cif,
  -- and ffi_call reads it again later. The array therefore has to live as
  -- long as the cif does, so both are placed in one ForeignPtr allocation
  -- rather than putting the array in scoped (alloca-style) memory that is
  -- released as soon as preparation finishes.
  cInterface <- mallocForeignPtrBytes (argOffset + numArgs * ptrSize)
  withForeignPtr cInterface \cif -> do
    let argTys = cif `plusPtr` argOffset
    encodeTypes args argTys
    status <- ffi_prep_cif cif ffi_default_abi n (encodeType ret) argTys
    unless (status == ffi_ok) $
      throwIO BadInit

  pure $ CSpec {cInterface, numArgs, ffSpec}

-- | Look up a symbol in a loaded library and pair it with a call interface
-- prepared from the given signature.
--
-- The call interface is prepared first, so any 'PrepException' from
-- 'prepareSpec' is raised before the symbol lookup is attempted. The
-- resulting function is named @\<library path\>$\<symbol\>@.
loadForeign :: DLL -> FFSpec -> String -> IO CDynFunc
loadForeign dll fspec sym =
  CDynFunc name <$> prepareSpec fspec <*> getDLLSym dll sym
  where
    name = getDLLPath dll ++ "$" ++ sym

-- Calls a foreign function with arguments stored in memory pointed to by
-- the first pointer, and returning the result to the second pointer. The
-- argument pointer should be to a 64-bit type, and it should point to
-- memory with as many arguments as are taken by the specification of the
-- function argument.
--
-- If some of the function's arguments are smaller than 64-bits, they
-- should be written as individual 64-bit locations in the pointer, so
-- that casting the offset pointer to the smaller type gives the pointer
-- used to write the smaller value to the memory. E.G.
--
--     Store.poke (castPtr (plusPtr p i)) <smaller-value>
callForeign :: CDynFunc -> Ptr (Ptr a) -> Ptr r -> IO ()
callForeign (CDynFunc _ (CSpec cInterface _ _) fun) cArgs cRet =
  -- withForeignPtr keeps the call interface alive for the whole call.
  withForeignPtr cInterface \cif ->
    -- ffi_call takes the result location before the argument array.
    ffi_call cif fun (castPtr cRet) (castPtr cArgs)