Skip to content

Commit

Permalink
[IR] Create predecessors() and successors() on ir.Node (#2022)
Browse files Browse the repository at this point in the history
- Also updated `Usage` to a named tuple
- Implement `consumers()` on `Value`

---------

Co-authored-by: Copilot <[email protected]>
  • Loading branch information
justinchuby and Copilot authored Jan 21, 2025
1 parent 7582138 commit 969c078
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 9 deletions.
50 changes: 45 additions & 5 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
Hashable,
Iterable,
Iterator,
NamedTuple,
OrderedDict,
Sequence,
SupportsInt,
Expand Down Expand Up @@ -1055,6 +1056,18 @@ def _quoted(string: str) -> str:
return f'"{string}"'


class Usage(NamedTuple):
"""A usage of a value in a node.
Attributes:
node: The node that uses the value.
idx: The input index of the value in the node.
"""

node: Node
idx: int


class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
"""IR Node.
Expand Down Expand Up @@ -1293,6 +1306,25 @@ def inputs(self, _: Any) -> None:
"Directly mutating the input sequence is unsupported. Please use Node.replace_input_with() instead."
)

def predecessors(self) -> Sequence[Node]:
"""Return the predecessor nodes of the node, deduplicated, in a deterministic order."""
# Use the ordered nature of a dictionary to deduplicate the nodes
predecessors: dict[Node, None] = {}
for value in self.inputs:
if value is not None and (producer := value.producer()) is not None:
predecessors[producer] = None
return tuple(predecessors)

def successors(self) -> Sequence[Node]:
"""Return the successor nodes of the node, deduplicated, in a deterministic order."""
# Use the ordered nature of a dictionary to deduplicate the nodes
successors: dict[Node, None] = {}
for value in self.outputs:
assert value is not None, "Bug: Output values are not expected to be None"
for usage in value.uses():
successors[usage.node] = None
return tuple(successors)

def replace_input_with(self, index: int, value: Value | None) -> None:
"""Replace an input with a new value."""
if index < 0 or index >= len(self.inputs):
Expand Down Expand Up @@ -1564,7 +1596,7 @@ def __init__(
# Use a collection of (Node, int) to store uses. This is needed
# because a single use can use the same value multiple times.
# Use a dictionary to preserve insertion order so that the visiting order is deterministic
self._uses: dict[tuple[Node, int], None] = {}
self._uses: dict[Usage, None] = {}
self.doc_string = doc_string

def __repr__(self) -> str:
Expand Down Expand Up @@ -1595,31 +1627,39 @@ def producer(self) -> Node | None:
"""
return self._producer

def consumers(self) -> Sequence[Node]:
"""Return the nodes (deduplicated) that consume this value."""
return tuple({usage.node: None for usage in self._uses})

def index(self) -> int | None:
"""The index of the output of the defining node."""
return self._index

def uses(self) -> Collection[tuple[Node, int]]:
def uses(self) -> Collection[Usage]:
"""Return a set of uses of the value.
The set contains tuples of ``(Node, index)`` where the index is the index of the input
of the node. For example, if ``node.inputs[1] == value``, then the use is ``(node, 1)``.
"""
return self._uses.keys()
# Create a tuple for the collection so that iteration on will will not
# be affected when the usage changes during graph mutation.
# This adds a small overhead but is better a user experience than
# having users call tuple().
return tuple(self._uses)

def _add_usage(self, use: Node, index: int) -> None:
"""Add a usage of this value.
This is an internal method. It should only be called by the Node class.
"""
self._uses[(use, index)] = None
self._uses[Usage(use, index)] = None

def _remove_usage(self, use: Node, index: int) -> None:
"""Remove a node from the uses of this value.
This is an internal method. It should only be called by the Node class.
"""
self._uses.pop((use, index))
self._uses.pop(Usage(use, index))

@property
def name(self) -> str | None:
Expand Down
48 changes: 44 additions & 4 deletions onnxscript/ir/_core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,13 @@ def test_is_dynamic_on_empty_shape(self):


class ValueTest(unittest.TestCase):
def setUp(self) -> None:
self.v0 = _core.Value(name="v0")
self.v1 = _core.Value(name="v1")
self.node = _core.Node(
"test", "TestOp", inputs=(self.v0, self.v1, self.v1), num_outputs=2
)

def test_initialize(self):
_ = _core.Value()

Expand All @@ -732,14 +739,30 @@ def test_meta(self):
value.metadata_props["test"] = "any string"
self.assertEqual(value.metadata_props["test"], "any string")

def test_producer(self):
self.assertEqual(self.v0.producer(), None)
self.assertEqual(self.v1.producer(), None)
self.assertEqual(self.node.outputs[0].producer(), self.node)
self.assertEqual(self.node.outputs[1].producer(), self.node)

def test_consumers(self):
self.assertEqual(self.v0.consumers(), (self.node,))
self.assertEqual(self.v1.consumers(), (self.node,))
self.assertEqual(self.node.outputs[0].consumers(), ())
self.assertEqual(self.node.outputs[1].consumers(), ())

# TODO(justinchuby): Test all methods


class NodeTest(unittest.TestCase):
def setUp(self) -> None:
self.v0 = _core.Value()
self.v1 = _core.Value()
self.node = _core.Node("test", "TestOp", inputs=(self.v0, self.v1), num_outputs=3)
self.v0 = _core.Value(name="v0")
self.v1 = _core.Value(name="v1")
self.node = _core.Node(
"test", "TestOp", inputs=(self.v0, self.v1, self.v1), num_outputs=3
)
self.node_a = _core.Node("test", "TestOpA", inputs=[self.node.outputs[0]])
self.node_b = _core.Node("test", "TestOpB", inputs=self.node.outputs)

def test_it_is_hashable(self):
self.assertIsInstance(hash(self.node), int)
Expand All @@ -748,7 +771,7 @@ def test_it_is_hashable(self):
def test_init_with_values(self):
self.assertEqual(self.node.domain, "test")
self.assertEqual(self.node.op_type, "TestOp")
self.assertEqual(self.node.inputs, (self.v0, self.v1))
self.assertEqual(self.node.inputs, (self.v0, self.v1, self.v1))
self.assertEqual(len(self.node.outputs), 3)
self.assertEqual(self.node.attributes, {})

Expand Down Expand Up @@ -807,6 +830,23 @@ def test_it_is_added_to_a_graph_if_specified(self):
)
self.assertIn(self.node, graph)

def test_predecessors(self):
self.assertEqual(self.node.predecessors(), ())
self.assertEqual(self.node_a.predecessors(), (self.node,))
self.assertEqual(self.node_b.predecessors(), (self.node,))

def test_predecessors_are_unique(self):
# node_b has three inputs from node, but only one predecessor
self.assertEqual(self.node_b.predecessors(), (self.node,))

def test_successors(self):
self.assertEqual(self.node.successors(), (self.node_a, self.node_b))
self.assertEqual(self.node_a.successors(), ())
self.assertEqual(self.node_b.successors(), ())

def test_successors_are_unique(self):
self.assertEqual(self.node.successors(), (self.node_a, self.node_b))

# TODO(justinchuby): Test all methods


Expand Down

0 comments on commit 969c078

Please sign in to comment.