Skip to content

Commit

Permalink
ops refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
qihqi committed May 1, 2024
1 parent 2907ab3 commit 625e106
Show file tree
Hide file tree
Showing 18 changed files with 2,315 additions and 2,371 deletions.
5 changes: 1 addition & 4 deletions experimental/torch_xla2/test/llama/test_llama.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
import unittest
import jax
import torch
from torch._functorch.make_functional import make_functional_with_buffers
from torch_xla2 import tensor, ops # pylint: disable=unused-import
from torch_xla2 import tensor # pylint: disable=unused-import
import torch_xla2

from .. import test_base
Expand Down
8 changes: 5 additions & 3 deletions experimental/torch_xla2/test/test_context.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
import unittest

import torch
import torch_xla2
from torch_xla2 import tensor

xla_env = tensor.Environment(0)


class TestContext(unittest.TestCase):

def test_mode_context_manager(self):
with torch_xla2.mode():
with xla_env.mode():
x = torch.full((3, 3), -1)
self.assertIsInstance(x, tensor.XLATensor2)
y = x.abs()
self.assertIsInstance(y, tensor.XLATensor2)

@staticmethod
@torch_xla2.mode()
@xla_env.mode()
def _test_mode_decorator():
x = torch.full((3, 3), -1)
y = x.abs()
Expand Down
11 changes: 7 additions & 4 deletions experimental/torch_xla2/test/test_core_aten_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,14 @@ def run_export_and_compare(testcase,
rtol=1e-5,
equal_nan=True,
ignore_indices=False):

with testcase.subTest("torch_eval"):
res = func(*args, **kwargs)
with testcase.subTest("torch_xla2_eval"):
args2, kwargs2 = pytree.tree_map_only(torch.Tensor, tensor.move_to_device,
(args, kwargs))
res2 = func(*args2, **kwargs2)
with testcase.env.mode():
res2 = func(*args2, **kwargs2)
res2 = pytree.tree_map_only(tensor.XLATensor2, lambda t: t.torch(), res2)
# import pdb; pdb.set_trace()
with testcase.subTest("torch_xla2_diff:" + str(atol)):
Expand All @@ -61,11 +63,11 @@ class TestCoreAtenOps(unittest.TestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
ops_registry.print_missing_ops()

def setUp(self):
super().setUp()
torch.manual_seed(0)
self.env = tensor.Environment(0)

def test_aten_abs_0(self):
args = (torch.randn((10, 10)).to(torch.float32),)
Expand Down Expand Up @@ -2109,7 +2111,7 @@ def test_aten_logit_0(self):
def test_aten_logit_1(self):
args = (torch.randn((10, 10)).to(torch.float16),)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.logit, args, kwargs)
run_export_and_compare(self, torch.ops.aten.logit, args, kwargs, atol=0.01,)

def test_aten_logit_2(self):
args = (torch.randint(0, 10, (10, 10)).to(torch.int32),)
Expand Down Expand Up @@ -3640,7 +3642,8 @@ def _compare_sorted_result(self, args):
res = torch.ops.aten.sort(*args)
with self.subTest("torch_xla2_eval"):
args2 = pytree.tree_map_only(torch.Tensor, tensor.move_to_device, args)
res2 = torch.ops.aten.sort(*args2)
with self.env.mode():
res2 = torch.ops.aten.sort(*args2)

# The second argument is the sorted index. These might not be
# identical from torch vs. jax; but both can be correct
Expand Down
64 changes: 0 additions & 64 deletions experimental/torch_xla2/test/test_extra.py

This file was deleted.

5 changes: 4 additions & 1 deletion experimental/torch_xla2/test/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@

class TestTorchFunctions(parameterized.TestCase):

def setUp(self):
self.env = torch_xla2.tensor.Environment(0)

@parameterized.named_parameters(
('tensor_2d', lambda: torch.tensor([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]])),
('tensor_1d', lambda: torch.tensor([0, 1],)),
Expand All @@ -32,7 +35,7 @@ class TestTorchFunctions(parameterized.TestCase):
def test_tensor_constructor(self, func: Callable[[], torch.Tensor]):
expected = func()

with torch_xla2.functions.XLAFunctionMode():
with self.env.mode():
actual = func()
self.assertIsInstance(actual, torch_xla2.tensor.XLATensor2)

Expand Down
61 changes: 29 additions & 32 deletions experimental/torch_xla2/test/test_mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,46 +6,43 @@

class TestMutations(TestCase):

def test_add(self):
x = torch.tensor([1, 2, 3], dtype=torch.int32)
y = torch.tensor([4, 5, 6], dtype=torch.int32)
def setUp(self):
self.env = torch_xla2.tensor.Environment(0)

x = torch_xla2.tensor.move_to_device(x)
y = torch_xla2.tensor.move_to_device(y)
x.add_(y)
xt = torch_xla2.tensor.j2t(x._elem)
self.assertEqual(xt, torch.tensor([5, 7, 9], dtype=torch.int32))
def test_add(self):
with self.env.mode():
x = torch.tensor([1, 2, 3], dtype=torch.int32)
y = torch.tensor([4, 5, 6], dtype=torch.int32)
x.add_(y)
xt = torch_xla2.tensor.j2t(x._elem)
self.assertEqual(xt, torch.tensor([5, 7, 9], dtype=torch.int32))

def test_sub(self):
x = torch.tensor([1, 2, 3], dtype=torch.int32)
y = torch.tensor([4, 5, 6], dtype=torch.int32)

x = torch_xla2.tensor.move_to_device(x)
y = torch_xla2.tensor.move_to_device(y)
x.sub_(y)
xt = torch_xla2.tensor.j2t(x._elem)
self.assertEqual(xt, torch.tensor([-3, -3, -3], dtype=torch.int32))
with self.env.mode():
x = torch.tensor([1, 2, 3], dtype=torch.int32)
y = torch.tensor([4, 5, 6], dtype=torch.int32)
x.sub_(y)
xt = torch_xla2.tensor.j2t(x._elem)
self.assertEqual(xt, torch.tensor([-3, -3, -3], dtype=torch.int32))

def test_mul(self):
x = torch.tensor([1, 2, 3], dtype=torch.int32)
y = torch.tensor([4, 5, 6], dtype=torch.int32)
with self.env.mode():
x = torch.tensor([1, 2, 3], dtype=torch.int32)
y = torch.tensor([4, 5, 6], dtype=torch.int32)

x = torch_xla2.tensor.move_to_device(x)
y = torch_xla2.tensor.move_to_device(y)
x.mul_(y)
xt = torch_xla2.tensor.j2t(x._elem)
self.assertEqual(xt, torch.tensor([4, 10, 18], dtype=torch.int32))
x.mul_(y)
xt = torch_xla2.tensor.j2t(x._elem)
self.assertEqual(xt, torch.tensor([4, 10, 18], dtype=torch.int32))

def test_div(self):
x = torch.tensor([1, 2, 3], dtype=torch.int32)
y = torch.tensor([4, 5, 6], dtype=torch.int32)

x = torch_xla2.tensor.move_to_device(x)
y = torch_xla2.tensor.move_to_device(y)
x.div_(y)
xt = torch_xla2.tensor.j2t(x._elem)
self.assertEqual(xt,
torch.tensor([1. / 4, 2. / 5, 3. / 6], dtype=torch.float))
with self.env.mode():
x = torch.tensor([1, 2, 3], dtype=torch.int32)
y = torch.tensor([4, 5, 6], dtype=torch.int32)

x.div_(y)
xt = torch_xla2.tensor.j2t(x._elem)
self.assertEqual(xt,
torch.tensor([1. / 4, 2. / 5, 3. / 6], dtype=torch.float))


if __name__ == '__main__':
Expand Down
5 changes: 4 additions & 1 deletion experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@ def run_export_and_compare(testcase,
input2, args2, kwargs2 = pytree.tree_map_only(
torch.Tensor, tensor.move_to_device,
(sample_input.input, sample_input.args, sample_input.kwargs))
with torch_xla2.mode():
with testcase.env.mode():
res2 = func(input2, *args2, **kwargs2)
res2 = pytree.tree_map_only(tensor.XLATensor2, lambda t: t.torch(), res2)
with testcase.subTest("torch_xla2_diff:" + str(atol)):
Expand All @@ -655,6 +655,9 @@ class TestOpInfo(TestCase):
def setUpClass(cls):
print('op_db size: ', len(op_db), 'testing: ', len(ops_to_test))

def setUp(self):
self.env = tensor.Environment(0)

@ops(ops_to_test, allowed_dtypes=(torch.float32, torch.long))
def test_reference_eager(self, device, dtype, op):
sample_inputs = op.sample_inputs(device, dtype)
Expand Down
12 changes: 4 additions & 8 deletions experimental/torch_xla2/torch_xla2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,10 @@
jax.config.update('jax_enable_x64', True)


@contextlib.contextmanager
def mode():
with tensor.XLADispatchMode(), functions.XLAFunctionMode():
yield


def extract_jax(mod: torch.nn.Module):
def extract_jax(mod: torch.nn.Module, env=None):
"""Returns a pytree of jax.ndarray and a jax callable."""
if env is None:
env = tensor.Environment(0)
func, weights, buffer = make_functional.make_functional_with_buffers(mod)
states = (weights, buffer)
states = pytree.tree_map_only(torch.Tensor, tensor.t2j, states)
Expand All @@ -24,7 +20,7 @@ def extract_jax(mod: torch.nn.Module):
def jax_func(states, inputs):
(states, inputs) = tensor.wrap((states, inputs))
weights, buffer = states
with tensor.XLADispatchMode():
with env.mode():
res = func(weights, buffer, *inputs)
return tensor.unwrap(res)

Expand Down
Loading

0 comments on commit 625e106

Please sign in to comment.