Merge branch 'main' into justinchu/ir-save
justinchuby authored Jan 21, 2025
2 parents 273919d + 969c078 commit 67d8ab9
Showing 10 changed files with 177 additions and 32 deletions.
23 changes: 7 additions & 16 deletions onnxscript/function_libs/torch_lib/ops/nn.py
@@ -825,26 +825,17 @@ def aten_leaky_relu_backward(
    raise NotImplementedError()


# NOTE: Do not register - We rely on PyTorch decomposition to aten_addmm (Gemm)
def aten_linear(input: TFloat, weight: TFloat) -> TFloat:
@torch_op("aten::linear", trace_only=True)
def aten_linear(input: TFloat, weight: TFloat, bias: Optional[TFloat] = None) -> TFloat:
    """linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor"""

    # NOTE: The symbolic function in torch.onnx also uses Gemm in certain cases
    # Optimizers may consider this path and replace it with Gemm
    # We do not use Gemm here because input can have batch dimensions, which Gemm does not support
    weight_transposed = op.Transpose(weight, perm=[1, 0])
    return op.MatMul(input, weight_transposed)


# NOTE: Do not register - We rely on PyTorch decomposition to aten_addmm (Gemm)
def aten_linear_bias(input: TFloat, weight: TFloat, bias: TFloat) -> TFloat:
    """linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor"""

    # NOTE: The symbolic function in torch.onnx also uses Gemm in certain cases
    # Optimizers may consider this path and replace it with Gemm
    # We do not use Gemm here because input can have batch dimensions, which Gemm does not support
    if len(input.shape) == 2:
        # Use Gemm for the rank 2 input
        return op.Gemm(input, weight, bias, transB=True)
    weight_transposed = op.Transpose(weight, perm=[1, 0])
    mul = op.MatMul(input, weight_transposed)
    if bias is None:
        return mul
    return op.Add(mul, bias)


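For reference, the unified aten_linear above reduces to the following NumPy semantics. This is an illustrative sketch, not library code: rank-2 inputs take the single-op Gemm path, while batched inputs fall back to Transpose + MatMul (+ Add), because Gemm does not accept batch dimensions.

import numpy as np

def linear_reference(x, w, b=None):
    # Equivalent of MatMul(input, Transpose(weight, perm=[1, 0])).
    out = x @ w.T
    if b is not None:
        # Equivalent of Add(mul, bias); fused into Gemm when x is rank 2.
        out = out + b
    return out

# x of shape (16, 8) exercises the Gemm path; (batch, 16, 8) the MatMul path.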
50 changes: 45 additions & 5 deletions onnxscript/ir/_core.py
@@ -30,6 +30,7 @@
    Hashable,
    Iterable,
    Iterator,
    NamedTuple,
    OrderedDict,
    Sequence,
    SupportsInt,
@@ -1055,6 +1056,18 @@ def _quoted(string: str) -> str:
    return f'"{string}"'


class Usage(NamedTuple):
    """A usage of a value in a node.

    Attributes:
        node: The node that uses the value.
        idx: The input index of the value in the node.
    """

    node: Node
    idx: int


class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
    """IR Node.
@@ -1293,6 +1306,25 @@ def inputs(self, _: Any) -> None:
            "Directly mutating the input sequence is unsupported. Please use Node.replace_input_with() instead."
        )

    def predecessors(self) -> Sequence[Node]:
        """Return the predecessor nodes of the node, deduplicated, in a deterministic order."""
        # Use the ordered nature of a dictionary to deduplicate the nodes
        predecessors: dict[Node, None] = {}
        for value in self.inputs:
            if value is not None and (producer := value.producer()) is not None:
                predecessors[producer] = None
        return tuple(predecessors)

    def successors(self) -> Sequence[Node]:
        """Return the successor nodes of the node, deduplicated, in a deterministic order."""
        # Use the ordered nature of a dictionary to deduplicate the nodes
        successors: dict[Node, None] = {}
        for value in self.outputs:
            assert value is not None, "Bug: Output values are not expected to be None"
            for usage in value.uses():
                successors[usage.node] = None
        return tuple(successors)

    def replace_input_with(self, index: int, value: Value | None) -> None:
        """Replace an input with a new value."""
        if index < 0 or index >= len(self.inputs):
@@ -1564,7 +1596,7 @@ def __init__(
        # Use a collection of (Node, int) to store uses. This is needed
        # because a single node can use the same value multiple times.
        # Use a dictionary to preserve insertion order so that the visiting order is deterministic
        self._uses: dict[tuple[Node, int], None] = {}
        self._uses: dict[Usage, None] = {}
        self.doc_string = doc_string

    def __repr__(self) -> str:
@@ -1595,31 +1627,39 @@ def producer(self) -> Node | None:
"""
return self._producer

    def consumers(self) -> Sequence[Node]:
        """Return the nodes (deduplicated) that consume this value."""
        return tuple({usage.node: None for usage in self._uses})

    def index(self) -> int | None:
        """The index of the output of the defining node."""
        return self._index

    def uses(self) -> Collection[tuple[Node, int]]:
    def uses(self) -> Collection[Usage]:
        """Return the uses of the value.

        The returned collection contains ``Usage`` tuples ``(node, idx)``, where ``idx``
        is the index of the input of the node. For example, if ``node.inputs[1] == value``,
        then the use is ``Usage(node, 1)``.
        """
        return self._uses.keys()
        # Create a tuple for the collection so that iteration over it is not
        # affected when usages change during graph mutation.
        # This adds a small overhead but is a better user experience than
        # having users call tuple().
        return tuple(self._uses)

    def _add_usage(self, use: Node, index: int) -> None:
        """Add a usage of this value.

        This is an internal method. It should only be called by the Node class.
        """
        self._uses[(use, index)] = None
        self._uses[Usage(use, index)] = None

    def _remove_usage(self, use: Node, index: int) -> None:
        """Remove a node from the uses of this value.

        This is an internal method. It should only be called by the Node class.
        """
        self._uses.pop((use, index))
        self._uses.pop(Usage(use, index))

    @property
    def name(self) -> str | None:
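The new traversal helpers compose as in the sketch below, which mirrors the setup used in _core_test.py; the Add/Relu names are illustrative, not part of the library. Note that uses() now yields Usage named tuples, so fields can be read by name instead of by position.

from onnxscript.ir import _core

v0 = _core.Value(name="v0")
v1 = _core.Value(name="v1")
add = _core.Node("", "Add", inputs=(v0, v1), num_outputs=1)
relu = _core.Node("", "Relu", inputs=(add.outputs[0],), num_outputs=1)

for usage in add.outputs[0].uses():
    print(usage.node.op_type, usage.idx)  # Relu 0

print(add.outputs[0].consumers())  # (relu,) - deduplicated consuming nodes
print(relu.predecessors())         # (add,)
print(add.successors())            # (relu,)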
48 changes: 44 additions & 4 deletions onnxscript/ir/_core_test.py
@@ -717,6 +717,13 @@ def test_is_dynamic_on_empty_shape(self):


class ValueTest(unittest.TestCase):
    def setUp(self) -> None:
        self.v0 = _core.Value(name="v0")
        self.v1 = _core.Value(name="v1")
        self.node = _core.Node(
            "test", "TestOp", inputs=(self.v0, self.v1, self.v1), num_outputs=2
        )

    def test_initialize(self):
        _ = _core.Value()

@@ -732,14 +739,30 @@ def test_meta(self):
        value.metadata_props["test"] = "any string"
        self.assertEqual(value.metadata_props["test"], "any string")

    def test_producer(self):
        self.assertEqual(self.v0.producer(), None)
        self.assertEqual(self.v1.producer(), None)
        self.assertEqual(self.node.outputs[0].producer(), self.node)
        self.assertEqual(self.node.outputs[1].producer(), self.node)

    def test_consumers(self):
        self.assertEqual(self.v0.consumers(), (self.node,))
        self.assertEqual(self.v1.consumers(), (self.node,))
        self.assertEqual(self.node.outputs[0].consumers(), ())
        self.assertEqual(self.node.outputs[1].consumers(), ())

    # TODO(justinchuby): Test all methods


class NodeTest(unittest.TestCase):
    def setUp(self) -> None:
        self.v0 = _core.Value()
        self.v1 = _core.Value()
        self.node = _core.Node("test", "TestOp", inputs=(self.v0, self.v1), num_outputs=3)
        self.v0 = _core.Value(name="v0")
        self.v1 = _core.Value(name="v1")
        self.node = _core.Node(
            "test", "TestOp", inputs=(self.v0, self.v1, self.v1), num_outputs=3
        )
        self.node_a = _core.Node("test", "TestOpA", inputs=[self.node.outputs[0]])
        self.node_b = _core.Node("test", "TestOpB", inputs=self.node.outputs)

    def test_it_is_hashable(self):
        self.assertIsInstance(hash(self.node), int)
@@ -748,7 +771,7 @@ def test_it_is_hashable(self):
    def test_init_with_values(self):
        self.assertEqual(self.node.domain, "test")
        self.assertEqual(self.node.op_type, "TestOp")
        self.assertEqual(self.node.inputs, (self.v0, self.v1))
        self.assertEqual(self.node.inputs, (self.v0, self.v1, self.v1))
        self.assertEqual(len(self.node.outputs), 3)
        self.assertEqual(self.node.attributes, {})

@@ -807,6 +830,23 @@ def test_it_is_added_to_a_graph_if_specified(self):
        )
        self.assertIn(self.node, graph)

    def test_predecessors(self):
        self.assertEqual(self.node.predecessors(), ())
        self.assertEqual(self.node_a.predecessors(), (self.node,))
        self.assertEqual(self.node_b.predecessors(), (self.node,))

    def test_predecessors_are_unique(self):
        # node_b has three inputs from node, but only one predecessor
        self.assertEqual(self.node_b.predecessors(), (self.node,))

    def test_successors(self):
        self.assertEqual(self.node.successors(), (self.node_a, self.node_b))
        self.assertEqual(self.node_a.successors(), ())
        self.assertEqual(self.node_b.successors(), ())

    def test_successors_are_unique(self):
        self.assertEqual(self.node.successors(), (self.node_a, self.node_b))

    # TODO(justinchuby): Test all methods


16 changes: 16 additions & 0 deletions onnxscript/ir/_tape.py
@@ -18,6 +18,7 @@ class Tape(Iterable[ir.Node]):

    def __init__(self) -> None:
        self._nodes: list[ir.Node] = []
        self._initializers: list[ir.Value] = []

    def __iter__(self) -> Iterator[ir.Node]:
        return iter(self._nodes)
@@ -26,6 +27,10 @@ def __iter__(self) -> Iterator[ir.Node]:
    def nodes(self) -> Sequence[ir.Node]:
        return tuple(self._nodes)

    @property
    def initializers(self) -> Sequence[ir.Value]:
        return tuple(self._initializers)

    def op(
        self,
        op_type: str,
@@ -60,6 +65,17 @@ def op_multi_output(

        return node.outputs

    def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir.Value:
        name = name or tensor.name
        if name is None:
            raise ValueError("Name must be provided for initializer.")
        shape = ir.Shape((d if isinstance(d, int) else d.value) for d in tensor.shape.dims)
        value = ir.Value(
            name=name, shape=shape, type=ir.TensorType(tensor.dtype), const_value=tensor
        )
        self._initializers.append(value)
        return value


# A type representing the domains/versions used in creating nodes in IR.
UsedOpsets = List[Tuple[str, Optional[int]]]
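A small sketch of the new Tape.initializer in use (the tensor contents here are illustrative): it wraps a tensor in an ir.Value carrying name, shape, dtype, and const_value, and records it so the rewriter can later register it on the graph.

import numpy as np
from onnxscript import ir
from onnxscript.ir import _tape

tape = _tape.Tape()
weight = tape.initializer(ir.tensor(np.ones((4, 4), dtype=np.float32), name="w"))
assert weight.const_value is not None
assert tape.initializers == (weight,)  # recorded for later registration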
6 changes: 5 additions & 1 deletion onnxscript/optimizer/_constant_folding.py
@@ -119,7 +119,11 @@ def evaluate(self, domain: str, op: str, version: int, *args, **kwargs) -> Any:
        evaluator = self.get_evaluator(domain, op, version)
        if evaluator is None:
            return None
        return evaluator(*args, **kwargs)
        try:
            return evaluator(*args, **kwargs)
        except Exception as e:
            logger.warning("Evaluation failed: %s", e)
            return None


_reference_evaluator = ReferenceEvaluator()
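The guarded call above changes constant folding from crash-on-error to skip-on-error: any exception raised by the reference evaluator now downgrades to a warning, and the node is simply left unfolded. A minimal sketch of the same pattern (names are illustrative):

import logging

logger = logging.getLogger(__name__)

def try_evaluate(evaluator, *args, **kwargs):
    # Returning None signals "cannot fold"; the optimizer keeps the node as-is.
    try:
        return evaluator(*args, **kwargs)
    except Exception as e:
        logger.warning("Evaluation failed: %s", e)
        return None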
22 changes: 21 additions & 1 deletion onnxscript/rewriter/pattern.py
@@ -900,6 +900,7 @@ class ReplacementSubgraph:
    match: MatchResult
    new_outputs: Sequence[ir.Value]
    new_nodes: Sequence[ir.Node]
    new_initializers: Sequence[ir.Value]
    used_opsets: _tape.UsedOpsets


@@ -928,7 +929,9 @@ def get_replacement(self, match: MatchResult) -> ReplacementSubgraph | None:
            return None  # Failed to create replacement subgraph
        if not isinstance(new_outputs, Sequence):
            new_outputs = [new_outputs]
        return ReplacementSubgraph(match, new_outputs, context.nodes, context.used_opsets)
        return ReplacementSubgraph(
            match, new_outputs, context.nodes, context.initializers, context.used_opsets
        )


def _update_opset_imports(
@@ -1566,6 +1569,23 @@ def _apply_to_graph_or_function(
        if delta is None or tracer is not None:
            continue
        assert isinstance(delta, ReplacementSubgraph)
        if delta.new_initializers:
            if isinstance(graph_or_function, ir.Function):
                # TODO(rama): Can't add initializers to functions. But currently this is not
                # an issue, as we apply inlining before applying rewrite rules.
                if verbose:
                    print(
                        f"Rewrites adding initializers not supported for functions: {rule}"
                    )
                continue
            initializers = graph_or_function.initializers
            for initializer in delta.new_initializers:
                if initializer.name in initializers:
                    if verbose:
                        print(f"Initializer {initializer.name} already exists.")
                    continue
                initializers[initializer.name] = initializer  # type: ignore[index]
        # TODO: This does not yet handle the problem of determining the correct insertion point
        # for inserted nodes in the case of patterns with multiple output-nodes. The following
        # is sufficient for patterns with a single output-node "node", which can serve as the
34 changes: 34 additions & 0 deletions onnxscript/rewriter/pattern_test.py
@@ -5,6 +5,7 @@
import logging
import unittest

import numpy as np
import onnx.checker
import onnx.parser

@@ -543,6 +544,39 @@ def test_model(x: FLOAT[1024]) -> FLOAT[1024]:
        # Not a robust test, but it serves to ensure that debug mode is producing something.
        self.assertIn("OpType mismatch: expected Abs, got Neg", captured_output)

    def test_new_initializer(self):
        def source_pattern(op, x, y):
            return op.Gemm(x, op.Transpose(y))

        def check(context, x, y):
            return y.const_value is not None

        def replacement(op, x, y):
            tensor = y.const_value
            name = y.name + "_transposed"
            transposed = ir.tensor(tensor.numpy().T, name=name)
            initializer = op.initializer(transposed)
            return op.Gemm(x, initializer)

        rule = pattern.RewriteRule(source_pattern, replacement, check)

        y_value = np.random.rand(8, 4).astype(np.float32)

        @script()
        def test_model(x: FLOAT[16, 8]) -> FLOAT[16, 4]:
            y = op.Constant(value=y_value)
            return op.Gemm(x, op.Transpose(y))

        model_proto = test_model.to_model_proto()
        model = ir.serde.deserialize_model(model_proto)
        rule.apply_to_model(model)
        self.assertEqual(len(model.graph.initializers), 1)
        last_node = model.graph[-1]
        self.assertEqual(len(last_node.inputs), 2)
        init_name = last_node.inputs[1].name
        self.assertIn(init_name, model.graph.initializers)
        self.assertIs(last_node.inputs[1], model.graph.initializers[init_name])


class PatternBuilderTest(unittest.TestCase):
    def test_pattern_builder_context(self):
2 changes: 1 addition & 1 deletion requirements/ci/requirements-onnx-weekly.txt
@@ -1 +1 @@
onnx-weekly==1.18.0.dev20250113
onnx-weekly==1.18.0.dev20250120
2 changes: 1 addition & 1 deletion requirements/lintrunner/requirements.txt
@@ -1,7 +1,7 @@
# This file is auto updated by dependabot
lintrunner-adapters>=0.8.0
# RUFF, RUFF-FIX
ruff==0.9.1
ruff==0.9.2
# MYPY
mypy==1.10.1
types-PyYAML==6.0.12.20241230
6 changes: 3 additions & 3 deletions tests/function_libs/torch_lib/ops_test_data.py
@@ -1855,6 +1855,9 @@ def _where_input_wrangler(
        tolerance={torch.float16: (8e-2, 1e-4)},
    ),
    TorchLibOpInfo("nn.functional.glu", nn_ops.aten_glu),
    TorchLibOpInfo(
        "nn.functional.linear", nn_ops.aten_linear, tolerance={torch.float16: (1e-2, 1e-3)}
    ),
    TorchLibOpInfo(
        "nn.functional.unfold",
        nn_ops.aten_im2col,
@@ -2176,9 +2179,6 @@ def _where_input_wrangler(
ops_test_common.duplicate_opinfo(OPS_DB, "mean", ("mean_dim",))
ops_test_common.duplicate_opinfo(OPS_DB, "min", ("min_dim",))
ops_test_common.duplicate_opinfo(OPS_DB, "minimum", ("minimum_bool",))
ops_test_common.duplicate_opinfo(
    OPS_DB, "nn.functional.linear", ("nn.functional.linear_bias",)
)
ops_test_common.duplicate_opinfo(
    OPS_DB,
    "nn.functional.pad",
