From 16dc35afc753966bf341ac902e86e0ed22abf9e3 Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Fri, 2 Feb 2024 10:25:09 -0800 Subject: [PATCH] Unify TORCH_LIBRARY definitions (#6455) Summary: This pull request tries to unify all TORCH_LIBRARY definitions across torch_xla into one xla library. Test Plan: CI --- test/stablehlo/test_mark_pattern.py | 58 +++++++++----------- torch_xla/core/xla_model.py | 3 + torch_xla/csrc/aten_autograd_ops.cpp | 13 +++++ torch_xla/csrc/xla_sharding_util.cpp | 16 +----- torch_xla/experimental/mark_pattern_utils.py | 12 ++-- torch_xla/experimental/quantized.py | 12 ++-- torch_xla/experimental/xla_marker.py | 15 +++-- 7 files changed, 60 insertions(+), 69 deletions(-) diff --git a/test/stablehlo/test_mark_pattern.py b/test/stablehlo/test_mark_pattern.py index e0fcce201f74..0e635da1e4ac 100644 --- a/test/stablehlo/test_mark_pattern.py +++ b/test/stablehlo/test_mark_pattern.py @@ -39,9 +39,9 @@ def test_basic(self): def f(x): x = x + 1 - x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", True) + x = torch.ops.xla.mark_tensor(x, "p", 0, "0", True) x = x + 2 - x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", False) + x = torch.ops.xla.mark_tensor(x, "p", 0, "0", False) return x input_args = (torch.randn(5),) @@ -59,14 +59,11 @@ def __init__(self): def forward(self, x, y): q, k, v = x.split(128, dim=-2) - q = torch.ops.xla_pattern_marking.mark_tensor( - q, "sdpa", pos=0, id="0", is_input=True) - k = torch.ops.xla_pattern_marking.mark_tensor( - k, "sdpa", pos=1, id="0", is_input=True) - v = torch.ops.xla_pattern_marking.mark_tensor( - v, "sdpa", pos=2, id="0", is_input=True) + q = torch.ops.xla.mark_tensor(q, "sdpa", pos=0, id="0", is_input=True) + k = torch.ops.xla.mark_tensor(k, "sdpa", pos=1, id="0", is_input=True) + v = torch.ops.xla.mark_tensor(v, "sdpa", pos=2, id="0", is_input=True) attn_out = F.scaled_dot_product_attention(q, k, v, scale=0.25) - attn_out = torch.ops.xla_pattern_marking.mark_tensor( + attn_out = torch.ops.xla.mark_tensor( attn_out, "sdpa", pos=0, @@ -74,14 +71,11 @@ def forward(self, x, y): is_input=False, attr={"scale": 0.25}) q, k, v = y.split(128, dim=-2) - q = torch.ops.xla_pattern_marking.mark_tensor( - q, "sdpa", pos=0, id="1", is_input=True) - k = torch.ops.xla_pattern_marking.mark_tensor( - k, "sdpa", pos=1, id="1", is_input=True) - v = torch.ops.xla_pattern_marking.mark_tensor( - v, "sdpa", pos=2, id="1", is_input=True) + q = torch.ops.xla.mark_tensor(q, "sdpa", pos=0, id="1", is_input=True) + k = torch.ops.xla.mark_tensor(k, "sdpa", pos=1, id="1", is_input=True) + v = torch.ops.xla.mark_tensor(v, "sdpa", pos=2, id="1", is_input=True) attn_out2 = F.scaled_dot_product_attention(q, k, v, scale=4) - attn_out2 = torch.ops.xla_pattern_marking.mark_tensor( + attn_out2 = torch.ops.xla.mark_tensor( attn_out2, "sdpa", pos=0, id="1", is_input=False, attr={"scale": 2}) return attn_out, attn_out2 @@ -193,11 +187,11 @@ def forward(self, x, y): def test_multiple_input(self): def f(x, y): - x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", True) - y = torch.ops.xla_pattern_marking.mark_tensor(y, "p", 1, "0", True) + x = torch.ops.xla.mark_tensor(x, "p", 0, "0", True) + y = torch.ops.xla.mark_tensor(y, "p", 1, "0", True) out = x + y out = out * x * y - out = torch.ops.xla_pattern_marking.mark_tensor(out, "p", 0, "0", False) + out = torch.ops.xla.mark_tensor(out, "p", 0, "0", False) return out input_args = (torch.ones(5), torch.ones(5)) @@ -209,12 +203,12 @@ def f(x, y): def test_multiple_output(self): def f(x, y): - 
x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", True) - y = torch.ops.xla_pattern_marking.mark_tensor(y, "p", 1, "0", True) + x = torch.ops.xla.mark_tensor(x, "p", 0, "0", True) + y = torch.ops.xla.mark_tensor(y, "p", 1, "0", True) out1 = x + y out2 = x * y - out1 = torch.ops.xla_pattern_marking.mark_tensor(out1, "p", 0, "0", False) - out2 = torch.ops.xla_pattern_marking.mark_tensor(out2, "p", 1, "0", False) + out1 = torch.ops.xla.mark_tensor(out1, "p", 0, "0", False) + out2 = torch.ops.xla.mark_tensor(out2, "p", 1, "0", False) return out1, out2 input_args = (torch.ones(5), torch.ones(5)) @@ -224,14 +218,13 @@ def f(x, y): def test_nested_pattern(self): def f(x): - x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_outter", 0, "0", True) + x = torch.ops.xla.mark_tensor(x, "p_outter", 0, "0", True) x = x + 1 - x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", True) + x = torch.ops.xla.mark_tensor(x, "p_inner", 0, "0", True) x = x + 1 - x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", False) + x = torch.ops.xla.mark_tensor(x, "p_inner", 0, "0", False) x = x * 2 - x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_outter", 0, "0", - False) + x = torch.ops.xla.mark_tensor(x, "p_outter", 0, "0", False) input_args = (torch.ones(5),) stablehlo = self.run_func_get_stablehlo(f, input_args) @@ -240,14 +233,13 @@ def f(x): def test_tangent_output(self): # Special case of nested pattern, outputs don't have dependencies. def f(x): - x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_outter", 0, "0", True) + x = torch.ops.xla.mark_tensor(x, "p_outter", 0, "0", True) x = x + 1 - x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", True) + x = torch.ops.xla.mark_tensor(x, "p_inner", 0, "0", True) x = x + 1 y = x - 1 - x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", False) - y = torch.ops.xla_pattern_marking.mark_tensor(y, "p_outter", 0, "0", - False) + x = torch.ops.xla.mark_tensor(x, "p_inner", 0, "0", False) + y = torch.ops.xla.mark_tensor(y, "p_outter", 0, "0", False) input_args = (torch.ones(5),) stablehlo = self.run_func_get_stablehlo(f, input_args) diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index c2f59f19d41a..650f65f8ae9c 100755 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -9,6 +9,7 @@ from typing import List, Optional import torch import torch.distributed._functional_collectives +from torch.library import Library import torch.nn.functional as F import torch_xla from torch_xla import runtime @@ -38,6 +39,8 @@ # Default bucket size for all-reduce _ALLREDUCE_BUCKET_CAP_MB = 50 +XLA_LIB = Library("xla", "DEF") + def _init_world_size_ordinal(): global _WORLD_SIZE, _ORDINAL diff --git a/torch_xla/csrc/aten_autograd_ops.cpp b/torch_xla/csrc/aten_autograd_ops.cpp index 81cfdfb4f428..40b9790a1211 100644 --- a/torch_xla/csrc/aten_autograd_ops.cpp +++ b/torch_xla/csrc/aten_autograd_ops.cpp @@ -253,5 +253,18 @@ torch::Tensor max_pool2d_backward(torch::Tensor grad_output, torch::Tensor self, return grad; } +TORCH_LIBRARY_FRAGMENT(xla, m) { + m.def( + "max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], " + "int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", + torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_forward))); + + m.def( + "max_pool2d_backward(Tensor grad_output, Tensor self, int[2] " + "kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) " + "-> Tensor", + torch::dispatch(c10::DispatchKey::XLA, 
TORCH_FN(max_pool2d_backward))); +} + } // namespace aten_autograd_ops } // namespace torch_xla diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index c7469e18e846..c0481e4bca28 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -36,21 +36,7 @@ namespace torch_xla { // Macro for defining a function that will be run at static initialization time // to define a library of operators in the namespace. Used to define a new set // of custom operators that do not already exist in PyTorch. -TORCH_LIBRARY(xla, m) { - m.def( - "max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], " - "int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", - torch::dispatch( - c10::DispatchKey::XLA, - TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_forward))); - - m.def( - "max_pool2d_backward(Tensor grad_output, Tensor self, int[2] " - "kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) " - "-> Tensor", - torch::dispatch( - c10::DispatchKey::XLA, - TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_backward))); +TORCH_LIBRARY_FRAGMENT(xla, m) { m.def( "xla_mark_sharding_dynamo_custom_op(Tensor input, int[][] " "tile_assignment, int[][] group_assignment, int[][] replication_groups, " diff --git a/torch_xla/experimental/mark_pattern_utils.py b/torch_xla/experimental/mark_pattern_utils.py index 4665f9b14e07..2e0baa129f1b 100644 --- a/torch_xla/experimental/mark_pattern_utils.py +++ b/torch_xla/experimental/mark_pattern_utils.py @@ -15,7 +15,7 @@ class StableHLOCompositeBuilder: """ Helper for building a StableHLO Composite by marking input and output tensors. It should be used with the StableHLO converters from `torch_xla.stablehlo`. - + Args: name (str): The name of the built StableHLO Composite op. @@ -37,7 +37,7 @@ def _mark_tensor(self, *tensors: torch.Tensor, is_input: bool): if not isinstance(tensor, torch.Tensor): raise ValueError(f"input must be a torch tensor. Got {type(tensor)}.") marked_tensors.append( - torch.ops.xla_pattern_marking.mark_tensor( + torch.ops.xla.mark_tensor( tensor, name=self.name, pos=pos, @@ -52,9 +52,9 @@ def _mark_tensor(self, *tensors: torch.Tensor, is_input: bool): def mark_inputs(self, *tensors: torch.Tensor): """ - Mark the input tensors of the StableHLO Composite. This method must only be + Mark the input tensors of the StableHLO Composite. This method must only be called once per builder. - + Args: *tensors (torch.Tensor): Torch tensors to mark. @@ -68,9 +68,9 @@ def mark_inputs(self, *tensors: torch.Tensor): def mark_outputs(self, *tensors: torch.Tensor): """ - Mark the output tensors of the StableHLO Composite. This method must only be + Mark the output tensors of the StableHLO Composite. This method must only be called once per builder. - + Args: *tensors (torch.Tensor): Torch tensors to mark. 
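For reference, the builder above now lowers to torch.ops.xla.mark_tensor rather than torch.ops.xla_pattern_marking.mark_tensor. A minimal usage sketch, assuming the StableHLOCompositeBuilder(name, attr=None) constructor implied by the docstring and that mark_inputs/mark_outputs return the marked tensors:

import torch
import torch.nn.functional as F
from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder

def sdpa_composite(q, k, v):
  # One builder per composite instance; mark_inputs/mark_outputs may each be
  # called only once per builder (see the docstrings above). The attr dict
  # mirrors the "sdpa" test case earlier in this patch.
  builder = StableHLOCompositeBuilder("sdpa", {"scale": 0.25})
  q, k, v = builder.mark_inputs(q, k, v)
  out = F.scaled_dot_product_attention(q, k, v, scale=0.25)
  return builder.mark_outputs(out)
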
diff --git a/torch_xla/experimental/quantized.py b/torch_xla/experimental/quantized.py index 2ab48ebf13ed..7fed638895ac 100644 --- a/torch_xla/experimental/quantized.py +++ b/torch_xla/experimental/quantized.py @@ -1,12 +1,10 @@ import numpy as np import torch +from torch.library import impl import torch_xla -from torch.library import Library, impl -quantized_decomposed_lib = Library("quantized_decomposed", "IMPL") - -@impl(quantized_decomposed_lib, "quantize_per_tensor", "XLA") +@impl("quantized_decomposed::quantize_per_tensor", "XLA") def xla_quantize_per_tensor(input: torch.Tensor, scale: float, zero_point: int, quant_min: int, quant_max: int, dtype: torch.dtype): return _xla_quantize(input, torch.tensor([scale]), @@ -14,7 +12,7 @@ def xla_quantize_per_tensor(input: torch.Tensor, scale: float, zero_point: int, quant_max, dtype) -@impl(quantized_decomposed_lib, "quantize_per_channel", "XLA") +@impl("quantized_decomposed::quantize_per_channel", "XLA") def xla_quantize_per_channel(input: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor, axis: int, quant_min: int, quant_max: int, @@ -23,7 +21,7 @@ def xla_quantize_per_channel(input: torch.Tensor, scale: torch.Tensor, axis) -@impl(quantized_decomposed_lib, "dequantize_per_tensor", "XLA") +@impl("quantized_decomposed::dequantize_per_tensor", "XLA") def xla_dequantize_per_tensor(input: torch.Tensor, scale: float, zero_point: int, quant_min: int, quant_max: int, dtype: torch.dtype): @@ -32,7 +30,7 @@ def xla_dequantize_per_tensor(input: torch.Tensor, scale: float, quant_max, dtype) -@impl(quantized_decomposed_lib, "dequantize_per_channel", "XLA") +@impl("quantized_decomposed::dequantize_per_channel", "XLA") def xla_dequantize_per_tensor(input: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor, axis: int, quant_min: int, quant_max: int, diff --git a/torch_xla/experimental/xla_marker.py b/torch_xla/experimental/xla_marker.py index 2c4a6d97fa26..647b3a30927d 100644 --- a/torch_xla/experimental/xla_marker.py +++ b/torch_xla/experimental/xla_marker.py @@ -4,12 +4,11 @@ from typing import Dict import torch +from torch.library import impl import torch_xla -from torch.library import Library, impl +from torch_xla.core.xla_model import XLA_LIB -xla_pattern_marking_lib = Library("xla_pattern_marking", "DEF") - -xla_pattern_marking_lib.define( +XLA_LIB.define( "mark_tensor(Tensor x, str name, int pos, str id, bool is_input, Any? attr=None) -> Tensor" ) @@ -45,7 +44,7 @@ def _assert_valid_composite_attr(attr): "Composite attr value must be either Python str, float, or int.") -@impl(xla_pattern_marking_lib, "mark_tensor", "XLA") +@impl(XLA_LIB, "mark_tensor", "XLA") def mark_tensor_xla(x: torch.Tensor, name: str, pos: int, @@ -53,7 +52,7 @@ def mark_tensor_xla(x: torch.Tensor, is_input: bool, attr: Dict = None): """Attach pattern boundary metadata to a XLA Tensor. - + Args: x: torch.Tensor (On XLA device) - the marked tensor. name: str - The name of the pattern, it will be the name of the stablehlo composite op. @@ -69,7 +68,7 @@ def mark_tensor_xla(x: torch.Tensor, x, json.dumps(pattern_info, cls=BoundaryMetadataSerializer)) -@impl(xla_pattern_marking_lib, "mark_tensor", "CompositeExplicitAutograd") +@impl(XLA_LIB, "mark_tensor", "CompositeExplicitAutograd") def mark_tensor(x: torch.Tensor, name: str, pos: int, @@ -80,7 +79,7 @@ def mark_tensor(x: torch.Tensor, return x -@impl(xla_pattern_marking_lib, "mark_tensor", "Meta") +@impl(XLA_LIB, "mark_tensor", "Meta") def mark_tensor_meta(x: torch.Tensor, name: str, pos: int,
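
Taken together, the registration pattern after this patch is: Python code defines schemas on the shared XLA_LIB handle from torch_xla.core.xla_model, while each C++ translation unit contributes to the same library via TORCH_LIBRARY_FRAGMENT(xla, m). A minimal sketch of adding a new op under the unified namespace; my_identity is a hypothetical op used only for illustration:

import torch
from torch.library import impl
from torch_xla.core.xla_model import XLA_LIB  # XLA_LIB = Library("xla", "DEF")

# Schemas attach to the single shared "xla" library; C++ fragments
# (TORCH_LIBRARY_FRAGMENT(xla, m)) extend the same library.
XLA_LIB.define("my_identity(Tensor x) -> Tensor")  # hypothetical example op

@impl(XLA_LIB, "my_identity", "CompositeExplicitAutograd")
def my_identity(x: torch.Tensor) -> torch.Tensor:
  # Pass-through implementation, mirroring mark_tensor's non-XLA fallback.
  return x

# Callers then resolve the op under the unified namespace:
# y = torch.ops.xla.my_identity(torch.ones(3))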