Skip to content

Commit

Permalink
tidying
Browse files Browse the repository at this point in the history
  • Loading branch information
doug-q committed Jun 10, 2024
1 parent aeac385 commit 067bf8d
Show file tree
Hide file tree
Showing 5 changed files with 216 additions and 54 deletions.
8 changes: 6 additions & 2 deletions hugr-core/src/partial_value/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,15 +316,19 @@ proptest! {

#[test]
fn bounded_lattice(v in any_partial_value()) {
prop_assert!(v <= PartialValue::Top);
prop_assert!(v >= PartialValue::Bottom);
prop_assert!(v <= PartialValue::top());
prop_assert!(v >= PartialValue::bottom());
}

#[test]
fn meet_join_self_noop(v1 in any_partial_value()) {
let mut subject = v1.clone();

assert_eq!(v1.clone(), v1.clone().join(v1.clone()));
assert!(!subject.join_mut(v1.clone()));
assert_eq!(subject, v1);

assert_eq!(v1.clone(), v1.clone().meet(v1.clone()));
assert!(!subject.meet_mut(v1.clone()));
assert_eq!(subject, v1);
}
Expand Down
2 changes: 2 additions & 0 deletions hugr-passes/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,5 @@ extension_inference = ["hugr-core/extension_inference"]

[dev-dependencies]
rstest = { workspace = true }
proptest = { workspace = true }
proptest-derive = { workspace = true }
50 changes: 30 additions & 20 deletions hugr-passes/src/const_fold2/datalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ ascent::ascent! {
node_in_value_row(c, n, utils::singleton_in_row(c, n, p, v.clone())) <-- in_wire_value(c, n, p, v);


// Per node-type rules
// TODO do all leaf ops with a rule
// define `fn propagate_leaf_op(Context, Node, ValueRow) -> ValueRow

// LoadConstant
relation load_constant_node(C, Node);
load_constant_node(c, n) <-- node(c, n), if c.hugr().get_optype(*n).is_load_constant();
Expand Down Expand Up @@ -104,29 +108,35 @@ ascent::ascent! {
io_node(c,tl,i, IO::Input), in_wire_value(c, tl, p, v);

// Output node of child region propagate to Input node of child region
out_wire_value(c, i, input_p, v) <-- tail_loop_node(c, tl),
io_node(c,tl,i, IO::Input),
io_node(c,tl,o, IO::Output),
in_wire_value(c, o, output_p, output_v),
if let Some(tailloop) = c.hugr().get_optype(*tl).as_tail_loop(),
out_wire_value(c, in_n, out_p, v) <-- tail_loop_node(c, tl_n),
io_node(c,tl_n,in_n, IO::Input),
io_node(c,tl_n,out_n, IO::Output),
node_in_value_row(c, out_n, out_in_row), // get the whole input row for the output node
if out_in_row[0].supports_tag(0), // if it is possible for tag to be 0
if let Some(tailloop) = c.hugr().get_optype(*tl_n).as_tail_loop(),
let variant_len = tailloop.just_inputs.len(),
for (input_p, v) in utils::tail_loop_worker(*output_p, 0, variant_len, output_v);
for (out_p, v) in out_in_row.iter(c, *out_n).flat_map(
|(input_p, v)| utils::outputs_for_variant(input_p, 0, variant_len, v)
);

// Output node of child region propagate to outputs of tail loop
out_wire_value(c, tl, p, v) <-- tail_loop_node(c, tl),
io_node(c,tl,o, IO::Output),
in_wire_value(c, o, output_p, output_v),
if let Some(tailloop) = c.hugr().get_optype(*tl).as_tail_loop(),
out_wire_value(c, tl_n, out_p, v) <-- tail_loop_node(c, tl_n),
io_node(c,tl_n,out_n, IO::Output),
node_in_value_row(c, out_n, out_in_row), // get the whole input row for the output node
if out_in_row[0].supports_tag(1), // if it is possible for the tag to be 1
if let Some(tailloop) = c.hugr().get_optype(*tl_n).as_tail_loop(),
let variant_len = tailloop.just_outputs.len(),
for (p, v) in utils::tail_loop_worker(*output_p, 1, variant_len, output_v);
for (out_p, v) in out_in_row.iter(c, *out_n).flat_map(
|(input_p, v)| utils::outputs_for_variant(input_p, 1, variant_len, v)
);

lattice tail_loop_termination(C,Node,OrdLattice<TailLoopTermination>);
tail_loop_termination(c,tl,TailLoopTermination::NeverTerminates.into()) <--
tail_loop_node(c,tl);
tail_loop_termination(c,tl,TailLoopTermination::from_control_value(v).into()) <--
tail_loop_node(c,tl),
io_node(c,tl,o, IO::Output),
in_wire_value(c, o, Into::<IncomingPort>::into(0usize), v);
lattice tail_loop_termination(C,Node,TailLoopTermination);
tail_loop_termination(c,tl_n,TailLoopTermination::bottom()) <--
tail_loop_node(c,tl_n);
tail_loop_termination(c,tl_n,TailLoopTermination::from_control_value(v)) <--
tail_loop_node(c,tl_n),
io_node(c,tl,out_n, IO::Output),
in_wire_value(c, out_n, IncomingPort::from(0), v);


// Conditional
Expand All @@ -145,7 +155,7 @@ ascent::ascent! {
in_wire_value(c, cond, cond_in_p, cond_in_v),
if let Some(conditional) = c.hugr().get_optype(*cond).as_conditional(),
let variant_len = conditional.sum_rows[*case_index].len(),
for (i_p, v) in utils::tail_loop_worker(*cond_in_p, *case_index, variant_len, cond_in_v);
for (i_p, v) in utils::outputs_for_variant(*cond_in_p, *case_index, variant_len, cond_in_v);

// outputs of case nodes propagate to outputs of conditional
out_wire_value(c, cond, OutgoingPort::from(o_p.index()), v) <--
Expand Down Expand Up @@ -219,7 +229,7 @@ impl<'a, H: HugrView> Machine<'a, H> {
self.program
.tail_loop_termination
.iter()
.find_map(|(c, n, v)| (c == context && n == &node).then_some(v.0.clone()))
.find_map(|(c, n, v)| (c == context && n == &node).then_some(*v))
.unwrap()
}

Expand Down
37 changes: 22 additions & 15 deletions hugr-passes/src/const_fold2/datalog/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ fn test_tail_loop_never_iterates() {
let o_r = machine.read_out_wire_value(&c, tl_o).unwrap();
assert_eq!(o_r, r_v);
assert_eq!(
TailLoopTermination::SingleIteration,
TailLoopTermination::ExactlyZeroContinues,
machine.tail_loop_terminates(&c, tail_loop.node())
)
}
Expand All @@ -96,22 +96,29 @@ fn test_tail_loop_always_iterates() {
let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap();
let r_w = builder
.add_load_value(Value::sum(0, [], SumType::new([type_row![], BOOL_T.into()])).unwrap());
let true_w = builder.add_load_value(Value::true_val());

let tlb = builder
.tail_loop_builder([], [], vec![BOOL_T].into())
.tail_loop_builder([], [(BOOL_T,true_w)], vec![BOOL_T].into())
.unwrap();
let tail_loop = tlb.finish_with_outputs(r_w, []).unwrap();
let [tl_o] = tail_loop.outputs_arr();

// r_w has tag 0, so we always continue;
// we put true in our "other_output", but we should not propagate this to
// output because r_w never supports 1.
let tail_loop = tlb.finish_with_outputs(r_w, [true_w]).unwrap();

let [tl_o1, tl_o2] = tail_loop.outputs_arr();
let hugr = builder.finish_hugr(&EMPTY_REG).unwrap();

let mut machine = Machine::new();
let c = machine.run_hugr(&hugr);
// dbg!(&machine.tail_loop_io_node);
// dbg!(&machine.out_wire_value);

let o_r = machine.read_out_wire_partial_value(&c, tl_o).unwrap();
assert_eq!(o_r, PartialValue::Bottom);
let o_r1 = machine.read_out_wire_partial_value(&c, tl_o1).unwrap();
assert_eq!(o_r1, PartialValue::bottom());
let o_r2 = machine.read_out_wire_partial_value(&c, tl_o2).unwrap();
assert_eq!(o_r2, PartialValue::bottom());
assert_eq!(
TailLoopTermination::NeverTerminates,
TailLoopTermination::bottom(),
machine.tail_loop_terminates(&c, tail_loop.node())
)
}
Expand Down Expand Up @@ -146,20 +153,20 @@ fn test_tail_loop_iterates_twice() {
let hugr = builder.finish_hugr(&EMPTY_REG).unwrap();
// TODO once we can do conditionals put these wires inside `just_outputs` and
// we should be able to propagate their values
// let [o_w1, o_w2, _] = tail_loop.outputs_arr();
let [o_w1, o_w2, _] = tail_loop.outputs_arr();

let mut machine = Machine::new();
let c = machine.run_hugr(&hugr);
// dbg!(&machine.tail_loop_io_node);
// dbg!(&machine.out_wire_value);

// TODO these hould be the propagated values
// let o_r1 = machine.read_out_wire_value(&c, o_w1).unwrap();
// assert_eq!(o_r1, Value::false_val());
// let o_r2 = machine.read_out_wire_value(&c, o_w2).unwrap();
// TODO these hould be the propagated values for now they will bt join(true,false)
let o_r1 = machine.read_out_wire_partial_value(&c, o_w1).unwrap();
// assert_eq!(o_r1, PartialValue::top());
let o_r2 = machine.read_out_wire_partial_value(&c, o_w2).unwrap();
// assert_eq!(o_r2, Value::true_val());
assert_eq!(
TailLoopTermination::Terminates,
TailLoopTermination::Top,
machine.tail_loop_terminates(&c, tail_loop.node())
)
}
Expand Down
Loading

0 comments on commit 067bf8d

Please sign in to comment.