Skip to content

Commit

Permalink
copy when flipping function type
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Jun 18, 2024
1 parent 80e6832 commit 4c616c0
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
13 changes: 7 additions & 6 deletions hugr-py/src/hugr/_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,16 @@ def add_block(self, input_types: TypeRow) -> Block:
return new_block

def add_successor(self, pred: Wire) -> Block:
pred = pred.out_port()
block = self.hugr._get_typed_op(pred.node, ops.DataflowBlock)
inputs = block.nth_outputs(pred.offset)
b = self.add_block(inputs)
b = self.add_block(self._nth_outputs(pred))

self.branch(pred, b)
return b

def _nth_outputs(self, wire: Wire) -> TypeRow:
port = wire.out_port()
block = self.hugr._get_typed_op(port.node, ops.DataflowBlock)
return block.nth_outputs(port.offset)

def branch(self, src: Wire, dst: ToNode) -> None:
# TODO check for existing link/type compatibility
if dst.to_node() == self.exit:
Expand All @@ -116,8 +118,7 @@ def branch_exit(self, src: Wire) -> None:
src = src.out_port()
self.hugr.add_link(src, self.exit.inp(0))

src_block = self.hugr._get_typed_op(src.node, ops.DataflowBlock)
out_types = src_block.nth_outputs(src.offset)
out_types = self._nth_outputs(src)
if self._exit_op._cfg_outputs is not None:
if self._exit_op._cfg_outputs != out_types:
raise MismatchedExit(src.node.idx)

Check warning on line 124 in hugr-py/src/hugr/_cfg.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_cfg.py#L124

Added line #L124 was not covered by tests
Expand Down
2 changes: 1 addition & 1 deletion hugr-py/src/hugr/_tys.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def empty(cls) -> FunctionType:
return cls(input=[], output=[])

def flip(self) -> FunctionType:
return FunctionType(input=self.output, output=self.input)
return FunctionType(input=list(self.output), output=list(self.input))


@dataclass(frozen=True)
Expand Down

0 comments on commit 4c616c0

Please sign in to comment.