Skip to content

Commit

Permalink
Add _if_true context manager
Browse files Browse the repository at this point in the history
  • Loading branch information
mark-koch committed Jan 11, 2024
1 parent a14ae45 commit 8f561e2
Showing 1 changed file with 29 additions and 31 deletions.
60 changes: 29 additions & 31 deletions guppy/compiler/expr_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,25 @@ def _new_case(
with self._new_dfcontainer(inputs, self.graph.add_case(cond_node)):
yield
self.graph.add_output([self.visit(name) for name in outputs])
# Update the DFG with the outputs from the Conditional node, but only we haven't
# already added some
if cond_node.num_out_ports == 0:
for name in outputs:
self.dfg[name.id].port = cond_node.add_out_port(get_type(name))

@contextmanager
def _if_true(self, cond: ast.expr, inputs: list[ast.Name]) -> Iterator[None]:
"""Context manager to build a graph inside the `true` case of a `Conditional`
In the `false` case, the inputs are outputted as is.
"""
cond_node = self.graph.add_conditional(
self.visit(cond), [self.visit(inp) for inp in inputs]
)
# If the condition is false, output the inputs as is
with self._new_case(inputs, inputs, cond_node):
pass
# If the condition is true, we enter the `with` block
with self._new_case(inputs, inputs, cond_node):
yield
# Update the DFG with the outputs from the Conditional node
for name in inputs:
self.dfg[name.id].port = cond_node.add_out_port(get_type(name))

def visit_Constant(self, node: ast.Constant) -> OutPortV:
if value := python_value_to_hugr(node.value):
Expand Down Expand Up @@ -230,38 +244,22 @@ def compile_generators(elt: ast.expr, gens: list[DesugaredGenerator]) -> None:
compiler.compile_stmts([gen.iter_assign], self.dfg)
inputs = [gen.iter, list_name]
with self._new_loop(inputs, gen.hasnext):
# Compile the `hasnext` check and plug it into a conditional
compiler.compile_stmts([gen.hasnext_assign], self.dfg)
cond = self.graph.add_conditional(
self.visit(gen.hasnext),
[self.visit(inp) for inp in inputs],
)

# If the iterator is finished, output the iterator and list as is (this
# is achieved by passing `inputs, inputs` below)
with self._new_case(inputs, inputs, cond):
pass

# If there is a next element, compile it and continue with the next
# generator
with self._new_case(inputs, inputs, cond):
compiler.compile_stmts([gen.hasnext_assign], self.dfg)
with self._if_true(gen.hasnext, inputs):

def compile_ifs(ifs: list[ast.expr]) -> None:
if not ifs:
"""Helper function to compile a series of if-guards into nested
Conditional nodes."""
if ifs:
if_expr, *ifs = ifs
# If the condition is true, continue with the next one
with self._if_true(if_expr, inputs):
compile_ifs(ifs)
else:
# If there are no guards left, compile the next generator
compile_generators(elt, gens)
return
if_expr, *ifs = ifs
cond = self.graph.add_conditional(
self.visit(if_expr),
[self.visit(gen.iter), self.visit(list_name)],
)
# If the condition is false, output the iterator and list as is
with self._new_case(inputs, inputs, cond):
pass
# If the condition is true, continue with the next one
with self._new_case(inputs, inputs, cond):
compile_ifs(ifs)

compiler.compile_stmts([gen.next_assign], self.dfg)
compile_ifs(gen.ifs)
Expand Down

0 comments on commit 8f561e2

Please sign in to comment.