Skip to content

Commit

Permalink
feat(dfir_lang): add repeat_n windowing operator
Browse files Browse the repository at this point in the history
  • Loading branch information
MingweiSamuel committed Jan 17, 2025
1 parent 5e8e2f1 commit 59fa7e0
Show file tree
Hide file tree
Showing 14 changed files with 529 additions and 32 deletions.
4 changes: 3 additions & 1 deletion dfir_lang/src/graph/ops/batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ pub const BATCH: OperatorConstraints = OperatorConstraints {
let input = &inputs[0];
quote_spanned! {op_span=>
let mut #vec_ident = #context.state_ref(#singleton_output_ident).borrow_mut();
*#vec_ident = #input.collect::<::std::vec::Vec<_>>();
if #context.is_first_run_this_tick() {
*#vec_ident = #input.collect::<::std::vec::Vec<_>>();
}
let #ident = ::std::iter::once(::std::clone::Clone::clone(&*#vec_ident));
}
} else if let Some(_output) = outputs.first() {
Expand Down
1 change: 1 addition & 0 deletions dfir_lang/src/graph/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ declare_ops![
persist_mut_keyed::PERSIST_MUT_KEYED,
py_udf::PY_UDF,
reduce::REDUCE,
repeat_n::REPEAT_N,
spin::SPIN,
sort::SORT,
sort_by_key::SORT_BY_KEY,
Expand Down
57 changes: 57 additions & 0 deletions dfir_lang/src/graph/ops/repeat_n.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
use quote::quote_spanned;

use super::{OperatorConstraints, OperatorWriteOutput, WriteContextArgs};

/// Windowing operator which re-runs its subgraph until it has run `n` times within the tick.
///
/// > Arguments: the repeat count `n`, an expression which is compared against a `usize` run
/// > counter.
///
/// Every field other than `name`, `num_args` and `write_fn` is inherited from `all_once`, and the
/// generated code is whatever `all_once` generates plus a per-operator run counter. After each
/// run the counter is compared against `n`; while `counter + 1 < n` the counter is incremented
/// and the current subgraph is rescheduled. Because the comparison happens after a run, the
/// subgraph still runs once when `n` is `0` or `1`.
///
/// TODO(mingwei): add a usage example.
pub const REPEAT_N: OperatorConstraints = OperatorConstraints {
    name: "repeat_n",
    num_args: 1,
    write_fn: |wc @ &WriteContextArgs {
                   context,
                   hydroflow,
                   op_span,
                   arguments,
                   ..
               },
               diagnostics| {
        // Start from the code emitted by `all_once`; the iterator itself is reused untouched and
        // only the prologue and the post-iterator code are extended below.
        let OperatorWriteOutput {
            write_prologue,
            write_iterator,
            write_iterator_after,
        } = (super::all_once::ALL_ONCE.write_fn)(wc, diagnostics)?;

        // Handle of the counter state.
        let count_ident = wc.make_ident("count");
        // Locals used by the generated code. The user-supplied `#count_arg` expression is spliced
        // into the same scope, so plain names such as `count` would shadow a user variable of the
        // same name (`repeat_n(count)` would then compare the counter against itself and never
        // repeat). NOTE(review): assumes `make_ident` yields operator-unique names -- confirm.
        let count_ref_ident = wc.make_ident("count_ref");
        let next_count_ident = wc.make_ident("next_count");

        let write_prologue = quote_spanned! {op_span=>
            #write_prologue

            // Run counter; `Cell::take` in the tick hook puts it back to `0`.
            let #count_ident = #hydroflow.add_state(::std::cell::Cell::new(0_usize));
            #hydroflow.set_state_tick_hook(#count_ident, move |cell| { cell.take(); });
        };

        // Reschedule, to repeat.
        // Indexing relies on `num_args: 1` above -- presumably enforced before `write_fn` runs.
        let count_arg = &arguments[0];
        let write_iterator_after = quote_spanned! {op_span=>
            #write_iterator_after

            {
                let #count_ref_ident = #context.state_ref(#count_ident);
                if #context.is_first_loop_iteration() {
                    #count_ref_ident.set(0);
                }
                let #next_count_ident = #count_ref_ident.get() + 1;
                if #next_count_ident < #count_arg {
                    #count_ref_ident.set(#next_count_ident);
                    #context.reschedule_current_subgraph();
                }
            }
        };

        Ok(OperatorWriteOutput {
            write_prologue,
            write_iterator,
            write_iterator_after,
        })
    },
    ..super::all_once::ALL_ONCE
};
16 changes: 15 additions & 1 deletion dfir_rs/src/scheduled/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
//! Provides APIs for state and scheduling.
use std::any::Any;
use std::cell::Cell;
use std::future::Future;
use std::marker::PhantomData;
use std::ops::DerefMut;
Expand Down Expand Up @@ -41,6 +42,9 @@ pub struct Context {
// Second field (bool) is for if the event is an external "important" event (true).
pub(super) event_queue_send: UnboundedSender<(SubgraphId, bool)>,

/// If the current subgraph wants to reschedule in the current tick+stratum.
pub(super) reschedule_current_subgraph: Cell<bool>,

pub(super) current_tick: TickInstant,
pub(super) current_stratum: usize,

Expand Down Expand Up @@ -92,11 +96,20 @@ impl Context {
self.subgraph_id
}

/// Schedules the subgraph `sg_id` for the next tick.
///
/// When `is_external` is `true` the scheduling triggers the next tick to begin. When it is
/// `false` the scheduling is lazy: the next tick will not begin unless there is some other
/// reason for it to.
///
/// # Panics
///
/// Panics if sending on the event queue fails.
pub fn schedule_subgraph(&self, sg_id: SubgraphId, is_external: bool) {
    let event = (sg_id, is_external);
    self.event_queue_send.send(event).unwrap()
}

/// Requests that the currently running subgraph be run again _this tick_.
///
/// This only raises the `reschedule_current_subgraph` flag on the context; the scheduler is
/// expected to read and clear that flag once the subgraph has finished running.
pub fn reschedule_current_subgraph(&self) {
    std::cell::Cell::set(&self.reschedule_current_subgraph, true);
}

/// Returns a `Waker` for interacting with async Rust.
/// Waker events are considered to be external.
pub fn waker(&self) -> std::task::Waker {
Expand Down Expand Up @@ -238,6 +251,7 @@ impl Default for Context {
events_received_tick: false,

event_queue_send,
reschedule_current_subgraph: Cell::new(false),

current_stratum: 0,
current_tick: TickInstant::default(),
Expand Down
10 changes: 10 additions & 0 deletions dfir_rs/src/scheduled/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,8 @@ impl<'a> Dfir<'a> {

{
let sg_data = &mut self.subgraphs[sg_id];
debug_assert_eq!(self.context.current_stratum, sg_data.stratum);

// This must be true for the subgraph to be enqueued.
assert!(sg_data.is_scheduled.take());
tracing::trace!(
Expand Down Expand Up @@ -324,6 +326,14 @@ impl<'a> Dfir<'a> {
}
}
}

// Check if subgraph wants rescheduling
if self.context.reschedule_current_subgraph.take() {
// Add subgraph to stratum queue if it is not already scheduled.
if !sg_data.is_scheduled.replace(true) {
self.context.stratum_queues[sg_data.stratum].push(sg_data.loop_depth, sg_id);
}
}
}
work_done
}
Expand Down
19 changes: 11 additions & 8 deletions dfir_rs/tests/snapshots/surface_loop__flo_nested@graphvis_dot.snap
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,23 @@ digraph {
n6v1 [label="(n6v1) flatten()", shape=invhouse, fillcolor="#88aaff"]
n7v1 [label="(n7v1) cross_join::<'static, 'tick>()", shape=invhouse, fillcolor="#88aaff"]
n8v1 [label="(n8v1) all_once()", shape=invhouse, fillcolor="#88aaff"]
n9v1 [label="(n9v1) for_each(|all| println!(\"{}: {:?}\", context.current_tick(), all))", shape=house, fillcolor="#ffff88"]
n10v1 [label="(n10v1) handoff", shape=parallelogram, fillcolor="#ddddff"]
n9v1 [label="(n9v1) map(|vec| (context.current_tick().0, vec))", shape=invhouse, fillcolor="#88aaff"]
n10v1 [label="(n10v1) assert_eq([\l (\l 0,\l vec![\l (\"alice\", 0),\l (\"alice\", 1),\l (\"alice\", 2),\l (\"bob\", 0),\l (\"bob\", 1),\l (\"bob\", 2),\l ],\l ),\l (\l 1,\l vec![\l (\"alice\", 3),\l (\"alice\", 4),\l (\"alice\", 5),\l (\"bob\", 3),\l (\"bob\", 4),\l (\"bob\", 5),\l ],\l ),\l (\l 2,\l vec![\l (\"alice\", 6),\l (\"alice\", 7),\l (\"alice\", 8),\l (\"bob\", 6),\l (\"bob\", 7),\l (\"bob\", 8),\l ],\l ),\l (\l 3,\l vec![\l (\"alice\", 9),\l (\"alice\", 10),\l (\"alice\", 11),\l (\"bob\", 9),\l (\"bob\", 10),\l (\"bob\", 11),\l ],\l ),\l])\l", shape=house, fillcolor="#ffff88"]
n11v1 [label="(n11v1) handoff", shape=parallelogram, fillcolor="#ddddff"]
n12v1 [label="(n12v1) handoff", shape=parallelogram, fillcolor="#ddddff"]
n13v1 [label="(n13v1) handoff", shape=parallelogram, fillcolor="#ddddff"]
n4v1 -> n7v1 [label="0"]
n3v1 -> n4v1
n1v1 -> n10v1
n1v1 -> n11v1
n6v1 -> n7v1 [label="1"]
n5v1 -> n6v1
n2v1 -> n11v1
n2v1 -> n12v1
n9v1 -> n10v1
n8v1 -> n9v1
n7v1 -> n12v1
n10v1 -> n3v1
n11v1 -> n5v1
n12v1 -> n8v1 [color=red]
n7v1 -> n13v1
n11v1 -> n3v1
n12v1 -> n5v1
n13v1 -> n8v1 [color=red]
subgraph "cluster n1v1" {
fillcolor="#dddddd"
style=filled
Expand Down Expand Up @@ -68,5 +70,6 @@ digraph {
label = "sg_4v1\nstratum 1"
n8v1
n9v1
n10v1
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,23 @@ linkStyle default stroke:#aaa
6v1[\"(6v1) <code>flatten()</code>"/]:::pullClass
7v1[\"(7v1) <code>cross_join::&lt;'static, 'tick&gt;()</code>"/]:::pullClass
8v1[\"(8v1) <code>all_once()</code>"/]:::pullClass
9v1[/"(9v1) <code>for_each(|all| println!(&quot;{}: {:?}&quot;, context.current_tick(), all))</code>"\]:::pushClass
10v1["(10v1) <code>handoff</code>"]:::otherClass
9v1[\"(9v1) <code>map(|vec| (context.current_tick().0, vec))</code>"/]:::pullClass
10v1[/"<div style=text-align:center>(10v1)</div> <code>assert_eq([<br> (<br> 0,<br> vec![<br> (&quot;alice&quot;, 0),<br> (&quot;alice&quot;, 1),<br> (&quot;alice&quot;, 2),<br> (&quot;bob&quot;, 0),<br> (&quot;bob&quot;, 1),<br> (&quot;bob&quot;, 2),<br> ],<br> ),<br> (<br> 1,<br> vec![<br> (&quot;alice&quot;, 3),<br> (&quot;alice&quot;, 4),<br> (&quot;alice&quot;, 5),<br> (&quot;bob&quot;, 3),<br> (&quot;bob&quot;, 4),<br> (&quot;bob&quot;, 5),<br> ],<br> ),<br> (<br> 2,<br> vec![<br> (&quot;alice&quot;, 6),<br> (&quot;alice&quot;, 7),<br> (&quot;alice&quot;, 8),<br> (&quot;bob&quot;, 6),<br> (&quot;bob&quot;, 7),<br> (&quot;bob&quot;, 8),<br> ],<br> ),<br> (<br> 3,<br> vec![<br> (&quot;alice&quot;, 9),<br> (&quot;alice&quot;, 10),<br> (&quot;alice&quot;, 11),<br> (&quot;bob&quot;, 9),<br> (&quot;bob&quot;, 10),<br> (&quot;bob&quot;, 11),<br> ],<br> ),<br>])</code>"\]:::pushClass
11v1["(11v1) <code>handoff</code>"]:::otherClass
12v1["(12v1) <code>handoff</code>"]:::otherClass
13v1["(13v1) <code>handoff</code>"]:::otherClass
4v1-->|0|7v1
3v1-->4v1
1v1-->10v1
1v1-->11v1
6v1-->|1|7v1
5v1-->6v1
2v1-->11v1
2v1-->12v1
9v1-->10v1
8v1-->9v1
7v1-->12v1
10v1-->3v1
11v1-->5v1
12v1--x8v1; linkStyle 10 stroke:red
7v1-->13v1
11v1-->3v1
12v1-->5v1
13v1--x8v1; linkStyle 11 stroke:red
subgraph sg_1v1 ["sg_1v1 stratum 0"]
1v1
subgraph sg_1v1_var_users ["var <tt>users</tt>"]
Expand All @@ -56,4 +58,5 @@ end
subgraph sg_4v1 ["sg_4v1 stratum 1"]
8v1
9v1
10v1
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
---
source: dfir_rs/tests/surface_loop.rs
expression: "df.meta_graph().unwrap().to_dot(& Default :: default())"
---
digraph {
node [fontname="Monaco,Menlo,Consolas,&quot;Droid Sans Mono&quot;,Inconsolata,&quot;Courier New&quot;,monospace", style=filled];
edge [fontname="Monaco,Menlo,Consolas,&quot;Droid Sans Mono&quot;,Inconsolata,&quot;Courier New&quot;,monospace"];
n1v1 [label="(n1v1) source_iter([\"alice\", \"bob\"])", shape=invhouse, fillcolor="#88aaff"]
n2v1 [label="(n2v1) source_stream(iter_batches_stream(0..12, 3))", shape=invhouse, fillcolor="#88aaff"]
n3v1 [label="(n3v1) batch()", shape=invhouse, fillcolor="#88aaff"]
n4v1 [label="(n4v1) flatten()", shape=invhouse, fillcolor="#88aaff"]
n5v1 [label="(n5v1) batch()", shape=invhouse, fillcolor="#88aaff"]
n6v1 [label="(n6v1) flatten()", shape=invhouse, fillcolor="#88aaff"]
n7v1 [label="(n7v1) cross_join::<'static, 'tick>()", shape=invhouse, fillcolor="#88aaff"]
n8v1 [label="(n8v1) repeat_n(3)", shape=invhouse, fillcolor="#88aaff"]
n9v1 [label="(n9v1) map(|vec| (context.current_tick().0, vec))", shape=invhouse, fillcolor="#88aaff"]
n10v1 [label="(n10v1) inspect(|x| println!(\"{:?}\", x))", shape=invhouse, fillcolor="#88aaff"]
n11v1 [label="(n11v1) assert_eq([\l (\l 0,\l vec![\l (\"alice\", 0),\l (\"alice\", 1),\l (\"alice\", 2),\l (\"bob\", 0),\l (\"bob\", 1),\l (\"bob\", 2),\l ],\l ),\l (\l 0,\l vec![\l (\"alice\", 0),\l (\"alice\", 1),\l (\"alice\", 2),\l (\"bob\", 0),\l (\"bob\", 1),\l (\"bob\", 2),\l ],\l ),\l (\l 0,\l vec![\l (\"alice\", 0),\l (\"alice\", 1),\l (\"alice\", 2),\l (\"bob\", 0),\l (\"bob\", 1),\l (\"bob\", 2),\l ],\l ),\l (\l 1,\l vec![\l (\"alice\", 3),\l (\"alice\", 4),\l (\"alice\", 5),\l (\"bob\", 3),\l (\"bob\", 4),\l (\"bob\", 5),\l ],\l ),\l (\l 1,\l vec![\l (\"alice\", 3),\l (\"alice\", 4),\l (\"alice\", 5),\l (\"bob\", 3),\l (\"bob\", 4),\l (\"bob\", 5),\l ],\l ),\l (\l 1,\l vec![\l (\"alice\", 3),\l (\"alice\", 4),\l (\"alice\", 5),\l (\"bob\", 3),\l (\"bob\", 4),\l (\"bob\", 5),\l ],\l ),\l (\l 2,\l vec![\l (\"alice\", 6),\l (\"alice\", 7),\l (\"alice\", 8),\l (\"bob\", 6),\l (\"bob\", 7),\l (\"bob\", 8),\l ],\l ),\l (\l 2,\l vec![\l (\"alice\", 6),\l (\"alice\", 7),\l (\"alice\", 8),\l (\"bob\", 6),\l (\"bob\", 7),\l (\"bob\", 8),\l ],\l ),\l (\l 2,\l vec![\l (\"alice\", 6),\l (\"alice\", 7),\l (\"alice\", 8),\l (\"bob\", 6),\l (\"bob\", 7),\l (\"bob\", 8),\l ],\l ),\l (\l 3,\l vec![\l (\"alice\", 9),\l (\"alice\", 10),\l (\"alice\", 11),\l (\"bob\", 9),\l (\"bob\", 10),\l (\"bob\", 11),\l ],\l ),\l (\l 3,\l vec![\l (\"alice\", 9),\l (\"alice\", 10),\l (\"alice\", 11),\l (\"bob\", 9),\l (\"bob\", 10),\l (\"bob\", 11),\l ],\l ),\l (\l 3,\l vec![\l (\"alice\", 9),\l (\"alice\", 10),\l (\"alice\", 11),\l (\"bob\", 9),\l (\"bob\", 10),\l (\"bob\", 11),\l ],\l ),\l])\l", shape=house, fillcolor="#ffff88"]
n12v1 [label="(n12v1) handoff", shape=parallelogram, fillcolor="#ddddff"]
n13v1 [label="(n13v1) handoff", shape=parallelogram, fillcolor="#ddddff"]
n14v1 [label="(n14v1) handoff", shape=parallelogram, fillcolor="#ddddff"]
n4v1 -> n7v1 [label="0"]
n3v1 -> n4v1
n1v1 -> n12v1
n6v1 -> n7v1 [label="1"]
n5v1 -> n6v1
n2v1 -> n13v1
n10v1 -> n11v1
n9v1 -> n10v1
n8v1 -> n9v1
n7v1 -> n14v1
n12v1 -> n3v1
n13v1 -> n5v1
n14v1 -> n8v1 [color=red]
subgraph "cluster n1v1" {
fillcolor="#dddddd"
style=filled
label = "sg_1v1\nstratum 0"
n1v1
subgraph "cluster_sg_1v1_var_users" {
label="var users"
n1v1
}
}
subgraph "cluster n2v1" {
fillcolor="#dddddd"
style=filled
label = "sg_2v1\nstratum 0"
n2v1
subgraph "cluster_sg_2v1_var_messages" {
label="var messages"
n2v1
}
}
subgraph "cluster n3v1" {
fillcolor="#dddddd"
style=filled
label = "sg_3v1\nstratum 0"
n3v1
n4v1
n5v1
n6v1
n7v1
subgraph "cluster_sg_3v1_var_cp" {
label="var cp"
n7v1
}
}
subgraph "cluster n4v1" {
fillcolor="#dddddd"
style=filled
label = "sg_4v1\nstratum 1"
n8v1
n9v1
n10v1
n11v1
}
}
Loading

0 comments on commit 59fa7e0

Please sign in to comment.