From 15571223bbac6f18bc7b11e4a2838e37e2dc104f Mon Sep 17 00:00:00 2001
From: Han Qi
Date: Mon, 25 Nov 2024 19:54:42 -0800
Subject: [PATCH] Make use_torch_native_for_cpu_tensor default to True.

---
 .../torch_xla2/test/llama/test_llama.py      | 172 +++++++++---
 experimental/torch_xla2/test/test_context.py |   7 +
 .../torch_xla2/test/test_core_aten_ops.py    |   5 +
 .../torch_xla2/test/test_functions.py        |   1 +
 experimental/torch_xla2/test/test_ops.py     |   5 +
 .../test/test_unbounded_dynamism.py          |  10 +-
 .../torch_xla2/torch_xla2/__init__.py        |  10 +
 experimental/torch_xla2/torch_xla2/config.py |   2 +-
 experimental/torch_xla2/torch_xla2/export.py |   5 +
 .../torch_xla2/torch_xla2/ops/jaten.py       |   3 +
 .../torch_xla2/torch_xla2/ops/jtorch.py      |  11 +-
 experimental/torch_xla2/torch_xla2/tensor.py |  10 +-
 12 files changed, 144 insertions(+), 97 deletions(-)

diff --git a/experimental/torch_xla2/test/llama/test_llama.py b/experimental/torch_xla2/test/llama/test_llama.py
index 03dfef27d0e..a47e8572186 100644
--- a/experimental/torch_xla2/test/llama/test_llama.py
+++ b/experimental/torch_xla2/test/llama/test_llama.py
@@ -12,101 +12,101 @@ class LlamaTest(test_base.TestCase):
 
   def test_can_run(self):
-    sample_args = (
-        torch.randint(0, 32000, (1, 2048)),
-        torch.arange(0, 2048),
-    )
-    sample_args = pytree.tree_map(tensor.t2j, sample_args)
+    with torch_xla2.default_env():
+      sample_args = (
+          torch.randint(0, 32000, (1, 2048), device='jax:0'),
+          torch.arange(0, 2048, device='jax:0'),
+      )
 
-    model_args = llama_model.ModelArgs(
-        block_size=2048,
-        vocab_size=32000,
-        n_layer=2,
-        n_head=4,
-        dim=256,
-    )
-    m = llama_model.Transformer(model_args)
-    m.to(torch.bfloat16)
-    m.setup_caches(1, 2048)
+      model_args = llama_model.ModelArgs(
+          block_size=2048,
+          vocab_size=32000,
+          n_layer=2,
+          n_head=4,
+          dim=256,
+      )
+      m = llama_model.Transformer(model_args)
+      m.to(torch.bfloat16)
+      m.setup_caches(1, 2048)
+      m = m.to('jax')
+
+      print(m(*sample_args))
 
-    # NOTE: this API does NOT use torch export
-    weights, jax_func = torch_xla2.extract_jax(m)
-    print(jax_func(weights, sample_args))
 
   def test_can_run_exportable(self):
-    model_args = model_exportable.ModelArgs(
-        vocab_size=32000,
-        n_layers=2,
-        n_heads=4,
-        dim=256,
-    )
-    m = model_exportable.Transformer(model_args)
-    context_length = 2048
-    input_shape_prefill = (1, context_length)
-    input_shape_decode = (1, 1)
+    model_args = model_exportable.ModelArgs(
+        vocab_size=32000,
+        n_layers=2,
+        n_heads=4,
+        dim=256,
+    )
+    m = model_exportable.Transformer(model_args)
+    context_length = 2048
+    input_shape_prefill = (1, context_length)
+    input_shape_decode = (1, 1)
 
-    def make_cache(args, batch_size):
-      n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
-      n_local_heads = args.n_heads
-      n_local_kv_heads = n_kv_heads
-      n_rep = n_local_heads // n_local_kv_heads
-      head_dim = args.dim // args.n_heads
-      res = []
-      for i in range(args.n_layers):
-        if batch_size is None:
-          size = (
-              args.max_seq_len,
-              n_local_kv_heads,
-              head_dim,
-          )
-        else:
-          size = (
-              batch_size,
-              args.max_seq_len,
-              n_local_kv_heads,
-              head_dim,
-          )
-        res.append(
-            (torch.zeros(
-                size,
-                dtype=torch.bfloat16 if args.bf16_enable else torch.float),
-             torch.zeros(
-                size,
-                dtype=torch.bfloat16 if args.bf16_enable else torch.float)))
-      return res
+    def make_cache(args, batch_size):
+      n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
+      n_local_heads = args.n_heads
+      n_local_kv_heads = n_kv_heads
+      n_rep = n_local_heads // n_local_kv_heads
+      head_dim = args.dim // args.n_heads
+      res = []
+      for i in range(args.n_layers):
+        if batch_size is None:
+          size = (
+              args.max_seq_len,
+              n_local_kv_heads,
+              head_dim,
+          )
+        else:
+          size = (
+              batch_size,
+              args.max_seq_len,
+              n_local_kv_heads,
+              head_dim,
+          )
+        res.append(
+            (torch.zeros(
+                size,
+                dtype=torch.bfloat16 if args.bf16_enable else torch.float),
+             torch.zeros(
+                size,
+                dtype=torch.bfloat16 if args.bf16_enable else torch.float)))
+      return res
 
-    prefill_caches = make_cache(model_args, 1)
+    prefill_caches = make_cache(model_args, 1)
 
-    sample_input_prefill = (
-        torch.randint(0, 1000, input_shape_prefill,
-                      dtype=torch.int32),  # len seq length
-        torch.arange(0, context_length, dtype=torch.int32),  # input indexes
-        torch.arange(0, context_length, dtype=torch.int32),  # context indexes
-        prefill_caches,
-        True,  # prefil
-    )
-    with torch.no_grad():
-      m_prefill = torch.export.export(m, sample_input_prefill)
+    sample_input_prefill = (
+        torch.randint(0, 1000, input_shape_prefill,
+                      dtype=torch.int32),  # len seq length
+        torch.arange(0, context_length, dtype=torch.int32),  # input indexes
+        torch.arange(0, context_length, dtype=torch.int32),  # context indexes
+        prefill_caches,
+        True,  # prefil
+    )
+    with torch.no_grad():
+      m_prefill = torch.export.export(m, sample_input_prefill)
 
-    weights, mj_prefill = torch_xla2.export.exported_program_to_jax(m_prefill)
-    sample_inputs = pytree.tree_map_only(torch.Tensor, tensor.t2j,
-                                         sample_input_prefill)
-    print('Prefill', mj_prefill(weights, sample_inputs))
+    weights, mj_prefill = torch_xla2.export.exported_program_to_jax(m_prefill)
+    sample_inputs = pytree.tree_map_only(torch.Tensor, tensor.t2j,
+                                         sample_input_prefill)
+    print('Prefill', mj_prefill(weights, sample_inputs))
 
-    sample_input_decode = (
-        torch.randint(0, 1000, input_shape_decode,
-                      dtype=torch.int32),  # len = 1
-        torch.tensor([0], dtype=torch.int32),
-        torch.roll(torch.arange(context_length, dtype=torch.int32), 1, 0),
-        prefill_caches,
-        False  # prefill
-    )
-    with torch.no_grad():
-      m_decode = torch.export.export(m, sample_input_decode)
-    weights, mj_decode = torch_xla2.export.exported_program_to_jax(m_decode)
-    sample_inputs = pytree.tree_map_only(torch.Tensor, tensor.t2j,
-                                         sample_input_decode)
-    print('Decode', mj_decode(weights, sample_inputs))
+    sample_input_decode = (
+        torch.randint(0, 1000, input_shape_decode,
+                      dtype=torch.int32),  # len = 1
+        torch.tensor([0], dtype=torch.int32),
+        torch.roll(torch.arange(context_length, dtype=torch.int32), 1, 0),
+        prefill_caches,
+        False  # prefill
+    )
+    with torch.no_grad():
+      m_decode = torch.export.export(m, sample_input_decode)
+    weights, mj_decode = torch_xla2.export.exported_program_to_jax(m_decode)
+    sample_inputs = pytree.tree_map_only(torch.Tensor, tensor.t2j,
+                                         sample_input_decode)
+    print('Decode', mj_decode(weights, sample_inputs))
 
 
 if __name__ == "__main__":
diff --git a/experimental/torch_xla2/test/test_context.py b/experimental/torch_xla2/test/test_context.py
index 16bcedf7931..5255f415ee1 100644
--- a/experimental/torch_xla2/test/test_context.py
+++ b/experimental/torch_xla2/test/test_context.py
@@ -10,6 +10,13 @@
 
 class TestContext(unittest.TestCase):
 
+  def setUp(self):
+    self.old_var = xla_env.config.use_torch_native_for_cpu_tensor
+    xla_env.config.use_torch_native_for_cpu_tensor = False
+
+  def tearDown(self):
+    xla_env.config.use_torch_native_for_cpu_tensor = self.old_var
+
   def test_mode_context_manager(self):
     with xla_env:
       x = torch.full((3, 3), -1)
diff --git a/experimental/torch_xla2/test/test_core_aten_ops.py b/experimental/torch_xla2/test/test_core_aten_ops.py
index d207bc22a82..e60086db087 100644
--- a/experimental/torch_xla2/test/test_core_aten_ops.py
+++ b/experimental/torch_xla2/test/test_core_aten_ops.py
@@ -66,6 +66,11 @@ def setUp(self):
     super().setUp()
     torch.manual_seed(0)
     self.env = tensor.Environment()
+    self.old_var = self.env.config.use_torch_native_for_cpu_tensor
+    self.env.config.use_torch_native_for_cpu_tensor = False
+
+  def tearDown(self):
+    self.env.config.use_torch_native_for_cpu_tensor = self.old_var
 
   def test_aten_abs_0(self):
     args = (torch.randn((10, 10)).to(torch.float32),)
diff --git a/experimental/torch_xla2/test/test_functions.py b/experimental/torch_xla2/test/test_functions.py
index 9e291dc802a..aab34bd1472 100644
--- a/experimental/torch_xla2/test/test_functions.py
+++ b/experimental/torch_xla2/test/test_functions.py
@@ -10,6 +10,7 @@ class TestTorchFunctions(parameterized.TestCase):
 
   def setUp(self):
     self.env = torch_xla2.tensor.Environment()
+    self.env.config.use_torch_native_for_cpu_tensor = False
     torch_xla2.enable_accuracy_mode()
 
   @parameterized.named_parameters(
diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py
index 28d0f29f0c1..d79b35e533a 100644
--- a/experimental/torch_xla2/test/test_ops.py
+++ b/experimental/torch_xla2/test/test_ops.py
@@ -192,6 +192,11 @@ def setUp(self):
     torch_xla2.enable_accuracy_mode()
     #self.env.config.debug_accuracy_for_each_op = True
     torch.manual_seed(0)
+    self.old_var = self.env.config.use_torch_native_for_cpu_tensor
+    self.env.config.use_torch_native_for_cpu_tensor = False
+
+  def tearDown(self):
+    self.env.config.use_torch_native_for_cpu_tensor = self.old_var
 
 # Replaces all values in the input torch_tensor that are less than the given threshold
 # with the threshold value itself.
diff --git a/experimental/torch_xla2/test/test_unbounded_dynamism.py b/experimental/torch_xla2/test/test_unbounded_dynamism.py
index 0cd800cb1a7..06d7b19b149 100644
--- a/experimental/torch_xla2/test/test_unbounded_dynamism.py
+++ b/experimental/torch_xla2/test/test_unbounded_dynamism.py
@@ -2,10 +2,10 @@
 import sys
 import unittest
 
-import numpy as np
 import torch
 from torch.export import Dim, export
 from torch_xla2.export import exported_program_to_stablehlo as exp2shlo
+import torch_xla2
 
 ## This file is copied from `xla/test/stablehlo/test_unbounded_dynamism.py`
 ## To test that torch_xla2 has identical behavior.
@@ -44,6 +44,14 @@ def forward(self, *args):
 
 class UnboundedDynamismExportTest(unittest.TestCase):
 
+  def setUp(self):
+    self.env = torch_xla2.default_env()
+    self.env.config.use_torch_native_for_cpu_tensor = False
+    torch_xla2.enable_accuracy_mode()
+
+  def tearDown(self):
+    self.env.config.use_torch_native_for_cpu_tensor = True
+
   def test_add(self):
     args = (torch.rand((10, 197, 768)), torch.rand((10, 197, 768)))
     dynamic_shapes = (({0: Dim("dim")}, {0: Dim("dim")}),)
diff --git a/experimental/torch_xla2/torch_xla2/__init__.py b/experimental/torch_xla2/torch_xla2/__init__.py
index ef6cd058429..f36a0737c00 100644
--- a/experimental/torch_xla2/torch_xla2/__init__.py
+++ b/experimental/torch_xla2/torch_xla2/__init__.py
@@ -1,3 +1,4 @@
+import contextlib
 from typing import List, Dict, Any, Optional
 import dataclasses
 import jax
@@ -73,6 +74,15 @@ def disable_globally():
   global env
   default_env().__exit__(None, None, None)
 
+@contextlib.contextmanager
+def disable_temporarily():
+  prev = default_env().enabled
+  if prev:
+    disable_globally()
+  yield
+  if prev:
+    enable_globally()
+
 torch.utils.rename_privateuse1_backend('jax')
 
 unsupported_dtype = [torch.quint8]
diff --git a/experimental/torch_xla2/torch_xla2/config.py b/experimental/torch_xla2/torch_xla2/config.py
index 8a0870996a2..351d137df57 100644
--- a/experimental/torch_xla2/torch_xla2/config.py
+++ b/experimental/torch_xla2/torch_xla2/config.py
@@ -14,5 +14,5 @@ class Configuration:
 
   # device
   treat_cuda_as_jax_device: bool = True
-  use_torch_native_for_cpu_tensor: bool = False
+  use_torch_native_for_cpu_tensor: bool = True
   internal_respect_torch_return_dtypes: bool = False
diff --git a/experimental/torch_xla2/torch_xla2/export.py b/experimental/torch_xla2/torch_xla2/export.py
index 2744d931de4..07bee48a43a 100644
--- a/experimental/torch_xla2/torch_xla2/export.py
+++ b/experimental/torch_xla2/torch_xla2/export.py
@@ -31,6 +31,10 @@ def call_function(self, target, args: Tuple, kwargs: Dict) -> Any:
       print('Running ', target.name(), '--------')
 
     op = ops_registry.all_aten_ops.get(target)
+    if op is None:
+      op = ops_registry.all_aten_ops.get(target.overloadpacket)
+      assert op is not None, target
+      assert op.is_jax_function, op
     if op is None:
       op = ops_registry.all_aten_ops.get(target.overloadpacket)
     if op is None:
@@ -226,5 +230,6 @@ def exported_program_to_stablehlo(exported_program):
   """
   weights, func = exported_program_to_jax(exported_program)
   jax_avals = extract_avals(exported_program)
+  print('avals', jax_avals)
   jax_export = jax.export.export(jax.jit(func))(weights, (jax_avals,))
   return jax_export
diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py
index 1f9e9f4a045..57254d18028 100644
--- a/experimental/torch_xla2/torch_xla2/ops/jaten.py
+++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -133,6 +133,7 @@ def _aten_copy(x, y, memory_format=None):
 
 
 @op(torch.ops.aten.clone)
+@op(torch.ops.aten.clone.default)
 def _aten_clone(x, memory_format=None):
   return x
 
@@ -675,6 +676,7 @@ def _aten_dot(x, y):
 
 
 @op(torch.ops.aten._to_copy)
+@op(torch.ops.aten._to_copy.default)
 def _aten__to_copy(self, **kwargs):
   dtype = mappings.t2j_dtype(kwargs["dtype"])
   if dtype != self.dtype:
@@ -1472,6 +1474,7 @@ def _aten_reflection_pad1d(input, padding):
 
 
 # aten.alias
 @op(torch.ops.aten.alias)
+@op(torch.ops.aten.alias.default)
 def _aten_alias(self, *args):
   return self
diff --git a/experimental/torch_xla2/torch_xla2/ops/jtorch.py b/experimental/torch_xla2/torch_xla2/ops/jtorch.py
index f21c5b8f671..4d541cd04d1 100644
--- a/experimental/torch_xla2/torch_xla2/ops/jtorch.py
+++ b/experimental/torch_xla2/torch_xla2/ops/jtorch.py
@@ -103,13 +103,12 @@ def _sdpa_reference(query, key, value, attn_mask=None, dropout_p=0.0,
         is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
     L, S = query.size(-2), key.size(-2)
     scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
-    attn_bias = torch.zeros(L, S, dtype=query.dtype)
+    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
     if is_causal:
         assert attn_mask is None
-        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
+        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
         attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
         attn_bias.to(query.dtype)
-
     if attn_mask is not None:
         if attn_mask.dtype == torch.bool:
             attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
@@ -249,14 +248,14 @@ def _aten_isclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
 def _ones(*size: int, dtype=None, **kwargs):
   if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
     size = size[0]
-  return torch.ops.aten.ones(size, dtype=dtype)
+  return jaten._ones(size, dtype=dtype)
 
 
-@register_function(torch.zeros, is_jax_function=False)
+@register_function(torch.zeros, is_jax_function=True)
 def _zeros(*size: int, dtype=None, **kwargs):
   if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
     size = size[0]
-  return torch.ops.aten.zeros(size, dtype=dtype)
+  return jaten._zeros(size, dtype=dtype)
 
 
 @register_function(torch.eye)
diff --git a/experimental/torch_xla2/torch_xla2/tensor.py b/experimental/torch_xla2/torch_xla2/tensor.py
index 3f0b81e09ba..6424349b24c 100644
--- a/experimental/torch_xla2/torch_xla2/tensor.py
+++ b/experimental/torch_xla2/torch_xla2/tensor.py
@@ -247,6 +247,7 @@ def _name_of_func(func):
     torch.randn,
     torch.rand,
     torch.randint,
+    torch.full,
 }
 
 
@@ -285,7 +286,8 @@ def get_as_jax_device(self, device: Any):
     if isinstance(device, torch.device):
       device = str(device)
 
-    if self.config.use_torch_native_for_cpu_tensor and device.startswith('cpu'):
+    if (self.config.use_torch_native_for_cpu_tensor and
+        not device.startswith('jax') and not device.startswith('cuda')):
       return None
 
     if not self.config.treat_cuda_as_jax_device and device.startswith('cuda'):
@@ -358,7 +360,6 @@ def _handle_tensor_constructor(self, func, args, kwargs):
       # let torch handle it
       with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
         return func(*args, **kwargs)
-
     with jax.default_device(jax_device):
       op = self._ops.get(func)
       res = op.func(*args, **kwargs)
@@ -396,7 +397,8 @@ def dispatch(self, func, types, args, kwargs):
 
     # If the func doesn't act on XLATensor2, and is not a tensor constructor,
     # We should skip and let torch handle it.
-    tensor_args = [t for t in args if isinstance(t, torch.Tensor)]
+
+    tensor_args = [t for t in torch_pytree.tree_flatten(args)[0] if isinstance(t, torch.Tensor)]
     if tensor_args and all(not isinstance(t, XLATensor2) for t in tensor_args):
       return func(*args, **kwargs)
 
@@ -444,11 +446,13 @@ def dispatch(self, func, types, args, kwargs):
   def __enter__(self):
     self._dispatch_mode.__enter__()
     self._function_mode.__enter__()
+    self.enabled = True
     return self
 
   def __exit__(self, *exc):
     self._function_mode.__exit__(*exc)
     self._dispatch_mode.__exit__(*exc)
+    self.enabled = False
 
   def _move_one_value(self, val):
     if isinstance(val, torch.nn.Module):