
Commit

Make use_torch_native_for_cpu_tensor default to True.
qihqi committed Nov 26, 2024
1 parent 72da142 commit 1557122
Showing 12 changed files with 144 additions and 97 deletions.
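
Judging from the updated tests, the flipped default keeps CPU tensors created under the torch_xla2 environment as native torch tensors unless a JAX device is requested explicitly; that is why the llama test below asks for device='jax:0' and moves the model with .to('jax'), while the other tests pin the flag back to False in setUp. A minimal usage sketch (illustrative only, assuming the flag behaves as the tests suggest):

import torch
import torch_xla2

env = torch_xla2.default_env()
with env:
  # Assumption: with use_torch_native_for_cpu_tensor=True, a factory call
  # without an explicit device stays a plain CPU torch.Tensor.
  cpu_t = torch.arange(0, 8)
  # JAX-backed tensors are requested explicitly, as in the updated llama test.
  jax_t = torch.arange(0, 8, device='jax:0')

# Tests that still need the old behavior flip the flag back, e.g. in setUp:
env.config.use_torch_native_for_cpu_tensor = False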
172 changes: 86 additions & 86 deletions experimental/torch_xla2/test/llama/test_llama.py
@@ -12,101 +12,101 @@
 class LlamaTest(test_base.TestCase):
 
   def test_can_run(self):
-    sample_args = (
-        torch.randint(0, 32000, (1, 2048)),
-        torch.arange(0, 2048),
-    )
-    sample_args = pytree.tree_map(tensor.t2j, sample_args)
+    with torch_xla2.default_env():
+      sample_args = (
+          torch.randint(0, 32000, (1, 2048), device='jax:0'),
+          torch.arange(0, 2048, device='jax:0'),
+      )
 
-    model_args = llama_model.ModelArgs(
-        block_size=2048,
-        vocab_size=32000,
-        n_layer=2,
-        n_head=4,
-        dim=256,
-    )
-    m = llama_model.Transformer(model_args)
-    m.to(torch.bfloat16)
-    m.setup_caches(1, 2048)
+      model_args = llama_model.ModelArgs(
+          block_size=2048,
+          vocab_size=32000,
+          n_layer=2,
+          n_head=4,
+          dim=256,
+      )
+      m = llama_model.Transformer(model_args)
+      m.to(torch.bfloat16)
+      m.setup_caches(1, 2048)
+      m = m.to('jax')
 
     print(m(*sample_args))
 
     # NOTE: this API does NOT use torch export
     weights, jax_func = torch_xla2.extract_jax(m)
     print(jax_func(weights, sample_args))
 
   def test_can_run_exportable(self):
-    model_args = model_exportable.ModelArgs(
-        vocab_size=32000,
-        n_layers=2,
-        n_heads=4,
-        dim=256,
-    )
-    m = model_exportable.Transformer(model_args)
-    context_length = 2048
-    input_shape_prefill = (1, context_length)
-    input_shape_decode = (1, 1)
+    model_args = model_exportable.ModelArgs(
+        vocab_size=32000,
+        n_layers=2,
+        n_heads=4,
+        dim=256,
+    )
+    m = model_exportable.Transformer(model_args)
+    context_length = 2048
+    input_shape_prefill = (1, context_length)
+    input_shape_decode = (1, 1)
 
-    def make_cache(args, batch_size):
-      n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
-      n_local_heads = args.n_heads
-      n_local_kv_heads = n_kv_heads
-      n_rep = n_local_heads // n_local_kv_heads
-      head_dim = args.dim // args.n_heads
-      res = []
-      for i in range(args.n_layers):
-        if batch_size is None:
-          size = (
-              args.max_seq_len,
-              n_local_kv_heads,
-              head_dim,
-          )
-        else:
-          size = (
-              batch_size,
-              args.max_seq_len,
-              n_local_kv_heads,
-              head_dim,
-          )
-        res.append(
-            (torch.zeros(
-                size,
-                dtype=torch.bfloat16 if args.bf16_enable else torch.float),
-             torch.zeros(
-                 size,
-                 dtype=torch.bfloat16 if args.bf16_enable else torch.float)))
-      return res
+    def make_cache(args, batch_size):
+      n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
+      n_local_heads = args.n_heads
+      n_local_kv_heads = n_kv_heads
+      n_rep = n_local_heads // n_local_kv_heads
+      head_dim = args.dim // args.n_heads
+      res = []
+      for i in range(args.n_layers):
+        if batch_size is None:
+          size = (
+              args.max_seq_len,
+              n_local_kv_heads,
+              head_dim,
+          )
+        else:
+          size = (
+              batch_size,
+              args.max_seq_len,
+              n_local_kv_heads,
+              head_dim,
+          )
+        res.append(
+            (torch.zeros(
+                size,
+                dtype=torch.bfloat16 if args.bf16_enable else torch.float),
+             torch.zeros(
+                 size,
+                 dtype=torch.bfloat16 if args.bf16_enable else torch.float)))
+      return res
 
-    prefill_caches = make_cache(model_args, 1)
+    prefill_caches = make_cache(model_args, 1)
 
-    sample_input_prefill = (
-        torch.randint(0, 1000, input_shape_prefill,
-                      dtype=torch.int32),  # len seq length
-        torch.arange(0, context_length, dtype=torch.int32),  # input indexes
-        torch.arange(0, context_length, dtype=torch.int32),  # context indexes
-        prefill_caches,
-        True,  # prefil
-    )
-    with torch.no_grad():
-      m_prefill = torch.export.export(m, sample_input_prefill)
+    sample_input_prefill = (
+        torch.randint(0, 1000, input_shape_prefill,
+                      dtype=torch.int32),  # len seq length
+        torch.arange(0, context_length, dtype=torch.int32),  # input indexes
+        torch.arange(0, context_length, dtype=torch.int32),  # context indexes
+        prefill_caches,
+        True,  # prefil
+    )
+    with torch.no_grad():
+      m_prefill = torch.export.export(m, sample_input_prefill)
 
-    weights, mj_prefill = torch_xla2.export.exported_program_to_jax(m_prefill)
-    sample_inputs = pytree.tree_map_only(torch.Tensor, tensor.t2j,
-                                         sample_input_prefill)
-    print('Prefill', mj_prefill(weights, sample_inputs))
+    weights, mj_prefill = torch_xla2.export.exported_program_to_jax(m_prefill)
+    sample_inputs = pytree.tree_map_only(torch.Tensor, tensor.t2j,
+                                         sample_input_prefill)
+    print('Prefill', mj_prefill(weights, sample_inputs))
 
-    sample_input_decode = (
-        torch.randint(0, 1000, input_shape_decode,
-                      dtype=torch.int32),  # len = 1
-        torch.tensor([0], dtype=torch.int32),
-        torch.roll(torch.arange(context_length, dtype=torch.int32), 1, 0),
-        prefill_caches,
-        False  # prefill
-    )
-    with torch.no_grad():
-      m_decode = torch.export.export(m, sample_input_decode)
-    weights, mj_decode = torch_xla2.export.exported_program_to_jax(m_decode)
-    sample_inputs = pytree.tree_map_only(torch.Tensor, tensor.t2j,
-                                         sample_input_decode)
-    print('Decode', mj_decode(weights, sample_inputs))
+    sample_input_decode = (
+        torch.randint(0, 1000, input_shape_decode,
+                      dtype=torch.int32),  # len = 1
+        torch.tensor([0], dtype=torch.int32),
+        torch.roll(torch.arange(context_length, dtype=torch.int32), 1, 0),
+        prefill_caches,
+        False  # prefill
+    )
+    with torch.no_grad():
+      m_decode = torch.export.export(m, sample_input_decode)
+    weights, mj_decode = torch_xla2.export.exported_program_to_jax(m_decode)
+    sample_inputs = pytree.tree_map_only(torch.Tensor, tensor.t2j,
+                                         sample_input_decode)
+    print('Decode', mj_decode(weights, sample_inputs))
 
 
 if __name__ == "__main__":
7 changes: 7 additions & 0 deletions experimental/torch_xla2/test/test_context.py
@@ -10,6 +10,13 @@
 
 class TestContext(unittest.TestCase):
 
+  def setUp(self):
+    self.old_var = xla_env.config.use_torch_native_for_cpu_tensor
+    xla_env.config.use_torch_native_for_cpu_tensor = False
+
+  def tearDown(self):
+    xla_env.config.use_torch_native_for_cpu_tensor = self.old_var
+
   def test_mode_context_manager(self):
     with xla_env:
       x = torch.full((3, 3), -1)
5 changes: 5 additions & 0 deletions experimental/torch_xla2/test/test_core_aten_ops.py
@@ -66,6 +66,11 @@ def setUp(self):
     super().setUp()
     torch.manual_seed(0)
     self.env = tensor.Environment()
+    self.old_var = self.env.config.use_torch_native_for_cpu_tensor
+    self.env.config.use_torch_native_for_cpu_tensor = False
+
+  def tearDown(self):
+    self.env.config.use_torch_native_for_cpu_tensor = self.old_var
 
   def test_aten_abs_0(self):
     args = (torch.randn((10, 10)).to(torch.float32),)
1 change: 1 addition & 0 deletions experimental/torch_xla2/test/test_functions.py
@@ -10,6 +10,7 @@ class TestTorchFunctions(parameterized.TestCase):
 
   def setUp(self):
     self.env = torch_xla2.tensor.Environment()
+    self.env.config.use_torch_native_for_cpu_tensor = False
     torch_xla2.enable_accuracy_mode()
 
   @parameterized.named_parameters(
5 changes: 5 additions & 0 deletions experimental/torch_xla2/test/test_ops.py
@@ -192,6 +192,11 @@ def setUp(self):
     torch_xla2.enable_accuracy_mode()
     #self.env.config.debug_accuracy_for_each_op = True
     torch.manual_seed(0)
+    self.old_var = self.env.config.use_torch_native_for_cpu_tensor
+    self.env.config.use_torch_native_for_cpu_tensor = False
+
+  def tearDown(self):
+    self.env.config.use_torch_native_for_cpu_tensor = self.old_var
 
   # Replaces all values in the input torch_tensor that are less than the given threshold
   # with the threshold value itself.
10 changes: 9 additions & 1 deletion experimental/torch_xla2/test/test_unbounded_dynamism.py
@@ -2,10 +2,10 @@
 import sys
 import unittest
 
-import numpy as np
 import torch
 from torch.export import Dim, export
 from torch_xla2.export import exported_program_to_stablehlo as exp2shlo
+import torch_xla2
 
 ## This file is copied from `xla/test/stablehlo/test_unbounded_dynamism.py`
 ## To test that torch_xla2 has identical behavior.
@@ -44,6 +44,14 @@ def forward(self, *args):
 
 class UnboundedDynamismExportTest(unittest.TestCase):
 
+  def setUp(self):
+    self.env = torch_xla2.default_env()
+    self.env.config.use_torch_native_for_cpu_tensor = False
+    torch_xla2.enable_accuracy_mode()
+
+  def tearDown(self):
+    self.env.config.use_torch_native_for_cpu_tensor = True
+
   def test_add(self):
     args = (torch.rand((10, 197, 768)), torch.rand((10, 197, 768)))
     dynamic_shapes = (({0: Dim("dim")}, {0: Dim("dim")}),)
10 changes: 10 additions & 0 deletions experimental/torch_xla2/torch_xla2/__init__.py
@@ -1,3 +1,4 @@
+import contextlib
 from typing import List, Dict, Any, Optional
 import dataclasses
 import jax
@@ -73,6 +74,15 @@ def disable_globally():
   global env
   default_env().__exit__(None, None, None)
 
+@contextlib.contextmanager
+def disable_temporarily():
+  prev = default_env().enabled
+  if prev:
+    disable_globally()
+  yield()
+  if prev:
+    enable_globally()
+
 
 torch.utils.rename_privateuse1_backend('jax')
 unsupported_dtype = [torch.quint8]
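
A hedged usage sketch of the new disable_temporarily() helper (not part of the commit; it assumes the environment was enabled globally beforehand):

import torch
import torch_xla2

torch_xla2.enable_globally()             # torch ops are routed through JAX
with torch_xla2.disable_temporarily():   # back to native torch inside the block
  cpu_only = torch.ones(2, 2)            # a plain CPU torch.Tensor
# on exit, the previous (enabled) state is restored via enable_globally()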
2 changes: 1 addition & 1 deletion experimental/torch_xla2/torch_xla2/config.py
@@ -14,5 +14,5 @@ class Configuration:
 
   # device
   treat_cuda_as_jax_device: bool = True
-  use_torch_native_for_cpu_tensor: bool = False
+  use_torch_native_for_cpu_tensor: bool = True
   internal_respect_torch_return_dtypes: bool = False
5 changes: 5 additions & 0 deletions experimental/torch_xla2/torch_xla2/export.py
@@ -31,6 +31,10 @@ def call_function(self, target, args: Tuple, kwargs: Dict) -> Any:
       print('Running ', target.name(), '--------')
 
     op = ops_registry.all_aten_ops.get(target)
+    if op is None:
+      op = ops_registry.all_aten_ops.get(target.overloadpacket)
+    assert op is not None, target
+    assert op.is_jax_function, op
     if op is None:
       op = ops_registry.all_aten_ops.get(target.overloadpacket)
     if op is None:
@@ -226,5 +230,6 @@ def exported_program_to_stablehlo(exported_program):
   """
   weights, func = exported_program_to_jax(exported_program)
   jax_avals = extract_avals(exported_program)
+  print('avals', jax_avals)
   jax_export = jax.export.export(jax.jit(func))(weights, (jax_avals,))
   return jax_export
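
For orientation, a hedged usage sketch of the function the second hunk touches (not part of the commit; TinyModel, its input shape, and the mlir_module() call are illustrative assumptions):

import torch
from torch_xla2.export import exported_program_to_stablehlo

class TinyModel(torch.nn.Module):
  def forward(self, x):
    return torch.nn.functional.relu(x) + 1

# torch.export produces an ExportedProgram; exported_program_to_stablehlo
# lowers it through JAX (jax.export), and the print('avals', ...) added above
# runs as part of that call.
with torch.no_grad():
  ep = torch.export.export(TinyModel(), (torch.randn(4, 8),))
shlo = exported_program_to_stablehlo(ep)
print(shlo.mlir_module())  # assumes the jax.export.Exported API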
3 changes: 3 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -133,6 +133,7 @@ def _aten_copy(x, y, memory_format=None):
 
 
 @op(torch.ops.aten.clone)
+@op(torch.ops.aten.clone.default)
 def _aten_clone(x, memory_format=None):
   return x
 
@@ -675,6 +676,7 @@ def _aten_dot(x, y):
 
 
 @op(torch.ops.aten._to_copy)
+@op(torch.ops.aten._to_copy.default)
 def _aten__to_copy(self, **kwargs):
   dtype = mappings.t2j_dtype(kwargs["dtype"])
   if dtype != self.dtype:
@@ -1472,6 +1474,7 @@ def _aten_reflection_pad1d(input, padding):
 
 
 # aten.alias
 @op(torch.ops.aten.alias)
+@op(torch.ops.aten.alias.default)
 def _aten_alias(self, *args):
   return self
 
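
Why the extra .default registrations matter (a hedged aside, not part of the commit): torch.ops.aten.clone is an OpOverloadPacket, while FX/torch.export graphs typically carry the concrete OpOverload such as torch.ops.aten.clone.default, and the two are distinct registry keys:

import torch

packet = torch.ops.aten.clone            # OpOverloadPacket
overload = torch.ops.aten.clone.default  # OpOverload (what exported graphs carry)
print(type(packet).__name__)             # OpOverloadPacket
print(type(overload).__name__)           # OpOverload
print(packet is overload)                # False: a registry keyed on one misses the other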
11 changes: 5 additions & 6 deletions experimental/torch_xla2/torch_xla2/ops/jtorch.py
@@ -103,13 +103,12 @@ def _sdpa_reference(query, key, value, attn_mask=None, dropout_p=0.0,
                     is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
   L, S = query.size(-2), key.size(-2)
   scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
-  attn_bias = torch.zeros(L, S, dtype=query.dtype)
+  attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
   if is_causal:
     assert attn_mask is None
-    temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
+    temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
     attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
-    attn_bias.to(query.dtype)
 
   if attn_mask is not None:
     if attn_mask.dtype == torch.bool:
       attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
@@ -249,14 +248,14 @@ def _aten_isclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
 def _ones(*size: int, dtype=None, **kwargs):
   if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
     size = size[0]
-  return torch.ops.aten.ones(size, dtype=dtype)
+  return jaten._ones(size, dtype=dtype)
 
 
-@register_function(torch.zeros, is_jax_function=False)
+@register_function(torch.zeros, is_jax_function=True)
 def _zeros(*size: int, dtype=None, **kwargs):
   if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
     size = size[0]
-  return torch.ops.aten.zeros(size, dtype=dtype)
+  return jaten._zeros(size, dtype=dtype)
 
 
 @register_function(torch.eye)