Merge pull request #584 from JuliaParallel/jps/stream-teardown
streaming: Add DAG teardown option
jpsamaroo authored Dec 9, 2024
2 parents a765cbe + 3c5c389 commit 3656030
Showing 6 changed files with 223 additions and 25 deletions.
3 changes: 2 additions & 1 deletion docs/src/index.md
@@ -427,4 +427,5 @@ wait(t)
The above example demonstrates a streaming region that generates random numbers
continuously and writes each random number to a file. The streaming region is
terminated when a random number less than 0.01 is generated, which is done by
-calling `Dagger.finish_stream()` (this exits the current streaming task).
+calling `Dagger.finish_stream()` (this terminates the current task, and will
+also terminate all streaming tasks launched by `spawn_streaming`).
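For reference, here is a condensed sketch of the pattern this passage describes; the full example (including the file-writing step) appears just above in the docs, and `check_and_print` is an illustrative helper rather than part of that example:

```julia
using Dagger

# Illustrative helper: finish the stream once a small value is seen,
# otherwise print the value.
check_and_print(x) = x < 0.01 ? Dagger.finish_stream() : println(x)

t = Dagger.spawn_streaming() do
    vals = Dagger.@spawn rand()          # produces values continuously
    Dagger.@spawn check_and_print(vals)  # consumes each value as it arrives
end
wait(t)  # returns once the streaming region has terminated
```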
5 changes: 2 additions & 3 deletions docs/src/streaming.md
@@ -79,9 +79,8 @@ end
```

If you want to stop the streaming DAG and tear it all down, you can call
-`Dagger.cancel!.(all_vals)` and `Dagger.cancel!.(all_vals_written)` to
-terminate each streaming task. In the future, a more convenient way to tear
-down a full DAG will be added; for now, each task must be cancelled individually.
+`Dagger.cancel!(all_vals[1])` (or with any other task in the streaming DAG) to
+terminate all streaming tasks.
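For example (reusing the `all_vals` and `all_vals_written` task vectors from the example above):

```julia
# Cancelling any one task in the streaming DAG now tears the whole DAG down:
Dagger.cancel!(all_vals[1])

# Previously, every task had to be cancelled individually:
# Dagger.cancel!.(all_vals)
# Dagger.cancel!.(all_vals_written)
```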

Alternatively, tasks can stop themselves from the inside with
`finish_stream`, optionally returning a value that can be `fetch`'d. Let's
26 changes: 26 additions & 0 deletions src/dtask.jl
@@ -85,6 +85,32 @@ function Base.fetch(t::DTask; raw=false)
end
return fetch(t.future; raw)
end
function waitany(tasks::Vector{DTask})
    if isempty(tasks)
        return
    end
    # Spawn one listener per task; the first listener whose task finishes
    # notifies `cond` and wakes up the waiter below.
    cond = Threads.Condition()
    for task in tasks
        Sch.errormonitor_tracked("waitany listener", Threads.@spawn begin
            wait(task)
            @lock cond notify(cond)
        end)
    end
    @lock cond wait(cond)
    return
end
function waitall(tasks::Vector{DTask})
    if isempty(tasks)
        return
    end
    # `@sync` blocks until every spawned `wait(task)` below has returned.
    @sync for task in tasks
        Threads.@spawn wait(task)
    end
    return
end
function Base.show(io::IO, t::DTask)
status = if istaskstarted(t)
isready(t) ? "finished" : "running"
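A short usage sketch for the two new helpers (they are assumed here to be unexported, so they are called with the `Dagger.` prefix):

```julia
using Dagger

ts = [Dagger.@spawn sleep(i) for i in 1:3]

Dagger.waitany(ts)  # returns as soon as any one DTask has finished
Dagger.waitall(ts)  # returns only after every DTask has finished
```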
27 changes: 26 additions & 1 deletion src/stream.jl
@@ -426,12 +426,37 @@ function initialize_streaming!(self_streams, spec, task)
end
end

-function spawn_streaming(f::Base.Callable)
"""
Starts a streaming region, within which all tasks run continuously and
concurrently. Any `DTask` argument that is itself a streaming task will be
treated as a streaming input/output. The streaming region will automatically
handle the buffering and synchronization of these tasks' values.
# Keyword Arguments
- `teardown::Bool=true`: If `true`, the streaming region will automatically
cancel all tasks if any task fails or is cancelled. Otherwise, a failing task
will not cancel the other tasks, which will continue running.
"""
function spawn_streaming(f::Base.Callable; teardown::Bool=true)
queue = StreamingTaskQueue()
result = with_options(f; task_queue=queue)
if length(queue.tasks) > 0
finalize_streaming!(queue.tasks, queue.self_streams)
enqueue!(queue.tasks)

if teardown
# Start teardown monitor
dtasks = map(last, queue.tasks)::Vector{DTask}
Sch.errormonitor_tracked("streaming teardown", Threads.@spawn begin
# Wait for any task to finish
waitany(dtasks)

# Cancel all tasks
for task in dtasks
cancel!(task; graceful=false)
end
end)
end
end
return result
end
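A usage sketch for the new keyword, with a hypothetical two-task streaming region; the default `teardown=true` gives the cancel-everything behavior described in the docstring:

```julia
# With teardown disabled, a failing or cancelled streaming task leaves the
# other tasks in the region running:
Dagger.spawn_streaming(; teardown=false) do
    vals = Dagger.@spawn rand()
    Dagger.@spawn println(vals)
end
```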
112 changes: 112 additions & 0 deletions src/utils/tasks.jl
@@ -18,3 +18,115 @@ function set_task_tid!(task::Task, tid::Integer)
end
@assert Threads.threadid(task) == tid "jl_set_task_tid failed!"
end

if isdefined(Base, :waitany)
import Base: waitany, waitall
else
# Vendored from Base
# License is MIT
waitany(tasks; throw=true) = _wait_multiple(tasks, throw)
waitall(tasks; failfast=true, throw=true) = _wait_multiple(tasks, throw, true, failfast)
function _wait_multiple(waiting_tasks, throwexc=false, all=false, failfast=false)
tasks = Task[]

for t in waiting_tasks
t isa Task || error("Expected an iterator of `Task` object")
push!(tasks, t)
end

if (all && !failfast) || length(tasks) <= 1
exception = false
# Force everything to finish synchronously for the case of waitall
# with failfast=false
for t in tasks
_wait(t)
exception |= istaskfailed(t)
end
if exception && throwexc
exceptions = [TaskFailedException(t) for t in tasks if istaskfailed(t)]
throw(CompositeException(exceptions))
else
return tasks, Task[]
end
end

exception = false
nremaining::Int = length(tasks)
done_mask = falses(nremaining)
for (i, t) in enumerate(tasks)
if istaskdone(t)
done_mask[i] = true
exception |= istaskfailed(t)
nremaining -= 1
else
done_mask[i] = false
end
end

if nremaining == 0
return tasks, Task[]
elseif any(done_mask) && (!all || (failfast && exception))
if throwexc && (!all || failfast) && exception
exceptions = [TaskFailedException(t) for t in tasks[done_mask] if istaskfailed(t)]
throw(CompositeException(exceptions))
else
return tasks[done_mask], tasks[.~done_mask]
end
end

chan = Channel{Int}(Inf)
sentinel = current_task()
waiter_tasks = fill(sentinel, length(tasks))

for (i, done) in enumerate(done_mask)
done && continue
t = tasks[i]
if istaskdone(t)
done_mask[i] = true
exception |= istaskfailed(t)
nremaining -= 1
exception && failfast && break
else
waiter = @task put!(chan, i)
waiter.sticky = false
_wait2(t, waiter)
waiter_tasks[i] = waiter
end
end

while nremaining > 0
i = take!(chan)
t = tasks[i]
waiter_tasks[i] = sentinel
done_mask[i] = true
exception |= istaskfailed(t)
nremaining -= 1

# stop early if requested, unless there is something immediately
# ready to consume from the channel (using a race-y check)
if (!all || (failfast && exception)) && !isready(chan)
break
end
end

close(chan)

if nremaining == 0
return tasks, Task[]
else
remaining_mask = .~done_mask
for i in findall(remaining_mask)
waiter = waiter_tasks[i]
donenotify = tasks[i].donenotify::ThreadSynchronizer
@lock donenotify Base.list_deletefirst!(donenotify.waitq, waiter)
end
done_tasks = tasks[done_mask]
if throwexc && exception
exceptions = [TaskFailedException(t) for t in done_tasks if istaskfailed(t)]
throw(CompositeException(exceptions))
else
return done_tasks, tasks[remaining_mask]
end
end
end
end
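A sketch of the vendored API's behavior on plain `Task`s (on Julia versions where `Base.waitany` already exists, the Base definitions are used instead); both functions return a `(done, remaining)` pair of task vectors:

```julia
ts = [Threads.@spawn sleep(i) for i in 1:3]

done, remaining = Dagger.waitany(ts)  # `done` holds at least one finished task
done, remaining = Dagger.waitall(ts)  # waits for all; `remaining` is empty
```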
(The diff for the remaining changed file did not load.)
