{-# LANGUAGE MagicHash #-}
{-# LANGUAGE UnboxedTuples #-}

-- | Miscellaneous IO utilities
module Ki.Internal.IO
  ( -- * Unexceptional IO
    UnexceptionalIO (..),
    IOResult (..),
    unexceptionalTry,
    unexceptionalTryEither,

    -- * Exception utils
    isAsyncException,
    interruptiblyMasked,
    uninterruptiblyMasked,
    tryEitherSTM,

    -- * Fork utils
    forkIO,
    forkOn,
  )
where

import Control.Exception
import Control.Monad (join)
import Data.Coerce (coerce)
import GHC.Base (maskAsyncExceptions#, maskUninterruptible#)
import GHC.Conc (STM, ThreadId (ThreadId), catchSTM)
import GHC.Exts (Int (I#), fork#, forkOn#)
import GHC.IO (IO (IO))
import Prelude

-- A little promise that this IO action cannot throw an exception.
--
-- Yeah it's verbose, and maybe not that necessary, but the code that bothers to use it really does require
-- un-exceptiony IO actions for correctness, so here we are.
--
-- The promise is by convention only: this is a plain newtype over 'IO' whose constructor is exported, so nothing in
-- the type system checks it. Whoever wraps an action with 'UnexceptionalIO' is responsible for the guarantee.
--
-- The 'Functor', 'Applicative' and 'Monad' instances are borrowed unchanged from 'IO' via newtype deriving.
newtype UnexceptionalIO a = UnexceptionalIO
  {runUnexceptionalIO :: IO a}
  deriving newtype (Applicative, Functor, Monad)

-- The outcome of running an IO action, as produced by 'unexceptionalTry': either the exception that was caught, or
-- the value the action returned.
data IOResult a
  = Failure !SomeException -- sync or async exception
  | Success a

-- Run an IO action and report how it ended, instead of letting an exception escape.
--
-- If the action returns normally, its value is wrapped in 'Success'. If it throws, the handler catches the exception
-- as a 'SomeException' (so every exception type is caught) and returns it wrapped in 'Failure'.
--
-- Only exceptions raised while the action runs are caught: the value inside 'Success' is not forced here.
unexceptionalTry :: forall a. IO a -> UnexceptionalIO (IOResult a)
unexceptionalTry action =
  UnexceptionalIO do
    (Success <$> action) `catch` \exception ->
      pure (Failure exception)

-- Like try, but with continuations. Also, catches all exceptions, because that's the only flavor we need.
--
-- Runs the action; if it throws, the exception is handed to the first continuation, otherwise the result is handed
-- to the second continuation.
--
-- The 'catch' only covers the action itself: inside it we merely *select* which continuation to run (producing an
-- @IO (IO b)@), and the outer 'join' then runs the selected continuation outside of the handler. The 'coerce's just
-- strip the 'UnexceptionalIO' newtype off the continuations.
unexceptionalTryEither ::
  forall a b.
  (SomeException -> UnexceptionalIO b) ->
  (a -> UnexceptionalIO b) ->
  IO a ->
  UnexceptionalIO b
unexceptionalTryEither onFailure onSuccess action =
  UnexceptionalIO do
    join do
      catch
        (coerce @_ @(a -> IO b) onSuccess <$> action)
        (pure . coerce @_ @(SomeException -> IO b) onFailure)

-- Is this exception an asynchronous one?
--
-- Returns 'True' exactly when the exception can be cast (via 'fromException') to 'SomeAsyncException', and 'False'
-- otherwise.
isAsyncException :: SomeException -> Bool
isAsyncException exception =
  case fromException @SomeAsyncException exception of
    Nothing -> False
    Just _ -> True

-- | Call an action with asynchronous exceptions interruptibly masked.
--
-- This unwraps the 'IO' action and passes it straight to the @maskAsyncExceptions#@ primop; no restore function is
-- provided to the action, and nothing else is done around it.
interruptiblyMasked :: IO a -> IO a
interruptiblyMasked (IO io) =
  IO (maskAsyncExceptions# io)

-- | Call an action with asynchronous exceptions uninterruptibly masked.
--
-- This unwraps the 'IO' action and passes it straight to the @maskUninterruptible#@ primop; no restore function is
-- provided to the action, and nothing else is done around it.
uninterruptiblyMasked :: IO a -> IO a
uninterruptiblyMasked (IO io) =
  IO (maskUninterruptible# io)

-- Like try, but with continuations
--
-- Runs the STM action; if it throws an exception of type @e@, that exception is handed to the first continuation,
-- otherwise the result is handed to the second continuation. Exceptions of other types are not caught.
--
-- The 'catchSTM' only covers the action itself: inside it we merely select which continuation to run (producing an
-- @STM (STM b)@), and the outer 'join' runs the selected continuation outside of the handler.
tryEitherSTM :: Exception e => (e -> STM b) -> (a -> STM b) -> STM a -> STM b
tryEitherSTM onFailure onSuccess action =
  join (catchSTM (onSuccess <$> action) (pure . onFailure))

-- Control.Concurrent.forkIO without the exception handler
--
-- The action is unwrapped and handed directly to the @fork#@ primop, with nothing installed around it; the resulting
-- unboxed thread id is then boxed up as a 'ThreadId'.
forkIO :: IO () -> IO ThreadId
forkIO (IO action) =
  IO \s0 ->
    case fork# action s0 of
      (# s1, tid #) -> (# s1, ThreadId tid #)

-- Control.Concurrent.forkOn without the exception handler
--
-- The 'Int' argument is unboxed and passed, together with the unwrapped action, directly to the @forkOn#@ primop,
-- with nothing installed around the action; the resulting unboxed thread id is then boxed up as a 'ThreadId'.
forkOn :: Int -> IO () -> IO ThreadId
forkOn (I# cap) (IO action) =
  IO \s0 ->
    case forkOn# cap action s0 of
      (# s1, tid #) -> (# s1, ThreadId tid #)