From 825ba0da8bfd0267159d088c327486b7e9ea602a Mon Sep 17 00:00:00 2001 From: qihqi Date: Tue, 7 May 2024 17:49:17 -0700 Subject: [PATCH] Move op dispatching logic into an `Environment` class; and use Mode to capture dispatcher instead of tensor. (#7009) --- experimental/torch_xla2/docs/ops_registry.md | 40 + .../torch_xla2/examples/basic_training.py | 31 +- .../torch_xla2/examples/basic_training_jax.py | 12 +- .../torch_xla2/examples/eager_mode.py | 13 +- .../torch_xla2/test/gemma/test_gemma.py | 2 +- .../torch_xla2/test/llama/test_llama.py | 5 +- experimental/torch_xla2/test/test_context.py | 8 +- .../torch_xla2/test/test_core_aten_ops.py | 17 +- experimental/torch_xla2/test/test_extra.py | 64 - .../torch_xla2/test/test_functions.py | 6 +- .../torch_xla2/test/test_mutations.py | 61 +- experimental/torch_xla2/test/test_ops.py | 11 +- .../torch_xla2/torch_xla2/__init__.py | 29 +- experimental/torch_xla2/torch_xla2/_ops.py | 1781 --------------- .../torch_xla2/torch_xla2/decompositions.py | 19 +- .../torch_xla2/torch_xla2/environment.py | 24 - experimental/torch_xla2/torch_xla2/export.py | 147 +- experimental/torch_xla2/torch_xla2/extra.py | 62 - .../torch_xla2/torch_xla2/functions.py | 135 -- experimental/torch_xla2/torch_xla2/interop.py | 65 + .../torch_xla2/torch_xla2/ops/__init__.py | 9 + .../torch_xla2/torch_xla2/ops/jaten.py | 1951 ++++++++++++++++- .../torch_xla2/torch_xla2/ops/jtorch.py | 115 +- .../torch_xla2/torch_xla2/ops/op_base.py | 30 +- .../torch_xla2/torch_xla2/ops/ops_registry.py | 47 + .../torch_xla2/torch_xla2/ops_registry.py | 74 - experimental/torch_xla2/torch_xla2/tensor.py | 310 ++- experimental/torch_xla2/torch_xla2/types.py | 12 + 28 files changed, 2555 insertions(+), 2525 deletions(-) create mode 100644 experimental/torch_xla2/docs/ops_registry.md delete mode 100644 experimental/torch_xla2/test/test_extra.py delete mode 100644 experimental/torch_xla2/torch_xla2/_ops.py delete mode 100644 experimental/torch_xla2/torch_xla2/extra.py delete mode 100644 experimental/torch_xla2/torch_xla2/functions.py create mode 100644 experimental/torch_xla2/torch_xla2/interop.py create mode 100644 experimental/torch_xla2/torch_xla2/ops/ops_registry.py delete mode 100644 experimental/torch_xla2/torch_xla2/ops_registry.py create mode 100644 experimental/torch_xla2/torch_xla2/types.py diff --git a/experimental/torch_xla2/docs/ops_registry.md b/experimental/torch_xla2/docs/ops_registry.md new file mode 100644 index 00000000000..c0e68f42fc4 --- /dev/null +++ b/experimental/torch_xla2/docs/ops_registry.md @@ -0,0 +1,40 @@ +# Ops Registry + +## Background + +In the [How it works](how_it_works.md) doc, we mentioned 2 important pieces: + +1. A mechanism to route `ATen` ops to implementation written in + Jax or in PyTorch, and + +2. The ops themselves. + + +Ops Registry is there to help us to organize the ops themselves. + +An op implementation can written in terms of Jax, or in other PyTorch ops. +The latter is also known as "decompositions". For decompositions, +one need to be careful of not introducing circular dependencies. + +Here we simply store the operator implementations in a dictionary, +which key the torch / Aten callable that we wish to override, and +value an instance of `Operator` class. + +`Operator` class has this schema: + +```python +@dataclasses.dataclass +class Operator: + torch_op: TorchCallable + func: Union[TorchCallable, JaxCallable] + is_jax_function: bool + is_user_defined: bool + needs_env: bool +``` + +The `torch_op` is the corresponding torch callable, and `func` the implementation. `is_jax_function` is True if `func` is implemented using Jax, False if `func` is implemented using other torch ops. We can use this information to decide how to call it. + +If `needs_env` is true, `func` will recieve an extra kwarg with name `env`. +This will be the "Environment" in which this op operate on. In particular, +the environment will contain the Jax random number generator key, that might be useful for ops like `aten::rand`. + diff --git a/experimental/torch_xla2/examples/basic_training.py b/experimental/torch_xla2/examples/basic_training.py index 5d3f5a734c5..29e55700a32 100644 --- a/experimental/torch_xla2/examples/basic_training.py +++ b/experimental/torch_xla2/examples/basic_training.py @@ -10,7 +10,11 @@ from torch.utils import _pytree as pytree import torchvision import torchvision.transforms as transforms -import torch_xla2 +import torch_xla2.tensor + + +xla_env = torch_xla2.tensor.Environment(0) +mode = xla_env.mode() # PyTorch TensorBoard support from torch.utils.tensorboard import SummaryWriter @@ -80,6 +84,7 @@ def forward(self, x): model = GarmentClassifier() +model = xla_env.to_xla(model) loss_fn = torch.nn.CrossEntropyLoss() @@ -96,13 +101,6 @@ def forward(self, x): print('Total loss for this batch: {}'.format(loss.item())) # Optimizers specified in the torch.optim package - -# NEW: Move model to XLA device -state_dict = model.state_dict() -state_dict = pytree.tree_map_only(torch.Tensor, - torch_xla2.tensor.move_to_device, state_dict) -model.load_state_dict(state_dict, strict=False, assign=True) - optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9) def train_one_epoch(epoch_index, tb_writer): @@ -115,14 +113,14 @@ def train_one_epoch(epoch_index, tb_writer): for i, data in enumerate(training_loader): # Every data instance is an input + label pair # NEW: Move model to XLA device - data = pytree.tree_map_only(torch.Tensor, - torch_xla2.tensor.move_to_device, data) + data = xla_env.to_xla(data) inputs, labels = data # Zero your gradients for every batch! optimizer.zero_grad() # Make predictions for this batch + outputs = model(inputs) # Compute the loss and its gradients @@ -169,14 +167,11 @@ def train_one_epoch(epoch_index, tb_writer): # Disable gradient computation and reduce memory consumption. with torch.no_grad(): for i, vdata in enumerate(validation_loader): - # NOTE: move to XLA device - vinputs, vlabels = pytree.tree_map_only( - torch.Tensor, - torch_xla2.tensor.move_to_device, - vdata) - voutputs = model(vinputs) # call model's forward - vloss = loss_fn(voutputs, vlabels) - running_vloss += vloss + # NOTE: move to XLA device + vinputs, vlabels = xla_env.to_xla(vdata) + voutputs = model(vinputs) # call model's forward + vloss = loss_fn(voutputs, vlabels) + running_vloss += vloss avg_vloss = running_vloss / (i + 1) print('LOSS train {} valid {}'.format(avg_loss, avg_vloss)) diff --git a/experimental/torch_xla2/examples/basic_training_jax.py b/experimental/torch_xla2/examples/basic_training_jax.py index 3941fcdf8fe..ae6efdf4856 100644 --- a/experimental/torch_xla2/examples/basic_training_jax.py +++ b/experimental/torch_xla2/examples/basic_training_jax.py @@ -8,7 +8,7 @@ import torchvision import torchvision.transforms as transforms import torch_xla2 -import torch_xla2.extra +import torch_xla2.interop import jax import optax import numpy as np @@ -91,7 +91,7 @@ def forward(self, x): def jax_loss(weights, data, label): pred = jax_func(weights, data) - loss = torch_xla2.extra.call_torch(loss_fn, pred, label) + loss = torch_xla2.interop.call_torch(loss_fn, pred, label) return loss grad_fn = jax.jit(jax.value_and_grad(jax_loss)) @@ -155,12 +155,6 @@ def train_one_epoch(jax_weights, opt_state, epoch_index, tb_writer): # Make sure gradient tracking is on, and do a pass over the data model.train(True) - # NEW: Move model to XLA device - state_dict = model.state_dict() - state_dict = pytree.tree_map_only(torch.Tensor, - torch_xla2.tensor.move_to_device, state_dict) - model.load_state_dict(state_dict, strict=False, assign=True) - avg_loss, opt_state = train_one_epoch(jax_weights, opt_state, epoch_number, writer) running_vloss = 0.0 @@ -174,7 +168,7 @@ def train_one_epoch(jax_weights, opt_state, epoch_index, tb_writer): vinputs, vlabels = pytree.tree_map_only(torch.Tensor, torch_xla2.tensor.t2j, vdata) voutputs = jax_func(jax_weights, (vinputs, )) # call model's forward - vloss = torch_xla2.extra.call_torch(loss_fn, voutputs, vlabels) + vloss = torch_xla2.interop.call_torch(loss_fn, voutputs, vlabels) running_vloss += vloss avg_vloss = running_vloss / (i + 1) diff --git a/experimental/torch_xla2/examples/eager_mode.py b/experimental/torch_xla2/examples/eager_mode.py index 358ee6256c6..755f24b0d2b 100644 --- a/experimental/torch_xla2/examples/eager_mode.py +++ b/experimental/torch_xla2/examples/eager_mode.py @@ -1,10 +1,9 @@ - -from torch_xla2.tensor import move_to_device import torch_xla2 from torch import nn from torch.nn import functional as F import torch -from torch.utils import _pytree as pytree + +xla_env = torch_xla2.default_env() class MyModel(nn.Module): @@ -22,21 +21,21 @@ def forward(self, x): return x m = MyModel() +m = xla_env.to_xla(m) # Execute this model using torch inputs = (torch.randn(3, 3, 28, 28), ) +inputs = xla_env.to_xla(inputs) -inputs, state_dict = pytree.tree_map_only(torch.Tensor, move_to_device, (inputs, m.state_dict())) -m.load_state_dict(state_dict, strict=False, assign=True) print(m(*inputs)) print('---=====') -from torch_xla2.extra import jax_jit +from torch_xla2.interop import jax_jit @jax_jit def model_func(param, inputs): return torch.func.functional_call(m, param, inputs) -print(model_func(state_dict, inputs)) +print(model_func(m.state_dict(), inputs)) diff --git a/experimental/torch_xla2/test/gemma/test_gemma.py b/experimental/torch_xla2/test/gemma/test_gemma.py index bd0bb21dbb1..4d91bc6f9b0 100644 --- a/experimental/torch_xla2/test/gemma/test_gemma.py +++ b/experimental/torch_xla2/test/gemma/test_gemma.py @@ -74,7 +74,7 @@ def test_gemma(self): weights, jax_func = torch_xla2.extract_jax(model) inputs_jax = pytree.tree_map_only( - torch.Tensor, torch_xla2.tensor.move_to_device, inputs) + torch.Tensor, torch_xla2.tensor.t2j, inputs) import jax print(jax.jit(jax_func)(weights, inputs_jax)) diff --git a/experimental/torch_xla2/test/llama/test_llama.py b/experimental/torch_xla2/test/llama/test_llama.py index dae7bf0cc5c..083116ab89e 100644 --- a/experimental/torch_xla2/test/llama/test_llama.py +++ b/experimental/torch_xla2/test/llama/test_llama.py @@ -1,8 +1,5 @@ -import unittest -import jax import torch -from torch._functorch.make_functional import make_functional_with_buffers -from torch_xla2 import tensor, ops # pylint: disable=unused-import +from torch_xla2 import tensor # pylint: disable=unused-import import torch_xla2 from .. import test_base diff --git a/experimental/torch_xla2/test/test_context.py b/experimental/torch_xla2/test/test_context.py index 1a75a7d23d0..a6bcda5113a 100644 --- a/experimental/torch_xla2/test/test_context.py +++ b/experimental/torch_xla2/test/test_context.py @@ -1,20 +1,22 @@ import unittest import torch -import torch_xla2 from torch_xla2 import tensor +xla_env = tensor.Environment(0) + class TestContext(unittest.TestCase): + def test_mode_context_manager(self): - with torch_xla2.mode(): + with xla_env: x = torch.full((3, 3), -1) self.assertIsInstance(x, tensor.XLATensor2) y = x.abs() self.assertIsInstance(y, tensor.XLATensor2) @staticmethod - @torch_xla2.mode() + @xla_env def _test_mode_decorator(): x = torch.full((3, 3), -1) y = x.abs() diff --git a/experimental/torch_xla2/test/test_core_aten_ops.py b/experimental/torch_xla2/test/test_core_aten_ops.py index 357e41c9101..6a1cef306be 100644 --- a/experimental/torch_xla2/test/test_core_aten_ops.py +++ b/experimental/torch_xla2/test/test_core_aten_ops.py @@ -1,7 +1,6 @@ import unittest import torch -from torch_xla2 import ops_registry from torch_xla2 import tensor from . import test_base @@ -34,12 +33,13 @@ def run_export_and_compare(testcase, rtol=1e-5, equal_nan=True, ignore_indices=False): + with testcase.subTest("torch_eval"): res = func(*args, **kwargs) with testcase.subTest("torch_xla2_eval"): - args2, kwargs2 = pytree.tree_map_only(torch.Tensor, tensor.move_to_device, - (args, kwargs)) - res2 = func(*args2, **kwargs2) + args2, kwargs2 = testcase.env.to_xla((args, kwargs)) + with testcase.env: + res2 = func(*args2, **kwargs2) res2 = pytree.tree_map_only(tensor.XLATensor2, lambda t: t.torch(), res2) # import pdb; pdb.set_trace() with testcase.subTest("torch_xla2_diff:" + str(atol)): @@ -61,11 +61,11 @@ class TestCoreAtenOps(unittest.TestCase): @classmethod def setUpClass(cls): super().setUpClass() - ops_registry.print_missing_ops() def setUp(self): super().setUp() torch.manual_seed(0) + self.env = tensor.Environment(0) def test_aten_abs_0(self): args = (torch.randn((10, 10)).to(torch.float32),) @@ -2109,7 +2109,7 @@ def test_aten_logit_0(self): def test_aten_logit_1(self): args = (torch.randn((10, 10)).to(torch.float16),) kwargs = dict() - run_export_and_compare(self, torch.ops.aten.logit, args, kwargs) + run_export_and_compare(self, torch.ops.aten.logit, args, kwargs, atol=0.01,) def test_aten_logit_2(self): args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) @@ -3639,8 +3639,9 @@ def test_aten__softmax_1(self): def _compare_sorted_result(self, args): res = torch.ops.aten.sort(*args) with self.subTest("torch_xla2_eval"): - args2 = pytree.tree_map_only(torch.Tensor, tensor.move_to_device, args) - res2 = torch.ops.aten.sort(*args2) + args2 = self.env.to_xla(args) + with self.env: + res2 = torch.ops.aten.sort(*args2) # The second argument is the sorted index. These might not be # identical from torch vs. jax; but both can be correct diff --git a/experimental/torch_xla2/test/test_extra.py b/experimental/torch_xla2/test/test_extra.py deleted file mode 100644 index 768488d6a99..00000000000 --- a/experimental/torch_xla2/test/test_extra.py +++ /dev/null @@ -1,64 +0,0 @@ -import unittest -import torch -import torch.nn.functional as F -import jax -import jax.numpy as jnp -import torch_xla2 -from torch_xla2 import tensor, extra - - -class ExtraTest(unittest.TestCase): - - def setUp(self): - torch.manual_seed(0) - - def test_standard_callable(self): - def f(a, b): - return torch.add(a, b) - - a = jnp.ones((10, )) - b = jnp.ones((10, )) - - c = extra.jax_view(f)(a, b) - self.assertTrue(jnp.allclose(c, a + b)) - - def f2(a, b): - return jnp.add(a, b) - - a = tensor.move_to_device(torch.ones((10, ))) - b = tensor.move_to_device(torch.ones((10, ))) - c2 = extra.torch_view(f2)(a, b) - - self.assertTrue(jnp.allclose(c2._elem, c)) - - - - def test_fori_loop(self): - a = tensor.move_to_device(torch.ones((10, 10))) - - def body(i, c): - return c + a[i] - - init_val = tensor.move_to_device(torch.zeros(10)) - res = extra.fori_loop(0, 10, body, init_val) - expect = torch.ones(10) * 10 - self.assertTrue(torch.allclose(tensor.j2t(res._elem), expect)) - - def test_jax_jit(self): - - # functions that acts on torch tensor - def f(a, b): - return torch.sin(a) + torch.cos(b) - - fjitted = extra.jax_jit(f) - a = torch.rand((10, 10)) - b = torch.rand((10, 10)) - aj = tensor.move_to_device(a) - bj = tensor.move_to_device(b) - res = f(a, b) - res2 = fjitted(aj, bj) - self.assertTrue(torch.allclose(res, tensor.j2t(res2._elem))) - - -if __name__ == '__main__': - unittest.main() diff --git a/experimental/torch_xla2/test/test_functions.py b/experimental/torch_xla2/test/test_functions.py index 76e842d6fdd..2d624b25b5b 100644 --- a/experimental/torch_xla2/test/test_functions.py +++ b/experimental/torch_xla2/test/test_functions.py @@ -3,12 +3,14 @@ from absl.testing import parameterized import torch import torch_xla2 -import torch_xla2.functions import torch_xla2.tensor class TestTorchFunctions(parameterized.TestCase): + def setUp(self): + self.env = torch_xla2.tensor.Environment(0) + @parameterized.named_parameters( ('tensor_2d', lambda: torch.tensor([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]])), ('tensor_1d', lambda: torch.tensor([0, 1],)), @@ -32,7 +34,7 @@ class TestTorchFunctions(parameterized.TestCase): def test_tensor_constructor(self, func: Callable[[], torch.Tensor]): expected = func() - with torch_xla2.functions.XLAFunctionMode(): + with self.env: actual = func() self.assertIsInstance(actual, torch_xla2.tensor.XLATensor2) diff --git a/experimental/torch_xla2/test/test_mutations.py b/experimental/torch_xla2/test/test_mutations.py index 2f9ddca975b..50d78aa0fae 100644 --- a/experimental/torch_xla2/test/test_mutations.py +++ b/experimental/torch_xla2/test/test_mutations.py @@ -6,46 +6,43 @@ class TestMutations(TestCase): - def test_add(self): - x = torch.tensor([1, 2, 3], dtype=torch.int32) - y = torch.tensor([4, 5, 6], dtype=torch.int32) + def setUp(self): + self.env = torch_xla2.tensor.Environment(0) - x = torch_xla2.tensor.move_to_device(x) - y = torch_xla2.tensor.move_to_device(y) - x.add_(y) - xt = torch_xla2.tensor.j2t(x._elem) - self.assertEqual(xt, torch.tensor([5, 7, 9], dtype=torch.int32)) + def test_add(self): + with self.env: + x = torch.tensor([1, 2, 3], dtype=torch.int32) + y = torch.tensor([4, 5, 6], dtype=torch.int32) + x.add_(y) + xt = torch_xla2.tensor.j2t(x._elem) + self.assertEqual(xt, torch.tensor([5, 7, 9], dtype=torch.int32)) def test_sub(self): - x = torch.tensor([1, 2, 3], dtype=torch.int32) - y = torch.tensor([4, 5, 6], dtype=torch.int32) - - x = torch_xla2.tensor.move_to_device(x) - y = torch_xla2.tensor.move_to_device(y) - x.sub_(y) - xt = torch_xla2.tensor.j2t(x._elem) - self.assertEqual(xt, torch.tensor([-3, -3, -3], dtype=torch.int32)) + with self.env: + x = torch.tensor([1, 2, 3], dtype=torch.int32) + y = torch.tensor([4, 5, 6], dtype=torch.int32) + x.sub_(y) + xt = torch_xla2.tensor.j2t(x._elem) + self.assertEqual(xt, torch.tensor([-3, -3, -3], dtype=torch.int32)) def test_mul(self): - x = torch.tensor([1, 2, 3], dtype=torch.int32) - y = torch.tensor([4, 5, 6], dtype=torch.int32) + with self.env: + x = torch.tensor([1, 2, 3], dtype=torch.int32) + y = torch.tensor([4, 5, 6], dtype=torch.int32) - x = torch_xla2.tensor.move_to_device(x) - y = torch_xla2.tensor.move_to_device(y) - x.mul_(y) - xt = torch_xla2.tensor.j2t(x._elem) - self.assertEqual(xt, torch.tensor([4, 10, 18], dtype=torch.int32)) + x.mul_(y) + xt = torch_xla2.tensor.j2t(x._elem) + self.assertEqual(xt, torch.tensor([4, 10, 18], dtype=torch.int32)) def test_div(self): - x = torch.tensor([1, 2, 3], dtype=torch.int32) - y = torch.tensor([4, 5, 6], dtype=torch.int32) - - x = torch_xla2.tensor.move_to_device(x) - y = torch_xla2.tensor.move_to_device(y) - x.div_(y) - xt = torch_xla2.tensor.j2t(x._elem) - self.assertEqual(xt, - torch.tensor([1. / 4, 2. / 5, 3. / 6], dtype=torch.float)) + with self.env: + x = torch.tensor([1, 2, 3], dtype=torch.int32) + y = torch.tensor([4, 5, 6], dtype=torch.int32) + + x.div_(y) + xt = torch_xla2.tensor.j2t(x._elem) + self.assertEqual(xt, + torch.tensor([1. / 4, 2. / 5, 3. / 6], dtype=torch.float)) if __name__ == '__main__': diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 5f6fdbbeab2..20686f2fe6c 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -7,7 +7,6 @@ instantiate_device_type_tests, ops) from torch.utils import _pytree as pytree from torch_xla2 import tensor -import torch_xla2 skiplist = { @@ -626,10 +625,9 @@ def run_export_and_compare(testcase, with testcase.subTest("torch_eval"): res = func(sample_input.input, *sample_input.args, **sample_input.kwargs) with testcase.subTest("torch_xla2_eval"): - input2, args2, kwargs2 = pytree.tree_map_only( - torch.Tensor, tensor.move_to_device, - (sample_input.input, sample_input.args, sample_input.kwargs)) - with torch_xla2.mode(): + input2, args2, kwargs2 = testcase.env.to_xla(( + sample_input.input, sample_input.args, sample_input.kwargs)) + with testcase.env: res2 = func(input2, *args2, **kwargs2) res2 = pytree.tree_map_only(tensor.XLATensor2, lambda t: t.torch(), res2) with testcase.subTest("torch_xla2_diff:" + str(atol)): @@ -655,6 +653,9 @@ class TestOpInfo(TestCase): def setUpClass(cls): print('op_db size: ', len(op_db), 'testing: ', len(ops_to_test)) + def setUp(self): + self.env = tensor.Environment(0) + @ops(ops_to_test, allowed_dtypes=(torch.float32, torch.long)) def test_reference_eager(self, device, dtype, op): sample_inputs = op.sample_inputs(device, dtype) diff --git a/experimental/torch_xla2/torch_xla2/__init__.py b/experimental/torch_xla2/torch_xla2/__init__.py index b0bb20712d4..bd0e00fa6ca 100644 --- a/experimental/torch_xla2/torch_xla2/__init__.py +++ b/experimental/torch_xla2/torch_xla2/__init__.py @@ -1,31 +1,34 @@ -import contextlib import jax import torch from torch._functorch import make_functional from torch.utils import _pytree as pytree -from torch_xla2 import export, _ops, ops_registry, tensor, tf_integration, functions +from torch_xla2 import export, tensor, tf_integration jax.config.update('jax_enable_x64', True) +env = None +def default_env(): + global env + if env is None: + env = tensor.Environment(0) + return env -@contextlib.contextmanager -def mode(): - with tensor.XLADispatchMode(), functions.XLAFunctionMode(): - yield -def extract_jax(mod: torch.nn.Module): +def extract_jax(mod: torch.nn.Module, env=None): """Returns a pytree of jax.ndarray and a jax callable.""" + if env is None: + env = default_env() func, weights, buffer = make_functional.make_functional_with_buffers(mod) - states = (weights, buffer) + states = mod.state_dict() + states = pytree.tree_map_only(torch.Tensor, tensor.t2j, states) #@jax.jit def jax_func(states, inputs): - (states, inputs) = tensor.wrap((states, inputs)) - weights, buffer = states - with tensor.XLADispatchMode(): - res = func(weights, buffer, *inputs) - return tensor.unwrap(res) + (states, inputs) = env.j2t_iso((states, inputs)) + with env: + res = torch.func.functional_call(mod, states, inputs) + return env.t2j_iso(res) return states, jax_func diff --git a/experimental/torch_xla2/torch_xla2/_ops.py b/experimental/torch_xla2/torch_xla2/_ops.py deleted file mode 100644 index e3650234372..00000000000 --- a/experimental/torch_xla2/torch_xla2/_ops.py +++ /dev/null @@ -1,1781 +0,0 @@ -# pylint: disable -"""Torch ops implemented using jax.""" -import sys - -import jax -from jax import numpy as jnp -import numpy as np -import torch -from torch_xla2 import ops_registry -from torch_xla2 import tensor - - -class TorchFunctionLowering: - - def __init__(self, func, is_jax_func, should_jit=False): - if is_jax_func and should_jit: - func = jax.jit(func) - self.func = func - self.is_jax_func = is_jax_func - - def __call__(self, *args, **kwargs): - if self.is_jax_func: - (args, kwargs) = tensor.unwrap((args, kwargs)) - res = self.func(*args, **kwargs) - if self.is_jax_func: - res = tensor.wrap(res) - return res - - -def op(aten_op, is_jax_func=True): - """if is_jax_func is true, then the function it will register - - should takes jax array as input and returns jax array. - - Which means we need to wrap it - """ - - def inner(func): - ops_registry.lowerings.register(aten_op, - TorchFunctionLowering(func, is_jax_func)) - return func - - return inner - - -@op(torch.ops.aten.view_copy) -@op(torch.ops.aten.view) -@op(torch.ops.aten._unsafe_view) -@op(torch.ops.aten.reshape) -def _aten_unsafe_view(x, shape): - return jnp.reshape(x, shape) - - -@op(torch.ops.aten.add) -def _aten_add(x, y, *, alpha=1): - """if isinstance(x, jnp.ndarray) and isinstance(y, jnp.ndarray): - - assert x.dtype == y.dtype, (x.dtype, y.dtype) - """ - return x + y * alpha - - -@op(torch.ops.aten.copy_, is_jax_func=False) -def _aten_copy(x, y, memory_format=None): - if isinstance(x, tensor.XLATensor2): - x._elem = y._elem - elif isinstance(x, tensor.SliceView): - x.mutate(y) - return x - - -@op(torch.ops.aten.clone) -def _aten_clone(x, memory_format=None): - return jnp.copy(x) - - -@op(torch.ops.aten.full) -def _aten_full(size, value, **kwargs): - return jnp.full(size, value) - - -@op(torch.ops.aten.index_copy) -def _aten_index_copy(x, dim, indexes, source): - # return jax.lax.scatter(x, index, dim) - dims = [] - for i in range(len(x.shape)): - if i == dim: - dims.append(indexes) - else: - dims.append(slice(None, None, None)) - return x.at[dim].set(source) - - -@op(torch.ops.aten.select) -@op(torch.ops.aten.index_select) -@op(torch.ops.aten.select_copy) -def _aten_index_select(x, dim, indexes): - dims = [] - for i in range(len(x.shape)): - if i == dim: - dims.append(indexes) - else: - dims.append(slice(None, None, None)) - return x[tuple(dims)] - - -@op(torch.ops.aten.mean) -def _aten_mean(x, dim=None, keepdim=False): - return jnp.mean(x, dim, keepdims=keepdim) - - -def _torch_binary_scalar_type(scalar, tensor): - if "float" in str(tensor.dtype): - return tensor.dtype - - if isinstance(scalar, int): - if "int" in str(tensor.dtype): - return tensor.dtype - - return jnp.float32 - - -@op(torch.ops.aten.sub) -def _aten_sub(x, y): - if isinstance(x, float): - dtype = _torch_binary_scalar_type(x, y) - x = jnp.array(x, dtype=dtype) - if isinstance(y, float): - dtype = _torch_binary_scalar_type(y, x) - y = jnp.array(y, dtype=dtype) - return x - y - - -@op(torch.ops.aten.mm) -def _aten_mm(x, y): - res = x @ y - return res - - -@op(torch.ops.aten.mul) -def _aten_mul(x, y): - return x * y - - -@op(torch.ops.aten.silu) -def _aten_silu(x): - return jax.nn.silu(x) - - -@op(torch.ops.aten.t) -def _aten_t(x): - return jnp.transpose(x) - - -@op(torch.ops.aten.transpose) -@op(torch.ops.aten.transpose_copy) -def _aten_transpose(x, dim0, dim1): - shape = list(range(len(x.shape))) - shape[dim0], shape[dim1] = shape[dim1], shape[dim0] - return jnp.transpose(x, shape) - - -@op(torch.ops.aten.triu) -def _aten_triu(m, k): - return jnp.triu(m, k) - - -@op(torch.ops.aten.slice) -@op(torch.ops.aten.slice_copy) -def _aten_slice(self, dim=0, start=None, end=None, step=1): - if end == sys.maxsize: - end = self.shape[dim] - sl = slice(start, end, step) - dims = [] - for i in range(len(self.shape)): - if i == dim: - dims.append(sl) - else: - dims.append(slice(None, None, None)) - return self[tuple(dims)] - - -@op(torch.ops.aten.detach) -def _aten_detach(self): - return self - - -@op(torch.ops.aten.view_as_real) -def _aten_view_as_real(x): - real = jnp.real(x) - im = jnp.imag(x) - res = jnp.stack([real, im], -1) - return res - - -@op(torch.ops.aten.stack) -def _aten_stack(tensors, dim=0): - return jnp.stack(tensors, dim) - - -@op(torch.ops.aten._softmax) -def _aten_softmax(x, dim, halftofloat): - return jax.nn.softmax(x, dim) - - -@op(torch.ops.aten.pow) -def _aten_pow(x, y): - if isinstance(y, int): - y = float(y) - return jnp.power(x, y) - - -@op(torch.ops.aten.view_as_complex) -def _aten_view_as_complex(input): - if input.dtype == jnp.bfloat16: - input = input.astype(jnp.float32) - x, y = input[..., 0], input[..., 1] - return jax.lax.complex(x, y) - - -@op(torch.ops.aten.div) -def _aten_div(x, y, rounding_mode=""): - res = x / y - if rounding_mode == "trunc": - res = jnp.trunc(res) - return res - - -@op(torch.ops.aten.div_, is_jax_func=False) -def _aten_div_(x, y, rounding_mode=""): - x._elem = _aten_div(x._elem, y._elem, rounding_mode) - return x - - -@op(torch.ops.aten.true_divide) -def _aten_true_divide(x, y): - return x / y - - -@op(torch.ops.aten.bmm) -def _aten_bmm(x, y): - res = x @ y - return res - # return jnp.einsum('bnm,bmk->bnk', x, y) - - -@op(torch.ops.aten.embedding) -# embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -def _aten_embedding(a, w, padding_idx=-1): - return jnp.take(a, w, axis=0) - - -@op(torch.ops.aten.rsqrt) -def _aten_rsqrt(x): - if isinstance(x, int): - x = float(x) - if x.dtype == jnp.int32: - x = x.astype(jnp.float32) - return jax.lax.rsqrt(x) - - -@op(torch.ops.aten.expand) -@op(torch.ops.aten.expand_copy) -def _aten_expand(x, dims): - - def fix_dims(d, xs): - if d == -1: - return xs - return d - - dims = [fix_dims(p, s) for p, s in zip(dims, x.shape)] - return jnp.broadcast_to(x, dims) - - -@op(torch.ops.aten.dot) -def _aten_dot(x, y): - return jnp.dot(x, y) - - -@op(torch.ops.aten._to_copy) -def _aten__to_copy(self, **kwargs): - dtype = tensor.t2j_dtype(kwargs["dtype"]) - if dtype != self.dtype: - return self.astype(dtype) - return jnp.copy(self) - - -@op(torch.ops.aten.empty) -def _aten_empty(sizes, **kwargs): - return jnp.zeros(sizes) - - -@op(torch.ops.aten.index_put_) -@op(torch.ops.aten.index_put) -def _aten_index_put(self, indexes, values, accumulate=False): - indexes = [slice(None, None, None) if i is None else i for i in indexes] - indexes = tuple(indexes) - if accumulate: - return self.at[indexes].add(values) - else: - return self.at[indexes].set(values) - - -@op(torch.ops.aten.index) -@op(torch.ops.aten._unsafe_index) -@op(torch.ops.aten.index.Tensor) -def _aten_index(self, indexes): - indexes = [slice(None, None, None) if i is None else i for i in indexes] - indexes = tuple(indexes) - return self[indexes] - - -@op(torch.ops.aten.split) -@op(torch.ops.aten.split_copy) -@op(torch.ops.aten.split_with_sizes) -def split_with_sizes(x, sizes, dim=0): - """Splits an array `x` into sub-arrays based on static sizes `sizes`. - - Args: - x: The input array to split. - sizes: A 1D array of integer sizes for each sub-array. - - Returns: - A list of sub-arrays. - """ - if isinstance(sizes, int): - # split equal size - new_sizes = [sizes] * (x.shape[dim] // sizes) - sizes = new_sizes - rank = x.ndim - splits = np.cumsum(sizes) # Cumulative sum for split points - - def make_range(rank, dim, start, end): - res = [slice(None, None, None)] * rank - res[dim] = slice(start, end) - return tuple(res) - - return [ - x[make_range(rank, dim, start, end)] - for start, end in zip([0] + list(splits[:-1]), splits) - ] - - -@op(torch.ops.aten.permute) -@op(torch.ops.aten.permute_copy) -def permute(t, dims): - return jnp.transpose(t, dims) - - -@op(torch.ops.aten.unsqueeze) -@op(torch.ops.aten.unsqueeze_copy) -@op(torch.ops.aten.unsqueeze.default) -def _aten_unsqueeze(self, dim): - if dim < 0: - dim += self.ndim + 1 - return jnp.expand_dims(self, dim) - - -@op(torch.ops.aten.ne) -def _aten_ne(x, y): - return jnp.not_equal(x, y) - - -@op(torch.ops.aten.cumsum) -def _aten_cumsum(x, y, dtype=None): - dtype = tensor.t2j_dtype(dtype) - res = jnp.cumsum(x, y, dtype) - return res - - -@op(torch.ops.aten.native_layer_norm) -def _aten_native_layer_norm(input, - normalized_shape, - weight=None, - bias=None, - eps=1e-5): - """Implements layer normalization in Jax as defined by `aten::native_layer_norm`. - - Args: - input: The input tensor. - normalized_shape: A list of integer dimensions to be normalized over. - weight: Optional weight tensor for the affine transformation. - bias: Optional bias tensor for the affine transformation. - eps: A small epsilon value for numerical stability. - - Returns: - output: The normalized tensor. - mean: The calculated mean tensor. - std: The calculated standard deviation tensor. - """ - if isinstance(normalized_shape, int): - normalized_shape = [normalized_shape] - axis = [i for i, d in enumerate(input.shape) if d in normalized_shape] - - # Calculate mean and standard deviation - mean = jnp.mean(input, axis=axis, keepdims=True) - var = jnp.var(input, axis=axis, keepdims=True) - rstd = jax.lax.rsqrt(var + eps) - - # Normalize the input - norm_x = (input - mean) * rstd - - # Apply affine transformation (if provided) - if weight is not None: - norm_x *= weight - if bias is not None: - norm_x += bias - return norm_x, mean, rstd - - -# - func: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor -@op(torch.ops.aten.addmm) -@op(torch.ops.aten.addmv) -def _aten_addmm(self, mat1, mat2, *, beta=1.0, alpha=1.0): - alpha = jnp.array(alpha).astype(mat1.dtype) - beta = jnp.array(beta).astype(mat1.dtype) - self *= beta - self += alpha * jnp.matmul(mat1, mat2) - return self - -@op(torch.ops.aten.addbmm.default) -def _aten_addbmm(input, batch1, batch2, *, beta=1, alpha=1): - alpha = jnp.array(alpha).astype(batch1.dtype) - beta = jnp.array(beta).astype(batch1.dtype) - mm = jnp.einsum('bxy, byz -> xz', batch1, batch2) - return jax.lax.cond(beta == 0, - lambda: alpha * mm, - lambda: beta*input + alpha*mm) - - -@op(torch.ops.aten.gelu) -def _aten_gelu(self, *, approximate="none"): - approx = approximate == "tanh" - return jax.nn.gelu(self, approx) - - -@op(torch.ops.aten.squeeze) -@op(torch.ops.aten.squeeze_copy) -def _aten_squeeze_dim(self, dim): - """Squeezes a Jax tensor by removing a single dimension of size 1. - - Args: - self: The input tensor. - dim: The dimension to squeeze. - - Returns: - The squeezed tensor with the specified dimension removed if it is 1, - otherwise the original tensor is returned. - """ - - # Validate input arguments - if not isinstance(self, jnp.ndarray): - raise TypeError(f"Expected a Jax tensor, got {type(self)}.") - if isinstance(dim, int): - dim = [dim] - - # Check if the specified dimension has size 1 - if all([self.shape[d] != 1 for d in dim]): - return self - - # Use slicing to remove the dimension if it is 1 - new_shape = list(self.shape) - - def fix_dim(p): - if p < 0: - return p + len(self.shape) - return p - - dim = [fix_dim(d) for d in dim] - new_shape = [p for i, p in enumerate(self.shape) if i not in dim or p != 1] - return self.reshape(new_shape) - - -@op(torch.ops.aten.convolution) -def _aten_convolution( - input, - weight, - bias, - stride, - padding, - dilation, - transposed, - output_padding, - groups, -): - if transposed: - raise NotImplementedError("Transposed convolution is not implemented.") - - def make_padding(padding): - return ((p, p) for p in padding) - - def create_default_conv_dimension_numbers(num_spatial_dims): - # Ref: https://github.com/openxla/xla/blob/main/xla/client/xla_builder.cc#L4211 - # (batch dimension, feature dimension, spatial dimensions...) - lhs_spec = [0, 1] - # (out feature dimension, in feature dimension, spatial dimensions...) - rhs_spec = [0, 1] - # (batch dimension, feature dimension, spatial dimensions...) - out_spec = [0, 1] - for i in range(0, num_spatial_dims): - lhs_spec.append(i + 2) - rhs_spec.append(i + 2) - out_spec.append(i + 2) - return jax.lax.ConvDimensionNumbers( - *map(tuple, (lhs_spec, rhs_spec, out_spec))) - - res = jax.lax.conv_general_dilated( - input, - weight, - stride, - make_padding(padding), - lhs_dilation=(1,) * len(stride), - rhs_dilation=dilation, - dimension_numbers=create_default_conv_dimension_numbers(len(stride)), - feature_group_count=groups, - batch_group_count=1, - ) - - if bias is not None: - # TODO(qihqi): bias always on channel? - if len(bias.shape) == 1: - shape = [1] * len(res.shape) - shape[1] = bias.shape[0] - bias = bias.reshape(tuple(shape)) - res = res + bias - return res - - -# _native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps) -@op(torch.ops.aten._native_batch_norm_legit) -def _aten__native_batch_norm_legit(input, weight, bias, running_mean, - running_var, training, momentum, eps): - return _aten__native_batch_norm_legit_no_training(input, weight, bias, - running_mean, running_var, - momentum, eps) - - -@op(torch.ops.aten._native_batch_norm_legit_no_training) -def _aten__native_batch_norm_legit_no_training(input, weight, bias, - running_mean, running_var, - momentum, eps): - if weight is None: - weight = jnp.ones_like(running_mean) - if bias is None: - bias = jnp.zeros_like(running_mean) - - def broadcast(t): - return jax.lax.broadcast_in_dim(t, input.shape, broadcast_dimensions=(1,)) - - a = input - broadcast(running_mean) - b = broadcast(jnp.sqrt(running_var + eps)) - return ( - a / b * broadcast(weight) + broadcast(bias), - jnp.array([]), - jnp.array([]), - ) - - -@op(torch.ops.aten.relu) -def _aten_relu(self): - return jax.nn.relu(self) - - -@op(torch.ops.aten.cat) -def _aten_cat(tensors, dims=0): - return jnp.concatenate(tensors, dims) - - -@op(torch.ops.aten.max_pool2d_with_indices) -@op(torch.ops.aten.max_pool3d_with_indices) -def _aten_max_pool2d_with_indices(inputs, - kernel_size, - strides, - padding=0, - dilation=1, - ceil_mode=False): - num_batch_dims = len(inputs.shape) - len(kernel_size) - 1 - kernel_size = tuple(kernel_size) - strides = tuple(strides) - if isinstance(padding, int): - padding = tuple((padding, padding) for _ in range(len(kernel_size))) - elif isinstance(padding, list): - padding = tuple((p, p) for p in padding) - - window_shape = kernel_size - num_batch_dims = inputs.ndim - (len(window_shape) + 1) - strides = strides or (1,) * len(window_shape) - assert len(window_shape) == len( - strides), f'len({window_shape}) must equal len({strides})' - strides = (1,) * (1 + num_batch_dims) + strides - dims = (1,) * (1 + num_batch_dims) + window_shape - - is_single_input = False - if num_batch_dims == 0: - # add singleton batch dimension because lax.reduce_window always - # needs a batch dimension. - inputs = inputs[None] - strides = (1,) + strides - dims = (1,) + dims - is_single_input = True - - assert inputs.ndim == len(dims), f'len({inputs.shape}) != len({dims})' - if not isinstance(padding, str): - padding = tuple(map(tuple, padding)) - assert len(padding) == len(window_shape), ( - f'padding {padding} must specify pads for same number of dims as ' - f'window_shape {window_shape}') - assert all([len(x) == 2 for x in padding - ]), f'each entry in padding {padding} must be length 2' - padding = ((0, 0), (0, 0)) + padding - - indices = jnp.arange(np.prod(inputs.shape)).reshape(inputs.shape) - - def reduce_fn(a, b): - ai, av = a - bi, bv = b - which = av > bv - return jnp.where(which, ai, bi), jnp.where(which, av, bv) - - init_val = -jnp.inf - if inputs.dtype in (jnp.int32, jnp.int64): - init_val = -(1 << 31) - init_val = jnp.array(init_val).astype(inputs.dtype) - - indices, y = jax.lax.reduce_window((indices, inputs), (0, init_val), - reduce_fn, dims, strides, padding) - if is_single_input: - indices = jnp.squeeze(indices, axis=0) - y = jnp.squeeze(y, axis=0) - return y, indices - - batch_result = pool(inputs, -jnp.inf, jax.lax.max, kernel_size, strides, - padding) - indices = pool(inputs, 0, jnp.argmax, kernel_size, strides, padding) - return batch_result, indices - - -# TODO add more ops - - -@op(torch.ops.aten.min) -def _aten_min(x, axis=None): - return jnp.min(x, axis=axis), jnp.argmin(x, axis=axis).astype(jnp.int64) - - -@op(torch.ops.aten.amin) -def _aten_amin(x, dim=None, keepdim=False): - return _with_reduction_scalar(jnp.amin, x, dim, keepdim) - - -@op(torch.ops.aten.argmin) -def _aten_argmin(self, dim=None, keepdim=False): - return _with_reduction_scalar( - jnp.argmin, self, dim, keepdim) - - -@op(torch.ops.aten.sin) -def _aten_sin(x): - return jnp.sin(x) - - -@op(torch.ops.aten.sym_size) -def _aten_sym_size(x, dim): - return x.shape[dim] - - -@op(torch.ops.aten.var) -@op(torch.ops.prims.var) -def _aten_var(x, dim=None, *, correction=1, keepdim=False, out=None): - return jnp.var(x, axis=dim, ddof=correction, keepdims=keepdim) - - -@op(torch.ops.prims.broadcast_in_dim) -def _prims_broadcast_in_dim(t, shape, broadcast_dimensions): - return jax.lax.broadcast_in_dim( - t, shape, broadcast_dimensions=broadcast_dimensions) - - -# aten.native_group_norm -- should use decomp table -# func: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor) - - -@op(torch.ops.aten.native_group_norm) -def _aten_native_group_norm(input, weight, bias, N, C, HxW, group, eps=1e-5): - """Group Normalization implementation in JAX. - - Args: - input: Input tensor. Expected shape (batch_size, channels, ... spatial dims - ...) - weight: Optional scaling (gamma) parameter. Shape (channels,) - bias: Optional shifting (beta) parameter. Shape (channels,) - N: Batch size. - C: Number of channels. - HxW: Product of spatial dimensions (number of elements per channel after - flattening). - group: Number of groups for Group Normalization. - eps: Small value added for numerical stability. - - Returns: - A tuple of (normalized_output, mean, rstd) - """ - - input_shape = input.shape - - # Reshape for group-wise normalization - reshaped_input = jnp.reshape(input, (1, N * group, -1)) - - # **Core Group Normalization** - def group_norm_body(x): # Function to apply within each group - mean = jnp.mean(x, axis=-1, keepdims=True) - var = jnp.var(x, axis=-1, keepdims=True) - rstd = jax.lax.rsqrt(var + eps) # Reciprocal of std with epsilon - normalized = (x - mean) * rstd - return normalized, mean, rstd - - normalized, group_mean, group_rstd = jax.lax.map(group_norm_body, - reshaped_input) - - # Reshape back to original input shape - output = jnp.reshape(normalized, input_shape) - - # **Affine transformation** - affine_shape = [-1 if i == 1 else 1 for i in range(input.ndim) - ] # Shape for broadcasting - if weight is not None and bias is not None: - output = bias.reshape(affine_shape) + output * weight.reshape(affine_shape) - elif weight is not None: - output = output * weight.reshape(affine_shape) - elif bias is not None: - output = output + bias.reshape(affine_shape) - - # Reshape mean and rstd - mean = jnp.reshape(group_mean, (N, group)) - rstd = jnp.reshape(group_rstd, (N, group)) - - return output, mean, rstd - - -@op(torch.ops.aten.linalg_vector_norm) -def _aten_linalg_vector_norm(self, ord=2, dim=None, keepdim=False, dtype=None): - """Calculates the vector norm along specified dimensions. - - Args: - self: The input tensor. - ord: The order of the norm. Can be a float or 'inf', '-inf', 'fro'. - Default is 2 (Euclidean norm). - dim: Dimensions along which to calculate the norm. If None, the norm is - calculated over all dimensions. - keepdim: Whether to keep the reduced dimensions. - dtype: Optional data type for the output. - - Returns: - The tensor containing the calculated vector norms. - """ - - if ord not in {2, float("inf"), float("-inf"), "fro"}: - raise ValueError( - f"Unsupported ord value: {ord}. Supported values are 2, inf, -inf, and" - " 'fro'.") - - # Special cases (for efficiency and clarity) - if ord == 2: # Euclidean norm - result = jnp.sqrt(jnp.sum(jnp.abs(self)**2, axis=dim, keepdims=keepdim)) - - elif ord == float("inf"): - result = jnp.max(jnp.abs(self), axis=dim, keepdims=keepdim) - - elif ord == float("-inf"): - result = jnp.min(jnp.abs(self), axis=dim, keepdims=keepdim) - - elif ord == "fro": # Frobenius norm - result = jnp.sqrt(jnp.sum(jnp.abs(self)**2, axis=dim, keepdims=keepdim)) - - else: # General case (e.g., ord = 1, ord = 3) - result = jnp.sum( - jnp.abs(self)**ord, axis=dim, keepdims=keepdim)**(1.0 / ord) - - # (Optional) dtype conversion - if dtype is not None: - result = result.astype(dtype) - - return result - - -# aten.reflection_pad1d -@op(torch.ops.aten.reflection_pad1d) -def _aten_reflection_pad1d(input, padding): - rank = len(input.shape) - pad_size = [(0, 0)] * rank - pad_size[-1] = padding - return jnp.pad(input, pad_size, mode="reflect") - - -# aten.alias -@op(torch.ops.aten.alias) -def _aten_alias(self, *args): - return self - - -# aten.sinh -@op(torch.ops.aten.sinh) -def _aten_sinh(self): - return jnp.sinh(self) - - -# aten.native_layer_norm_backward -@op(torch.ops.aten.native_layer_norm_backward) -def _aten_native_layer_norm_backward(grad_out, - input, - normalized_shape, - weight, - bias, - eps=1e-5): - """Implements the backward pass of layer normalization in Jax as defined by `aten::native_layer_norm_backward`. - - Args: - grad_out: The gradient of the output tensor. - input: The input tensor. - normalized_shape: A list of integer dimensions to be normalized over. - weight: Optional weight tensor for the affine transformation. - bias: Optional bias tensor for the affine transformation. - eps: A small epsilon value for numerical stability. - - Returns: - A tuple of (grad_input, grad_weight, grad_bias). - """ - return jax.lax.native_layer_norm_backward(grad_out, input, normalized_shape, - weight, bias, eps) - - -# aten.reflection_pad3d_backward -# aten.reflection_pad2d - - -# aten.atanh -@op(torch.ops.aten.atanh) -def _aten_atanh(self): - return jnp.arctanh(self) - - -# aten.bitwise_not -@op(torch.ops.aten.bitwise_not) -def _aten_bitwise_not(self): - return ~self - - -# aten.embedding_dense_backward - - -# aten.sum -@op(torch.ops.aten.sum) -def _aten_sum(self, dim=None, keepdim=False, dtype=None): - return jnp.sum(self, axis=dim, keepdims=keepdim, dtype=dtype) - - -# aten.sqrt -@op(torch.ops.aten.sqrt) -def _aten_sqrt(self): - return jnp.sqrt(self) - - -@op(torch.ops.aten.tan) -def _aten_tanh(self): - return jnp.tan(self) - - -# aten.tanh -@op(torch.ops.aten.tanh) -def _aten_tanh(self): - return jnp.tanh(self) - - -# aten.ceil -@op(torch.ops.aten.ceil) -def _aten_ceil(self): - return jnp.ceil(self) - - -# aten.asin -@op(torch.ops.aten.asin) -def _aten_asin(self): - return jnp.arcsin(self) - - -# aten.minimum -@op(torch.ops.aten.minimum) -def _aten_minimum(self, other): - return jnp.minimum(self, other) - - -# aten.max_pool2d_backward - - -def _scatter_index(dim, index): - """Returns a tuple of indexes; - - The first is to select in input (to modify), - the second is to select from the values. - """ - index_shape = list(index.shape) - input_indexes = [] - source_indexes = [] - for i in range(len(index_shape)): - source_indexes.append(slice(0, index_shape[i])) - if i == dim: - input_indexes.append(index) - else: - target_shape = [1] * len(index_shape) - target_shape[i] = index_shape[i] - input_indexes.append( - jnp.broadcast_to( - jnp.arange(index_shape[i]).reshape(target_shape), index_shape)) - return tuple(input_indexes), tuple(source_indexes) - - -# aten.scatter_add -@op(torch.ops.aten.scatter_add) -def _aten_scatter_add(input, dim, index, src): - """JAX implementation of scatter, mimicking torch.scatter behavior""" - - input_indexes, source_indexes = _scatter_index(dim, index) - return input.at[input_indexes].add(src[source_indexes]) - - -# aten.logical_not - - -# aten.sign -@op(torch.ops.aten.sign) -def _aten_sign(x): - return jnp.sign(x) - - -# aten.sigmoid -@op(torch.ops.aten.sigmoid) -def _aten_sigmoid(x): - if x.dtype in (jnp.int32, jnp.int64): - x = x.astype(jnp.float32) - return jax.nn.sigmoid(x) - - -# implement aten.asinh in jax -@op(torch.ops.aten.asinh) -def _aten_asinh(self): - return jnp.arcsinh(self) - - -# aten.atan -@op(torch.ops.aten.atan) -def _aten_atan(self): - return jnp.arctan(self) - - -# aten.scatter_reduce -@op(torch.ops.aten.scatter_reduce) -def _aten_scatter_reduce(input, dim, index, src, reduce, *, include_self=True): - input_indexes, source_indexes = _scatter_index(dim, index) - if reduce == "sum": - return input.at[input_indexes].add(src[source_indexes]) - elif reduce == "prod": - return input.at[input_indexes].multiply(src[source_indexes]) - elif reduce == "mean": - return input.at[input_indexes].add(src[source_indexes]) - elif reduce == "amax": - return input.at[input_indexes].max(src[source_indexes]) - elif reduce == "amin": - return input.at[input_indexes].min(src[source_indexes]) - else: - raise RuntimeError('Unknow reduction type: ', reduce) - - -# aten.acos -@op(torch.ops.aten.acos) -def _aten_acos(self): - return jnp.arccos(self) - - -# aten.sym_storage_offset -# aten.native_layer_norm_backward -# aten.max_pool3d_with_indices - - -# aten.gt -@op(torch.ops.aten.gt) -def _aten_gt(self, other): - return self > other - - -# aten.pixel_shuffle -@op(torch.ops.aten.pixel_shuffle) -def _aten_pixel_shuffle(x, upscale_factor): - """PixelShuffle implementation in JAX. - - Args: - x: Input tensor. Typically a feature map. - upscale_factor: Integer by which to upscale the spatial dimensions. - - Returns: - Tensor after PixelShuffle operation. - """ - - batch_size, channels, height, width = x.shape - - if channels % (upscale_factor**2) != 0: - raise ValueError( - 'Number of channels must be divisible by the square of the upscale factor.' - ) - - new_channels = channels // (upscale_factor**2) - new_height = height * upscale_factor - new_width = width * upscale_factor - - x = x.reshape(batch_size, new_channels, upscale_factor, upscale_factor, - height, width) - x = jnp.transpose(x, - (0, 1, 2, 4, 3, 5)) # Move channels to spatial dimensions - x = x.reshape(batch_size, new_channels, new_height, new_width) - - return x - - -# aten.sym_stride -# aten.lt -@op(torch.ops.aten.lt) -def _aten_lt(self, other): - return self < other - - -def pool(inputs, init, reduce_fn, window_shape, strides, padding): - """Helper function to define pooling functions. - - Pooling functions are implemented using the ReduceWindow XLA op. - NOTE: Be aware that pooling is not generally differentiable. - That means providing a reduce_fn that is differentiable does not imply that - pool is differentiable. - - Args: - inputs: input data with dimensions (batch, window dims..., features). - init: the initial value for the reduction - reduce_fn: a reduce function of the form ``(T, T) -> T``. - window_shape: a shape tuple defining the window to reduce over. - strides: a sequence of ``n`` integers, representing the inter-window - strides (default: ``(1, ..., 1)``). - padding: either the string ``'SAME'``, the string ``'VALID'``, or a sequence - of ``n`` ``(low, high)`` integer pairs that give the padding to apply before - and after each spatial dimension. - Returns: - The output of the reduction for each window slice. - """ - num_batch_dims = inputs.ndim - (len(window_shape) + 1) - strides = strides or (1,) * len(window_shape) - assert len(window_shape) == len( - strides), f'len({window_shape}) must equal len({strides})' - strides = (1,) * (1 + num_batch_dims) + strides - dims = (1,) * (1 + num_batch_dims) + window_shape - - is_single_input = False - if num_batch_dims == 0: - # add singleton batch dimension because lax.reduce_window always - # needs a batch dimension. - inputs = inputs[None] - strides = (1,) + strides - dims = (1,) + dims - is_single_input = True - - assert inputs.ndim == len(dims), f'len({inputs.shape}) != len({dims})' - if not isinstance(padding, str): - padding = tuple(map(tuple, padding)) - assert len(padding) == len(window_shape), ( - f'padding {padding} must specify pads for same number of dims as ' - f'window_shape {window_shape}') - assert all([len(x) == 2 for x in padding - ]), f'each entry in padding {padding} must be length 2' - padding = ((0, 0), (0, 0)) + padding - y = jax.lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding) - if is_single_input: - y = jnp.squeeze(y, axis=0) - return y - - -@op(torch.ops.aten._adaptive_avg_pool3d) -def _aten_adaptive_avg_pool3d(x, output_shape): - return _aten_adaptive_avg_pool(x, output_shape, 3) - - -@op(torch.ops.aten._adaptive_avg_pool2d) -def _aten_adaptive_avg_pool3d(x, output_shape): - return _aten_adaptive_avg_pool(x, output_shape, 2) - - -def _aten_adaptive_avg_pool(x, output_shape, pool_dim): - - def adaptive_kernel_size(input_shape, output_shape): - sizes = [1, 1] - spatial_dim_off = len(input_shape) - pool_dim - for spatial_dim in range(pool_dim): - sizes.append(input_shape[spatial_dim_off + spatial_dim] // - output_shape[spatial_dim]) - return tuple(sizes) - - kernel_sizes = adaptive_kernel_size(x.shape, output_shape) - y = pool(x, 0.0, jax.lax.add, kernel_sizes, kernel_sizes, padding='VALID') - - div_shape = list(x.shape) - num_batch_dims = len(x.shape) - pool_dim - 1 - div_shape[num_batch_dims] = 1 - div_shape = tuple(div_shape) - if len(div_shape) - 2 == len(kernel_sizes): - div_shape = (1,) + div_shape[1:] - y = y / pool( - jnp.ones(div_shape), 0.0, jax.lax.add, kernel_sizes, kernel_sizes, - 'VALID') - return y - - -# aten.avg_pool2d -@op(torch.ops.aten.avg_pool2d) -@op(torch.ops.aten.avg_pool3d) -def _aten_avg_pool(inputs, - kernel_size, - strides=None, - padding=0, - ceil_mode=False, - count_include_pad=True, - divisor_override=None): - - num_batch_dims = len(inputs.shape) - len(kernel_size) - 1 - kernel_size = tuple(kernel_size) - strides = tuple(strides) - if isinstance(padding, int): - padding = tuple((padding, padding) for _ in range(len(kernel_size))) - elif isinstance(padding, list): - padding = tuple((p, p) for p in padding) - - y = pool(inputs, 0.0, jax.lax.add, kernel_size, strides, padding) - if count_include_pad: - y = y / np.prod(kernel_size) - else: - div_shape = list(inputs.shape) - div_shape[num_batch_dims] = 1 - div_shape = tuple(div_shape) - if len(div_shape) - 2 == len(kernel_size): - div_shape = (1,) + div_shape[1:] - y = y / pool( - jnp.ones(div_shape), 0.0, jax.lax.add, kernel_size, strides, padding) - return y - - -# aten.sym_numel -# aten.reciprocal -@op(torch.ops.aten.reciprocal) -def _aten_reciprocal(a): - return 1 / a - - -# aten.scatter -@op(torch.ops.aten.select_scatter) -def _aten_select_scatter(input, src, dim, index): - input_indexes = [] - for x in range(len(input.shape)): - if x == dim: - input_indexes.append(index) - else: - input_indexes.append(slice(None, None, None)) - return input.at[tuple(input_indexes)].set(src) - - -@op(torch.ops.aten.scatter.src) -def _aten_scatter_src(input, dim, index, src, reduce=None): - input_index, source_indexes = _scatter_index(dim, index) - return input.at[input_index].set(src[source_indexes]) - - -@op(torch.ops.aten.scatter.value) -def _aten_scatter(input, dim, index, src, reduce=None): - input_index, source_indexes = _scatter_index(dim, index) - return input.at[input_index].set(src) - - -# aten.acosh -@op(torch.ops.aten.acosh) -def _aten_acosh(self): - return jnp.arccosh(self) - - -# aten.avg_pool2d_backward -# aten.col2im -# aten.avg_pool3d -# aten.round -@op(torch.ops.aten.round) -def _aten_round(input, decimals=0): - return jnp.round(input, decimals) - - -# aten.max -@op(torch.ops.aten.max) -def _aten_max(self, dim=None, keepdim=False): - return jnp.max( - self, axis=dim, keepdims=keepdim), jnp.argmax( - self, axis=dim, keepdims=keepdim) - - -# aten.maximum -@op(torch.ops.aten.maximum) -def _aten_maximum(self, other): - return jnp.maximum(self, other) - - -# aten.abs -@op(torch.ops.aten.abs) -def _aten_abs(self): - return jnp.abs(self) - - -# generate aten.amax only -@op(torch.ops.aten.amax) -def _aten_amax(self, dim=None, keepdim=False): - return _with_reduction_scalar(jnp.amax, self, dim, keepdim) - - -def _with_reduction_scalar(jax_func, self, dim, keepdim): - expanded = False - if self.ndim == 0: - # for self of rank 0: - # torch.any(x, 0), torch.any(x, -1) works; - # torch.any(x, 1) throws out of bounds, so it's - # behavior is the same as a jnp array of rank 1 - expanded = True - self = jnp.expand_dims(self, 0) - res = jax_func(self, axis=dim, keepdims=keepdim) - if expanded: - res = res.squeeze() - return res - -# aten.any -@op(torch.ops.aten.any) -def _aten_any(self, dim=None, keepdim=False): - return _with_reduction_scalar(jnp.any, self, dim, keepdim) - - -# aten.arange -@op(torch.ops.aten.arange) -def _aten_arange(start, - end=None, - step=1, - *, - dtype=None, - layout=None, - requires_grad=False, - device=None, - pin_memory=False): - if end is None: - end = start - start = 0 - dtype = tensor.t2j_dtype(dtype) - return jnp.arange( - start, - end, - step, - dtype=dtype, - ) - - -# aten.argmax -@op(torch.ops.aten.argmax) -def _aten_argmax(self, dim=None, keepdim=False): - return _with_reduction_scalar( - jnp.argmax, self, dim, keepdim) - - -# aten.as_strided -@op(torch.ops.aten.as_strided) -@op(torch.ops.aten.as_strided_copy) -def _aten_as_strided(x, sizes, strides, storage_offset=None): - ind = jnp.zeros(sizes, dtype=jnp.int32) - - for i, (size, stride) in enumerate(zip(sizes, strides)): - result_shape = (1,) * i + (size,) + (1,) * (len(sizes) - i - 1) - indexes = (jnp.arange(size) * stride).reshape(result_shape) - ind += indexes - - return jnp.ravel(x)[ind] - - -# aten.atan2 -@op(torch.ops.aten.atan2) -def _aten_atan2(self, other): - return jnp.arctan2(self, other) - - -# aten.bitwise_and -@op(torch.ops.aten.bitwise_and) -def _aten_bitwise_and(self, other): - return self & other - - -# aten.bitwise_or -@op(torch.ops.aten.bitwise_or) -def _aten_bitwise_or(self, other): - return self | other - - -# aten.bitwise_xor -@op(torch.ops.aten.bitwise_xor) -def _aten_bitwise_xor(self, other): - return self ^ other - - -# aten.clamp -@op(torch.ops.aten.clamp) -def _aten_clamp(self, min=None, max=None): - return jnp.clip(self, min, max) - - -# aten.constant_pad_nd -@op(torch.ops.aten.constant_pad_nd) -def _aten_constant_pad_nd(input, padding, value=0): - # NOTE: Torch padding is flat and reversed: (1, 1, 2, 2) - # means last dim get padded 1 in front and 1 in back; - # and second last dim get padded 2 in front and 2 in back. - # Jax padding tuple of 2-tuple: the same padding is - # [(0, 0), ..., (2,2), (1,1)] - m = len(padding) - rev_padding = [(padding[i - 1], padding[i]) for i in range(m - 1, 0, -2)] - pad_dim = tuple(([(0, 0)] * (len(input.shape) - m // 2)) + rev_padding) - return jnp.pad(input, pad_dim, mode="constant", constant_values=value) - - -# aten.convolution_backward -@op(torch.ops.aten.copy) -@op(torch.ops.aten.lift_fresh_copy) -def _aten_copy(x): - return jnp.copy(x) - - -@op(torch.ops.aten._cdist_forward) -def _aten_cdist_forward(x1, x2, p, compute_mode=''): - # x1 is B x P x M - # x2 is B x Q x M - # res is B x P x Q - x1 = jnp.expand_dims(x1, len(x1.shape) - 1) - x2 = jnp.expand_dims(x2, len(x2.shape) - 2) - return jnp.linalg.norm(x1 - x2, ord=p, axis=-1) - - -@op(torch.ops.aten._pdist_forward) -def _aten__pdist_forward(x, p): - pairwise_dists = _aten_cdist_forward(x, x, p) - condensed_dists = pairwise_dists[jnp.triu_indices( - pairwise_dists.shape[0], k=1)] - return condensed_dists - - -# aten.cos -@op(torch.ops.aten.cos) -def _aten_cos(input): - return jnp.cos(input) - - -# aten.cosh -@op(torch.ops.aten.cosh) -def _aten_cosh(input): - return jnp.cosh(input) - - -# aten.diagonal -@op(torch.ops.aten.diagonal) -def _aten_diagonal(input, offset=0, dim1=0, dim2=1): - return jnp.diagonal(input, offset, dim1, dim2) - - -# aten.empty_strided -# aten.eq -@op(torch.ops.aten.eq) -def _aten_eq(input1, input2): - return input1 == input2 - - -# aten.erf -@op(torch.ops.aten.erf) -def _aten_erf(x): - if x.dtype in (jnp.int32, jnp.int64): - x = x.astype(jnp.float32) - return jax.lax.erf(x) - - -# aten.exp -@op(torch.ops.aten.exp) -def _aten_exp(input): - return jnp.exp(input) - - -# aten.expm1 -@op(torch.ops.aten.expm1) -def _aten_expm1(input): - return jnp.expm1(input) - - -# aten.fill -@op(torch.ops.aten.fill) -@op(torch.ops.aten.full_like) -def _aten_fill(x, value, dtype=None, pin_memory=None, memory_format=None): - if dtype is None: - dtype = x.dtype - else: - dtype = tensor.t2j_dtype(dtype) - return jnp.full(x.shape, value, dtype) - - -# aten.flip -@op(torch.ops.aten.flip) -def _aten_flip(input, dims): - if dims is not None: - return jnp.flip(input, tuple(dims)) - else: - return jnp.flip(input) - - -# aten.floor -@op(torch.ops.aten.floor) -def _aten_floor(input): - return jnp.floor(input) - - -# aten.fmod -@op(torch.ops.aten.fmod) -def _aten_fmod(input, other): - return input - other * _aten_div(input, other, 'trunc') - - -# aten.gather -@op(torch.ops.aten.gather) -def _aten_gather(input, dim, index): - input_indexes, source_indexes = _scatter_index(dim, index) - return input[input_indexes] - - -# aten.ge -@op(torch.ops.aten.ge) -def _aten_ge(self, other): - return self >= other - - -@op(torch.ops.aten.glu) -@op(torch.ops.aten.glu.default) -def _aten_glu(x, dim=-1): - return jax.nn.glu(x, dim) - - -# aten.hardtanh -@op(torch.ops.aten.hardtanh) -def _aten_hardtanh(input, min_val=-1., max_val=1., inplace=False): - return jnp.clip(input, min_val, max_val) - - -# aten.isinf -@op(torch.ops.aten.isinf) -def _aten_isinf(input): - return jnp.isinf(input) - - -# aten.isnan -@op(torch.ops.aten.isnan) -def _aten_isnan(input): - return jnp.isnan(input) - - -@op(torch.ops.aten.le) -def _aten_le(self, other): - return self <= other - - -# aten.leaky_relu -@op(torch.ops.aten.leaky_relu) -def _aten_leaky_relu(x, negative_slope): - return jax.nn.leaky_relu(x, negative_slope) - - -# aten.log -@op(torch.ops.aten.log) -def _aten_log(x): - return jnp.log(x) - - -# aten.log10 -@op(torch.ops.aten.log10) -def _aten_log10(x): - return jnp.log10(x) - - -# aten.log1p -@op(torch.ops.aten.log1p) -def _aten_log1p(x): - return jnp.log1p(x) - - -# aten.log2 -@op(torch.ops.aten.log2) -def _aten_log2(x): - return jnp.log2(x) - - -# aten.logical_and -@op(torch.ops.aten.logical_and) -def _aten_logical_and(self, other): - return jnp.logical_and(self, other) - - -# aten.logical_or -@op(torch.ops.aten.logical_or) -def _aten_logical_or(self, other): - return jnp.logical_or(self, other) - - -# aten.logical_not -@op(torch.ops.aten.logical_not) -def _aten_logical_not(self): - return jnp.logical_not(self) - - -# aten.log_softmax -@op(torch.ops.aten._log_softmax) -def _aten_log_softmax(self, axis=-1, half_to_float=False): - return jax.nn.log_softmax(self, axis) - - -# aten.max_pool3d_backward -# aten.logical_xor -@op(torch.ops.aten.logical_xor) -def _aten_logical_xor(self, other): - return jnp.logical_xor(self, other) - - -# aten.max_pool2d_with_indices_backward -# aten.native_dropout -# aten.native_group_norm_backward -# aten.neg -@op(torch.ops.aten.neg) -def _aten_neg(x): - return -1 * x - - -# aten.nonzero -@op(torch.ops.aten.nonzero) -def _aten_nonzero(x): - index_tuple = jnp.nonzero(x) - index_tuple = [jnp.expand_dims(p, -1) for p in index_tuple] - return jnp.concatenate(index_tuple, axis=-1) - - -# aten.prod - - -@op(torch.ops.aten.prod) -def _aten_prod(self, dim=None, keepdim=False): - return jnp.prod(self, axis=dim, keepdims=keepdim) - - -# aten.rand -# aten.randn -# aten.randperm -# aten.reflection_pad3d -# aten.remainder -@op(torch.ops.aten.remainder) -def _aten_remainder(inputs, other): - return inputs % other - - -# aten.repeat -@op(torch.ops.aten.repeat) -def _aten_repeat(x, reps): - return jnp.tile(x, reps) - - -# aten.replication_pad2d -# aten.replication_pad3d -# aten.roll -@op(torch.ops.aten.roll) -def _aten_roll(input, shifts, dims=None): - return jnp.roll(input, shifts, dims) - - -# aten.scalar_tensor -# aten.slice_scatter -@op(torch.ops.aten.slice_scatter) -def _aten_slice_scatter(input, src, dim=0, start=None, end=None, step=1): - input_index = [] - for x in range(len(input.shape)): - if x == dim: - input_index.append(slice(start, end, step)) - else: - input_index.append(slice(None, None, None)) - return input.at[tuple(input_index)].set(src) - - -# aten.sort -# torch.sort(input, dim=-1, descending=False, stable=False, *, out=None) -@op(torch.ops.aten.sort) -def _aten_sort(a, dim=-1, descending=False, stable=False): - return ( - jnp.sort(a, axis=dim, stable=stable, descending=descending), - jnp.argsort(a, axis=dim, stable=stable, descending=descending), - ) - - -# aten.sym_size - - -# aten.topk -@op(torch.ops.aten.topk) -def _aten_topk(input, k, dim=None, largest=True, sorted=True, *, out=None): - """JAX top-k implementation using jax.lax.top_k for improved efficiency. - - Args: - input: The input JAX array. - k: The number of top elements to return. - dim: The dimension along which to find the top-k. If None, operates on the - flattened array. - largest: If True, returns the largest k elements. Otherwise, smallest k. - sorted: If True, returns the elements in sorted order. - - Returns: - A tuple (values, indices) containing: - - values: The top k values. - - indices: The indices of the top k values in the original array. - """ - if dim is None: - input = input.flatten() - dim = 0 - - if not largest: - input = -input # Find top-k of negated input if we want the smallest - - transpose_shape = None - if dim != -1 and dim != len(input.shape) - 1: - transpose_shape = list(range(len(input.shape))) - transpose_shape[dim], transpose_shape[-1] = (transpose_shape[-1], - transpose_shape[dim]) - input = jnp.transpose(input, transpose_shape) - - values, indices = jax.lax.top_k(input, k) - - if sorted: - values = jnp.sort(values, descending=True) - indices = jnp.take_along_axis( - indices, jnp.argsort(values, axis=-1, descending=True), axis=-1) - - if not largest: - values = -values # Negate values back if we found smallest - - if transpose_shape is not None: - values = jnp.transpose(values, transpose_shape) - indices = jnp.transpose(indices, transpose_shape) - - return values, indices - - -# aten.trunc -@op(torch.ops.aten.trunc) -def _aten_trunc(a): - return jnp.trunc(a) - - -@op(torch.ops.aten.unbind) -@op(torch.ops.aten.unbind_copy) -def _aten_unbind(a, dim=0): - return tuple( - _aten_squeeze_dim(jax.lax.index_in_dim(a, i, axis=dim), dim) - for i in range(a.shape[dim])) - - -# NOTE: skip aten.upsample_nearest2d and aten.upsample_bilinear2d -# despite those being core aten ops, they also have decompositions. -# here we are using torch decompositions. - - -# aten.where -@op(torch.ops.aten.where) -def _aten_where(condition, x, y): - return jnp.where(condition, x, y) - - -# aten.to.dtype -#Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None -@op(torch.ops.aten.to.dtype) -def _aten_to_dtype(a, - dtype, - non_blocking=False, - copy=False, - memory_format=None): - jaxdtype = tensor.t2j_dtype(dtype) - return a.astype(jaxdtype) - - -# aten.to.device - - -#Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False -@op(torch.ops.aten.var_mean.correction) -def _aten_var_mean_correction(self, dim=None, correction=None, keepdim=False): - return (jnp.var(self, axis=dim, ddof=correction, - keepdims=keepdim), jnp.mean(self, dim, keepdims=keepdim)) - - -@op(torch.ops.aten.scalar_tensor) -def _aten_scalar_tensor(s, - dtype=None, - layout=None, - device=None, - pin_memory=None): - if dtype is not None: - dtype = tensor.t2j_dtype(dtype) - return jnp.array(s, dtype=dtype) - return jnp.array(s) - - -@op(torch.ops.aten.to.device) -def _aten_to_device(x,device, dtype): - return x - - -@op(torch.ops.aten.max_pool2d_with_indices_backward) -def max_pool2d_with_indices_backward_custom(grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices): - - """ - Approximates the gradient calculation of PyTorch's max_pool2d_with_indices_backward. - - Args: - grad_output: The gradient tensor from the preceding layer. - self: The input tensor on which the original max pooling was performed. - kernel_size: The size of the pooling window. - stride: The stride of the pooling window. - padding: The padding applied during max pooling. - dilation: The dilation factor for the pooling operation. - ceil_mode: Whether to use ceil or floor when calculating output shapes. - indices: The indices of the maximum values, as produced by max_pool2d_with_indices. - - Returns: - The calculated gradient with respect to the input (grad_input). - """ - - kH, kW = kernel_size - dH, dW = stride - padH, padW = padding - dilH, dilW = dilation - - # Calculate output shape (may need adjustment based on ceil_mode) - out_shape = jnp.array(self.shape) - grad_input = jnp.zeros_like(self) - - # Iterate over the flattened input and output tensors - for i, idx in enumerate(indices.flatten()): - # Calculate input coordinates corresponding to the maximum value - out_y, out_x = i // grad_output.shape[3], i % grad_output.shape[3] - in_y = out_y * dH - padH + out_y * (dilH - 1) - in_x = out_x * dW - padW + out_x * (dilW - 1) - - # Scatter the gradient to the appropriate input locations (handling potential overlaps) - for y in range(in_y, in_y + kH): - for x in range(in_x, in_x + kW): - if 0 <= y < grad_input.shape[2] and 0 <= x < grad_input.shape[3]: - grad_input = grad_input.at[y, x].add(grad_output.flatten()[i]) - - return grad_input - - -@op(torch.ops.aten._local_scalar_dense) -def _aten_local_scalar_dense(x): - return x.item() - -@op(torch.ops.aten.tensor_split.sections) -def _aten_tensor_split(ary, indices_or_sections, axis=0): - return jnp.array_split(ary, indices_or_sections, axis) - -@op(torch.ops.aten.outer) -def _aten_outer(a, b): - return jnp.outer(a, b) - -@op(torch.ops.aten.allclose) -def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False): - return jnp.allclose(input, other, rtol, atol, equal_nan) \ No newline at end of file diff --git a/experimental/torch_xla2/torch_xla2/decompositions.py b/experimental/torch_xla2/torch_xla2/decompositions.py index e85e49e13ee..81b48bb5da8 100644 --- a/experimental/torch_xla2/torch_xla2/decompositions.py +++ b/experimental/torch_xla2/torch_xla2/decompositions.py @@ -90,4 +90,21 @@ def _reflection_or_replication_pad( return result _try_register(aten.replication_pad1d, _replication_pad) -_try_register(aten.replication_pad3d, _replication_pad) \ No newline at end of file +_try_register(aten.replication_pad3d, _replication_pad) + +EXTRA_DECOMP = decomp.get_decompositions([ + torch.ops.aten.upsample_nearest2d, + torch.ops.aten._native_batch_norm_legit.no_stats, + torch.ops.aten._adaptive_avg_pool2d, + torch.ops.aten._adaptive_avg_pool3d, + torch.ops.aten.grid_sampler_2d, + torch.ops.aten.native_dropout, + torch.ops.aten.reflection_pad1d, + torch.ops.aten.reflection_pad2d, + torch.ops.aten.reflection_pad3d, + torch.ops.aten.replication_pad1d, + torch.ops.aten.replication_pad2d, + torch.ops.aten.replication_pad3d, +]) + +EXTRA_DECOMP[torch.ops.aten.uniform] = torch.ops.aten.rand \ No newline at end of file diff --git a/experimental/torch_xla2/torch_xla2/environment.py b/experimental/torch_xla2/torch_xla2/environment.py index 6a71c7d51c0..139597f9cb0 100644 --- a/experimental/torch_xla2/torch_xla2/environment.py +++ b/experimental/torch_xla2/torch_xla2/environment.py @@ -1,26 +1,2 @@ -import jax - - -class Environment: - """This class holds a set of configurations and "globals" needed - - for executing torch program using jax. - Things included so far: - - op registry - PRNGKey - Configs - - Also helper functions to manipulate those. - """ - - _prng_key: jax.random.PRNGKey - - - def __init__(self, random_seed): - self._prng_key = jax.random.PRNGKey(random_seed) - - def get_and_rotate_prng_key(self): - self._prng_key, key = jax.random.split(self._prng_key) diff --git a/experimental/torch_xla2/torch_xla2/export.py b/experimental/torch_xla2/torch_xla2/export.py index 64a3f9d175c..78430a6d537 100644 --- a/experimental/torch_xla2/torch_xla2/export.py +++ b/experimental/torch_xla2/torch_xla2/export.py @@ -2,146 +2,12 @@ """Utilities for exporting a torch program to jax/stablehlo.""" import copy from typing import Any, Dict, Tuple -import jax import torch -from torch.fx import _pytree as fx_pytree -from torch_xla2 import ops_registry, tensor +from torch_xla2.ops import ops_registry +from torch_xla2 import tensor from torch.utils import _pytree as pytree -class JaxProgram: - - def _wrap_inputs(self, xs, allow_torch_tensor=False): - - def convert(t): - if isinstance(t, tensor.XLATensor2): - return t - if isinstance(t, torch.Tensor): - if allow_torch_tensor: - return tensor.move_to_device(t) - else: - raise ValueError('Regular torch.Tensor is not allowed.') - if isinstance(t, jax.Array): - return tensor.XLATensor2(t) - return t - - return jax.tree_util.tree_map(convert, xs) - - def _unwrap_outputs(self, xs): - - def convert(t): - if isinstance(t, tensor.XLATensor2): - return t.jax() - if isinstance(t, torch.Tensor): - raise ValueError('Regular torch.Tensor is not allowed.') - return t - - return jax.tree_util.tree_map(convert, xs) - - def __init__( - self, - exported_program, - param_buffer_values, - ordered_tensor_constants, - ): - - self.param_buffer_values = self._wrap_inputs( - param_buffer_values, allow_torch_tensor=True) - self.ordered_tensor_constants = self._wrap_inputs( - ordered_tensor_constants, allow_torch_tensor=True) - self.exported_program = exported_program - - def __hash__(self): - return hash(self.exported_program) - - @property - def example_inputs(self): - args, kwargs = self.exported_program.example_inputs - args = pytree.tree_map(tensor.t2j, args) - kwargs = pytree.tree_map(tensor.t2j, kwargs) - return args, kwargs - - def flatten_inputs(self, args, kwargs): - if args is None: - args = tuple() - if kwargs is None: - kwargs = {} - - if (in_spec := self.exported_program.call_spec.in_spec) is not None: - if (in_spec.type == tuple and len(in_spec.children_specs) == 2 and - in_spec.children_specs[0].type == tuple and - in_spec.children_specs[1].type == dict): - # NOTE: this is the case where in_spec is for both args and kwargs - return fx_pytree.tree_flatten_spec((args, kwargs), in_spec) - return fx_pytree.tree_flatten_spec(args, in_spec) - return copy.deepcopy(args) - - def unflatten_outputs(self, res): - return pytree.tree_unflatten(res, self.exported_program.call_spec.out_spec) - - def __call__(self, *args, **kwargs): - - inputs = self.flatten_inputs(args, kwargs) - res = self.flatten_callable(*inputs) - res = self.unflatten_outputs(res) - - return res - - @property - def flatten_callable(self): - - def func(*inputs: jax.Array): - nonlocal self - inputs = self._wrap_inputs(inputs) - num_mutations = len( - self.exported_program.graph_signature.buffers_to_mutate) - res = torch.fx.Interpreter(self.exported_program.graph_module).run( - *self.param_buffer_values, - *inputs, - *self.ordered_tensor_constants, - enable_io_processing=False, - ) - res = res[num_mutations:] - res = self._unwrap_outputs(res) - return res - - return func - - def jit(self, *args, **kwargs): - """Returns `jax.jit(self, *args, **kwargs)`.""" - return jax.jit(self, *args, **kwargs) - - def jit_lower(self, *args, **kwargs): - """Returns `jax.jit(self, *args, **kwargs).lower(...)` with example_inputs used in export.""" - example_args, example_kwargs = self.example_inputs - return self.jit(*args, **kwargs).lower(*example_args, **example_kwargs) - - -def exported_program_to_jax_program(ep): - """exported_program_to_jax_program. - - Args: - ep: torch.export.ExportedProgram - - Returns: - JaxProgram - - """ - if torch.__version__ >= '2.2': - ep = ep.run_decompositions() - - param_buffer_keys = ep.graph_signature.parameters + ep.graph_signature.buffers - param_buffer_values = tuple(ep.state_dict[key] for key in param_buffer_keys) - - if hasattr(ep.graph_signature, 'lifted_tensor_constants'): - ordered_tensor_constants = tuple( - ep.tensor_constants[name] - for name in ep.graph_signature.lifted_tensor_constants) - else: - ordered_tensor_constants = tuple() - - return JaxProgram(ep, param_buffer_values, ordered_tensor_constants) - DEBUG = False @@ -149,6 +15,11 @@ def exported_program_to_jax_program(ep): class JaxInterpreter(torch.fx.Interpreter): """Experimental.""" + def __init__(self, graph_module): + super().__init__(graph_module) + import torch_xla2.ops.jaten + import torch_xla2.ops.jtorch + def call_function(self, target, args: Tuple, kwargs: Dict) -> Any: if not isinstance(target, (torch._ops.OpOverloadPacket, torch._ops.OpOverload)): @@ -157,7 +28,9 @@ def call_function(self, target, args: Tuple, kwargs: Dict) -> Any: if DEBUG: print('Running ', target.name(), '--------') - op = ops_registry.lowerings.lookup(target) + op = ops_registry.all_aten_ops.get(target) + if op is None: + op = ops_registry.all_aten_ops.get(target.overloadpacket) if op is None: print(target.name(), target.tags) raise RuntimeError('No lowering found for', target.name()) diff --git a/experimental/torch_xla2/torch_xla2/extra.py b/experimental/torch_xla2/torch_xla2/extra.py deleted file mode 100644 index ebfdb96b1db..00000000000 --- a/experimental/torch_xla2/torch_xla2/extra.py +++ /dev/null @@ -1,62 +0,0 @@ -import jax -import jax.numpy as jnp -import functools -import torch -from torch.utils import _pytree as pytree -from torch_xla2 import tensor - -def torch_view(t): - # t is an object from jax land - # view it as-if it's a torch land object - if isinstance(t, jax.Array): - return tensor.XLATensor2(t) - if isinstance(t, type(jnp.int32)): - return tensor.t2j_type(t) - if callable(t): - def new_t(*args, **kwargs): - # args, kwargs are torch-land - args, kwargs = pytree.tree_map(jax_view, (args, kwargs)) - # now they are objs in jax-land - res = t(*args, **kwargs) # t is jax callable - # res is jax-land obj - return pytree.tree_map(torch_view, res) - return new_t - # regular types are not changed - return t - - -def jax_view(t): - # t is an object from torch land - # view it as-if it's a jax land object - if isinstance(t, torch.Tensor): - assert isinstance(t, tensor.XLATensor2) - return t.jax() - if isinstance(t, type(torch.int32)): - return tensor.j2t_dtype(t) - if callable(t): - def new_t(*args, **kwargs): - # args, kwargs are jax-land - args, kwargs = pytree.tree_map(torch_view, (args, kwargs)) - # now they are objs in torch-land - res = t(*args, **kwargs) - # res is torch-land obj - return pytree.tree_map(jax_view, res) - return new_t - # regular types are not changed - return t - -def call_jax(jax_func, *args, **kwargs): - return torch_view(jax_func)(*args, **kwargs) - - -def call_torch(torch_func, *args, **kwargs): - return jax_view(torch_func)(*args, **kwargs) - - -fori_loop = torch_view(jax.lax.fori_loop) - -def jax_jit(torch_function, kwargs_for_jax_jit=None): - kwargs_for_jax_jit = kwargs_for_jax_jit or {} - jax_func = jax_view(torch_function) - jitted = jax.jit(jax_func, **kwargs_for_jax_jit) - return torch_view(jitted) diff --git a/experimental/torch_xla2/torch_xla2/functions.py b/experimental/torch_xla2/torch_xla2/functions.py deleted file mode 100644 index 94320fd7cb2..00000000000 --- a/experimental/torch_xla2/torch_xla2/functions.py +++ /dev/null @@ -1,135 +0,0 @@ -"""Tensor constructor overrides""" -import functools -import logging -from typing import Callable, Optional, ParamSpec, Sequence - -import jax -import torch -import jax.numpy as jnp -from torch_xla2 import tensor - -registry = {} - -P = ParamSpec('P') - - -def register_function(torch_func: Callable[P, torch.Tensor]): - """Registers a function as the JAX implementation of a torch function.""" - - def decorator(jax_impl: Callable[P, jax.Array]): - registry[torch_func] = jax_impl - return jax_impl - - return decorator - - -def convert_dtype(use_default_dtype: bool = True): - """Converts `dtype` kwarg of function from torch to JAX. - - Args: - use_default_dtype: Whether to use torch default dtype if none is provided. - - Returns: - A decorator that wraps a JAX implementation of a torch function. - """ - - def decorator(func: Callable[P, torch.Tensor]): - - @functools.wraps(func) - def wrapper(*args: P.args, - dtype: Optional[torch.dtype] = None, - **kwargs: P.kwargs): - if not dtype and use_default_dtype: - dtype = torch.get_default_dtype() - jax_dtype = tensor.t2j_dtype(dtype) - - return func(*args, dtype=jax_dtype, **kwargs) - - return wrapper - - return decorator - - -@register_function(torch.tensor) -@convert_dtype(use_default_dtype=False) # Attempt to infer type from elements -def _tensor(data, *, dtype=None, **kwargs): - python_types_to_torch_types = { - bool: jnp.bool, - int: jnp.int64, - float: jnp.float32, - complex: jnp.complex64, - } - if not dtype: - leaves = jax.tree_util.tree_leaves(data) - if len(leaves) > 0: - dtype = python_types_to_torch_types.get(type(leaves[0])) - - return jnp.array( - data, dtype=dtype or tensor.t2j_dtype(torch.get_default_dtype())) - - -@register_function(torch.ones) -@convert_dtype() -def _ones(*size: int, dtype=None, **kwargs): - return jnp.ones(size, dtype) - - -@register_function(torch.zeros) -@convert_dtype() -def _zeros(*size: int, dtype=None, **kwargs): - return jnp.zeros(size, dtype) - - -@register_function(torch.eye) -@convert_dtype() -def _eye(n: int, m: Optional[int] = None, *, dtype=None, **kwargs): - return jnp.eye(n, m, dtype=dtype) - - -@register_function(torch.full) -@convert_dtype() -def _full(size: Sequence[int], fill_value, *, dtype=None, **kwargs): - # TODO: handle torch.Size - return jnp.full(size, fill_value, dtype=dtype) - -@register_function(torch.allclose) -def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False): - return jnp.allclose(input, other, rtol, atol, equal_nan) - -@register_function(torch.angle) -def _torch_angle(input): - return jnp.angle(input) - - -@register_function(torch.argsort) -def _torch_argsort(input, dim=-1, descending=False, stable=False): - expanded = False - if input == 0: - # for self of rank 0: - # torch.any(x, 0), torch.any(x, -1) works; - # torch.any(x, 1) throws out of bounds, so it's - # behavior is the same as a jnp array of rank 1 - expanded = True - input = jnp.expand_dims(input, 0) - res = jnp.argsort(input, axis=dim, descending=descending, - stable=stable) - if expanded: - res = res.squeeze() - return res - - - -class XLAFunctionMode(torch.overrides.TorchFunctionMode): - """Context manager that dispatches torch function calls to JAX.""" - - def __torch_function__(self, - func, - types, - args=(), - kwargs=None) -> torch.Tensor: - jax_func = registry.get(func) - if not jax_func: - return func(*args, **(kwargs or {})) - - # TODO: unwrap args here or in implementations? - return tensor.wrap(jax_func(*args, **(kwargs or {}))) diff --git a/experimental/torch_xla2/torch_xla2/interop.py b/experimental/torch_xla2/torch_xla2/interop.py new file mode 100644 index 00000000000..fbcd47922e1 --- /dev/null +++ b/experimental/torch_xla2/torch_xla2/interop.py @@ -0,0 +1,65 @@ +import functools +import torch +import jax +import jax.numpy as jnp +from jax import tree_util as pytree +from torch_xla2 import tensor +import torch_xla2 + +from torch_xla2.types import JaxValue, TorchValue, JaxCallable, TorchCallable + + + + +def torch_view(t: JaxValue) -> TorchValue: + # t is an object from jax land + # view it as-if it's a torch land object + if isinstance(t, jax.Array): + # TODO + return tensor.XLATensor2(t, torch_xla2.default_env()) + if isinstance(t, type(jnp.int32)): + return tensor.t2j_type(t) + if callable(t): # t is a JaxCallable + return functools.partial(call_jax, t) + # regular types are not changed + return t + + +def jax_view(t: TorchValue) -> JaxValue: + # t is an object from torch land + # view it as-if it's a jax land object + if isinstance(t, torch.Tensor): + assert isinstance(t, tensor.XLATensor2) + return t.jax() + if isinstance(t, type(torch.int32)): + return tensor.j2t_dtype(t) + + # torch.nn.Module needs special handling + if not isinstance(t, torch.nn.Module) and callable(t): # t is a TorchCallable + return functools.partial(call_torch, t) + # regular types are not changed + return t + + +def call_jax(jax_func: JaxCallable, + *args: TorchValue, + **kwargs: TorchValue) -> TorchValue: + args, kwargs = pytree.tree_map(jax_view, (args, kwargs)) + res: JaxValue = jax_func(*args, **kwargs) + return torch_view(res) + + +def call_torch(torch_func: TorchCallable, *args: JaxValue, **kwargs: JaxValue) -> JaxValue: + args, kwargs = pytree.tree_map(torch_view, (args, kwargs)) + with torch_xla2.default_env(): + res: TorchValue = torch_func(*args, **kwargs) + return jax_view(res) + + +fori_loop = torch_view(jax.lax.fori_loop) + +def jax_jit(torch_function, kwargs_for_jax_jit=None): + kwargs_for_jax_jit = kwargs_for_jax_jit or {} + jax_func = jax_view(torch_function) + jitted = jax.jit(jax_func, **kwargs_for_jax_jit) + return torch_view(jitted) diff --git a/experimental/torch_xla2/torch_xla2/ops/__init__.py b/experimental/torch_xla2/torch_xla2/ops/__init__.py index e69de29bb2d..abefc8344b1 100644 --- a/experimental/torch_xla2/torch_xla2/ops/__init__.py +++ b/experimental/torch_xla2/torch_xla2/ops/__init__.py @@ -0,0 +1,9 @@ +def all_aten_jax_ops(): + # to load the ops + import torch_xla2.jaten # type: ignore + import torch_xla2.ops_registry # type: ignore + return { + key: val.func + for key, val in torch_xla2.ops_registry.all_aten_ops + if val.is_jax_function + } \ No newline at end of file diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index a30fae82de8..f6adc702a14 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -1,5 +1,14 @@ -"""This module contains implementation of ATen ops.""" +"""Torch ops implemented using jax.""" + +import sys + +import jax +from jax import numpy as jnp +import numpy as np import torch +from torch_xla2.ops import ops_registry +from torch_xla2 import tensor +from torch_xla2.ops import op_base # Keys are OpOverload, value is a callable that takes # XLATensor2 @@ -9,29 +18,1933 @@ # and need to be implemented in jax mutation_ops_to_functional = { - torch.ops.aten.add_: torch.ops.aten.add, - torch.ops.aten.sub_: torch.ops.aten.sub, - torch.ops.aten.mul_: torch.ops.aten.mul, - torch.ops.aten.div_: torch.ops.aten.div, - torch.ops.aten.pow_: torch.ops.aten.pow, - torch.ops.aten.lt_: torch.ops.aten.lt, - torch.ops.aten.le_: torch.ops.aten.le, - torch.ops.aten.gt_: torch.ops.aten.gt, - torch.ops.aten.ge_: torch.ops.aten.ge, - torch.ops.aten.eq_: torch.ops.aten.eq, - torch.ops.aten.ne_: torch.ops.aten.ne, + torch.ops.aten.add_: torch.ops.aten.add, + torch.ops.aten.sub_: torch.ops.aten.sub, + torch.ops.aten.mul_: torch.ops.aten.mul, + torch.ops.aten.div_: torch.ops.aten.div, + torch.ops.aten.pow_: torch.ops.aten.pow, + torch.ops.aten.lt_: torch.ops.aten.lt, + torch.ops.aten.le_: torch.ops.aten.le, + torch.ops.aten.gt_: torch.ops.aten.gt, + torch.ops.aten.ge_: torch.ops.aten.ge, + torch.ops.aten.eq_: torch.ops.aten.eq, + torch.ops.aten.ne_: torch.ops.aten.ne, + torch.ops.aten.uniform_: torch.ops.aten.uniform, } def make_mutation(op): + return op_base.InplaceOp(mutation_ops_to_functional[op], position_to_mutate=0) - def f(*args, **kwargs): - res = mutation_ops_to_functional[op](*args, **kwargs) - args[0].copy_(res) - return args[0] - return f +for op in mutation_ops_to_functional.keys(): + ops_registry.register_torch_dispatch_op( + op, make_mutation(op), is_jax_function=False + ) -for op in mutation_ops_to_functional.keys(): - all_ops[op] = make_mutation(op) +def op(*aten, **kwargs): + def inner(func): + for a in aten: + ops_registry.register_torch_dispatch_op(a, func, **kwargs) + return func + + return inner + + +@op( + torch.ops.aten.view_copy, + torch.ops.aten.view, + torch.ops.aten._unsafe_view, + torch.ops.aten.reshape, +) +def _aten_unsafe_view(x, shape): + return jnp.reshape(x, shape) + + +@op(torch.ops.aten.add.Tensor) +@op(torch.ops.aten.add.Scalar) +def _aten_add(x, y, *, alpha=1): + """if isinstance(x, jnp.ndarray) and isinstance(y, jnp.ndarray): + + assert x.dtype == y.dtype, (x.dtype, y.dtype) + """ + return x + y * alpha + + +@op(torch.ops.aten.copy_, torch.ops.aten.copy_.default, is_jax_function=False) +def _aten_copy(x, y, memory_format=None): + if isinstance(x, tensor.XLATensor2): + x._elem = y._elem + elif isinstance(x, tensor.SliceView): + x.mutate(y) + return x + + +@op(torch.ops.aten.clone) +@op(torch.ops.aten.clone.default) +def _aten_clone(x, memory_format=None): + return jnp.copy(x) + + +@op(torch.ops.aten.full) +def _aten_full(size, value, **kwargs): + return jnp.full(size, value) + + +@op(torch.ops.aten.index_copy) +def _aten_index_copy(x, dim, indexes, source): + # return jax.lax.scatter(x, index, dim) + dims = [] + for i in range(len(x.shape)): + if i == dim: + dims.append(indexes) + else: + dims.append(slice(None, None, None)) + return x.at[dim].set(source) + + +@op(torch.ops.aten.select) +@op(torch.ops.aten.index_select) +@op(torch.ops.aten.select_copy) +def _aten_index_select(x, dim, indexes): + dims = [] + for i in range(len(x.shape)): + if i == dim: + dims.append(indexes) + else: + dims.append(slice(None, None, None)) + return x[tuple(dims)] + + +@op(torch.ops.aten.mean) +def _aten_mean(x, dim=None, keepdim=False): + return jnp.mean(x, dim, keepdims=keepdim) + + +def _torch_binary_scalar_type(scalar, tensor): + if "float" in str(tensor.dtype): + return tensor.dtype + + if isinstance(scalar, int): + if "int" in str(tensor.dtype): + return tensor.dtype + + return jnp.float32 + + +@op(torch.ops.aten.sub.Tensor) +@op(torch.ops.aten.sub.Scalar) +def _aten_sub(x, y): + if isinstance(x, float): + dtype = _torch_binary_scalar_type(x, y) + x = jnp.array(x, dtype=dtype) + if isinstance(y, float): + dtype = _torch_binary_scalar_type(y, x) + y = jnp.array(y, dtype=dtype) + return x - y + + +@op(torch.ops.aten.mm) +def _aten_mm(x, y): + res = x @ y + return res + + +@op(torch.ops.aten.mul.Tensor, torch.ops.aten.mul.Scalar) +def _aten_mul(x, y): + return x * y + + +@op(torch.ops.aten.silu) +def _aten_silu(x): + return jax.nn.silu(x) + + +@op(torch.ops.aten.t) +def _aten_t(x): + return jnp.transpose(x) + + +@op(torch.ops.aten.transpose) +@op(torch.ops.aten.transpose_copy) +def _aten_transpose(x, dim0, dim1): + shape = list(range(len(x.shape))) + shape[dim0], shape[dim1] = shape[dim1], shape[dim0] + return jnp.transpose(x, shape) + + +@op(torch.ops.aten.triu) +def _aten_triu(m, k): + return jnp.triu(m, k) + + +@op(torch.ops.aten.slice) +@op(torch.ops.aten.slice_copy) +def _aten_slice(self, dim=0, start=None, end=None, step=1): + if end == sys.maxsize: + end = self.shape[dim] + sl = slice(start, end, step) + dims = [] + for i in range(len(self.shape)): + if i == dim: + dims.append(sl) + else: + dims.append(slice(None, None, None)) + return self[tuple(dims)] + + +@op(torch.ops.aten.detach) +def _aten_detach(self): + return self + + +@op(torch.ops.aten.view_as_real) +def _aten_view_as_real(x): + real = jnp.real(x) + im = jnp.imag(x) + res = jnp.stack([real, im], -1) + return res + + +@op(torch.ops.aten.stack) +def _aten_stack(tensors, dim=0): + return jnp.stack(tensors, dim) + + +@op(torch.ops.aten._softmax) +def _aten_softmax(x, dim, halftofloat): + return jax.nn.softmax(x, dim) + + +@op(torch.ops.aten.pow) +def _aten_pow(x, y): + if isinstance(y, int): + y = float(y) + return jnp.power(x, y) + + +@op(torch.ops.aten.view_as_complex) +def _aten_view_as_complex(input): + if input.dtype == jnp.bfloat16: + input = input.astype(jnp.float32) + x, y = input[..., 0], input[..., 1] + return jax.lax.complex(x, y) + + +@op(torch.ops.aten.div) +def _aten_div(x, y, rounding_mode=""): + res = x / y + if rounding_mode == "trunc": + res = jnp.trunc(res) + return res + + +@op(torch.ops.aten.div_, is_jax_function=False) +def _aten_div_(x, y, rounding_mode=""): + x._elem = _aten_div(x._elem, y._elem, rounding_mode) + return x + + +@op(torch.ops.aten.true_divide) +def _aten_true_divide(x, y): + return x / y + + +@op(torch.ops.aten.bmm) +def _aten_bmm(x, y): + res = x @ y + return res + # return jnp.einsum('bnm,bmk->bnk', x, y) + + +@op(torch.ops.aten.embedding) +# embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) +def _aten_embedding(a, w, padding_idx=-1): + return jnp.take(a, w, axis=0) + + +@op(torch.ops.aten.rsqrt) +def _aten_rsqrt(x): + if isinstance(x, int): + x = float(x) + if x.dtype == jnp.int32: + x = x.astype(jnp.float32) + return jax.lax.rsqrt(x) + + +@op(torch.ops.aten.expand) +@op(torch.ops.aten.expand_copy) +def _aten_expand(x, dims): + def fix_dims(d, xs): + if d == -1: + return xs + return d + + dims = [fix_dims(p, s) for p, s in zip(dims, x.shape)] + return jnp.broadcast_to(x, dims) + + +@op(torch.ops.aten.dot) +def _aten_dot(x, y): + return jnp.dot(x, y) + + +@op(torch.ops.aten._to_copy) +def _aten__to_copy(self, **kwargs): + dtype = tensor.t2j_dtype(kwargs["dtype"]) + if dtype != self.dtype: + return self.astype(dtype) + return jnp.copy(self) + + +@op(torch.ops.aten.empty) +def _aten_empty(sizes, **kwargs): + return jnp.zeros(sizes) + + +@op(torch.ops.aten.index_put_) +@op(torch.ops.aten.index_put) +def _aten_index_put(self, indexes, values, accumulate=False): + indexes = [slice(None, None, None) if i is None else i for i in indexes] + indexes = tuple(indexes) + if accumulate: + return self.at[indexes].add(values) + else: + return self.at[indexes].set(values) + + +@op(torch.ops.aten.index) +@op(torch.ops.aten._unsafe_index) +@op(torch.ops.aten.index.Tensor) +def _aten_index(self, indexes): + print(indexes) + indexes = [slice(None, None, None) if i is None else i for i in indexes] + indexes = tuple(indexes) + return self[indexes] + + +@op(torch.ops.aten.split) +@op(torch.ops.aten.split_copy) +@op(torch.ops.aten.split_with_sizes) +def split_with_sizes(x, sizes, dim=0): + """Splits an array `x` into sub-arrays based on static sizes `sizes`. + + Args: + x: The input array to split. + sizes: A 1D array of integer sizes for each sub-array. + + Returns: + A list of sub-arrays. + """ + if isinstance(sizes, int): + # split equal size + new_sizes = [sizes] * (x.shape[dim] // sizes) + sizes = new_sizes + rank = x.ndim + splits = np.cumsum(sizes) # Cumulative sum for split points + + def make_range(rank, dim, start, end): + res = [slice(None, None, None)] * rank + res[dim] = slice(start, end) + return tuple(res) + + return [ + x[make_range(rank, dim, start, end)] + for start, end in zip([0] + list(splits[:-1]), splits) + ] + + +@op(torch.ops.aten.permute) +@op(torch.ops.aten.permute_copy) +def permute(t, dims): + return jnp.transpose(t, dims) + + +@op(torch.ops.aten.unsqueeze) +@op(torch.ops.aten.unsqueeze_copy) +@op(torch.ops.aten.unsqueeze.default) +def _aten_unsqueeze(self, dim): + if dim < 0: + dim += self.ndim + 1 + return jnp.expand_dims(self, dim) + + +@op(torch.ops.aten.ne) +def _aten_ne(x, y): + return jnp.not_equal(x, y) + + +@op(torch.ops.aten.cumsum) +def _aten_cumsum(x, y, dtype=None): + if dtype: + dtype = tensor.t2j_dtype(dtype) + res = jnp.cumsum(x, y, dtype) + return res + + +@op(torch.ops.aten.native_layer_norm) +def _aten_native_layer_norm( + input, normalized_shape, weight=None, bias=None, eps=1e-5 +): + """Implements layer normalization in Jax as defined by `aten::native_layer_norm`. + + Args: + input: The input tensor. + normalized_shape: A list of integer dimensions to be normalized over. + weight: Optional weight tensor for the affine transformation. + bias: Optional bias tensor for the affine transformation. + eps: A small epsilon value for numerical stability. + + Returns: + output: The normalized tensor. + mean: The calculated mean tensor. + std: The calculated standard deviation tensor. + """ + if isinstance(normalized_shape, int): + normalized_shape = [normalized_shape] + axis = [i for i, d in enumerate(input.shape) if d in normalized_shape] + + # Calculate mean and standard deviation + mean = jnp.mean(input, axis=axis, keepdims=True) + var = jnp.var(input, axis=axis, keepdims=True) + rstd = jax.lax.rsqrt(var + eps) + + # Normalize the input + norm_x = (input - mean) * rstd + + # Apply affine transformation (if provided) + if weight is not None: + norm_x *= weight + if bias is not None: + norm_x += bias + return norm_x, mean, rstd + + +# - func: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor +@op(torch.ops.aten.addmm) +@op(torch.ops.aten.addmv) +def _aten_addmm(self, mat1, mat2, *, beta=1.0, alpha=1.0): + alpha = jnp.array(alpha).astype(mat1.dtype) + beta = jnp.array(beta).astype(mat1.dtype) + self *= beta + self += alpha * jnp.matmul(mat1, mat2) + return self + + +@op(torch.ops.aten.addbmm.default) +def _aten_addbmm(input, batch1, batch2, *, beta=1, alpha=1): + alpha = jnp.array(alpha).astype(batch1.dtype) + beta = jnp.array(beta).astype(batch1.dtype) + mm = jnp.einsum("bxy, byz -> xz", batch1, batch2) + return jax.lax.cond( + beta == 0, lambda: alpha * mm, lambda: beta * input + alpha * mm + ) + + +@op(torch.ops.aten.gelu) +def _aten_gelu(self, *, approximate="none"): + approx = approximate == "tanh" + return jax.nn.gelu(self, approx) + + +@op(torch.ops.aten.squeeze) +@op(torch.ops.aten.squeeze_copy) +def _aten_squeeze_dim(self, dim): + """Squeezes a Jax tensor by removing a single dimension of size 1. + + Args: + self: The input tensor. + dim: The dimension to squeeze. + + Returns: + The squeezed tensor with the specified dimension removed if it is 1, + otherwise the original tensor is returned. + """ + + # Validate input arguments + if not isinstance(self, jnp.ndarray): + raise TypeError(f"Expected a Jax tensor, got {type(self)}.") + if isinstance(dim, int): + dim = [dim] + + # Check if the specified dimension has size 1 + if all([self.shape[d] != 1 for d in dim]): + return self + + # Use slicing to remove the dimension if it is 1 + new_shape = list(self.shape) + + def fix_dim(p): + if p < 0: + return p + len(self.shape) + return p + + dim = [fix_dim(d) for d in dim] + new_shape = [p for i, p in enumerate(self.shape) if i not in dim or p != 1] + return self.reshape(new_shape) + + +@op(torch.ops.aten.convolution) +def _aten_convolution( + input, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, +): + if transposed: + raise NotImplementedError("Transposed convolution is not implemented.") + + def make_padding(padding): + return ((p, p) for p in padding) + + def create_default_conv_dimension_numbers(num_spatial_dims): + # Ref: https://github.com/openxla/xla/blob/main/xla/client/xla_builder.cc#L4211 + # (batch dimension, feature dimension, spatial dimensions...) + lhs_spec = [0, 1] + # (out feature dimension, in feature dimension, spatial dimensions...) + rhs_spec = [0, 1] + # (batch dimension, feature dimension, spatial dimensions...) + out_spec = [0, 1] + for i in range(0, num_spatial_dims): + lhs_spec.append(i + 2) + rhs_spec.append(i + 2) + out_spec.append(i + 2) + return jax.lax.ConvDimensionNumbers( + *map(tuple, (lhs_spec, rhs_spec, out_spec)) + ) + + res = jax.lax.conv_general_dilated( + input, + weight, + stride, + make_padding(padding), + lhs_dilation=(1,) * len(stride), + rhs_dilation=dilation, + dimension_numbers=create_default_conv_dimension_numbers(len(stride)), + feature_group_count=groups, + batch_group_count=1, + ) + + if bias is not None: + # TODO(qihqi): bias always on channel? + if len(bias.shape) == 1: + shape = [1] * len(res.shape) + shape[1] = bias.shape[0] + bias = bias.reshape(tuple(shape)) + res = res + bias + return res + + +# _native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps) +@op(torch.ops.aten._native_batch_norm_legit) +def _aten__native_batch_norm_legit( + input, weight, bias, running_mean, running_var, training, momentum, eps +): + return _aten__native_batch_norm_legit_no_training( + input, weight, bias, running_mean, running_var, momentum, eps + ) + + +@op(torch.ops.aten._native_batch_norm_legit_no_training) +def _aten__native_batch_norm_legit_no_training( + input, weight, bias, running_mean, running_var, momentum, eps +): + if weight is None: + weight = jnp.ones_like(running_mean) + if bias is None: + bias = jnp.zeros_like(running_mean) + + def broadcast(t): + return jax.lax.broadcast_in_dim(t, input.shape, broadcast_dimensions=(1,)) + + if running_mean is not None: + a = input - broadcast(running_mean) + else: + a = input + if running_var is not None: + b = broadcast(jnp.sqrt(running_var + eps)) + else: + b = broadcast(jnp.sqrt(eps)) + return ( + a / b * broadcast(weight) + broadcast(bias), + jnp.array([]), + jnp.array([]), + ) + + +@op(torch.ops.aten.relu) +def _aten_relu(self): + return jax.nn.relu(self) + + +@op(torch.ops.aten.cat) +def _aten_cat(tensors, dims=0): + return jnp.concatenate(tensors, dims) + + +@op(torch.ops.aten.max_pool2d_with_indices) +@op(torch.ops.aten.max_pool3d_with_indices) +def _aten_max_pool2d_with_indices( + inputs, kernel_size, strides, padding=0, dilation=1, ceil_mode=False +): + num_batch_dims = len(inputs.shape) - len(kernel_size) - 1 + kernel_size = tuple(kernel_size) + strides = tuple(strides) + if isinstance(padding, int): + padding = tuple((padding, padding) for _ in range(len(kernel_size))) + elif isinstance(padding, list): + padding = tuple((p, p) for p in padding) + + window_shape = kernel_size + num_batch_dims = inputs.ndim - (len(window_shape) + 1) + strides = strides or (1,) * len(window_shape) + assert len(window_shape) == len( + strides + ), f"len({window_shape}) must equal len({strides})" + strides = (1,) * (1 + num_batch_dims) + strides + dims = (1,) * (1 + num_batch_dims) + window_shape + + is_single_input = False + if num_batch_dims == 0: + # add singleton batch dimension because lax.reduce_window always + # needs a batch dimension. + inputs = inputs[None] + strides = (1,) + strides + dims = (1,) + dims + is_single_input = True + + assert inputs.ndim == len(dims), f"len({inputs.shape}) != len({dims})" + if not isinstance(padding, str): + padding = tuple(map(tuple, padding)) + assert len(padding) == len(window_shape), ( + f"padding {padding} must specify pads for same number of dims as " + f"window_shape {window_shape}" + ) + assert all( + [len(x) == 2 for x in padding] + ), f"each entry in padding {padding} must be length 2" + padding = ((0, 0), (0, 0)) + padding + + indices = jnp.arange(np.prod(inputs.shape)).reshape(inputs.shape) + + def reduce_fn(a, b): + ai, av = a + bi, bv = b + which = av > bv + return jnp.where(which, ai, bi), jnp.where(which, av, bv) + + init_val = -jnp.inf + if inputs.dtype in (jnp.int32, jnp.int64): + init_val = -(1 << 31) + init_val = jnp.array(init_val).astype(inputs.dtype) + + indices, y = jax.lax.reduce_window( + (indices, inputs), (0, init_val), reduce_fn, dims, strides, padding + ) + if is_single_input: + indices = jnp.squeeze(indices, axis=0) + y = jnp.squeeze(y, axis=0) + return y, indices + + batch_result = pool( + inputs, -jnp.inf, jax.lax.max, kernel_size, strides, padding + ) + indices = pool(inputs, 0, jnp.argmax, kernel_size, strides, padding) + return batch_result, indices + + +# TODO add more ops + + +@op(torch.ops.aten.min) +def _aten_min(x, axis=None): + return jnp.min(x, axis=axis), jnp.argmin(x, axis=axis).astype(jnp.int64) + + +@op(torch.ops.aten.amin) +def _aten_amin(x, dim=None, keepdim=False): + return _with_reduction_scalar(jnp.amin, x, dim, keepdim) + + +@op(torch.ops.aten.argmin) +def _aten_argmin(self, dim=None, keepdim=False): + return _with_reduction_scalar(jnp.argmin, self, dim, keepdim) + + +@op(torch.ops.aten.sin) +def _aten_sin(x): + return jnp.sin(x) + + +@op(torch.ops.aten.sym_size) +def _aten_sym_size(x, dim): + return x.shape[dim] + + +@op(torch.ops.aten.var.correction) +@op(torch.ops.prims.var) +def _aten_var(x, dim=None, *, correction=1, keepdim=False, out=None): + return jnp.var(x, axis=dim, ddof=correction, keepdims=keepdim) + + +@op(torch.ops.prims.broadcast_in_dim) +def _prims_broadcast_in_dim(t, shape, broadcast_dimensions): + return jax.lax.broadcast_in_dim( + t, shape, broadcast_dimensions=broadcast_dimensions + ) + + +# aten.native_group_norm -- should use decomp table +# func: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor) + + +@op(torch.ops.aten.native_group_norm) +def _aten_native_group_norm(input, weight, bias, N, C, HxW, group, eps=1e-5): + """Group Normalization implementation in JAX. + + Args: + input: Input tensor. Expected shape (batch_size, channels, ... spatial dims + ...) + weight: Optional scaling (gamma) parameter. Shape (channels,) + bias: Optional shifting (beta) parameter. Shape (channels,) + N: Batch size. + C: Number of channels. + HxW: Product of spatial dimensions (number of elements per channel after + flattening). + group: Number of groups for Group Normalization. + eps: Small value added for numerical stability. + + Returns: + A tuple of (normalized_output, mean, rstd) + """ + + input_shape = input.shape + + # Reshape for group-wise normalization + reshaped_input = jnp.reshape(input, (1, N * group, -1)) + + # **Core Group Normalization** + def group_norm_body(x): # Function to apply within each group + mean = jnp.mean(x, axis=-1, keepdims=True) + var = jnp.var(x, axis=-1, keepdims=True) + rstd = jax.lax.rsqrt(var + eps) # Reciprocal of std with epsilon + normalized = (x - mean) * rstd + return normalized, mean, rstd + + normalized, group_mean, group_rstd = jax.lax.map( + group_norm_body, reshaped_input + ) + + # Reshape back to original input shape + output = jnp.reshape(normalized, input_shape) + + # **Affine transformation** + affine_shape = [ + -1 if i == 1 else 1 for i in range(input.ndim) + ] # Shape for broadcasting + if weight is not None and bias is not None: + output = bias.reshape(affine_shape) + output * weight.reshape(affine_shape) + elif weight is not None: + output = output * weight.reshape(affine_shape) + elif bias is not None: + output = output + bias.reshape(affine_shape) + + # Reshape mean and rstd + mean = jnp.reshape(group_mean, (N, group)) + rstd = jnp.reshape(group_rstd, (N, group)) + + return output, mean, rstd + + +@op(torch.ops.aten.linalg_vector_norm) +def _aten_linalg_vector_norm(self, ord=2, dim=None, keepdim=False, dtype=None): + """Calculates the vector norm along specified dimensions. + + Args: + self: The input tensor. + ord: The order of the norm. Can be a float or 'inf', '-inf', 'fro'. + Default is 2 (Euclidean norm). + dim: Dimensions along which to calculate the norm. If None, the norm is + calculated over all dimensions. + keepdim: Whether to keep the reduced dimensions. + dtype: Optional data type for the output. + + Returns: + The tensor containing the calculated vector norms. + """ + + if ord not in {2, float("inf"), float("-inf"), "fro"}: + raise ValueError( + f"Unsupported ord value: {ord}. Supported values are 2, inf, -inf, and" + " 'fro'." + ) + + # Special cases (for efficiency and clarity) + if ord == 2: # Euclidean norm + result = jnp.sqrt(jnp.sum(jnp.abs(self) ** 2, axis=dim, keepdims=keepdim)) + + elif ord == float("inf"): + result = jnp.max(jnp.abs(self), axis=dim, keepdims=keepdim) + + elif ord == float("-inf"): + result = jnp.min(jnp.abs(self), axis=dim, keepdims=keepdim) + + elif ord == "fro": # Frobenius norm + result = jnp.sqrt(jnp.sum(jnp.abs(self) ** 2, axis=dim, keepdims=keepdim)) + + else: # General case (e.g., ord = 1, ord = 3) + result = jnp.sum(jnp.abs(self) ** ord, axis=dim, keepdims=keepdim) ** ( + 1.0 / ord + ) + + # (Optional) dtype conversion + if dtype is not None: + result = result.astype(dtype) + + return result + + +# aten.reflection_pad1d +@op(torch.ops.aten.reflection_pad1d) +def _aten_reflection_pad1d(input, padding): + rank = len(input.shape) + pad_size = [(0, 0)] * rank + pad_size[-1] = padding + return jnp.pad(input, pad_size, mode="reflect") + + +# aten.alias +@op(torch.ops.aten.alias) +def _aten_alias(self, *args): + return self + + +# aten.sinh +@op(torch.ops.aten.sinh) +def _aten_sinh(self): + return jnp.sinh(self) + + +# aten.native_layer_norm_backward +@op(torch.ops.aten.native_layer_norm_backward) +def _aten_native_layer_norm_backward( + grad_out, input, normalized_shape, weight, bias, eps=1e-5 +): + """Implements the backward pass of layer normalization in Jax as defined by `aten::native_layer_norm_backward`. + + Args: + grad_out: The gradient of the output tensor. + input: The input tensor. + normalized_shape: A list of integer dimensions to be normalized over. + weight: Optional weight tensor for the affine transformation. + bias: Optional bias tensor for the affine transformation. + eps: A small epsilon value for numerical stability. + + Returns: + A tuple of (grad_input, grad_weight, grad_bias). + """ + return jax.lax.native_layer_norm_backward( + grad_out, input, normalized_shape, weight, bias, eps + ) + + +# aten.reflection_pad3d_backward +# aten.reflection_pad2d + + +# aten.atanh +@op(torch.ops.aten.atanh) +def _aten_atanh(self): + return jnp.arctanh(self) + + +# aten.bitwise_not +@op(torch.ops.aten.bitwise_not) +def _aten_bitwise_not(self): + return ~self + + +# aten.embedding_dense_backward + + +# aten.sum +@op(torch.ops.aten.sum) +def _aten_sum(self, dim=None, keepdim=False, dtype=None): + if not dim: + dim = None + return jnp.sum(self, axis=dim, keepdims=keepdim, dtype=dtype) + + +# aten.sqrt +@op(torch.ops.aten.sqrt) +def _aten_sqrt(self): + return jnp.sqrt(self) + + +@op(torch.ops.aten.tan) +def _aten_tanh(self): + return jnp.tan(self) + + +# aten.tanh +@op(torch.ops.aten.tanh) +def _aten_tanh(self): + return jnp.tanh(self) + + +# aten.ceil +@op(torch.ops.aten.ceil) +def _aten_ceil(self): + return jnp.ceil(self) + + +# aten.asin +@op(torch.ops.aten.asin) +def _aten_asin(self): + return jnp.arcsin(self) + + +# aten.minimum +@op(torch.ops.aten.minimum) +def _aten_minimum(self, other): + return jnp.minimum(self, other) + + +# aten.max_pool2d_backward + + +def _scatter_index(dim, index): + """Returns a tuple of indexes; + + The first is to select in input (to modify), + the second is to select from the values. + """ + index_shape = list(index.shape) + input_indexes = [] + source_indexes = [] + for i in range(len(index_shape)): + source_indexes.append(slice(0, index_shape[i])) + if i == dim: + input_indexes.append(index) + else: + target_shape = [1] * len(index_shape) + target_shape[i] = index_shape[i] + input_indexes.append( + jnp.broadcast_to( + jnp.arange(index_shape[i]).reshape(target_shape), index_shape + ) + ) + return tuple(input_indexes), tuple(source_indexes) + + +# aten.scatter_add +@op(torch.ops.aten.scatter_add) +def _aten_scatter_add(input, dim, index, src): + """JAX implementation of scatter, mimicking torch.scatter behavior""" + + input_indexes, source_indexes = _scatter_index(dim, index) + return input.at[input_indexes].add(src[source_indexes]) + + +# aten.logical_not + + +# aten.sign +@op(torch.ops.aten.sign) +def _aten_sign(x): + return jnp.sign(x) + + +# aten.sigmoid +@op(torch.ops.aten.sigmoid) +def _aten_sigmoid(x): + if x.dtype in (jnp.int32, jnp.int64): + x = x.astype(jnp.float32) + return jax.nn.sigmoid(x) + + +# implement aten.asinh in jax +@op(torch.ops.aten.asinh) +def _aten_asinh(self): + return jnp.arcsinh(self) + + +# aten.atan +@op(torch.ops.aten.atan) +def _aten_atan(self): + return jnp.arctan(self) + + +# aten.scatter_reduce +@op(torch.ops.aten.scatter_reduce) +def _aten_scatter_reduce(input, dim, index, src, reduce, *, include_self=True): + input_indexes, source_indexes = _scatter_index(dim, index) + if reduce == "sum": + return input.at[input_indexes].add(src[source_indexes]) + elif reduce == "prod": + return input.at[input_indexes].multiply(src[source_indexes]) + elif reduce == "mean": + return input.at[input_indexes].add(src[source_indexes]) + elif reduce == "amax": + return input.at[input_indexes].max(src[source_indexes]) + elif reduce == "amin": + return input.at[input_indexes].min(src[source_indexes]) + else: + raise RuntimeError("Unknow reduction type: ", reduce) + + +# aten.acos +@op(torch.ops.aten.acos) +def _aten_acos(self): + return jnp.arccos(self) + + +# aten.sym_storage_offset +# aten.native_layer_norm_backward +# aten.max_pool3d_with_indices + + +# aten.gt +@op(torch.ops.aten.gt) +def _aten_gt(self, other): + return self > other + + +# aten.pixel_shuffle +@op(torch.ops.aten.pixel_shuffle) +def _aten_pixel_shuffle(x, upscale_factor): + """PixelShuffle implementation in JAX. + + Args: + x: Input tensor. Typically a feature map. + upscale_factor: Integer by which to upscale the spatial dimensions. + + Returns: + Tensor after PixelShuffle operation. + """ + + batch_size, channels, height, width = x.shape + + if channels % (upscale_factor**2) != 0: + raise ValueError( + "Number of channels must be divisible by the square of the upscale factor." + ) + + new_channels = channels // (upscale_factor**2) + new_height = height * upscale_factor + new_width = width * upscale_factor + + x = x.reshape( + batch_size, new_channels, upscale_factor, upscale_factor, height, width + ) + x = jnp.transpose( + x, (0, 1, 2, 4, 3, 5) + ) # Move channels to spatial dimensions + x = x.reshape(batch_size, new_channels, new_height, new_width) + + return x + + +# aten.sym_stride +# aten.lt +@op(torch.ops.aten.lt) +def _aten_lt(self, other): + return self < other + + +def pool(inputs, init, reduce_fn, window_shape, strides, padding): + """Helper function to define pooling functions. + + Pooling functions are implemented using the ReduceWindow XLA op. + NOTE: Be aware that pooling is not generally differentiable. + That means providing a reduce_fn that is differentiable does not imply that + pool is differentiable. + + Args: + inputs: input data with dimensions (batch, window dims..., features). + init: the initial value for the reduction + reduce_fn: a reduce function of the form ``(T, T) -> T``. + window_shape: a shape tuple defining the window to reduce over. + strides: a sequence of ``n`` integers, representing the inter-window + strides (default: ``(1, ..., 1)``). + padding: either the string ``'SAME'``, the string ``'VALID'``, or a sequence + of ``n`` ``(low, high)`` integer pairs that give the padding to apply before + and after each spatial dimension. + Returns: + The output of the reduction for each window slice. + """ + num_batch_dims = inputs.ndim - (len(window_shape) + 1) + strides = strides or (1,) * len(window_shape) + assert len(window_shape) == len( + strides + ), f"len({window_shape}) must equal len({strides})" + strides = (1,) * (1 + num_batch_dims) + strides + dims = (1,) * (1 + num_batch_dims) + window_shape + + is_single_input = False + if num_batch_dims == 0: + # add singleton batch dimension because lax.reduce_window always + # needs a batch dimension. + inputs = inputs[None] + strides = (1,) + strides + dims = (1,) + dims + is_single_input = True + + assert inputs.ndim == len(dims), f"len({inputs.shape}) != len({dims})" + if not isinstance(padding, str): + padding = tuple(map(tuple, padding)) + assert len(padding) == len(window_shape), ( + f"padding {padding} must specify pads for same number of dims as " + f"window_shape {window_shape}" + ) + assert all( + [len(x) == 2 for x in padding] + ), f"each entry in padding {padding} must be length 2" + padding = ((0, 0), (0, 0)) + padding + y = jax.lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding) + if is_single_input: + y = jnp.squeeze(y, axis=0) + return y + + +@op(torch.ops.aten._adaptive_avg_pool3d) +def _aten_adaptive_avg_pool3d(x, output_shape): + return _aten_adaptive_avg_pool(x, output_shape, 3) + + +@op(torch.ops.aten._adaptive_avg_pool2d) +def _aten_adaptive_avg_pool3d(x, output_shape): + return _aten_adaptive_avg_pool(x, output_shape, 2) + + +def _aten_adaptive_avg_pool(x, output_shape, pool_dim): + def adaptive_kernel_size(input_shape, output_shape): + sizes = [1, 1] + spatial_dim_off = len(input_shape) - pool_dim + for spatial_dim in range(pool_dim): + sizes.append( + input_shape[spatial_dim_off + spatial_dim] // output_shape[spatial_dim] + ) + return tuple(sizes) + + kernel_sizes = adaptive_kernel_size(x.shape, output_shape) + y = pool(x, 0.0, jax.lax.add, kernel_sizes, kernel_sizes, padding="VALID") + + div_shape = list(x.shape) + num_batch_dims = len(x.shape) - pool_dim - 1 + div_shape[num_batch_dims] = 1 + div_shape = tuple(div_shape) + if len(div_shape) - 2 == len(kernel_sizes): + div_shape = (1,) + div_shape[1:] + y = y / pool( + jnp.ones(div_shape), 0.0, jax.lax.add, kernel_sizes, kernel_sizes, "VALID" + ) + return y + + +# aten.avg_pool2d +@op(torch.ops.aten.avg_pool2d) +@op(torch.ops.aten.avg_pool3d) +def _aten_avg_pool( + inputs, + kernel_size, + strides=None, + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, +): + num_batch_dims = len(inputs.shape) - len(kernel_size) - 1 + kernel_size = tuple(kernel_size) + strides = tuple(strides) + if isinstance(padding, int): + padding = tuple((padding, padding) for _ in range(len(kernel_size))) + elif isinstance(padding, list): + padding = tuple((p, p) for p in padding) + + y = pool(inputs, 0.0, jax.lax.add, kernel_size, strides, padding) + if count_include_pad: + y = y / np.prod(kernel_size) + else: + div_shape = list(inputs.shape) + div_shape[num_batch_dims] = 1 + div_shape = tuple(div_shape) + if len(div_shape) - 2 == len(kernel_size): + div_shape = (1,) + div_shape[1:] + y = y / pool( + jnp.ones(div_shape), 0.0, jax.lax.add, kernel_size, strides, padding + ) + return y + + +# aten.sym_numel +# aten.reciprocal +@op(torch.ops.aten.reciprocal) +def _aten_reciprocal(a): + return 1 / a + + +# aten.scatter +@op(torch.ops.aten.select_scatter) +def _aten_select_scatter(input, src, dim, index): + input_indexes = [] + for x in range(len(input.shape)): + if x == dim: + input_indexes.append(index) + else: + input_indexes.append(slice(None, None, None)) + return input.at[tuple(input_indexes)].set(src) + + +@op(torch.ops.aten.scatter.src) +def _aten_scatter_src(input, dim, index, src, reduce=None): + input_index, source_indexes = _scatter_index(dim, index) + return input.at[input_index].set(src[source_indexes]) + + +@op(torch.ops.aten.scatter.value) +def _aten_scatter(input, dim, index, src, reduce=None): + input_index, source_indexes = _scatter_index(dim, index) + return input.at[input_index].set(src) + + +# aten.acosh +@op(torch.ops.aten.acosh) +def _aten_acosh(self): + return jnp.arccosh(self) + + +# aten.avg_pool2d_backward +# aten.col2im +# aten.avg_pool3d +# aten.round +@op(torch.ops.aten.round) +def _aten_round(input, decimals=0): + return jnp.round(input, decimals) + + +# aten.max +@op(torch.ops.aten.max) +def _aten_max(self, dim=None, keepdim=False): + return jnp.max(self, axis=dim, keepdims=keepdim), jnp.argmax( + self, axis=dim, keepdims=keepdim + ) + + +# aten.maximum +@op(torch.ops.aten.maximum) +def _aten_maximum(self, other): + return jnp.maximum(self, other) + + +# aten.abs +@op(torch.ops.aten.abs) +def _aten_abs(self): + return jnp.abs(self) + + +# generate aten.amax only +@op(torch.ops.aten.amax) +def _aten_amax(self, dim=None, keepdim=False): + return _with_reduction_scalar(jnp.amax, self, dim, keepdim) + + +def _with_reduction_scalar(jax_func, self, dim, keepdim): + expanded = False + if self.ndim == 0: + # for self of rank 0: + # torch.any(x, 0), torch.any(x, -1) works; + # torch.any(x, 1) throws out of bounds, so it's + # behavior is the same as a jnp array of rank 1 + expanded = True + self = jnp.expand_dims(self, 0) + res = jax_func(self, axis=dim, keepdims=keepdim) + if expanded: + res = res.squeeze() + return res + + +# aten.any +@op(torch.ops.aten.any) +def _aten_any(self, dim=None, keepdim=False): + return _with_reduction_scalar(jnp.any, self, dim, keepdim) + + +# aten.arange +@op(torch.ops.aten.arange.start_step) +@op(torch.ops.aten.arange.start) +@op(torch.ops.aten.arange.default) +def _aten_arange( + start, + end=None, + step=1, + *, + dtype=None, + layout=None, + requires_grad=False, + device=None, + pin_memory=False, +): + if end is None: + end = start + start = 0 + if dtype: + dtype = tensor.t2j_dtype(dtype) + return jnp.arange( + start, + end, + step, + dtype=dtype, + ) + + +# aten.argmax +@op(torch.ops.aten.argmax) +def _aten_argmax(self, dim=None, keepdim=False): + return _with_reduction_scalar(jnp.argmax, self, dim, keepdim) + + +# aten.as_strided +@op(torch.ops.aten.as_strided) +@op(torch.ops.aten.as_strided_copy) +def _aten_as_strided(x, sizes, strides, storage_offset=None): + ind = jnp.zeros(sizes, dtype=jnp.int32) + + for i, (size, stride) in enumerate(zip(sizes, strides)): + result_shape = (1,) * i + (size,) + (1,) * (len(sizes) - i - 1) + indexes = (jnp.arange(size) * stride).reshape(result_shape) + ind += indexes + + return jnp.ravel(x)[ind] + + +# aten.atan2 +@op(torch.ops.aten.atan2) +def _aten_atan2(self, other): + return jnp.arctan2(self, other) + + +# aten.bitwise_and +@op(torch.ops.aten.bitwise_and) +def _aten_bitwise_and(self, other): + return self & other + + +# aten.bitwise_or +@op(torch.ops.aten.bitwise_or) +def _aten_bitwise_or(self, other): + return self | other + + +# aten.bitwise_xor +@op(torch.ops.aten.bitwise_xor) +def _aten_bitwise_xor(self, other): + return self ^ other + + +# aten.clamp +@op(torch.ops.aten.clamp.default) +@op(torch.ops.aten.clamp.Tensor) +def _aten_clamp(self, min=None, max=None): + return jnp.clip(self, min, max) + + +# aten.constant_pad_nd +@op(torch.ops.aten.constant_pad_nd) +def _aten_constant_pad_nd(input, padding, value=0): + # NOTE: Torch padding is flat and reversed: (1, 1, 2, 2) + # means last dim get padded 1 in front and 1 in back; + # and second last dim get padded 2 in front and 2 in back. + # Jax padding tuple of 2-tuple: the same padding is + # [(0, 0), ..., (2,2), (1,1)] + m = len(padding) + rev_padding = [(padding[i - 1], padding[i]) for i in range(m - 1, 0, -2)] + pad_dim = tuple(([(0, 0)] * (len(input.shape) - m // 2)) + rev_padding) + return jnp.pad(input, pad_dim, mode="constant", constant_values=value) + + +# aten.convolution_backward +@op(torch.ops.aten.copy) +@op(torch.ops.aten.lift_fresh_copy) +def _aten_copy(x): + return jnp.copy(x) + + +@op(torch.ops.aten._cdist_forward) +def _aten_cdist_forward(x1, x2, p, compute_mode=""): + # x1 is B x P x M + # x2 is B x Q x M + # res is B x P x Q + x1 = jnp.expand_dims(x1, len(x1.shape) - 1) + x2 = jnp.expand_dims(x2, len(x2.shape) - 2) + return jnp.linalg.norm(x1 - x2, ord=p, axis=-1) + + +@op(torch.ops.aten._pdist_forward) +def _aten__pdist_forward(x, p): + pairwise_dists = _aten_cdist_forward(x, x, p) + condensed_dists = pairwise_dists[ + jnp.triu_indices(pairwise_dists.shape[0], k=1) + ] + return condensed_dists + + +# aten.cos +@op(torch.ops.aten.cos) +def _aten_cos(input): + return jnp.cos(input) + + +# aten.cosh +@op(torch.ops.aten.cosh) +def _aten_cosh(input): + return jnp.cosh(input) + + +# aten.diagonal +@op(torch.ops.aten.diagonal) +def _aten_diagonal(input, offset=0, dim1=0, dim2=1): + return jnp.diagonal(input, offset, dim1, dim2) + + +# aten.empty_strided +# aten.eq +@op(torch.ops.aten.eq) +def _aten_eq(input1, input2): + return input1 == input2 + + +# aten.erf +@op(torch.ops.aten.erf) +def _aten_erf(x): + if x.dtype in (jnp.int32, jnp.int64): + x = x.astype(jnp.float32) + return jax.lax.erf(x) + + +# aten.exp +@op(torch.ops.aten.exp) +def _aten_exp(input): + return jnp.exp(input) + + +# aten.expm1 +@op(torch.ops.aten.expm1) +def _aten_expm1(input): + return jnp.expm1(input) + + +# aten.fill +@op(torch.ops.aten.fill) +@op(torch.ops.aten.full_like) +def _aten_fill(x, value, dtype=None, pin_memory=None, memory_format=None): + if dtype is None: + dtype = x.dtype + else: + dtype = tensor.t2j_dtype(dtype) + return jnp.full(x.shape, value, dtype) + + +# aten.flip +@op(torch.ops.aten.flip) +def _aten_flip(input, dims): + if dims is not None: + return jnp.flip(input, tuple(dims)) + else: + return jnp.flip(input) + + +# aten.floor +@op(torch.ops.aten.floor) +def _aten_floor(input): + return jnp.floor(input) + + +# aten.fmod +@op(torch.ops.aten.fmod) +def _aten_fmod(input, other): + return input - other * _aten_div(input, other, "trunc") + + +# aten.gather +@op(torch.ops.aten.gather) +def _aten_gather(input, dim, index): + input_indexes, source_indexes = _scatter_index(dim, index) + return input[input_indexes] + + +# aten.ge +@op(torch.ops.aten.ge) +def _aten_ge(self, other): + return self >= other + + +@op(torch.ops.aten.glu) +@op(torch.ops.aten.glu.default) +def _aten_glu(x, dim=-1): + return jax.nn.glu(x, dim) + + +# aten.hardtanh +@op(torch.ops.aten.hardtanh) +def _aten_hardtanh(input, min_val=-1.0, max_val=1.0, inplace=False): + return jnp.clip(input, min_val, max_val) + + +# aten.isinf +@op(torch.ops.aten.isinf) +def _aten_isinf(input): + return jnp.isinf(input) + + +# aten.isnan +@op(torch.ops.aten.isnan) +def _aten_isnan(input): + return jnp.isnan(input) + + +@op(torch.ops.aten.le) +def _aten_le(self, other): + return self <= other + + +# aten.leaky_relu +@op(torch.ops.aten.leaky_relu) +def _aten_leaky_relu(x, negative_slope): + return jax.nn.leaky_relu(x, negative_slope) + + +# aten.log +@op(torch.ops.aten.log) +def _aten_log(x): + return jnp.log(x) + + +# aten.log10 +@op(torch.ops.aten.log10) +def _aten_log10(x): + return jnp.log10(x) + + +# aten.log1p +@op(torch.ops.aten.log1p) +def _aten_log1p(x): + return jnp.log1p(x) + + +# aten.log2 +@op(torch.ops.aten.log2) +def _aten_log2(x): + return jnp.log2(x) + + +# aten.logical_and +@op(torch.ops.aten.logical_and) +def _aten_logical_and(self, other): + return jnp.logical_and(self, other) + + +# aten.logical_or +@op(torch.ops.aten.logical_or) +def _aten_logical_or(self, other): + return jnp.logical_or(self, other) + + +# aten.logical_not +@op(torch.ops.aten.logical_not) +def _aten_logical_not(self): + return jnp.logical_not(self) + + +# aten.log_softmax +@op(torch.ops.aten._log_softmax) +def _aten_log_softmax(self, axis=-1, half_to_float=False): + return jax.nn.log_softmax(self, axis) + + +# aten.max_pool3d_backward +# aten.logical_xor +@op(torch.ops.aten.logical_xor) +def _aten_logical_xor(self, other): + return jnp.logical_xor(self, other) + + +# aten.max_pool2d_with_indices_backward +# aten.native_dropout +# aten.native_group_norm_backward +# aten.neg +@op(torch.ops.aten.neg) +def _aten_neg(x): + return -1 * x + + +# aten.nonzero +@op(torch.ops.aten.nonzero) +def _aten_nonzero(x): + index_tuple = jnp.nonzero(x) + index_tuple = [jnp.expand_dims(p, -1) for p in index_tuple] + return jnp.concatenate(index_tuple, axis=-1) + + +# aten.prod + + +@op(torch.ops.aten.prod) +def _aten_prod(self, dim=None, keepdim=False): + return jnp.prod(self, axis=dim, keepdims=keepdim) + + +# aten.randperm + + +# aten.reflection_pad3d + + +# aten.remainder +@op(torch.ops.aten.remainder) +def _aten_remainder(inputs, other): + return inputs % other + + +# aten.repeat +@op(torch.ops.aten.repeat) +def _aten_repeat(x, reps): + return jnp.tile(x, reps) + + +# aten.replication_pad2d +# aten.replication_pad3d +# aten.roll +@op(torch.ops.aten.roll) +def _aten_roll(input, shifts, dims=None): + return jnp.roll(input, shifts, dims) + + +# aten.scalar_tensor +# aten.slice_scatter +@op(torch.ops.aten.slice_scatter) +def _aten_slice_scatter(input, src, dim=0, start=None, end=None, step=1): + input_index = [] + for x in range(len(input.shape)): + if x == dim: + input_index.append(slice(start, end, step)) + else: + input_index.append(slice(None, None, None)) + return input.at[tuple(input_index)].set(src) + + +# aten.sort +# torch.sort(input, dim=-1, descending=False, stable=False, *, out=None) +@op(torch.ops.aten.sort) +def _aten_sort(a, dim=-1, descending=False, stable=False): + return ( + jnp.sort(a, axis=dim, stable=stable, descending=descending), + jnp.argsort(a, axis=dim, stable=stable, descending=descending), + ) + + +# aten.sym_size + + +# aten.topk +@op(torch.ops.aten.topk) +def _aten_topk(input, k, dim=None, largest=True, sorted=True, *, out=None): + """JAX top-k implementation using jax.lax.top_k for improved efficiency. + + Args: + input: The input JAX array. + k: The number of top elements to return. + dim: The dimension along which to find the top-k. If None, operates on the + flattened array. + largest: If True, returns the largest k elements. Otherwise, smallest k. + sorted: If True, returns the elements in sorted order. + + Returns: + A tuple (values, indices) containing: + - values: The top k values. + - indices: The indices of the top k values in the original array. + """ + if dim is None: + input = input.flatten() + dim = 0 + + if not largest: + input = -input # Find top-k of negated input if we want the smallest + + transpose_shape = None + if dim != -1 and dim != len(input.shape) - 1: + transpose_shape = list(range(len(input.shape))) + transpose_shape[dim], transpose_shape[-1] = ( + transpose_shape[-1], + transpose_shape[dim], + ) + input = jnp.transpose(input, transpose_shape) + + values, indices = jax.lax.top_k(input, k) + + if sorted: + values = jnp.sort(values, descending=True) + indices = jnp.take_along_axis( + indices, jnp.argsort(values, axis=-1, descending=True), axis=-1 + ) + + if not largest: + values = -values # Negate values back if we found smallest + + if transpose_shape is not None: + values = jnp.transpose(values, transpose_shape) + indices = jnp.transpose(indices, transpose_shape) + + return values, indices + + +# aten.trunc +@op(torch.ops.aten.trunc) +def _aten_trunc(a): + return jnp.trunc(a) + + +@op(torch.ops.aten.unbind) +@op(torch.ops.aten.unbind_copy) +def _aten_unbind(a, dim=0): + return tuple( + _aten_squeeze_dim(jax.lax.index_in_dim(a, i, axis=dim), dim) + for i in range(a.shape[dim]) + ) + + +# NOTE: skip aten.upsample_nearest2d and aten.upsample_bilinear2d +# despite those being core aten ops, they also have decompositions. +# here we are using torch decompositions. + + +# aten.where +@op(torch.ops.aten.where.self) +@op(torch.ops.aten.where.ScalarSelf) +@op(torch.ops.aten.where.ScalarOther) +def _aten_where(condition, x, y): + return jnp.where(condition, x, y) + + +# aten.to.dtype +# Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None +@op(torch.ops.aten.to.dtype) +def _aten_to_dtype( + a, dtype, non_blocking=False, copy=False, memory_format=None +): + if dtype: + jaxdtype = tensor.t2j_dtype(dtype) + return a.astype(jaxdtype) + + +# aten.to.device + + +# Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False +@op(torch.ops.aten.var_mean.correction) +def _aten_var_mean_correction(self, dim=None, correction=None, keepdim=False): + return ( + jnp.var(self, axis=dim, ddof=correction, keepdims=keepdim), + jnp.mean(self, dim, keepdims=keepdim), + ) + + +@op(torch.ops.aten.scalar_tensor) +def _aten_scalar_tensor( + s, dtype=None, layout=None, device=None, pin_memory=None +): + if dtype is not None: + dtype = tensor.t2j_dtype(dtype) + return jnp.array(s, dtype=dtype) + return jnp.array(s) + + +@op(torch.ops.aten.to.device) +def _aten_to_device(x, device, dtype): + return x + + +@op(torch.ops.aten.max_pool2d_with_indices_backward) +def max_pool2d_with_indices_backward_custom( + grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices +): + """ + Approximates the gradient calculation of PyTorch's max_pool2d_with_indices_backward. + + Args: + grad_output: The gradient tensor from the preceding layer. + self: The input tensor on which the original max pooling was performed. + kernel_size: The size of the pooling window. + stride: The stride of the pooling window. + padding: The padding applied during max pooling. + dilation: The dilation factor for the pooling operation. + ceil_mode: Whether to use ceil or floor when calculating output shapes. + indices: The indices of the maximum values, as produced by max_pool2d_with_indices. + + Returns: + The calculated gradient with respect to the input (grad_input). + """ + + kH, kW = kernel_size + dH, dW = stride + padH, padW = padding + dilH, dilW = dilation + + # Calculate output shape (may need adjustment based on ceil_mode) + out_shape = jnp.array(self.shape) + grad_input = jnp.zeros_like(self) + + # Iterate over the flattened input and output tensors + for i, idx in enumerate(indices.flatten()): + # Calculate input coordinates corresponding to the maximum value + out_y, out_x = i // grad_output.shape[3], i % grad_output.shape[3] + in_y = out_y * dH - padH + out_y * (dilH - 1) + in_x = out_x * dW - padW + out_x * (dilW - 1) + + # Scatter the gradient to the appropriate input locations (handling potential overlaps) + for y in range(in_y, in_y + kH): + for x in range(in_x, in_x + kW): + if 0 <= y < grad_input.shape[2] and 0 <= x < grad_input.shape[3]: + grad_input = grad_input.at[y, x].add(grad_output.flatten()[i]) + + return grad_input + + +@op(torch.ops.aten._local_scalar_dense) +def _aten_local_scalar_dense(x): + return x.item() + + +@op(torch.ops.aten.tensor_split.sections) +def _aten_tensor_split(ary, indices_or_sections, axis=0): + return jnp.array_split(ary, indices_or_sections, axis) + + +@op(torch.ops.aten.randn, needs_env=True) +def _randn( + *size, + generator=None, + out=None, + dtype=None, + layout=torch.strided, + device=None, + requires_grad=False, + pin_memory=False, + env=None, +): + shape = size + if len(shape) == 1 and isinstance(shape[0], (list, tuple)): + shape = shape[0] + key = env.get_and_rotate_prng_key() + res = jax.random.normal(key, shape) + if dtype is not None: + dtype = tensor.t2j_dtype(dtype) + res = res.astype(dtype) + return res + + +@op(torch.ops.aten.rand, needs_env=True) +def _rand( + *size, + generator=None, + out=None, + dtype=None, + layout=torch.strided, + device=None, + requires_grad=False, + pin_memory=False, + env=None, +): + shape = size + if len(shape) == 1 and isinstance(shape[0], (list, tuple)): + shape = shape[0] + key = env.get_and_rotate_prng_key() + res = jax.random.uniform(key, shape) + if dtype is not None: + dtype = tensor.t2j_dtype(dtype) + res = res.astype(dtype) + return res + + +@op(torch.ops.aten.scalar_tensor.default) +def _aten_scalar_tensor(val, **kwargs): + p = torch.ops.aten.scalar_tensor(val) + return tensor.t2j(p) + + +@op(torch.ops.aten.to.device) +def _aten_to_device(x, device, dtype): + return x + + +@op(torch.ops.aten.max_pool2d_with_indices_backward) +def max_pool2d_with_indices_backward_custom( + grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices +): + """ + Approximates the gradient calculation of PyTorch's max_pool2d_with_indices_backward. + + Args: + grad_output: The gradient tensor from the preceding layer. + self: The input tensor on which the original max pooling was performed. + kernel_size: The size of the pooling window. + stride: The stride of the pooling window. + padding: The padding applied during max pooling. + dilation: The dilation factor for the pooling operation. + ceil_mode: Whether to use ceil or floor when calculating output shapes. + indices: The indices of the maximum values, as produced by max_pool2d_with_indices. + + Returns: + The calculated gradient with respect to the input (grad_input). + """ + + kH, kW = kernel_size + dH, dW = stride + padH, padW = padding + dilH, dilW = dilation + + # Calculate output shape (may need adjustment based on ceil_mode) + out_shape = jnp.array(self.shape) + grad_input = jnp.zeros_like(self) + + # Iterate over the flattened input and output tensors + for i, idx in enumerate(indices.flatten()): + # Calculate input coordinates corresponding to the maximum value + out_y, out_x = i // grad_output.shape[3], i % grad_output.shape[3] + in_y = out_y * dH - padH + out_y * (dilH - 1) + in_x = out_x * dW - padW + out_x * (dilW - 1) + + # Scatter the gradient to the appropriate input locations (handling potential overlaps) + for y in range(in_y, in_y + kH): + for x in range(in_x, in_x + kW): + if 0 <= y < grad_input.shape[2] and 0 <= x < grad_input.shape[3]: + grad_input = grad_input.at[y, x].add(grad_output.flatten()[i]) + + return grad_input + + +@op(torch.ops.aten._local_scalar_dense) +def _aten_local_scalar_dense(x): + return x.item() + + +@op(torch.ops.aten.tensor_split.sections) +def _aten_tensor_split(ary, indices_or_sections, axis=0): + return jnp.array_split(ary, indices_or_sections, axis) + + +@op(torch.ops.aten.outer) +def _aten_outer(a, b): + return jnp.outer(a, b) + + +@op(torch.ops.aten.allclose) +def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False): + return jnp.allclose(input, other, rtol, atol, equal_nan) + diff --git a/experimental/torch_xla2/torch_xla2/ops/jtorch.py b/experimental/torch_xla2/torch_xla2/ops/jtorch.py index 6628b7e9510..ddc04fa4b1b 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jtorch.py +++ b/experimental/torch_xla2/torch_xla2/ops/jtorch.py @@ -1,7 +1,116 @@ +"""Tensor constructor overrides""" +import functools +from typing import Callable, Optional, ParamSpec, Sequence + +import jax import torch +import jax.numpy as jnp +from torch_xla2 import tensor +from torch_xla2.ops.ops_registry import register_torch_function_op + +def register_function(torch_func, **kwargs): + return functools.partial(register_torch_function_op, torch_func, **kwargs) + + +P = ParamSpec('P') + + +def convert_dtype(use_default_dtype: bool = True): + """Converts `dtype` kwarg of function from torch to JAX. + + Args: + use_default_dtype: Whether to use torch default dtype if none is provided. + + Returns: + A decorator that wraps a JAX implementation of a torch function. + """ + + def decorator(func: Callable[P, torch.Tensor]): + + @functools.wraps(func) + def wrapper(*args: P.args, + dtype: Optional[torch.dtype] = None, + **kwargs: P.kwargs): + if not dtype and use_default_dtype: + dtype = torch.get_default_dtype() + jax_dtype = tensor.t2j_dtype(dtype) + + return func(*args, dtype=jax_dtype, **kwargs) + + return wrapper + + return decorator + + +@register_function(torch.tensor) +@convert_dtype(use_default_dtype=False) # Attempt to infer type from elements +def _tensor(data, *, dtype=None, **kwargs): + python_types_to_torch_types = { + bool: jnp.bool, + int: jnp.int64, + float: jnp.float32, + complex: jnp.complex64, + } + if not dtype: + leaves = jax.tree_util.tree_leaves(data) + if len(leaves) > 0: + dtype = python_types_to_torch_types.get(type(leaves[0])) + + return jnp.array( + data, dtype=dtype or tensor.t2j_dtype(torch.get_default_dtype())) + + +@register_function(torch.ones) +@convert_dtype() +def _ones(*size: int, dtype=None, **kwargs): + return jnp.ones(size, dtype) + + +@register_function(torch.zeros) +@convert_dtype() +def _zeros(*size: int, dtype=None, **kwargs): + return jnp.zeros(size, dtype) + + +@register_function(torch.eye) +@convert_dtype() +def _eye(n: int, m: Optional[int] = None, *, dtype=None, **kwargs): + return jnp.eye(n, m, dtype=dtype) + + +@register_function(torch.full) +@convert_dtype() +def _full(size: Sequence[int], fill_value, *, dtype=None, **kwargs): + # TODO: handle torch.Size + return jnp.full(size, fill_value, dtype=dtype) + + +@register_function(torch.allclose) +def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False): + return jnp.allclose(input, other, rtol, atol, equal_nan) + +@register_function(torch.angle) +def _torch_angle(input): + return jnp.angle(input) +@register_function(torch.argsort) +def _torch_argsort(input, dim=-1, descending=False, stable=False): + expanded = False + if input == 0: + # for self of rank 0: + # torch.any(x, 0), torch.any(x, -1) works; + # torch.any(x, 1) throws out of bounds, so it's + # behavior is the same as a jnp array of rank 1 + expanded = True + input = jnp.expand_dims(input, 0) + res = jnp.argsort(input, axis=dim, descending=descending, + stable=stable) + if expanded: + res = res.squeeze() + return res -torch_ops_override = { - torch.allclose: torch.ops.aten.allclose -} \ No newline at end of file +@register_function(torch.einsum) +def _einsum(equation, *operands): + assert isinstance(equation, str), 'Only accept str equation' + return jnp.einsum(equation, *operands) \ No newline at end of file diff --git a/experimental/torch_xla2/torch_xla2/ops/op_base.py b/experimental/torch_xla2/torch_xla2/ops/op_base.py index 62df160edc9..983d20fb660 100644 --- a/experimental/torch_xla2/torch_xla2/ops/op_base.py +++ b/experimental/torch_xla2/torch_xla2/ops/op_base.py @@ -1,22 +1,11 @@ import torch -from torch_xla2 import extra - -class JaxOperator: - """This is a aten op backed by jax function.""" - - def __init__(self, jax_callable): - self.jax = jax_callable - - def __call__(self, *args, **kwargs): - # args are torch.Tensor - res = call_jax(self.jax, args, kwargs) - return res +from torch_xla2 import interop class BinaryOpWithPromotion: - def __init__(self, jax_callable): - self.jax = jax_callable + def __init__(self, inner): + self.inner = inner def _get_dtype(self, obj): if isinstance(obj, torch.Tensor): @@ -31,7 +20,7 @@ def _get_dtype(self, obj): def __call__(self, *args, **kwargs): # args are torch.Tensor - res = extra.torch_view(self.jax)(*args, **kwargs) + res = interop.torch_view(self.jax)(*args, **kwargs) dtype = torch.promote_types( self._get_dtype(args[0]), @@ -41,15 +30,6 @@ def __call__(self, *args, **kwargs): return res -class TorchLowering: - - def __init__(self, lowering): - self.lowering = lowering - - def __call__(self, *args, **kwargs): - return self.lowering(*args, **kwargs) - - class InplaceOp: def __init__(self, functional_op, position_to_mutate=0): @@ -58,7 +38,7 @@ def __init__(self, functional_op, position_to_mutate=0): def __call__(self, *args, **kwargs): to_mutate = args[0] - to_mutate._elem = self.functional(*args, **kwargs)._elem + to_mutate.copy_(self.functional(*args, **kwargs)) return to_mutate diff --git a/experimental/torch_xla2/torch_xla2/ops/ops_registry.py b/experimental/torch_xla2/torch_xla2/ops/ops_registry.py new file mode 100644 index 00000000000..e75d1549456 --- /dev/null +++ b/experimental/torch_xla2/torch_xla2/ops/ops_registry.py @@ -0,0 +1,47 @@ +import dataclasses +from torch_xla2.types import JaxCallable, TorchCallable + +from typing import Union, Dict + + +@dataclasses.dataclass +class Operator: + torch_op: TorchCallable + func: Union[TorchCallable, JaxCallable] + is_jax_function: bool + is_user_defined: bool + needs_env: bool + + +all_aten_ops: Dict[TorchCallable, Operator] = {} +all_torch_functions: Dict[TorchCallable, Operator] = {} + + +def register_torch_dispatch_op( + aten_op, impl_callable, + is_jax_function=True, + is_user_defined=False, + needs_env=False, +): + op = Operator( + aten_op, impl_callable, + is_jax_function=is_jax_function, + is_user_defined=is_user_defined, + needs_env=needs_env) + all_aten_ops[aten_op] = op + return impl_callable + + +def register_torch_function_op( + torch_func, impl_callable, + is_jax_function=True, + is_user_defined=False, + needs_env=False, +): + op = Operator( + torch_func, impl_callable, + is_jax_function=is_jax_function, + is_user_defined=is_user_defined, + needs_env=needs_env) + all_torch_functions[torch_func] = op + return impl_callable \ No newline at end of file diff --git a/experimental/torch_xla2/torch_xla2/ops_registry.py b/experimental/torch_xla2/torch_xla2/ops_registry.py deleted file mode 100644 index f1d115864d3..00000000000 --- a/experimental/torch_xla2/torch_xla2/ops_registry.py +++ /dev/null @@ -1,74 +0,0 @@ -import torch -import torch._decomp as decomp -import torch_xla2.decompositions - -class LoweringRegistry: - - def __init__(self): - self.registered_ops = {} - self.decomps = {} - - def lookup(self, op_or_name): - candidate = self._lookup(op_or_name) - if candidate is None: - if isinstance(op_or_name, torch._ops.OpOverloadPacket): - candidate = self._lookup(op_or_name.default) - if isinstance(op_or_name, torch._ops.OpOverload): - candidate = self._lookup(op_or_name.overloadpacket) - return candidate - - def _lookup(self, op): - candidate = self.registered_ops.get(op) - if candidate is None: - candidate = self.decomp.get(op) - return candidate - - def register(self, op, lowering): - if isinstance(op, torch._ops.OpOverloadPacket): - if hasattr(op, 'default'): - self.registered_ops[op.default] = lowering - self.registered_ops[op] = lowering - - -lowerings = LoweringRegistry() -EXTRA_DECOMP = decomp.get_decompositions([ - torch.ops.aten.upsample_nearest2d, - torch.ops.aten._native_batch_norm_legit.no_stats, - torch.ops.aten._adaptive_avg_pool2d, - torch.ops.aten._adaptive_avg_pool3d, - torch.ops.aten.grid_sampler_2d, - torch.ops.aten.native_dropout, - torch.ops.aten.reflection_pad1d, - torch.ops.aten.reflection_pad2d, - torch.ops.aten.reflection_pad3d, - torch.ops.aten.replication_pad1d, - torch.ops.aten.replication_pad2d, - torch.ops.aten.replication_pad3d, -]) -CORE_ATEN_DECOMP = decomp.core_aten_decompositions() -CORE_ATEN_DECOMP.update(EXTRA_DECOMP) -lowerings.decomp = CORE_ATEN_DECOMP - - -def _all_core_ops(): - """Yields all core ops.""" - import torch._ops - - for k, v in torch.ops.aten.__dict__.items(): - if k.startswith('__'): - continue - if k.startswith('_'): - continue - if isinstance(v, torch._ops.OpOverloadPacket): - for overload in v.overloads(): - op = getattr(v, overload) - if torch.Tag.core in op.tags: - yield v - break - - -def print_missing_ops(): - core_aten = set(_all_core_ops()) - existing = set(lowerings.registered_ops.keys()) - for v in core_aten - existing: - print(v) diff --git a/experimental/torch_xla2/torch_xla2/tensor.py b/experimental/torch_xla2/torch_xla2/tensor.py index 98953a8b04c..262bc95f566 100644 --- a/experimental/torch_xla2/torch_xla2/tensor.py +++ b/experimental/torch_xla2/torch_xla2/tensor.py @@ -1,53 +1,16 @@ -import functools +import contextlib import jax from jax import dlpack as jaxdl import jax.numpy as jnp import numpy import torch import torch.func -import torch._decomp.decompositions -from torch_xla2 import ops_registry import torch.utils._python_dispatch as torch_dispatch import torch.utils._pytree as torch_pytree import torch.utils.dlpack as torchdl -from torch_xla2.ops import jaten -from torch._subclasses.fake_tensor import FakeTensorMode -fake_mode = FakeTensorMode() - - -class XLADispatchMode(torch_dispatch.TorchDispatchMode): - - def __torch_dispatch__(self, fn, types, args=(), kwargs=None): - if fn in constructors: - args, kwargs = unwrap((args, kwargs)) - res = constructors[fn](*args, **kwargs) - return wrap(res) - - return fn(*args, **kwargs) - - -def _aten_arange(start, - end, - *, - dtype=None, - layout=None, - requires_grad=False, - device=None, - pin_memory=False): - return jnp.arange(start, end, 1) - - -def _aten_scalar_tensor(val, **kwargs): - p = torch.ops.aten.scalar_tensor(val) - return wrap(t2j(p)) - - -constructors = { - torch.ops.aten.scalar_tensor.default: _aten_scalar_tensor, - torch.ops.aten.arange.default: functools.partial(_aten_arange, 0), - torch.ops.aten.arange.start: _aten_arange, -} +class OperatorNotFound(Exception): + pass def wrap(jaxarray): @@ -61,7 +24,9 @@ def unwrap(torchtensors): def t2j(t): if isinstance(t, XLATensor2): return t._elem + is_bool = False if t.dtype == torch.bool: + is_bool = True t = t.to(torch.int8) if not t.is_contiguous(): @@ -82,7 +47,7 @@ def t2j(t): if t.dtype == torch.bfloat16: res = res.astype(jnp.bfloat16) - if t.dtype == torch.bool: + if is_bool: res = res.astype(jnp.bool_) return res @@ -97,48 +62,41 @@ def j2t(x): res = res.to(torch.bool) return res +TORCH_DTYPE_TO_JAX = { + torch.float16: jnp.dtype('float16'), + torch.bfloat16: jnp.dtype('bfloat16'), + torch.half: jnp.dtype('float16'), + torch.float32: jnp.dtype('float32'), + torch.double: jnp.dtype('double'), + torch.long: jnp.dtype('int64'), + torch.int32: jnp.dtype('int32'), + torch.int16: jnp.dtype('int16'), + torch.int8: jnp.dtype('int8'), + torch.uint8: jnp.dtype('uint8'), + torch.bool: jnp.dtype('bool_'), + torch.complex64: jnp.dtype('complex64'), + torch.complex128: jnp.dtype('complex128'), + None: None, +} + +JAX_DTYPE_TO_TORCH = {value: key for key, value in TORCH_DTYPE_TO_JAX.items()} def t2j_dtype(dtype): - return { - torch.float16: jnp.float16, - torch.bfloat16: jnp.bfloat16, - torch.half: jnp.float16, - torch.float32: jnp.float32, - torch.double: jnp.double, - torch.long: jnp.int64, - torch.int32: jnp.int32, - torch.int16: jnp.int16, - torch.int8: jnp.int8, - torch.uint8: jnp.uint8, - torch.bool: jnp.bool_, - torch.complex64: jnp.complex64, - torch.complex128: jnp.complex128, - }.get(dtype) + if dtype not in TORCH_DTYPE_TO_JAX: + raise RuntimeError(f'Attempting to convert unknown type: {dtype} to torch type,') + return TORCH_DTYPE_TO_JAX[dtype] def j2t_dtype(dtype): - return { - jnp.float16: torch.float16, - jnp.bfloat16: torch.bfloat16, - jnp.double: torch.double, - jnp.float32: torch.float32, - jnp.float16: torch.half, - jnp.int64: torch.long, - jnp.int32: torch.int32, - jnp.int16: torch.int16, - jnp.bool_: torch.bool, - jnp.complex64: torch.complex64, - }.get(dtype) - - -def move_to_device(t): - return XLATensor2(t2j(t)) + if dtype not in JAX_DTYPE_TO_TORCH: + raise RuntimeError(f'Attempting to convert unknown type: {dtype} to torch type,') + return JAX_DTYPE_TO_TORCH[dtype] class XLATensor2(torch.Tensor): @staticmethod - def __new__(cls, elem): + def __new__(cls, elem, env): dtype = j2t_dtype(elem.dtype) shape = list(elem.shape) for i, s in enumerate(shape): @@ -154,9 +112,10 @@ def __new__(cls, elem): requires_grad=False, ) - def __init__(self, elem: jax.Array): + def __init__(self, elem: jax.Array, env: 'Environment'): super().__init__() self._elem = elem + self._env = env def __str__(self): return "XLATensor2({} {})".format(str(type(self._elem)), str(self._elem)) @@ -178,7 +137,7 @@ def flatten(self, start_dim=0, end_dim=-1): new_shape = ( self._elem.shape[:start_dim] + (-1,) + self._elem.shape[end_dim:]) new_elem = jnp.reshape(self._elem, new_shape) - return XLATensor2(new_elem) + return XLATensor2(new_elem, self._env) # return torch.reshape(self, new_shape) def __setitem__(self, key, val): @@ -193,32 +152,17 @@ def type_as(self, other): @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - kwargs = kwargs or {} - with jax.named_scope(func.name()): + env = None + for arg in torch_pytree.arg_tree_leaves(*args, **kwargs): + if isinstance(arg, XLATensor2): + env = arg._env + break - if isinstance(func, torch._ops.OpOverloadPacket): - return func(*args, **kwargs) - - if func.name() == 'aten::copy_': - x, y = args - x._elem = y._elem - return - - if func.overloadpacket in jaten.all_ops: - return jaten.all_ops[func.overloadpacket](*args, **kwargs) - - lowering = ops_registry.lowerings.lookup(func) - - if lowering is None: - raise RuntimeError("No lowering found for", func.name()) - - with XLADispatchMode(): - res = lowering(*args, **kwargs) - debug_accuracy(func, args, kwargs, res) - return res + with env: + return func(*args, **(kwargs or {})) def detach(self): - return XLATensor2(jax.lax.stop_gradient(self.jax())) + return XLATensor2(jax.lax.stop_gradient(self.jax()), self._env) def numpy(self) -> numpy.ndarray: import numpy as np @@ -231,6 +175,20 @@ def jax(self) -> jax.Array: def torch(self) -> torch.Tensor: return j2t(self.jax()) + def to(self, *args, **kwargs): + if len(args) == 1: + if isinstance(args[0], torch.dtype): + return XLATensor2(self._elem.astype(t2j_dtype(args[0])), self._env) + if 'dtype' in kwargs: + dtype = kwargs['dtype'] + return XLATensor2(self._elem.astype(t2j_dtype(dtype)), self._env) + return self + + @property + def dtype(self): + return j2t_dtype(self._elem.dtype) + + # TODO: slice of slice should also be another slice class SliceView(XLATensor2): @@ -281,3 +239,159 @@ def debug_accuracy(func, args, kwargs, current_output): pdb.set_trace() return True + + +class XLAFunctionMode(torch.overrides.TorchFunctionMode): + """Context manager that dispatches torch function calls to JAX.""" + + def __init__(self, env): + self.env = env + + def __torch_function__(self, + func, + types, + args=(), + kwargs=None) -> torch.Tensor: + try: + return self.env.dispatch(func, types, args, kwargs) + except OperatorNotFound: + return func(*args, **(kwargs or {})) + + +class XLADispatchMode(torch_dispatch.TorchDispatchMode): + + def __init__(self, env): + self.env = env + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + if isinstance(func, torch._ops.OpOverloadPacket): + with self: + return func(*args, **kwargs) + if func.namespace != 'aten': + return func(*args, **kwargs) + return self.env.dispatch(func, types, args, kwargs) + +def _name_of_func(func): + if hasattr(func, 'name'): + return func.name() + return func.__name__ + + +class Environment(contextlib.ContextDecorator): + """This class holds a set of configurations and "globals" needed + + for executing torch program using jax. + Things included so far: + + op registry + PRNGKey + Configs + + Also helper functions to manipulate those. + """ + + _prng_key: jax.random.PRNGKey + + + def __init__(self, random_seed): + self._prng_key = jax.random.PRNGKey(random_seed) + self._function_mode = XLAFunctionMode(self) + self._dispatch_mode = XLADispatchMode(self) + + # name is torch callable + self._ops = {} + self.load_ops() + + def load_ops(self): + from torch_xla2.ops import jaten, jtorch, ops_registry + self._ops.update(ops_registry.all_aten_ops) + self._ops.update(ops_registry.all_torch_functions) + + decomps = torch._decomp.core_aten_decompositions() + from torch_xla2.decompositions import EXTRA_DECOMP + decomps.update(EXTRA_DECOMP) + for k, v in decomps.items(): + if k not in self._ops: + self._ops[k] = ops_registry.Operator( + k, + v, + is_jax_function=False, + is_user_defined=False, + needs_env=False + ) + + def get_and_rotate_prng_key(self): + self._prng_key, key = jax.random.split(self._prng_key) + return key + + def dispatch(self, func, types, args, kwargs): + with jax.named_scope(_name_of_func(func)): + kwargs = kwargs or {} + op = self._ops.get(func) + + if op is None and isinstance(func, torch._ops.OpOverloadPacket): + op = self._ops.get(func.default) + + if op is None and isinstance(func, torch._ops.OpOverload): + op = self._ops.get(func.overloadpacket) + + if op is None: + raise OperatorNotFound( + f'Operator with name {_name_of_func(func)} has no lowering') + + if op.is_jax_function: + args, kwargs = self.t2j_iso((args, kwargs)) + + if op.needs_env: + kwargs['env'] = self + + with self: + res = op.func(*args, **kwargs) + + if op.is_jax_function: + res = self.j2t_iso(res) + + #if self.config.debug_accuracy_for_each_op: + # debug_accuracy(func, args, kwargs, res) + return res + + def __enter__(self): + self._dispatch_mode.__enter__() + self._function_mode.__enter__() + return self + + def __exit__(self, *exc): + self._function_mode.__exit__(*exc) + self._dispatch_mode.__exit__(*exc) + + def _move_one_value(self, val): + if isinstance(val, torch.nn.Module): + state_dict = self.to_xla(val.state_dict()) + val.load_state_dict(state_dict, assign=True) + return val + if isinstance(val, XLATensor2): + return val + if isinstance(val, torch.Tensor): + return XLATensor2(t2j(val), self) + return val + + def to_xla(self, torchvalues): + # tensors are torch.Tensors (not XLATensor) + res = torch_pytree.tree_map( + self._move_one_value, + torchvalues) + return res + + def t2j_iso(self, torchtensors): + return torch_pytree.tree_map_only( + XLATensor2, lambda x: x.jax(), torchtensors) + + def j2t_iso(self, jaxarray): + return torch_pytree.tree_map_only( + jnp.ndarray, lambda x: XLATensor2(x, self), jaxarray) + + def j2t_copy(self, args): + pass + + def j2t_copy(self, args): + pass diff --git a/experimental/torch_xla2/torch_xla2/types.py b/experimental/torch_xla2/torch_xla2/types.py new file mode 100644 index 00000000000..f39d530c18d --- /dev/null +++ b/experimental/torch_xla2/torch_xla2/types.py @@ -0,0 +1,12 @@ +from typing import TypeAlias, Callable, ParamSpec, Any, Union +import torch +import jax +import jax.numpy as jnp + + +P = ParamSpec('P') + +TorchValue: TypeAlias = Union[torch.Tensor, torch.dtype, 'TorchCallable', Any] +TorchCallable: TypeAlias = Callable[P, TorchValue] +JaxValue: TypeAlias = Union[jax.Array, jnp.dtype, 'JaxCallable', Any] +JaxCallable: TypeAlias = Callable[P, JaxValue] \ No newline at end of file