Skip to content

Commit

Permalink
Unify TORCH_LIBRARY definitions (pytorch#6455)
Browse files Browse the repository at this point in the history
Summary:
This pull request tries to unify all TORCH_LIBRARY definitions across torch_xla into one xla library.

Test Plan:
CI
  • Loading branch information
alanwaketan authored and amithrm committed Mar 1, 2024
1 parent a0c23ea commit 16dc35a
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 69 deletions.
58 changes: 25 additions & 33 deletions test/stablehlo/test_mark_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ def test_basic(self):

def f(x):
x = x + 1
x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", True)
x = torch.ops.xla.mark_tensor(x, "p", 0, "0", True)
x = x + 2
x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", False)
x = torch.ops.xla.mark_tensor(x, "p", 0, "0", False)
return x

input_args = (torch.randn(5),)
Expand All @@ -59,29 +59,23 @@ def __init__(self):

def forward(self, x, y):
q, k, v = x.split(128, dim=-2)
q = torch.ops.xla_pattern_marking.mark_tensor(
q, "sdpa", pos=0, id="0", is_input=True)
k = torch.ops.xla_pattern_marking.mark_tensor(
k, "sdpa", pos=1, id="0", is_input=True)
v = torch.ops.xla_pattern_marking.mark_tensor(
v, "sdpa", pos=2, id="0", is_input=True)
q = torch.ops.xla.mark_tensor(q, "sdpa", pos=0, id="0", is_input=True)
k = torch.ops.xla.mark_tensor(k, "sdpa", pos=1, id="0", is_input=True)
v = torch.ops.xla.mark_tensor(v, "sdpa", pos=2, id="0", is_input=True)
attn_out = F.scaled_dot_product_attention(q, k, v, scale=0.25)
attn_out = torch.ops.xla_pattern_marking.mark_tensor(
attn_out = torch.ops.xla.mark_tensor(
attn_out,
"sdpa",
pos=0,
id="0",
is_input=False,
attr={"scale": 0.25})
q, k, v = y.split(128, dim=-2)
q = torch.ops.xla_pattern_marking.mark_tensor(
q, "sdpa", pos=0, id="1", is_input=True)
k = torch.ops.xla_pattern_marking.mark_tensor(
k, "sdpa", pos=1, id="1", is_input=True)
v = torch.ops.xla_pattern_marking.mark_tensor(
v, "sdpa", pos=2, id="1", is_input=True)
q = torch.ops.xla.mark_tensor(q, "sdpa", pos=0, id="1", is_input=True)
k = torch.ops.xla.mark_tensor(k, "sdpa", pos=1, id="1", is_input=True)
v = torch.ops.xla.mark_tensor(v, "sdpa", pos=2, id="1", is_input=True)
attn_out2 = F.scaled_dot_product_attention(q, k, v, scale=4)
attn_out2 = torch.ops.xla_pattern_marking.mark_tensor(
attn_out2 = torch.ops.xla.mark_tensor(
attn_out2, "sdpa", pos=0, id="1", is_input=False, attr={"scale": 2})
return attn_out, attn_out2

Expand Down Expand Up @@ -193,11 +187,11 @@ def forward(self, x, y):
def test_multiple_input(self):

def f(x, y):
x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", True)
y = torch.ops.xla_pattern_marking.mark_tensor(y, "p", 1, "0", True)
x = torch.ops.xla.mark_tensor(x, "p", 0, "0", True)
y = torch.ops.xla.mark_tensor(y, "p", 1, "0", True)
out = x + y
out = out * x * y
out = torch.ops.xla_pattern_marking.mark_tensor(out, "p", 0, "0", False)
out = torch.ops.xla.mark_tensor(out, "p", 0, "0", False)
return out

input_args = (torch.ones(5), torch.ones(5))
Expand All @@ -209,12 +203,12 @@ def f(x, y):
def test_multiple_output(self):

def f(x, y):
x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", True)
y = torch.ops.xla_pattern_marking.mark_tensor(y, "p", 1, "0", True)
x = torch.ops.xla.mark_tensor(x, "p", 0, "0", True)
y = torch.ops.xla.mark_tensor(y, "p", 1, "0", True)
out1 = x + y
out2 = x * y
out1 = torch.ops.xla_pattern_marking.mark_tensor(out1, "p", 0, "0", False)
out2 = torch.ops.xla_pattern_marking.mark_tensor(out2, "p", 1, "0", False)
out1 = torch.ops.xla.mark_tensor(out1, "p", 0, "0", False)
out2 = torch.ops.xla.mark_tensor(out2, "p", 1, "0", False)
return out1, out2

input_args = (torch.ones(5), torch.ones(5))
Expand All @@ -224,14 +218,13 @@ def f(x, y):
def test_nested_pattern(self):

def f(x):
x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_outter", 0, "0", True)
x = torch.ops.xla.mark_tensor(x, "p_outter", 0, "0", True)
x = x + 1
x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", True)
x = torch.ops.xla.mark_tensor(x, "p_inner", 0, "0", True)
x = x + 1
x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", False)
x = torch.ops.xla.mark_tensor(x, "p_inner", 0, "0", False)
x = x * 2
x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_outter", 0, "0",
False)
x = torch.ops.xla.mark_tensor(x, "p_outter", 0, "0", False)

input_args = (torch.ones(5),)
stablehlo = self.run_func_get_stablehlo(f, input_args)
Expand All @@ -240,14 +233,13 @@ def f(x):
def test_tangent_output(self):
# Special case of nested pattern, outputs don't have dependencies.
def f(x):
x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_outter", 0, "0", True)
x = torch.ops.xla.mark_tensor(x, "p_outter", 0, "0", True)
x = x + 1
x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", True)
x = torch.ops.xla.mark_tensor(x, "p_inner", 0, "0", True)
x = x + 1
y = x - 1
x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", False)
y = torch.ops.xla_pattern_marking.mark_tensor(y, "p_outter", 0, "0",
False)
x = torch.ops.xla.mark_tensor(x, "p_inner", 0, "0", False)
y = torch.ops.xla.mark_tensor(y, "p_outter", 0, "0", False)

input_args = (torch.ones(5),)
stablehlo = self.run_func_get_stablehlo(f, input_args)
Expand Down
3 changes: 3 additions & 0 deletions torch_xla/core/xla_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import List, Optional
import torch
import torch.distributed._functional_collectives
from torch.library import Library
import torch.nn.functional as F
import torch_xla
from torch_xla import runtime
Expand Down Expand Up @@ -38,6 +39,8 @@
# Default bucket size for all-reduce
_ALLREDUCE_BUCKET_CAP_MB = 50

XLA_LIB = Library("xla", "DEF")


def _init_world_size_ordinal():
global _WORLD_SIZE, _ORDINAL
Expand Down
13 changes: 13 additions & 0 deletions torch_xla/csrc/aten_autograd_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,5 +253,18 @@ torch::Tensor max_pool2d_backward(torch::Tensor grad_output, torch::Tensor self,
return grad;
}

TORCH_LIBRARY_FRAGMENT(xla, m) {
m.def(
"max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], "
"int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor",
torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_forward)));

m.def(
"max_pool2d_backward(Tensor grad_output, Tensor self, int[2] "
"kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) "
"-> Tensor",
torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_backward)));
}

} // namespace aten_autograd_ops
} // namespace torch_xla
16 changes: 1 addition & 15 deletions torch_xla/csrc/xla_sharding_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,7 @@ namespace torch_xla {
// Macro for defining a function that will be run at static initialization time
// to define a library of operators in the namespace. Used to define a new set
// of custom operators that do not already exist in PyTorch.
TORCH_LIBRARY(xla, m) {
m.def(
"max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], "
"int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor",
torch::dispatch(
c10::DispatchKey::XLA,
TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_forward)));

m.def(
"max_pool2d_backward(Tensor grad_output, Tensor self, int[2] "
"kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) "
"-> Tensor",
torch::dispatch(
c10::DispatchKey::XLA,
TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_backward)));
TORCH_LIBRARY_FRAGMENT(xla, m) {
m.def(
"xla_mark_sharding_dynamo_custom_op(Tensor input, int[][] "
"tile_assignment, int[][] group_assignment, int[][] replication_groups, "
Expand Down
12 changes: 6 additions & 6 deletions torch_xla/experimental/mark_pattern_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class StableHLOCompositeBuilder:
"""
Helper for building a StableHLO Composite by marking input and output tensors. It
should be used with the StableHLO converters from `torch_xla.stablehlo`.
Args:
name (str):
The name of the built StableHLO Composite op.
Expand All @@ -37,7 +37,7 @@ def _mark_tensor(self, *tensors: torch.Tensor, is_input: bool):
if not isinstance(tensor, torch.Tensor):
raise ValueError(f"input must be a torch tensor. Got {type(tensor)}.")
marked_tensors.append(
torch.ops.xla_pattern_marking.mark_tensor(
torch.ops.xla.mark_tensor(
tensor,
name=self.name,
pos=pos,
Expand All @@ -52,9 +52,9 @@ def _mark_tensor(self, *tensors: torch.Tensor, is_input: bool):

def mark_inputs(self, *tensors: torch.Tensor):
"""
Mark the input tensors of the StableHLO Composite. This method must only be
Mark the input tensors of the StableHLO Composite. This method must only be
called once per builder.
Args:
*tensors (torch.Tensor):
Torch tensors to mark.
Expand All @@ -68,9 +68,9 @@ def mark_inputs(self, *tensors: torch.Tensor):

def mark_outputs(self, *tensors: torch.Tensor):
"""
Mark the output tensors of the StableHLO Composite. This method must only be
Mark the output tensors of the StableHLO Composite. This method must only be
called once per builder.
Args:
*tensors (torch.Tensor):
Torch tensors to mark.
Expand Down
12 changes: 5 additions & 7 deletions torch_xla/experimental/quantized.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
import numpy as np
import torch
from torch.library import impl
import torch_xla
from torch.library import Library, impl

quantized_decomposed_lib = Library("quantized_decomposed", "IMPL")


@impl(quantized_decomposed_lib, "quantize_per_tensor", "XLA")
@impl("quantized_decomposed::quantize_per_tensor", "XLA")
def xla_quantize_per_tensor(input: torch.Tensor, scale: float, zero_point: int,
quant_min: int, quant_max: int, dtype: torch.dtype):
return _xla_quantize(input, torch.tensor([scale]),
torch.tensor([zero_point], dtype=dtype), quant_min,
quant_max, dtype)


@impl(quantized_decomposed_lib, "quantize_per_channel", "XLA")
@impl("quantized_decomposed::quantize_per_channel", "XLA")
def xla_quantize_per_channel(input: torch.Tensor, scale: torch.Tensor,
zero_point: torch.Tensor, axis: int,
quant_min: int, quant_max: int,
Expand All @@ -23,7 +21,7 @@ def xla_quantize_per_channel(input: torch.Tensor, scale: torch.Tensor,
axis)


@impl(quantized_decomposed_lib, "dequantize_per_tensor", "XLA")
@impl("quantized_decomposed::dequantize_per_tensor", "XLA")
def xla_dequantize_per_tensor(input: torch.Tensor, scale: float,
zero_point: int, quant_min: int, quant_max: int,
dtype: torch.dtype):
Expand All @@ -32,7 +30,7 @@ def xla_dequantize_per_tensor(input: torch.Tensor, scale: float,
quant_max, dtype)


@impl(quantized_decomposed_lib, "dequantize_per_channel", "XLA")
@impl("quantized_decomposed::dequantize_per_channel", "XLA")
def xla_dequantize_per_tensor(input: torch.Tensor, scale: torch.Tensor,
zero_point: torch.Tensor, axis: int,
quant_min: int, quant_max: int,
Expand Down
15 changes: 7 additions & 8 deletions torch_xla/experimental/xla_marker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
from typing import Dict

import torch
from torch.library import impl
import torch_xla
from torch.library import Library, impl
from torch_xla.core.xla_model import XLA_LIB

xla_pattern_marking_lib = Library("xla_pattern_marking", "DEF")

xla_pattern_marking_lib.define(
XLA_LIB.define(
"mark_tensor(Tensor x, str name, int pos, str id, bool is_input, Any? attr=None) -> Tensor"
)

Expand Down Expand Up @@ -45,15 +44,15 @@ def _assert_valid_composite_attr(attr):
"Composite attr value must be either Python str, float, or int.")


@impl(xla_pattern_marking_lib, "mark_tensor", "XLA")
@impl(XLA_LIB, "mark_tensor", "XLA")
def mark_tensor_xla(x: torch.Tensor,
name: str,
pos: int,
id: str,
is_input: bool,
attr: Dict = None):
"""Attach pattern boundary metadata to a XLA Tensor.
Args:
x: torch.Tensor (On XLA device) - the marked tensor.
name: str - The name of the pattern, it will be the name of the stablehlo composite op.
Expand All @@ -69,7 +68,7 @@ def mark_tensor_xla(x: torch.Tensor,
x, json.dumps(pattern_info, cls=BoundaryMetadataSerializer))


@impl(xla_pattern_marking_lib, "mark_tensor", "CompositeExplicitAutograd")
@impl(XLA_LIB, "mark_tensor", "CompositeExplicitAutograd")
def mark_tensor(x: torch.Tensor,
name: str,
pos: int,
Expand All @@ -80,7 +79,7 @@ def mark_tensor(x: torch.Tensor,
return x


@impl(xla_pattern_marking_lib, "mark_tensor", "Meta")
@impl(XLA_LIB, "mark_tensor", "Meta")
def mark_tensor_meta(x: torch.Tensor,
name: str,
pos: int,
Expand Down

0 comments on commit 16dc35a

Please sign in to comment.