diff --git a/experimental/torch_xla2/examples/basic_training.py b/experimental/torch_xla2/examples/basic_training.py index a723f647ca8..fb814fcf978 100644 --- a/experimental/torch_xla2/examples/basic_training.py +++ b/experimental/torch_xla2/examples/basic_training.py @@ -51,7 +51,8 @@ def matplotlib_imshow(img, one_channel=False): plt.imshow(npimg, cmap="Greys") else: plt.imshow(np.transpose(npimg, (1, 2, 0))) - +#torch_xla2.env.config.debug_print_each_op = True +#torch_xla2.env.config.debug_mixed_tensor = True dataiter = iter(training_loader) images, labels = next(dataiter) @@ -80,15 +81,15 @@ def forward(self, x): return x -model = GarmentClassifier() +model = GarmentClassifier().to('jax') loss_fn = torch.nn.CrossEntropyLoss() # NB: Loss functions expect data in batches, so we're creating batches of 4 # Represents the model's confidence in each of the 10 classes for a given input -dummy_outputs = torch.rand(4, 10) +dummy_outputs = torch.rand(4, 10, device='jax') # Represents the correct class among the 10 being tested -dummy_labels = torch.tensor([1, 5, 3, 7]) +dummy_labels = torch.tensor([1, 5, 3, 7], device='jax') print(dummy_outputs) print(dummy_labels) @@ -110,6 +111,8 @@ def train_one_epoch(epoch_index, tb_writer=None): # Every data instance is an input + label pair # NEW: Move model to XLA device inputs, labels = data + inputs = inputs.to('jax') + labels = labels.to('jax') # Zero your gradients for every batch! optimizer.zero_grad() @@ -162,7 +165,9 @@ def train_one_epoch(epoch_index, tb_writer=None): # Disable gradient computation and reduce memory consumption. with torch.no_grad(): for i, vdata in enumerate(validation_loader): - # NOTE: move to XLA device + vinputs, vlabels = vdata + vinputs = vinputs.to('jax') + vlabels = vlabels.to('jax') voutputs = model(vinputs) # call model's forward vloss = loss_fn(voutputs, vlabels) running_vloss += vloss @@ -172,15 +177,11 @@ def train_one_epoch(epoch_index, tb_writer=None): # Log the running loss averaged per batch # for both training and validation - writer.add_scalars('Training vs. Validation Loss', - { 'Training' : avg_loss, 'Validation' : avg_vloss }, - epoch_number + 1) - writer.flush() - - # Track best performance, and save the model's state - if avg_vloss < best_vloss: - best_vloss = avg_vloss - model_path = 'model_{}_{}'.format(timestamp, epoch_number) - torch.save(model.state_dict(), model_path) + + # # Track best performance, and save the model's state + # if avg_vloss < best_vloss: + # best_vloss = avg_vloss + # model_path = 'model_{}_{}'.format(timestamp, epoch_number) + # torch.save(model.state_dict(), model_path) epoch_number += 1 diff --git a/experimental/torch_xla2/test/llama/test_llama.py b/experimental/torch_xla2/test/llama/test_llama.py index 03dfef27d0e..a47e8572186 100644 --- a/experimental/torch_xla2/test/llama/test_llama.py +++ b/experimental/torch_xla2/test/llama/test_llama.py @@ -12,101 +12,101 @@ class LlamaTest(test_base.TestCase): def test_can_run(self): - sample_args = ( - torch.randint(0, 32000, (1, 2048)), - torch.arange(0, 2048), - ) - sample_args = pytree.tree_map(tensor.t2j, sample_args) + with torch_xla2.default_env(): + sample_args = ( + torch.randint(0, 32000, (1, 2048), device='jax:0'), + torch.arange(0, 2048, device='jax:0'), + ) - model_args = llama_model.ModelArgs( - block_size=2048, - vocab_size=32000, - n_layer=2, - n_head=4, - dim=256, - ) - m = llama_model.Transformer(model_args) - m.to(torch.bfloat16) - m.setup_caches(1, 2048) + model_args = llama_model.ModelArgs( + block_size=2048, + vocab_size=32000, + n_layer=2, + n_head=4, + dim=256, + ) + m = llama_model.Transformer(model_args) + m.to(torch.bfloat16) + m.setup_caches(1, 2048) + m = m.to('jax') + + print(m(*sample_args)) - # NOTE: this API does NOT use torch export - weights, jax_func = torch_xla2.extract_jax(m) - print(jax_func(weights, sample_args)) def test_can_run_exportable(self): - model_args = model_exportable.ModelArgs( - vocab_size=32000, - n_layers=2, - n_heads=4, - dim=256, - ) - m = model_exportable.Transformer(model_args) - context_length = 2048 - input_shape_prefill = (1, context_length) - input_shape_decode = (1, 1) + model_args = model_exportable.ModelArgs( + vocab_size=32000, + n_layers=2, + n_heads=4, + dim=256, + ) + m = model_exportable.Transformer(model_args) + context_length = 2048 + input_shape_prefill = (1, context_length) + input_shape_decode = (1, 1) - def make_cache(args, batch_size): - n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads - n_local_heads = args.n_heads - n_local_kv_heads = n_kv_heads - n_rep = n_local_heads // n_local_kv_heads - head_dim = args.dim // args.n_heads - res = [] - for i in range(args.n_layers): - if batch_size is None: - size = ( - args.max_seq_len, - n_local_kv_heads, - head_dim, - ) - else: - size = ( - batch_size, - args.max_seq_len, - n_local_kv_heads, - head_dim, - ) - res.append( - (torch.zeros( - size, - dtype=torch.bfloat16 if args.bf16_enable else torch.float), - torch.zeros( - size, - dtype=torch.bfloat16 if args.bf16_enable else torch.float))) - return res + def make_cache(args, batch_size): + n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads + n_local_heads = args.n_heads + n_local_kv_heads = n_kv_heads + n_rep = n_local_heads // n_local_kv_heads + head_dim = args.dim // args.n_heads + res = [] + for i in range(args.n_layers): + if batch_size is None: + size = ( + args.max_seq_len, + n_local_kv_heads, + head_dim, + ) + else: + size = ( + batch_size, + args.max_seq_len, + n_local_kv_heads, + head_dim, + ) + res.append( + (torch.zeros( + size, + dtype=torch.bfloat16 if args.bf16_enable else torch.float), + torch.zeros( + size, + dtype=torch.bfloat16 if args.bf16_enable else torch.float))) + return res - prefill_caches = make_cache(model_args, 1) + prefill_caches = make_cache(model_args, 1) - sample_input_prefill = ( - torch.randint(0, 1000, input_shape_prefill, - dtype=torch.int32), # len seq length - torch.arange(0, context_length, dtype=torch.int32), # input indexes - torch.arange(0, context_length, dtype=torch.int32), # context indexes - prefill_caches, - True, # prefil - ) - with torch.no_grad(): - m_prefill = torch.export.export(m, sample_input_prefill) + sample_input_prefill = ( + torch.randint(0, 1000, input_shape_prefill, + dtype=torch.int32), # len seq length + torch.arange(0, context_length, dtype=torch.int32), # input indexes + torch.arange(0, context_length, dtype=torch.int32), # context indexes + prefill_caches, + True, # prefil + ) + with torch.no_grad(): + m_prefill = torch.export.export(m, sample_input_prefill) - weights, mj_prefill = torch_xla2.export.exported_program_to_jax(m_prefill) - sample_inputs = pytree.tree_map_only(torch.Tensor, tensor.t2j, - sample_input_prefill) - print('Prefill', mj_prefill(weights, sample_inputs)) + weights, mj_prefill = torch_xla2.export.exported_program_to_jax(m_prefill) + sample_inputs = pytree.tree_map_only(torch.Tensor, tensor.t2j, + sample_input_prefill) + print('Prefill', mj_prefill(weights, sample_inputs)) - sample_input_decode = ( - torch.randint(0, 1000, input_shape_decode, - dtype=torch.int32), # len = 1 - torch.tensor([0], dtype=torch.int32), - torch.roll(torch.arange(context_length, dtype=torch.int32), 1, 0), - prefill_caches, - False # prefill - ) - with torch.no_grad(): - m_decode = torch.export.export(m, sample_input_decode) - weights, mj_decode = torch_xla2.export.exported_program_to_jax(m_decode) - sample_inputs = pytree.tree_map_only(torch.Tensor, tensor.t2j, - sample_input_decode) - print('Decode', mj_decode(weights, sample_inputs)) + sample_input_decode = ( + torch.randint(0, 1000, input_shape_decode, + dtype=torch.int32), # len = 1 + torch.tensor([0], dtype=torch.int32), + torch.roll(torch.arange(context_length, dtype=torch.int32), 1, 0), + prefill_caches, + False # prefill + ) + with torch.no_grad(): + m_decode = torch.export.export(m, sample_input_decode) + weights, mj_decode = torch_xla2.export.exported_program_to_jax(m_decode) + sample_inputs = pytree.tree_map_only(torch.Tensor, tensor.t2j, + sample_input_decode) + print('Decode', mj_decode(weights, sample_inputs)) if __name__ == "__main__": diff --git a/experimental/torch_xla2/test/test_context.py b/experimental/torch_xla2/test/test_context.py index 16bcedf7931..5255f415ee1 100644 --- a/experimental/torch_xla2/test/test_context.py +++ b/experimental/torch_xla2/test/test_context.py @@ -10,6 +10,13 @@ class TestContext(unittest.TestCase): + def setUp(self): + self.old_var = xla_env.config.use_torch_native_for_cpu_tensor + xla_env.config.use_torch_native_for_cpu_tensor = False + + def tearDown(self): + xla_env.config.use_torch_native_for_cpu_tensor = self.old_var + def test_mode_context_manager(self): with xla_env: x = torch.full((3, 3), -1) diff --git a/experimental/torch_xla2/test/test_core_aten_ops.py b/experimental/torch_xla2/test/test_core_aten_ops.py index d207bc22a82..e60086db087 100644 --- a/experimental/torch_xla2/test/test_core_aten_ops.py +++ b/experimental/torch_xla2/test/test_core_aten_ops.py @@ -66,6 +66,11 @@ def setUp(self): super().setUp() torch.manual_seed(0) self.env = tensor.Environment() + self.old_var = self.env.config.use_torch_native_for_cpu_tensor + self.env.config.use_torch_native_for_cpu_tensor = False + + def tearDown(self): + self.env.config.use_torch_native_for_cpu_tensor = self.old_var def test_aten_abs_0(self): args = (torch.randn((10, 10)).to(torch.float32),) diff --git a/experimental/torch_xla2/test/test_functions.py b/experimental/torch_xla2/test/test_functions.py index 9e291dc802a..aab34bd1472 100644 --- a/experimental/torch_xla2/test/test_functions.py +++ b/experimental/torch_xla2/test/test_functions.py @@ -10,6 +10,7 @@ class TestTorchFunctions(parameterized.TestCase): def setUp(self): self.env = torch_xla2.tensor.Environment() + self.env.config.use_torch_native_for_cpu_tensor = False torch_xla2.enable_accuracy_mode() @parameterized.named_parameters( diff --git a/experimental/torch_xla2/test/test_libraries.py b/experimental/torch_xla2/test/test_libraries.py index 019c967db56..492d15467d5 100644 --- a/experimental/torch_xla2/test/test_libraries.py +++ b/experimental/torch_xla2/test/test_libraries.py @@ -1,11 +1,9 @@ import unittest -import jax import torch -import torch.nn as nn import torch.nn.functional as F from torch.library import Library, impl, impl_abstract import torch_xla2 -from torch_xla2 import tensor +import torch_xla2.export from torch_xla2.ops import jaten from torch_xla2.ops import jlibrary @@ -56,6 +54,7 @@ class LibraryTest(unittest.TestCase): def setUp(self): torch.manual_seed(0) + torch_xla2.default_env().config.use_torch_native_for_cpu_tensor = False def test_basic_sdpa_library(self): @@ -78,3 +77,7 @@ def forward(self, q,k,v): ## stablehlo.composite ops. self.assertIn("call @mylib.scaled_dot_product_attention", module_str) self.assertIn("call @mylib.softmax", module_str) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 28d0f29f0c1..d79b35e533a 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -192,6 +192,11 @@ def setUp(self): torch_xla2.enable_accuracy_mode() #self.env.config.debug_accuracy_for_each_op = True torch.manual_seed(0) + self.old_var = self.env.config.use_torch_native_for_cpu_tensor + self.env.config.use_torch_native_for_cpu_tensor = False + + def tearDown(self): + self.env.config.use_torch_native_for_cpu_tensor = self.old_var # Replaces all values in the input torch_tensor that are less than the given threshold # with the threshold value itself. diff --git a/experimental/torch_xla2/test/test_tf_integration.py b/experimental/torch_xla2/test/test_tf_integration.py index ff9da220c57..4562ba8cb0c 100644 --- a/experimental/torch_xla2/test/test_tf_integration.py +++ b/experimental/torch_xla2/test/test_tf_integration.py @@ -1,6 +1,6 @@ -import jax import os import tempfile +import numpy as np import tensorflow as tf import torch import torch.nn.functional as F diff --git a/experimental/torch_xla2/test/test_unbounded_dynamism.py b/experimental/torch_xla2/test/test_unbounded_dynamism.py index 0cd800cb1a7..06d7b19b149 100644 --- a/experimental/torch_xla2/test/test_unbounded_dynamism.py +++ b/experimental/torch_xla2/test/test_unbounded_dynamism.py @@ -2,10 +2,10 @@ import sys import unittest -import numpy as np import torch from torch.export import Dim, export from torch_xla2.export import exported_program_to_stablehlo as exp2shlo +import torch_xla2 ## This file is copied from `xla/test/stablehlo/test_unbounded_dynamism.py` ## To test that torch_xla2 has identical behavior. @@ -44,6 +44,14 @@ def forward(self, *args): class UnboundedDynamismExportTest(unittest.TestCase): + def setUp(self): + self.env = torch_xla2.default_env() + self.env.config.use_torch_native_for_cpu_tensor = False + torch_xla2.enable_accuracy_mode() + + def tearDown(self): + self.env.config.use_torch_native_for_cpu_tensor = True + def test_add(self): args = (torch.rand((10, 197, 768)), torch.rand((10, 197, 768))) dynamic_shapes = (({0: Dim("dim")}, {0: Dim("dim")}),) diff --git a/experimental/torch_xla2/torch_xla2/__init__.py b/experimental/torch_xla2/torch_xla2/__init__.py index ef6cd058429..f36a0737c00 100644 --- a/experimental/torch_xla2/torch_xla2/__init__.py +++ b/experimental/torch_xla2/torch_xla2/__init__.py @@ -1,3 +1,4 @@ +import contextlib from typing import List, Dict, Any, Optional import dataclasses import jax @@ -73,6 +74,15 @@ def disable_globally(): global env default_env().__exit__(None, None, None) +@contextlib.contextmanager +def disable_temporarily(): + prev = default_env().enabled + if prev: + disable_globally() + yield() + if prev: + enable_globally() + torch.utils.rename_privateuse1_backend('jax') unsupported_dtype = [torch.quint8] diff --git a/experimental/torch_xla2/torch_xla2/config.py b/experimental/torch_xla2/torch_xla2/config.py index 8a0870996a2..351d137df57 100644 --- a/experimental/torch_xla2/torch_xla2/config.py +++ b/experimental/torch_xla2/torch_xla2/config.py @@ -14,5 +14,5 @@ class Configuration: # device treat_cuda_as_jax_device: bool = True - use_torch_native_for_cpu_tensor: bool = False + use_torch_native_for_cpu_tensor: bool = True internal_respect_torch_return_dtypes: bool = False diff --git a/experimental/torch_xla2/torch_xla2/export.py b/experimental/torch_xla2/torch_xla2/export.py index 2744d931de4..3fdbedc8474 100644 --- a/experimental/torch_xla2/torch_xla2/export.py +++ b/experimental/torch_xla2/torch_xla2/export.py @@ -31,6 +31,10 @@ def call_function(self, target, args: Tuple, kwargs: Dict) -> Any: print('Running ', target.name(), '--------') op = ops_registry.all_aten_ops.get(target) + if op is None: + op = ops_registry.all_aten_ops.get(target.overloadpacket) + assert op is not None, target + assert op.is_jax_function, op if op is None: op = ops_registry.all_aten_ops.get(target.overloadpacket) if op is None: diff --git a/experimental/torch_xla2/torch_xla2/ops/jtorch.py b/experimental/torch_xla2/torch_xla2/ops/jtorch.py index f21c5b8f671..4d541cd04d1 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jtorch.py +++ b/experimental/torch_xla2/torch_xla2/ops/jtorch.py @@ -103,13 +103,12 @@ def _sdpa_reference(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor: L, S = query.size(-2), key.size(-2) scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale - attn_bias = torch.zeros(L, S, dtype=query.dtype) + attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) if is_causal: assert attn_mask is None - temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) + temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0) attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) attn_bias.to(query.dtype) - if attn_mask is not None: if attn_mask.dtype == torch.bool: attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) @@ -249,14 +248,14 @@ def _aten_isclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False): def _ones(*size: int, dtype=None, **kwargs): if len(size) == 1 and isinstance(size[0], collections.abc.Iterable): size = size[0] - return torch.ops.aten.ones(size, dtype=dtype) + return jaten._ones(size, dtype=dtype) -@register_function(torch.zeros, is_jax_function=False) +@register_function(torch.zeros, is_jax_function=True) def _zeros(*size: int, dtype=None, **kwargs): if len(size) == 1 and isinstance(size[0], collections.abc.Iterable): size = size[0] - return torch.ops.aten.zeros(size, dtype=dtype) + return jaten._zeros(size, dtype=dtype) @register_function(torch.eye) diff --git a/experimental/torch_xla2/torch_xla2/tensor.py b/experimental/torch_xla2/torch_xla2/tensor.py index d14eb9a68e1..35d69eb7326 100644 --- a/experimental/torch_xla2/torch_xla2/tensor.py +++ b/experimental/torch_xla2/torch_xla2/tensor.py @@ -247,6 +247,8 @@ def _name_of_func(func): torch.randn, torch.rand, torch.randint, + torch.full, + torch.as_tensor, } @@ -285,7 +287,8 @@ def get_as_jax_device(self, device: Any): if isinstance(device, torch.device): device = str(device) - if self.config.use_torch_native_for_cpu_tensor and device.startswith('cpu'): + if (self.config.use_torch_native_for_cpu_tensor and + not device.startswith('jax') and not device.startswith('cuda')): return None if not self.config.treat_cuda_as_jax_device and device.startswith('cuda'): @@ -338,7 +341,7 @@ def _to_copy(self, the_tensor, new_dtype, new_device): arr = jax.device_put(arr, jax_device) else: with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): - return torch_tensor.to(new_device) + return the_tensor.to(new_device) return XLATensor2(arr, self) @@ -358,7 +361,6 @@ def _handle_tensor_constructor(self, func, args, kwargs): # let torch handle it with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): return func(*args, **kwargs) - with jax.default_device(jax_device): op = self._ops.get(func) res = op.func(*args, **kwargs) @@ -396,7 +398,8 @@ def dispatch(self, func, types, args, kwargs): # If the func doesn't act on XLATensor2, and is not a tensor constructor, # We should skip and let torch handle it. - tensor_args = [t for t in args if isinstance(t, torch.Tensor)] + + tensor_args = [t for t in torch_pytree.tree_flatten(args)[0] if isinstance(t, torch.Tensor)] if tensor_args and all(not isinstance(t, XLATensor2) for t in tensor_args): return func(*args, **kwargs) @@ -444,11 +447,13 @@ def dispatch(self, func, types, args, kwargs): def __enter__(self): self._dispatch_mode.__enter__() self._function_mode.__enter__() + self.enabled = True return self def __exit__(self, *exc): self._function_mode.__exit__(*exc) self._dispatch_mode.__exit__(*exc) + self.enabled = False def _move_one_value(self, val): if isinstance(val, torch.nn.Module):