{-# LANGUAGE OverloadedStrings #-}

module Network.UDP.Recv (
    recv
  , recvFrom
  , recvMsg
  ) where

import Data.ByteString (ByteString)
import Data.ByteString.Internal (create, ByteString(..), createUptoN)
import Foreign.ForeignPtr (withForeignPtr)
import Network.Socket (Socket, SockAddr, Cmsg, MsgFlag, recvBuf, recvBufFrom, recvBufMsg)
import Network.Socket.Internal (zeroMemory)
import System.IO.Error (ioeSetErrorString, mkIOError)
import GHC.IO.Exception (IOErrorType(..))

-- | Build the 'IOError' raised when a receive function is handed an
-- unusable buffer length.
--
-- The error has type 'InvalidArgument', carries @loc@ as its location and
-- has no handle or file path attached; its description is fixed.
mkInvalidRecvArgError :: String -> IOError
mkInvalidRecvArgError loc =
    mkIOError InvalidArgument loc Nothing Nothing
        `ioeSetErrorString` "non-positive length"

-- | Receive data from the socket into a fresh 'ByteString' holding at most
-- @siz@ bytes.
--
-- A negative @siz@ fails with the 'IOError' built by
-- 'mkInvalidRecvArgError'.  Note that zero is accepted: only strictly
-- negative sizes are rejected, even though the error text reads
-- \"non-positive length\".
recv :: Socket -> Int -> IO ByteString
recv s siz =
    if siz >= 0
        then createUptoN siz fill
        else ioError (mkInvalidRecvArgError "Network.UDP.Recv.recv")
  where
    -- The socket writes straight into the newly allocated buffer; the byte
    -- count returned by 'recvBuf' is what 'createUptoN' uses as the final
    -- length of the result.
    fill buf = recvBuf s buf siz

-- | Receive data from the socket into a buffer of @siz@ bytes and return it
-- together with the sender's address.
--
-- A negative @siz@ is rejected up front with the same 'InvalidArgument'
-- error that 'recv' raises, rather than being passed on to the buffer
-- allocator.
recvFrom :: Socket -> Int -> IO (ByteString, SockAddr)
recvFrom s siz
  | siz < 0   = ioError (mkInvalidRecvArgError "Network.UDP.Recv.recvFrom")
  | otherwise = do
      -- Allocate the receive buffer and zero it before handing it to the
      -- socket.
      bs@(PS fptr _ _) <- create siz $ \ptr -> zeroMemory ptr (fromIntegral siz)
      withForeignPtr fptr $ \ptr -> do
          (len, sa) <- recvBufFrom s ptr siz
          -- When fewer than @siz@ bytes arrived, return a view of just the
          -- received prefix.  The view shares the original allocation, so
          -- nothing is copied.
          let bs' | len < siz = PS fptr 0 len
                  | otherwise = bs
          return (bs', sa)

-- | Receive a message into a buffer of @siz@ bytes, returning the payload,
-- the sender's address, the control messages and the flags reported by
-- 'recvBufMsg'.
--
-- @clen@ and @flags@ are forwarded unchanged to 'recvBufMsg' as its third
-- and fourth arguments.
--
-- A negative @siz@ is rejected up front with the same 'InvalidArgument'
-- error that 'recv' raises, rather than being passed on to the buffer
-- allocator.
recvMsg :: Socket -> Int -> Int -> MsgFlag
        -> IO (ByteString, SockAddr, [Cmsg], MsgFlag)
recvMsg s siz clen flags
  | siz < 0   = ioError (mkInvalidRecvArgError "Network.UDP.Recv.recvMsg")
  | otherwise = do
      -- Allocate the receive buffer and zero it before handing it to the
      -- socket.
      bs@(PS fptr _ _) <- create siz $ \ptr -> zeroMemory ptr (fromIntegral siz)
      withForeignPtr fptr $ \ptr -> do
          -- A single-element buffer list: the whole payload lands in @ptr@.
          (addr, len, cmsgs, flags') <- recvBufMsg s [(ptr, siz)] clen flags
          -- When fewer than @siz@ bytes arrived, return a view of just the
          -- received prefix.  Any reported length of @siz@ or more yields
          -- the full buffer.
          let bs' | len < siz = PS fptr 0 len
                  | otherwise = bs
          return (bs', addr, cmsgs, flags')