Skip to content

Commit

Permalink
don't need per-parent io node relations
Browse files Browse the repository at this point in the history
  • Loading branch information
doug-q committed Jun 10, 2024
1 parent 81b49a3 commit aeac385
Showing 1 changed file with 29 additions and 31 deletions.
60 changes: 29 additions & 31 deletions hugr-passes/src/const_fold2/datalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ pub use utils::{TailLoopTermination, ValueRow, IO, PV};
ascent::ascent! {
struct AscentProgram<C: DFContext>;
relation context(C);
relation out_wire_value_proto(Node, OutgoingPort, PV);

relation node(C, Node);
relation in_wire(C, Node, IncomingPort);
relation out_wire(C, Node, OutgoingPort);
relation parent_of_node(C, Node, Node);
relation out_wire_value_proto(Node, OutgoingPort, PV);
relation io_node(C, Node, Node, IO);
lattice out_wire_value(C, Node, OutgoingPort, PV);
lattice node_in_value_row(C, Node, ValueRow);
lattice in_wire_value(C, Node, IncomingPort, PV);
Expand All @@ -38,12 +40,15 @@ ascent::ascent! {
parent_of_node(c, parent, child) <--
node(c, child), if let Some(parent) = c.hugr().get_parent(*child);

// All out wire values are initialised to Bottom. If any value is Bottom after
// running we can infer that execution never reaches that value.
io_node(c, parent, child, io) <-- node(c, parent),
if let Some([i,o]) = c.hugr().get_io(*parent),
for (child,io) in [(i,IO::Input),(o,IO::Output)];
// We support prepopulating out_wire_value via out_wire_value_proto.
//
// out wires that do not have prepopulation values are initialised to bottom.
out_wire_value(c, n,p, PV::bottom()) <-- out_wire(c, n,p);
out_wire_value(c, n, p, v) <-- out_wire(c,n,p) , out_wire_value_proto(n, p, v);


in_wire_value(c, n, ip, v) <-- in_wire(c, n, ip),
if let Some((m,op)) = c.hugr().single_linked_output(*n, *ip),
out_wire_value(c, m, op, v);
Expand All @@ -52,20 +57,23 @@ ascent::ascent! {
node_in_value_row(c, n, utils::bottom_row(c, *n)) <-- node(c, n);
node_in_value_row(c, n, utils::singleton_in_row(c, n, p, v.clone())) <-- in_wire_value(c, n, p, v);


// LoadConstant
relation load_constant_node(C, Node);
load_constant_node(c, n) <-- node(c, n), if c.hugr().get_optype(*n).is_load_constant();

out_wire_value(c, n, 0.into(), utils::partial_value_from_load_constant(c, *n)) <--
load_constant_node(c, n);


// MakeTuple
relation make_tuple_node(C, Node);
make_tuple_node(c, n) <-- node(c, n), if c.hugr().get_optype(*n).is_make_tuple();

out_wire_value(c, n, 0.into(), utils::partial_value_tuple_from_value_row(vs.clone())) <--
make_tuple_node(c, n), node_in_value_row(c, n, vs);


// UnpackTuple
relation unpack_tuple_node(C, Node);
unpack_tuple_node(c,n) <-- node(c, n), if c.hugr().get_optype(*n).is_unpack_tuple();
Expand All @@ -75,45 +83,38 @@ ascent::ascent! {
in_wire_value(c, n, IncomingPort::from(0), v),
out_wire(c, n, p);


// DFG
relation dfg_node(C, Node);
dfg_node(c,n) <-- node(c, n), if c.hugr().get_optype(*n).is_dfg();
relation dfg_io_node(C, Node, Node, IO);
dfg_io_node(c,dfg,n,io) <-- dfg_node(c,dfg),
if let Some([i,o]) = c.hugr().get_io(*dfg),
for (n, io) in [(i, IO::Input), (o, IO::Output)];

out_wire_value(c, i, OutgoingPort::from(p.index()), v) <--
dfg_io_node(c,dfg,i, IO::Input), in_wire_value(c, dfg, p, v);
out_wire_value(c, dfg, OutgoingPort::from(p.index()), v) <--
dfg_io_node(c,dfg,o, IO::Output), in_wire_value(c, o, p, v);
out_wire_value(c, i, OutgoingPort::from(p.index()), v) <-- dfg_node(c,dfg),
io_node(c, dfg, i, IO::Input), in_wire_value(c, dfg, p, v);

out_wire_value(c, dfg, OutgoingPort::from(p.index()), v) <-- dfg_node(c,dfg),
io_node(c,dfg,o, IO::Output), in_wire_value(c, o, p, v);


// TailLoop
relation tail_loop_node(C, Node);
tail_loop_node(c,n) <-- node(c, n), if c.hugr().get_optype(*n).is_tail_loop();
relation tail_loop_io_node(C, Node, Node, IO);
tail_loop_io_node(c,tl,n, io) <-- tail_loop_node(c,tl),
if let Some([i,o]) = c.hugr().get_io(*tl),
for (n,io) in [(i,IO::Input), (o, IO::Output)];

// inputs of tail loop propagate to Input node of child region
out_wire_value(c, i, OutgoingPort::from(p.index()), v) <--
tail_loop_io_node(c,tl,i, IO::Input), in_wire_value(c, tl, p, v);

out_wire_value(c, i, OutgoingPort::from(p.index()), v) <-- tail_loop_node(c, tl),
io_node(c,tl,i, IO::Input), in_wire_value(c, tl, p, v);

// Output node of child region propagate to Input node of child region
out_wire_value(c, i, input_p, v) <--
tail_loop_io_node(c,tl,i, IO::Input),
tail_loop_io_node(c,tl,o, IO::Output),
out_wire_value(c, i, input_p, v) <-- tail_loop_node(c, tl),
io_node(c,tl,i, IO::Input),
io_node(c,tl,o, IO::Output),
in_wire_value(c, o, output_p, output_v),
if let Some(tailloop) = c.hugr().get_optype(*tl).as_tail_loop(),
let variant_len = tailloop.just_inputs.len(),
for (input_p, v) in utils::tail_loop_worker(*output_p, 0, variant_len, output_v);

// Output node of child region propagate to outputs of tail loop
out_wire_value(c, tl, p, v) <--
tail_loop_io_node(c,tl,o, IO::Output),
out_wire_value(c, tl, p, v) <-- tail_loop_node(c, tl),
io_node(c,tl,o, IO::Output),
in_wire_value(c, o, output_p, output_v),
if let Some(tailloop) = c.hugr().get_optype(*tl).as_tail_loop(),
let variant_len = tailloop.just_outputs.len(),
Expand All @@ -124,26 +125,23 @@ ascent::ascent! {
tail_loop_node(c,tl);
tail_loop_termination(c,tl,TailLoopTermination::from_control_value(v).into()) <--
tail_loop_node(c,tl),
tail_loop_io_node(c,tl,o, IO::Output),
io_node(c,tl,o, IO::Output),
in_wire_value(c, o, Into::<IncomingPort>::into(0usize), v);


// Conditional
relation conditional_node(C, Node);
relation case_node(C,Node,usize, Node);
relation case_io_node(C, Node, Node, IO);

conditional_node (c,n)<-- node(c, n), if c.hugr().get_optype(*n).is_conditional();
case_node(c,cond,i, case) <-- conditional_node(c,cond),
for (i, case) in c.hugr().children(*cond).enumerate(),
if c.hugr().get_optype(case).is_case();
case_io_node(c,case, n, io) <-- case_node(c, _, _, case),
if let Some([i,o]) = c.hugr().get_io(*case),
for (n,io) in [(i,IO::Input), (o, IO::Output)];

// inputs of conditional propagate into case nodes
out_wire_value(c, i_node, i_p, v) <--
case_node(c, cond, case_index, case),
case_io_node(c, case, i_node, IO::Input),
io_node(c, case, i_node, IO::Input),
in_wire_value(c, cond, cond_in_p, cond_in_v),
if let Some(conditional) = c.hugr().get_optype(*cond).as_conditional(),
let variant_len = conditional.sum_rows[*case_index].len(),
Expand All @@ -152,7 +150,7 @@ ascent::ascent! {
// outputs of case nodes propagate to outputs of conditional
out_wire_value(c, cond, OutgoingPort::from(o_p.index()), v) <--
case_node(c, cond, _, case),
case_io_node(c, case, o, IO::Output),
io_node(c, case, o, IO::Output),
in_wire_value(c, o, o_p, v);

lattice case_reachable(C, Node, Node, bool);
Expand Down

0 comments on commit aeac385

Please sign in to comment.