Unify TORCH_LIBRARY definitions #6455

Merged · 6 commits · Feb 2, 2024

Changes from 4 commits
50 changes: 25 additions & 25 deletions test/stablehlo/test_mark_pattern.py
```diff
@@ -39,9 +39,9 @@ def test_basic(self):

     def f(x):
       x = x + 1
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", True)
+      x = torch.ops.xla.mark_tensor(x, "p", 0, "0", True)
       x = x + 2
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", False)
+      x = torch.ops.xla.mark_tensor(x, "p", 0, "0", False)
       return x

     input_args = (torch.randn(5),)
@@ -59,29 +59,29 @@ def __init__(self):

       def forward(self, x, y):
         q, k, v = x.split(128, dim=-2)
-        q = torch.ops.xla_pattern_marking.mark_tensor(
+        q = torch.ops.xla.mark_tensor(
             q, "sdpa", pos=0, id="0", is_input=True)
-        k = torch.ops.xla_pattern_marking.mark_tensor(
+        k = torch.ops.xla.mark_tensor(
             k, "sdpa", pos=1, id="0", is_input=True)
-        v = torch.ops.xla_pattern_marking.mark_tensor(
+        v = torch.ops.xla.mark_tensor(
             v, "sdpa", pos=2, id="0", is_input=True)
         attn_out = F.scaled_dot_product_attention(q, k, v, scale=0.25)
-        attn_out = torch.ops.xla_pattern_marking.mark_tensor(
+        attn_out = torch.ops.xla.mark_tensor(
             attn_out,
             "sdpa",
             pos=0,
             id="0",
             is_input=False,
             attr={"scale": 0.25})
         q, k, v = y.split(128, dim=-2)
-        q = torch.ops.xla_pattern_marking.mark_tensor(
+        q = torch.ops.xla.mark_tensor(
             q, "sdpa", pos=0, id="1", is_input=True)
-        k = torch.ops.xla_pattern_marking.mark_tensor(
+        k = torch.ops.xla.mark_tensor(
             k, "sdpa", pos=1, id="1", is_input=True)
-        v = torch.ops.xla_pattern_marking.mark_tensor(
+        v = torch.ops.xla.mark_tensor(
             v, "sdpa", pos=2, id="1", is_input=True)
         attn_out2 = F.scaled_dot_product_attention(q, k, v, scale=4)
-        attn_out2 = torch.ops.xla_pattern_marking.mark_tensor(
+        attn_out2 = torch.ops.xla.mark_tensor(
             attn_out2, "sdpa", pos=0, id="1", is_input=False, attr={"scale": 2})
         return attn_out, attn_out2
@@ -193,11 +193,11 @@ def forward(self, x, y):
   def test_multiple_input(self):

     def f(x, y):
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", True)
-      y = torch.ops.xla_pattern_marking.mark_tensor(y, "p", 1, "0", True)
+      x = torch.ops.xla.mark_tensor(x, "p", 0, "0", True)
+      y = torch.ops.xla.mark_tensor(y, "p", 1, "0", True)
       out = x + y
       out = out * x * y
-      out = torch.ops.xla_pattern_marking.mark_tensor(out, "p", 0, "0", False)
+      out = torch.ops.xla.mark_tensor(out, "p", 0, "0", False)
       return out

     input_args = (torch.ones(5), torch.ones(5))
@@ -209,12 +209,12 @@ def f(x, y):
   def test_multiple_output(self):

     def f(x, y):
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", True)
-      y = torch.ops.xla_pattern_marking.mark_tensor(y, "p", 1, "0", True)
+      x = torch.ops.xla.mark_tensor(x, "p", 0, "0", True)
+      y = torch.ops.xla.mark_tensor(y, "p", 1, "0", True)
       out1 = x + y
       out2 = x * y
-      out1 = torch.ops.xla_pattern_marking.mark_tensor(out1, "p", 0, "0", False)
-      out2 = torch.ops.xla_pattern_marking.mark_tensor(out2, "p", 1, "0", False)
+      out1 = torch.ops.xla.mark_tensor(out1, "p", 0, "0", False)
+      out2 = torch.ops.xla.mark_tensor(out2, "p", 1, "0", False)
       return out1, out2

     input_args = (torch.ones(5), torch.ones(5))
@@ -224,13 +224,13 @@ def f(x, y):
   def test_nested_pattern(self):

     def f(x):
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_outter", 0, "0", True)
+      x = torch.ops.xla.mark_tensor(x, "p_outter", 0, "0", True)
       x = x + 1
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", True)
+      x = torch.ops.xla.mark_tensor(x, "p_inner", 0, "0", True)
       x = x + 1
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", False)
+      x = torch.ops.xla.mark_tensor(x, "p_inner", 0, "0", False)
       x = x * 2
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_outter", 0, "0",
+      x = torch.ops.xla.mark_tensor(x, "p_outter", 0, "0",
                                     False)

     input_args = (torch.ones(5),)
@@ -240,13 +240,13 @@ def f(x):
   def test_tangent_output(self):
     # Special case of nested pattern, outputs don't have dependencies.
     def f(x):
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_outter", 0, "0", True)
+      x = torch.ops.xla.mark_tensor(x, "p_outter", 0, "0", True)
       x = x + 1
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", True)
+      x = torch.ops.xla.mark_tensor(x, "p_inner", 0, "0", True)
       x = x + 1
       y = x - 1
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", False)
-      y = torch.ops.xla_pattern_marking.mark_tensor(y, "p_outter", 0, "0",
+      x = torch.ops.xla.mark_tensor(x, "p_inner", 0, "0", False)
+      y = torch.ops.xla.mark_tensor(y, "p_outter", 0, "0",
                                     False)

     input_args = (torch.ones(5),)
```
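The mechanical change across this test file is a namespace move only; the schema and argument order are untouched. A minimal sketch of the call-site migration (plain CPU tensor shown for brevity; the module import is the assumed registration entry point, per the `xla_marker.py` diff further down):

```python
import torch
import torch_xla.experimental.xla_marker  # assumed: importing registers mark_tensor

x = torch.ones(5)

# Before: the op lived in its own one-off `xla_pattern_marking` library.
# x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", True)

# After: identical schema, now registered under the shared `xla` namespace.
x = torch.ops.xla.mark_tensor(x, "p", 0, "0", True)
```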
2 changes: 2 additions & 0 deletions torch_xla/core/xla_model.py
```diff
@@ -9,6 +9,7 @@
 from typing import List, Optional
 import torch
 import torch.distributed._functional_collectives
+from torch.library import Library
 import torch.nn.functional as F
 import torch_xla
 from torch_xla import runtime
@@ -35,6 +36,7 @@
 _WORLD_SIZE = None
 _ORDINAL = None

+XLA_LIB = Library("xla", "DEF")

 def _init_world_size_ordinal():
   global _WORLD_SIZE, _ORDINAL
```
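This one-line handle is the crux of the PR: a single `DEF` library owns the `xla` namespace, and every other Python module imports it instead of constructing its own `Library`. A minimal sketch of the intended pattern (the op name `my_op` is hypothetical, not part of this PR):

```python
import torch
from torch.library import impl

from torch_xla.core.xla_model import XLA_LIB  # the shared handle defined above

# Declare the schema once against the shared library...
XLA_LIB.define("my_op(Tensor x) -> Tensor")


# ...then register implementations for specific dispatch keys from any module.
@impl(XLA_LIB, "my_op", "CompositeExplicitAutograd")
def my_op(x: torch.Tensor) -> torch.Tensor:
  return x + 1
```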
13 changes: 13 additions & 0 deletions torch_xla/csrc/aten_autograd_ops.cpp
```diff
@@ -253,5 +253,18 @@ torch::Tensor max_pool2d_backward(torch::Tensor grad_output, torch::Tensor self,
   return grad;
 }

+TORCH_LIBRARY_FRAGMENT(xla, m) {
+  m.def(
+      "max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], "
+      "int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor",
+      torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_forward)));
+
+  m.def(
+      "max_pool2d_backward(Tensor grad_output, Tensor self, int[2] "
+      "kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) "
+      "-> Tensor",
+      torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_backward)));
+}
+
 }  // namespace aten_autograd_ops
 }  // namespace torch_xla
```
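Once this translation unit is linked in, the two pooling ops resolve through the regular dispatcher, so Python callers reach them via `torch.ops.xla`. A hedged sketch of an invocation (assumes an initialized XLA device; the argument values are illustrative):

```python
import torch
import torch_xla.core.xla_model as xm

x = torch.randn(1, 1, 4, 4, device=xm.xla_device())
# Dispatches to the XLA kernel registered by the fragment above;
# stride/padding/dilation/ceil_mode fall back to the schema defaults.
out = torch.ops.xla.max_pool2d_forward(x, [2, 2])
```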
16 changes: 1 addition & 15 deletions torch_xla/csrc/xla_sharding_util.cpp
```diff
@@ -36,21 +36,7 @@ namespace torch_xla {
 // Macro for defining a function that will be run at static initialization time
 // to define a library of operators in the namespace. Used to define a new set
 // of custom operators that do not already exist in PyTorch.
-TORCH_LIBRARY(xla, m) {
-  m.def(
-      "max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], "
-      "int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor",
-      torch::dispatch(
-          c10::DispatchKey::XLA,
-          TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_forward)));
-
-  m.def(
-      "max_pool2d_backward(Tensor grad_output, Tensor self, int[2] "
-      "kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) "
-      "-> Tensor",
-      torch::dispatch(
-          c10::DispatchKey::XLA,
-          TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_backward)));
+TORCH_LIBRARY_FRAGMENT(xla, m) {
   m.def(
       "xla_mark_sharding_dynamo_custom_op(Tensor input, int[][] "
       "tile_assignment, int[][] group_assignment, int[][] replication_groups, "
```
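The switch from `TORCH_LIBRARY` to `TORCH_LIBRARY_FRAGMENT` is forced by a dispatcher rule: a namespace may have exactly one defining owner, and that owner is now the Python-side `XLA_LIB`, so every C++ translation unit must contribute through fragments instead. The same rule can be observed from Python (the `demo` namespace is hypothetical):

```python
from torch.library import Library

owner = Library("demo", "DEF")  # exactly one DEF owner per namespace
# Library("demo", "DEF")        # a second DEF would raise at registration time

# Any number of FRAGMENT libraries may extend the same namespace.
frag = Library("demo", "FRAGMENT")
frag.define("double(Tensor x) -> Tensor")
```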
12 changes: 6 additions & 6 deletions torch_xla/experimental/mark_pattern_utils.py
```diff
@@ -15,7 +15,7 @@ class StableHLOCompositeBuilder:
   """
   Helper for building a StableHLO Composite by marking input and output tensors. It
   should be used with the StableHLO converters from `torch_xla.stablehlo`.
-  
+
   Args:
     name (str):
       The name of the built StableHLO Composite op.
@@ -37,7 +37,7 @@ def _mark_tensor(self, *tensors: torch.Tensor, is_input: bool):
       if not isinstance(tensor, torch.Tensor):
         raise ValueError(f"input must be a torch tensor. Got {type(tensor)}.")
       marked_tensors.append(
-          torch.ops.xla_pattern_marking.mark_tensor(
+          torch.ops.xla.mark_tensor(
              tensor,
              name=self.name,
              pos=pos,
              id=self.id,
@@ -52,9 +52,9 @@ def _mark_tensor(self, *tensors: torch.Tensor, is_input: bool):

   def mark_inputs(self, *tensors: torch.Tensor):
     """
-    Mark the input tensors of the StableHLO Composite. This method must only be 
+    Mark the input tensors of the StableHLO Composite. This method must only be
     called once per builder.
-    
+
     Args:
       *tensors (torch.Tensor):
         Torch tensors to mark.
@@ -68,9 +68,9 @@ def mark_inputs(self, *tensors: torch.Tensor):

   def mark_outputs(self, *tensors: torch.Tensor):
     """
-    Mark the output tensors of the StableHLO Composite. This method must only be 
+    Mark the output tensors of the StableHLO Composite. This method must only be
     called once per builder.
-    
+
     Args:
       *tensors (torch.Tensor):
         Torch tensors to mark.
```

(The -/+ pairs above that look identical differ only by trailing whitespace removed from the docstrings.)
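For context, a usage sketch of the builder whose internals change above; the constructor signature and the return conventions of `mark_inputs`/`mark_outputs` are assumed from the surrounding docstrings rather than spelled out in this diff:

```python
import torch
import torch.nn.functional as F
from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder


class SDPABlock(torch.nn.Module):

  def __init__(self):
    super().__init__()
    # One builder per composite instance; "sdpa" becomes the composite op name.
    self.builder = StableHLOCompositeBuilder(name="sdpa")

  def forward(self, q, k, v):
    q, k, v = self.builder.mark_inputs(q, k, v)
    out = F.scaled_dot_product_attention(q, k, v)
    return self.builder.mark_outputs(out)
```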
13 changes: 6 additions & 7 deletions torch_xla/experimental/quantized.py
```diff
@@ -1,20 +1,19 @@
 import numpy as np
 import torch
+from torch.library import impl
 import torch_xla
-from torch.library import Library, impl
+from torch_xla.core.xla_model import XLA_LIB

-quantized_decomposed_lib = Library("quantized_decomposed", "IMPL")
-

-@impl(quantized_decomposed_lib, "quantize_per_tensor", "XLA")
+@impl(XLA_LIB, "quantize_per_tensor", "XLA")
 def xla_quantize_per_tensor(input: torch.Tensor, scale: float, zero_point: int,
                             quant_min: int, quant_max: int, dtype: torch.dtype):
   return _xla_quantize(input, torch.tensor([scale]),
                        torch.tensor([zero_point], dtype=dtype), quant_min,
                        quant_max, dtype)


-@impl(quantized_decomposed_lib, "quantize_per_channel", "XLA")
+@impl(XLA_LIB, "quantize_per_channel", "XLA")
 def xla_quantize_per_channel(input: torch.Tensor, scale: torch.Tensor,
                              zero_point: torch.Tensor, axis: int,
                              quant_min: int, quant_max: int,
@@ -23,7 +22,7 @@ def xla_quantize_per_channel(input: torch.Tensor, scale: torch.Tensor,
                       axis)


-@impl(quantized_decomposed_lib, "dequantize_per_tensor", "XLA")
+@impl(XLA_LIB, "dequantize_per_tensor", "XLA")
 def xla_dequantize_per_tensor(input: torch.Tensor, scale: float,
                               zero_point: int, quant_min: int, quant_max: int,
                               dtype: torch.dtype):
@@ -32,7 +31,7 @@ def xla_dequantize_per_tensor(input: torch.Tensor, scale: float,
                      quant_max, dtype)


-@impl(quantized_decomposed_lib, "dequantize_per_channel", "XLA")
+@impl(XLA_LIB, "dequantize_per_channel", "XLA")
 def xla_dequantize_per_tensor(input: torch.Tensor, scale: torch.Tensor,
                               zero_point: torch.Tensor, axis: int,
                               quant_min: int, quant_max: int,
```
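These registrations only swap the library object; the underlying math is the standard affine quantization scheme. A reference sketch of the per-tensor formulas in plain PyTorch, independent of the XLA lowering inside `_xla_quantize`:

```python
import torch


def quantize_per_tensor_ref(x, scale, zero_point, quant_min, quant_max, dtype):
  # q = clamp(round(x / scale) + zero_point, quant_min, quant_max)
  q = torch.round(x / scale) + zero_point
  return torch.clamp(q, quant_min, quant_max).to(dtype)


def dequantize_per_tensor_ref(q, scale, zero_point):
  # x ~= (q - zero_point) * scale
  return (q.to(torch.float32) - zero_point) * scale
```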
15 changes: 7 additions & 8 deletions torch_xla/experimental/xla_marker.py
```diff
@@ -4,12 +4,11 @@
 from typing import Dict

 import torch
+from torch.library import impl
 import torch_xla
-from torch.library import Library, impl
+from torch_xla.core.xla_model import XLA_LIB

-xla_pattern_marking_lib = Library("xla_pattern_marking", "DEF")
-
-xla_pattern_marking_lib.define(
+XLA_LIB.define(
     "mark_tensor(Tensor x, str name, int pos, str id, bool is_input, Any? attr=None) -> Tensor"
 )
@@ -45,15 +44,15 @@ def _assert_valid_composite_attr(attr):
         "Composite attr value must be either Python str, float, or int.")


-@impl(xla_pattern_marking_lib, "mark_tensor", "XLA")
+@impl(XLA_LIB, "mark_tensor", "XLA")
 def mark_tensor_xla(x: torch.Tensor,
                     name: str,
                     pos: int,
                     id: str,
                     is_input: bool,
                     attr: Dict = None):
   """Attach pattern boundary metadata to a XLA Tensor.
-  
+
   Args:
     x: torch.Tensor (On XLA device) - the marked tensor.
     name: str - The name of the pattern, it will be the name of the stablehlo composite op.
@@ -69,7 +68,7 @@ def mark_tensor_xla(x: torch.Tensor,
       x, json.dumps(pattern_info, cls=BoundaryMetadataSerializer))


-@impl(xla_pattern_marking_lib, "mark_tensor", "CompositeExplicitAutograd")
+@impl(XLA_LIB, "mark_tensor", "CompositeExplicitAutograd")
 def mark_tensor(x: torch.Tensor,
                 name: str,
                 pos: int,
@@ -80,7 +79,7 @@
   return x


-@impl(xla_pattern_marking_lib, "mark_tensor", "Meta")
+@impl(XLA_LIB, "mark_tensor", "Meta")
 def mark_tensor_meta(x: torch.Tensor,
                      name: str,
                      pos: int,
```

(The docstring -/+ pair above differs only by trailing whitespace.)
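Registering `CompositeExplicitAutograd` and `Meta` implementations alongside the XLA kernel keeps `mark_tensor` usable by export and fake-tensor machinery that never touches an XLA device. A hedged sketch (assumes importing this module is enough to trigger registration, and that the Meta impl preserves shape; the Meta body is truncated in this diff):

```python
import torch
import torch_xla.experimental.xla_marker  # noqa: F401  (assumed registration side effect)

x = torch.ones(5, device="meta")
# Routed to mark_tensor_meta via the Meta dispatch key; no metadata is attached.
y = torch.ops.xla.mark_tensor(x, "p", 0, "0", True)
assert y.shape == x.shape
```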