module Unison.Util.Timing
  ( time,
    startTiming,
    stopTiming,
  )
where

import Data.Word (Word64)
import GHC.Clock (getMonotonicTimeNSec)
import System.CPUTime (getCPUTime)
import Text.Printf (printf)
import Unison.Debug qualified as Debug
import UnliftIO (MonadIO, liftIO)

-- | Run @action@ and, when the 'Debug.Timing' debug flag is enabled, print
-- how much CPU time and wall-clock time it took, prefixed with @label@.
--
-- When the flag is disabled the action is returned untouched, with no
-- measurement overhead.
--
-- Notes:
--
-- * The action's result is passed through without being forced. Any work
--   deferred lazily into the result is therefore not included in the
--   measurement.
-- * If @action@ throws, no timing line is printed.
time :: (MonadIO m) => String -> m a -> m a
time label action
  | Debug.shouldDebug Debug.Timing = do
      started <- liftIO startTiming
      -- Run the action, then report; '<*' keeps the action's result.
      action <* liftIO (stopTiming label started)
  | otherwise = action

-- | Take a timing sample to be passed later to 'stopTiming'.
--
-- The pair holds:
--
-- 1. the monotonic clock reading from 'getMonotonicTimeNSec', in nanoseconds;
-- 2. the process CPU time from 'getCPUTime', in picoseconds.
--
-- The monotonic clock is read first, then the CPU time.
startTiming :: IO (Word64, Integer)
startTiming = do
  wallNanos <- getMonotonicTimeNSec
  cpuPicos <- getCPUTime
  pure (wallNanos, cpuPicos)

-- | Finish a measurement started with 'startTiming' and print one line to
-- stdout:
--
-- > <label>: <cpu> (cpu), <wall> (system)
--
-- Both durations are rendered by a compact human-readable formatter.
stopTiming :: String -> (Word64, Integer) -> IO ()
stopTiming label (systemTimeStart, cpuTimeStart) = do
  -- Take a second sample and compute the deltas against the first one.
  (systemTimeEnd, cpuTimeEnd) <- startTiming
  let elapsedWallNanos = realToFrac @Word64 @Double (systemTimeEnd - systemTimeStart)
      -- The CPU component is in picoseconds; divide by 1000 to get nanoseconds.
      elapsedCpuNanos = realToFrac @Integer @Double (cpuTimeEnd - cpuTimeStart) / 1_000
  printf
    "%s: %s (cpu), %s (system)\n"
    label
    (renderNanos elapsedCpuNanos)
    (renderNanos elapsedWallNanos)
  where
    -- Render a nanosecond count with a unit suffix. The numeric part is kept
    -- to roughly four characters by choosing precision per magnitude band.
    renderNanos :: Double -> String
    renderNanos ns
      | ns < 0.5 = "0 ns"
      | otherwise = firstMatching bands
      where
        -- Walk the bands in ascending order and format with the first band
        -- whose upper limit exceeds @ns@. Anything larger falls back to
        -- whole seconds.
        firstMatching :: [(Double, String, Double)] -> String
        firstMatching [] = printf "%.0f s" (ns / 1_000_000_000)
        firstMatching ((limit, fmt, divisor) : rest)
          | ns < limit = printf fmt (ns / divisor)
          | otherwise = firstMatching rest

        -- (exclusive upper limit in ns, printf format, divisor to the unit)
        bands :: [(Double, String, Double)]
        bands =
          [ (995, "%.0f ns", 1),
            (9_950, "%.2f µs", 1_000),
            (99_500, "%.1f µs", 1_000),
            (995_000, "%.0f µs", 1_000),
            (9_950_000, "%.2f ms", 1_000_000),
            (99_500_000, "%.1f ms", 1_000_000),
            (995_000_000, "%.0f ms", 1_000_000),
            (9_950_000_000, "%.2f s", 1_000_000_000),
            (99_500_000_000, "%.1f s", 1_000_000_000)
          ]