From 8f561e25fc34fe2e4fe4bca266fa4b005b14bbef Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Thu, 11 Jan 2024 11:30:54 +0000 Subject: [PATCH] Add _if_true context manager --- guppy/compiler/expr_compiler.py | 60 ++++++++++++++++----------------- 1 file changed, 29 insertions(+), 31 deletions(-) diff --git a/guppy/compiler/expr_compiler.py b/guppy/compiler/expr_compiler.py index bcf3b4eb..cd083248 100644 --- a/guppy/compiler/expr_compiler.py +++ b/guppy/compiler/expr_compiler.py @@ -110,11 +110,25 @@ def _new_case( with self._new_dfcontainer(inputs, self.graph.add_case(cond_node)): yield self.graph.add_output([self.visit(name) for name in outputs]) - # Update the DFG with the outputs from the Conditional node, but only we haven't - # already added some - if cond_node.num_out_ports == 0: - for name in outputs: - self.dfg[name.id].port = cond_node.add_out_port(get_type(name)) + + @contextmanager + def _if_true(self, cond: ast.expr, inputs: list[ast.Name]) -> Iterator[None]: + """Context manager to build a graph inside the `true` case of a `Conditional` + + In the `false` case, the inputs are outputted as is. + """ + cond_node = self.graph.add_conditional( + self.visit(cond), [self.visit(inp) for inp in inputs] + ) + # If the condition is false, output the inputs as is + with self._new_case(inputs, inputs, cond_node): + pass + # If the condition is true, we enter the `with` block + with self._new_case(inputs, inputs, cond_node): + yield + # Update the DFG with the outputs from the Conditional node + for name in inputs: + self.dfg[name.id].port = cond_node.add_out_port(get_type(name)) def visit_Constant(self, node: ast.Constant) -> OutPortV: if value := python_value_to_hugr(node.value): @@ -230,38 +244,22 @@ def compile_generators(elt: ast.expr, gens: list[DesugaredGenerator]) -> None: compiler.compile_stmts([gen.iter_assign], self.dfg) inputs = [gen.iter, list_name] with self._new_loop(inputs, gen.hasnext): - # Compile the `hasnext` check and plug it into a conditional - compiler.compile_stmts([gen.hasnext_assign], self.dfg) - cond = self.graph.add_conditional( - self.visit(gen.hasnext), - [self.visit(inp) for inp in inputs], - ) - - # If the iterator is finished, output the iterator and list as is (this - # is achieved by passing `inputs, inputs` below) - with self._new_case(inputs, inputs, cond): - pass - # If there is a next element, compile it and continue with the next # generator - with self._new_case(inputs, inputs, cond): + compiler.compile_stmts([gen.hasnext_assign], self.dfg) + with self._if_true(gen.hasnext, inputs): def compile_ifs(ifs: list[ast.expr]) -> None: - if not ifs: + """Helper function to compile a series of if-guards into nested + Conditional nodes.""" + if ifs: + if_expr, *ifs = ifs + # If the condition is true, continue with the next one + with self._if_true(if_expr, inputs): + compile_ifs(ifs) + else: # If there are no guards left, compile the next generator compile_generators(elt, gens) - return - if_expr, *ifs = ifs - cond = self.graph.add_conditional( - self.visit(if_expr), - [self.visit(gen.iter), self.visit(list_name)], - ) - # If the condition is false, output the iterator and list as is - with self._new_case(inputs, inputs, cond): - pass - # If the condition is true, continue with the next one - with self._new_case(inputs, inputs, cond): - compile_ifs(ifs) compiler.compile_stmts([gen.next_assign], self.dfg) compile_ifs(gen.ifs)