From e3b9ce140c28c74467fb81d4f48f8e2f8d816507 Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Fri, 2 Feb 2024 01:26:22 +0000 Subject: [PATCH 1/6] initial commit --- torch_xla/core/xla_model.py | 2 ++ torch_xla/csrc/xla_sharding_util.cpp | 2 +- torch_xla/experimental/xla_marker.py | 15 +++++++-------- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index 10a860355e0..5774637a93c 100755 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -9,6 +9,7 @@ from typing import List, Optional import torch import torch.distributed._functional_collectives +from torch.library import Library import torch.nn.functional as F import torch_xla from torch_xla import runtime @@ -35,6 +36,7 @@ _WORLD_SIZE = None _ORDINAL = None +XLA_LIB = Library("xla", "DEF") def _init_world_size_ordinal(): global _WORLD_SIZE, _ORDINAL diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index c7469e18e84..e07cf45e224 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -36,7 +36,7 @@ namespace torch_xla { // Macro for defining a function that will be run at static initialization time // to define a library of operators in the namespace. Used to define a new set // of custom operators that do not already exist in PyTorch. -TORCH_LIBRARY(xla, m) { +TORCH_LIBRARY_FRAGMENT(xla, m) { m.def( "max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], " "int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", diff --git a/torch_xla/experimental/xla_marker.py b/torch_xla/experimental/xla_marker.py index 2c4a6d97fa2..647b3a30927 100644 --- a/torch_xla/experimental/xla_marker.py +++ b/torch_xla/experimental/xla_marker.py @@ -4,12 +4,11 @@ from typing import Dict import torch +from torch.library import impl import torch_xla -from torch.library import Library, impl +from torch_xla.core.xla_model import XLA_LIB -xla_pattern_marking_lib = Library("xla_pattern_marking", "DEF") - -xla_pattern_marking_lib.define( +XLA_LIB.define( "mark_tensor(Tensor x, str name, int pos, str id, bool is_input, Any? attr=None) -> Tensor" ) @@ -45,7 +44,7 @@ def _assert_valid_composite_attr(attr): "Composite attr value must be either Python str, float, or int.") -@impl(xla_pattern_marking_lib, "mark_tensor", "XLA") +@impl(XLA_LIB, "mark_tensor", "XLA") def mark_tensor_xla(x: torch.Tensor, name: str, pos: int, @@ -53,7 +52,7 @@ def mark_tensor_xla(x: torch.Tensor, is_input: bool, attr: Dict = None): """Attach pattern boundary metadata to a XLA Tensor. - + Args: x: torch.Tensor (On XLA device) - the marked tensor. name: str - The name of the pattern, it will be the name of the stablehlo composite op. @@ -69,7 +68,7 @@ def mark_tensor_xla(x: torch.Tensor, x, json.dumps(pattern_info, cls=BoundaryMetadataSerializer)) -@impl(xla_pattern_marking_lib, "mark_tensor", "CompositeExplicitAutograd") +@impl(XLA_LIB, "mark_tensor", "CompositeExplicitAutograd") def mark_tensor(x: torch.Tensor, name: str, pos: int, @@ -80,7 +79,7 @@ def mark_tensor(x: torch.Tensor, return x -@impl(xla_pattern_marking_lib, "mark_tensor", "Meta") +@impl(XLA_LIB, "mark_tensor", "Meta") def mark_tensor_meta(x: torch.Tensor, name: str, pos: int, From bb075e36e9c425c17f0e7e2479bcec4a217be7bc Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Fri, 2 Feb 2024 01:29:51 +0000 Subject: [PATCH 2/6] Replace quantize --- torch_xla/experimental/quantized.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/torch_xla/experimental/quantized.py b/torch_xla/experimental/quantized.py index 2ab48ebf13e..e9128b1ccd9 100644 --- a/torch_xla/experimental/quantized.py +++ b/torch_xla/experimental/quantized.py @@ -1,12 +1,11 @@ import numpy as np import torch +from torch.library import impl import torch_xla -from torch.library import Library, impl +from torch_xla.core.xla_model import XLA_LIB -quantized_decomposed_lib = Library("quantized_decomposed", "IMPL") - -@impl(quantized_decomposed_lib, "quantize_per_tensor", "XLA") +@impl(XLA_LIB, "quantize_per_tensor", "XLA") def xla_quantize_per_tensor(input: torch.Tensor, scale: float, zero_point: int, quant_min: int, quant_max: int, dtype: torch.dtype): return _xla_quantize(input, torch.tensor([scale]), @@ -14,7 +13,7 @@ def xla_quantize_per_tensor(input: torch.Tensor, scale: float, zero_point: int, quant_max, dtype) -@impl(quantized_decomposed_lib, "quantize_per_channel", "XLA") +@impl(XLA_LIB, "quantize_per_channel", "XLA") def xla_quantize_per_channel(input: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor, axis: int, quant_min: int, quant_max: int, @@ -23,7 +22,7 @@ def xla_quantize_per_channel(input: torch.Tensor, scale: torch.Tensor, axis) -@impl(quantized_decomposed_lib, "dequantize_per_tensor", "XLA") +@impl(XLA_LIB, "dequantize_per_tensor", "XLA") def xla_dequantize_per_tensor(input: torch.Tensor, scale: float, zero_point: int, quant_min: int, quant_max: int, dtype: torch.dtype): @@ -32,7 +31,7 @@ def xla_dequantize_per_tensor(input: torch.Tensor, scale: float, quant_max, dtype) -@impl(quantized_decomposed_lib, "dequantize_per_channel", "XLA") +@impl(XLA_LIB, "dequantize_per_channel", "XLA") def xla_dequantize_per_tensor(input: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor, axis: int, quant_min: int, quant_max: int, From cef020f5f18d747fc5a3c759894afca9259612ac Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Fri, 2 Feb 2024 01:32:24 +0000 Subject: [PATCH 3/6] Move maxpool to the original place --- torch_xla/csrc/aten_autograd_ops.cpp | 13 +++++++++++++ torch_xla/csrc/xla_sharding_util.cpp | 14 -------------- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/torch_xla/csrc/aten_autograd_ops.cpp b/torch_xla/csrc/aten_autograd_ops.cpp index 81cfdfb4f42..40b9790a121 100644 --- a/torch_xla/csrc/aten_autograd_ops.cpp +++ b/torch_xla/csrc/aten_autograd_ops.cpp @@ -253,5 +253,18 @@ torch::Tensor max_pool2d_backward(torch::Tensor grad_output, torch::Tensor self, return grad; } +TORCH_LIBRARY_FRAGMENT(xla, m) { + m.def( + "max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], " + "int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", + torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_forward))); + + m.def( + "max_pool2d_backward(Tensor grad_output, Tensor self, int[2] " + "kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) " + "-> Tensor", + torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_backward))); +} + } // namespace aten_autograd_ops } // namespace torch_xla diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index e07cf45e224..c0481e4bca2 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -37,20 +37,6 @@ namespace torch_xla { // to define a library of operators in the namespace. Used to define a new set // of custom operators that do not already exist in PyTorch. TORCH_LIBRARY_FRAGMENT(xla, m) { - m.def( - "max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], " - "int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", - torch::dispatch( - c10::DispatchKey::XLA, - TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_forward))); - - m.def( - "max_pool2d_backward(Tensor grad_output, Tensor self, int[2] " - "kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) " - "-> Tensor", - torch::dispatch( - c10::DispatchKey::XLA, - TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_backward))); m.def( "xla_mark_sharding_dynamo_custom_op(Tensor input, int[][] " "tile_assignment, int[][] group_assignment, int[][] replication_groups, " From 5e7ce0f1e348a2aeaf2fd6519f6286199f817cd1 Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Fri, 2 Feb 2024 01:46:38 +0000 Subject: [PATCH 4/6] Fix tests --- test/stablehlo/test_mark_pattern.py | 50 ++++++++++---------- torch_xla/experimental/mark_pattern_utils.py | 12 ++--- 2 files changed, 31 insertions(+), 31 deletions(-) diff --git a/test/stablehlo/test_mark_pattern.py b/test/stablehlo/test_mark_pattern.py index e0fcce201f7..06194b6069e 100644 --- a/test/stablehlo/test_mark_pattern.py +++ b/test/stablehlo/test_mark_pattern.py @@ -39,9 +39,9 @@ def test_basic(self): def f(x): x = x + 1 - x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", True) + x = torch.ops.xla.mark_tensor(x, "p", 0, "0", True) x = x + 2 - x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", False) + x = torch.ops.xla.mark_tensor(x, "p", 0, "0", False) return x input_args = (torch.randn(5),) @@ -59,14 +59,14 @@ def __init__(self): def forward(self, x, y): q, k, v = x.split(128, dim=-2) - q = torch.ops.xla_pattern_marking.mark_tensor( + q = torch.ops.xla.mark_tensor( q, "sdpa", pos=0, id="0", is_input=True) - k = torch.ops.xla_pattern_marking.mark_tensor( + k = torch.ops.xla.mark_tensor( k, "sdpa", pos=1, id="0", is_input=True) - v = torch.ops.xla_pattern_marking.mark_tensor( + v = torch.ops.xla.mark_tensor( v, "sdpa", pos=2, id="0", is_input=True) attn_out = F.scaled_dot_product_attention(q, k, v, scale=0.25) - attn_out = torch.ops.xla_pattern_marking.mark_tensor( + attn_out = torch.ops.xla.mark_tensor( attn_out, "sdpa", pos=0, @@ -74,14 +74,14 @@ def forward(self, x, y): is_input=False, attr={"scale": 0.25}) q, k, v = y.split(128, dim=-2) - q = torch.ops.xla_pattern_marking.mark_tensor( + q = torch.ops.xla.mark_tensor( q, "sdpa", pos=0, id="1", is_input=True) - k = torch.ops.xla_pattern_marking.mark_tensor( + k = torch.ops.xla.mark_tensor( k, "sdpa", pos=1, id="1", is_input=True) - v = torch.ops.xla_pattern_marking.mark_tensor( + v = torch.ops.xla.mark_tensor( v, "sdpa", pos=2, id="1", is_input=True) attn_out2 = F.scaled_dot_product_attention(q, k, v, scale=4) - attn_out2 = torch.ops.xla_pattern_marking.mark_tensor( + attn_out2 = torch.ops.xla.mark_tensor( attn_out2, "sdpa", pos=0, id="1", is_input=False, attr={"scale": 2}) return attn_out, attn_out2 @@ -193,11 +193,11 @@ def forward(self, x, y): def test_multiple_input(self): def f(x, y): - x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", True) - y = torch.ops.xla_pattern_marking.mark_tensor(y, "p", 1, "0", True) + x = torch.ops.xla.mark_tensor(x, "p", 0, "0", True) + y = torch.ops.xla.mark_tensor(y, "p", 1, "0", True) out = x + y out = out * x * y - out = torch.ops.xla_pattern_marking.mark_tensor(out, "p", 0, "0", False) + out = torch.ops.xla.mark_tensor(out, "p", 0, "0", False) return out input_args = (torch.ones(5), torch.ones(5)) @@ -209,12 +209,12 @@ def f(x, y): def test_multiple_output(self): def f(x, y): - x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", True) - y = torch.ops.xla_pattern_marking.mark_tensor(y, "p", 1, "0", True) + x = torch.ops.xla.mark_tensor(x, "p", 0, "0", True) + y = torch.ops.xla.mark_tensor(y, "p", 1, "0", True) out1 = x + y out2 = x * y - out1 = torch.ops.xla_pattern_marking.mark_tensor(out1, "p", 0, "0", False) - out2 = torch.ops.xla_pattern_marking.mark_tensor(out2, "p", 1, "0", False) + out1 = torch.ops.xla.mark_tensor(out1, "p", 0, "0", False) + out2 = torch.ops.xla.mark_tensor(out2, "p", 1, "0", False) return out1, out2 input_args = (torch.ones(5), torch.ones(5)) @@ -224,13 +224,13 @@ def f(x, y): def test_nested_pattern(self): def f(x): - x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_outter", 0, "0", True) + x = torch.ops.xla.mark_tensor(x, "p_outter", 0, "0", True) x = x + 1 - x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", True) + x = torch.ops.xla.mark_tensor(x, "p_inner", 0, "0", True) x = x + 1 - x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", False) + x = torch.ops.xla.mark_tensor(x, "p_inner", 0, "0", False) x = x * 2 - x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_outter", 0, "0", + x = torch.ops.xla.mark_tensor(x, "p_outter", 0, "0", False) input_args = (torch.ones(5),) @@ -240,13 +240,13 @@ def f(x): def test_tangent_output(self): # Special case of nested pattern, outputs don't have dependencies. def f(x): - x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_outter", 0, "0", True) + x = torch.ops.xla.mark_tensor(x, "p_outter", 0, "0", True) x = x + 1 - x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", True) + x = torch.ops.xla.mark_tensor(x, "p_inner", 0, "0", True) x = x + 1 y = x - 1 - x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", False) - y = torch.ops.xla_pattern_marking.mark_tensor(y, "p_outter", 0, "0", + x = torch.ops.xla.mark_tensor(x, "p_inner", 0, "0", False) + y = torch.ops.xla.mark_tensor(y, "p_outter", 0, "0", False) input_args = (torch.ones(5),) diff --git a/torch_xla/experimental/mark_pattern_utils.py b/torch_xla/experimental/mark_pattern_utils.py index 4665f9b14e0..2e0baa129f1 100644 --- a/torch_xla/experimental/mark_pattern_utils.py +++ b/torch_xla/experimental/mark_pattern_utils.py @@ -15,7 +15,7 @@ class StableHLOCompositeBuilder: """ Helper for building a StableHLO Composite by marking input and output tensors. It should be used with the StableHLO converters from `torch_xla.stablehlo`. - + Args: name (str): The name of the built StableHLO Composite op. @@ -37,7 +37,7 @@ def _mark_tensor(self, *tensors: torch.Tensor, is_input: bool): if not isinstance(tensor, torch.Tensor): raise ValueError(f"input must be a torch tensor. Got {type(tensor)}.") marked_tensors.append( - torch.ops.xla_pattern_marking.mark_tensor( + torch.ops.xla.mark_tensor( tensor, name=self.name, pos=pos, @@ -52,9 +52,9 @@ def _mark_tensor(self, *tensors: torch.Tensor, is_input: bool): def mark_inputs(self, *tensors: torch.Tensor): """ - Mark the input tensors of the StableHLO Composite. This method must only be + Mark the input tensors of the StableHLO Composite. This method must only be called once per builder. - + Args: *tensors (torch.Tensor): Torch tensors to mark. @@ -68,9 +68,9 @@ def mark_inputs(self, *tensors: torch.Tensor): def mark_outputs(self, *tensors: torch.Tensor): """ - Mark the output tensors of the StableHLO Composite. This method must only be + Mark the output tensors of the StableHLO Composite. This method must only be called once per builder. - + Args: *tensors (torch.Tensor): Torch tensors to mark. From 9ea14ef8c21e5ea99fbf42006a1a2c9b5def8ecc Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Fri, 2 Feb 2024 01:56:58 +0000 Subject: [PATCH 5/6] Fix tests --- torch_xla/experimental/quantized.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/torch_xla/experimental/quantized.py b/torch_xla/experimental/quantized.py index e9128b1ccd9..7fed638895a 100644 --- a/torch_xla/experimental/quantized.py +++ b/torch_xla/experimental/quantized.py @@ -2,10 +2,9 @@ import torch from torch.library import impl import torch_xla -from torch_xla.core.xla_model import XLA_LIB -@impl(XLA_LIB, "quantize_per_tensor", "XLA") +@impl("quantized_decomposed::quantize_per_tensor", "XLA") def xla_quantize_per_tensor(input: torch.Tensor, scale: float, zero_point: int, quant_min: int, quant_max: int, dtype: torch.dtype): return _xla_quantize(input, torch.tensor([scale]), @@ -13,7 +12,7 @@ def xla_quantize_per_tensor(input: torch.Tensor, scale: float, zero_point: int, quant_max, dtype) -@impl(XLA_LIB, "quantize_per_channel", "XLA") +@impl("quantized_decomposed::quantize_per_channel", "XLA") def xla_quantize_per_channel(input: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor, axis: int, quant_min: int, quant_max: int, @@ -22,7 +21,7 @@ def xla_quantize_per_channel(input: torch.Tensor, scale: torch.Tensor, axis) -@impl(XLA_LIB, "dequantize_per_tensor", "XLA") +@impl("quantized_decomposed::dequantize_per_tensor", "XLA") def xla_dequantize_per_tensor(input: torch.Tensor, scale: float, zero_point: int, quant_min: int, quant_max: int, dtype: torch.dtype): @@ -31,7 +30,7 @@ def xla_dequantize_per_tensor(input: torch.Tensor, scale: float, quant_max, dtype) -@impl(XLA_LIB, "dequantize_per_channel", "XLA") +@impl("quantized_decomposed::dequantize_per_channel", "XLA") def xla_dequantize_per_tensor(input: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor, axis: int, quant_min: int, quant_max: int, From db141fe2bfa7d3a585ea648198262a44be580836 Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Fri, 2 Feb 2024 01:59:20 +0000 Subject: [PATCH 6/6] Fix linters --- test/stablehlo/test_mark_pattern.py | 24 ++++++++---------------- torch_xla/core/xla_model.py | 1 + 2 files changed, 9 insertions(+), 16 deletions(-) diff --git a/test/stablehlo/test_mark_pattern.py b/test/stablehlo/test_mark_pattern.py index 06194b6069e..0e635da1e4a 100644 --- a/test/stablehlo/test_mark_pattern.py +++ b/test/stablehlo/test_mark_pattern.py @@ -59,12 +59,9 @@ def __init__(self): def forward(self, x, y): q, k, v = x.split(128, dim=-2) - q = torch.ops.xla.mark_tensor( - q, "sdpa", pos=0, id="0", is_input=True) - k = torch.ops.xla.mark_tensor( - k, "sdpa", pos=1, id="0", is_input=True) - v = torch.ops.xla.mark_tensor( - v, "sdpa", pos=2, id="0", is_input=True) + q = torch.ops.xla.mark_tensor(q, "sdpa", pos=0, id="0", is_input=True) + k = torch.ops.xla.mark_tensor(k, "sdpa", pos=1, id="0", is_input=True) + v = torch.ops.xla.mark_tensor(v, "sdpa", pos=2, id="0", is_input=True) attn_out = F.scaled_dot_product_attention(q, k, v, scale=0.25) attn_out = torch.ops.xla.mark_tensor( attn_out, @@ -74,12 +71,9 @@ def forward(self, x, y): is_input=False, attr={"scale": 0.25}) q, k, v = y.split(128, dim=-2) - q = torch.ops.xla.mark_tensor( - q, "sdpa", pos=0, id="1", is_input=True) - k = torch.ops.xla.mark_tensor( - k, "sdpa", pos=1, id="1", is_input=True) - v = torch.ops.xla.mark_tensor( - v, "sdpa", pos=2, id="1", is_input=True) + q = torch.ops.xla.mark_tensor(q, "sdpa", pos=0, id="1", is_input=True) + k = torch.ops.xla.mark_tensor(k, "sdpa", pos=1, id="1", is_input=True) + v = torch.ops.xla.mark_tensor(v, "sdpa", pos=2, id="1", is_input=True) attn_out2 = F.scaled_dot_product_attention(q, k, v, scale=4) attn_out2 = torch.ops.xla.mark_tensor( attn_out2, "sdpa", pos=0, id="1", is_input=False, attr={"scale": 2}) @@ -230,8 +224,7 @@ def f(x): x = x + 1 x = torch.ops.xla.mark_tensor(x, "p_inner", 0, "0", False) x = x * 2 - x = torch.ops.xla.mark_tensor(x, "p_outter", 0, "0", - False) + x = torch.ops.xla.mark_tensor(x, "p_outter", 0, "0", False) input_args = (torch.ones(5),) stablehlo = self.run_func_get_stablehlo(f, input_args) @@ -246,8 +239,7 @@ def f(x): x = x + 1 y = x - 1 x = torch.ops.xla.mark_tensor(x, "p_inner", 0, "0", False) - y = torch.ops.xla.mark_tensor(y, "p_outter", 0, "0", - False) + y = torch.ops.xla.mark_tensor(y, "p_outter", 0, "0", False) input_args = (torch.ones(5),) stablehlo = self.run_func_get_stablehlo(f, input_args) diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index 5774637a93c..b3f7b4c9ad1 100755 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -38,6 +38,7 @@ XLA_LIB = Library("xla", "DEF") + def _init_world_size_ordinal(): global _WORLD_SIZE, _ORDINAL