{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE CPP #-}

{-|

The 'Thunk' API provides a way to defer potentially recursive computations:

* 'thunk' is lazy in its argument, and does not run it directly
* the first 'force' triggers execution of the action passed to 'thunk'
* that action is run at most once, and returns a list of other thunks
* 'force' forces these thunks as well, and does not return before all of them have executed
* Cycles are allowed: The action passed to 'thunk' may return a thunk whose action returns the first thunk.

The implementation is hopefully thread safe: Even if multiple threads force or
kick related thunks, all actions are still run at most once, and all calls to
force terminate (no deadlock).

>>> :set -XRecursiveDo
>>> :{
  mdo t1 <- thunk $ putStrLn "Hello" >> pure [t1, t2]
      t2 <- thunk $ putStrLn "World" >> pure [t1, t2]
      putStrLn "Nothing happened so far, but now:"
      force t1
      putStrLn "No more will happen now:"
      force t1
      putStrLn "That's it"
:}
Nothing happened so far, but now:
Hello
World
No more will happen now:
That's it

-}
module System.IO.RecThunk
    ( Thunk
    , thunk
    , doneThunk
    , force
    )
where


-- This code is tested with dejafu, but dejafu should not become a dependency
-- of the main library. The CPP below therefore abstracts over the monad: the
-- same code runs either in plain IO or in any MonadConc used by the tests.

#ifdef DEJAFU

#define Ctxt   (MonadConc m, MonadIO m) =>
#define Thunk_  (Thunk m)
#define ResolvingState_  (ResolvingState m)
#define KickedThunk_  (KickedThunk m)
#define MVar_  MVar m
#define M      m

import Control.Concurrent.Classy hiding (wait)
import Data.Unique
import Control.Monad.IO.Class

#else

#define Ctxt
#define Thunk_  Thunk
#define ResolvingState_  ResolvingState
#define KickedThunk_  KickedThunk
#define MVar_  MVar
#define M      IO

import Control.Concurrent.MVar
import Data.Unique
import Control.Monad.IO.Class

#endif

-- | A deferred @IO@ action that is to be run at most once.
--
-- The 'MVar' holds the not-yet-started action ('Left') until the thunk is
-- kicked, after which it holds the corresponding 'KickedThunk' ('Right').
newtype Thunk_ = Thunk (MVar_ (Either (M [Thunk_]) KickedThunk_))

-- How far a kicked thunk has progressed towards being fully forced.
data ResolvingState_
    -- No forcing thread has started processing this thunk yet.
    = NotStarted
    -- Being processed by the force call holding this id; the unit MVar is
    -- filled once that call has finished processing the thunk.
    | ProcessedBy Unique (MVar_ ())
    -- The thunk and all of its dependencies are done.
    | Done

-- | A 'Thunk' that is being evaluated.
--
-- The first MVar receives the kicked dependencies once the action has run;
-- the second tracks the resolving state.
data KickedThunk_ = KickedThunk (MVar_ [KickedThunk_]) (MVar_ ResolvingState_)

-- | Create a new 'Thunk' from an 'IO' action.
--
-- The action is not run here. It is run at most once, when this thunk (or a
-- thunk whose action returns it) is first forced. The 'IO' action may return
-- other thunks that should be forced together whenever this thunk is forced
-- (in arbitrary order).
thunk :: Ctxt M [Thunk_] -> M Thunk_
thunk act = Thunk <$> newMVar (Left act)

-- | A 'Thunk' that is already done.
--
-- Equivalent to @do {t <- thunk (pure []); force t; pure t }@
doneThunk :: Ctxt M Thunk_
doneThunk = do
    -- Already kicked ('Right'), with no dependencies and in state 'Done',
    -- so forcing it finds nothing left to do.
    mv_ts <- newMVar []
    mv_s <- newMVar Done
    Thunk <$> newMVar (Right (KickedThunk mv_ts mv_s))

-- Recursively explores the thunk, and kicks off its execution.
--
-- On the first kick the thunk's MVar is switched from @Left act@ to
-- @Right kt@ before @act@ runs. Because of this, a cycle (an action returning a
-- thunk that leads back here) takes the 'Right' branch: it neither runs @act@
-- again nor blocks on the MVar. Then @act@ runs, every thunk it returns is
-- kicked recursively, and the resulting kicked thunks are published in
-- @mv_thunks@.
--
-- May return before execution is done (if started by another thread).
--
-- NOTE(review): if @act@ throws, @mv_thunks@ is never filled, so a later
-- 'wait' that reads it would block; confirm whether exception safety matters.
kick :: Ctxt Thunk_ -> M KickedThunk_
kick (Thunk t) = takeMVar t >>= \case
    Left act -> do
        mv_thunks <- newEmptyMVar
        mv_state <- newMVar NotStarted
        let kt = KickedThunk mv_thunks mv_state
        -- Publish the kicked thunk before running the action, so that
        -- re-entrant kicks (cycles) see 'Right' and do not run it again.
        putMVar t (Right kt)

        ts <- act
        kts <- mapM kick ts
        putMVar mv_thunks kts
        pure kt

    -- Thunk was already kicked, nothing to do
    Right kt -> do
        putMVar t (Right kt)
        pure kt

-- Waits until the kicked thunk and, transitively, all of its dependencies are
-- done.
--
-- Each call to 'force' uses a fresh 'Unique' id, and a smaller id means a
-- higher priority. The calling thread processes a thunk itself unless a
-- higher-priority thread is already processing it. In that case it blocks
-- until that thread signals completion. A thunk that the calling thread is
-- already processing (reached again through a cycle) is skipped.
wait :: Ctxt Unique -> KickedThunk_ -> M ()
wait my_id (KickedThunk mv_deps mv_s) = do
    s <- takeMVar mv_s
    case s of
        -- Thunk and all dependencies are done
        Done -> putMVar mv_s s
        -- Thunk is being processed by a higher priority thread, so simply wait
        ProcessedBy other_id done_mv | other_id < my_id -> do
            putMVar mv_s s
            readMVar done_mv
        -- Thunk is already being processed by this thread (a cycle), ignore
        ProcessedBy other_id _done_mv | other_id == my_id -> do
            putMVar mv_s s
            pure ()
        -- Thunk is not yet processed, or processed by a lower priority thread, so process now
        _ -> do
            done_mv <- newEmptyMVar
            putMVar mv_s (ProcessedBy my_id done_mv)
            -- Blocks until 'kick' has run the action and published the
            -- kicked dependencies
            ts <- readMVar mv_deps
            mapM_ (wait my_id) ts
            -- Mark kicked thunk as done
            _ <- swapMVar mv_s Done
            -- Wake up waiting threads
            putMVar done_mv ()

-- | Force the execution of the thunk. If it has been forced already, it will
-- do nothing. Else it will run the action passed to 'thunk', force thunks
-- returned by that action, and not return until all of them are forced.
force :: Ctxt Thunk_ -> M ()
force t = do
    -- Run the action (and those of reachable thunks) unless already started
    rt <- kick t
    -- A fresh id gives this call its priority when waiting
    my_id <- liftIO newUnique
    wait my_id rt