Skip to content

Commit

Permalink
Use regular torch.Tensor for CPU tensors (#8416)
Browse files Browse the repository at this point in the history
  • Loading branch information
qihqi authored Nov 27, 2024
1 parent 39e67b5 commit 20f5166
Show file tree
Hide file tree
Showing 14 changed files with 165 additions and 117 deletions.
31 changes: 16 additions & 15 deletions experimental/torch_xla2/examples/basic_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def matplotlib_imshow(img, one_channel=False):
plt.imshow(npimg, cmap="Greys")
else:
plt.imshow(np.transpose(npimg, (1, 2, 0)))

#torch_xla2.env.config.debug_print_each_op = True
#torch_xla2.env.config.debug_mixed_tensor = True
dataiter = iter(training_loader)
images, labels = next(dataiter)

Expand Down Expand Up @@ -80,15 +81,15 @@ def forward(self, x):
return x


model = GarmentClassifier()
model = GarmentClassifier().to('jax')

loss_fn = torch.nn.CrossEntropyLoss()

# NB: Loss functions expect data in batches, so we're creating batches of 4
# Represents the model's confidence in each of the 10 classes for a given input
dummy_outputs = torch.rand(4, 10)
dummy_outputs = torch.rand(4, 10, device='jax')
# Represents the correct class among the 10 being tested
dummy_labels = torch.tensor([1, 5, 3, 7])
dummy_labels = torch.tensor([1, 5, 3, 7], device='jax')

print(dummy_outputs)
print(dummy_labels)
Expand All @@ -110,6 +111,8 @@ def train_one_epoch(epoch_index, tb_writer=None):
# Every data instance is an input + label pair
# NEW: Move the batch (inputs and labels) to the 'jax' device
inputs, labels = data
inputs = inputs.to('jax')
labels = labels.to('jax')

# Zero your gradients for every batch!
optimizer.zero_grad()
Expand Down Expand Up @@ -162,7 +165,9 @@ def train_one_epoch(epoch_index, tb_writer=None):
# Disable gradient computation and reduce memory consumption.
with torch.no_grad():
for i, vdata in enumerate(validation_loader):
# NOTE: move to XLA device
vinputs, vlabels = vdata
vinputs = vinputs.to('jax')
vlabels = vlabels.to('jax')
voutputs = model(vinputs) # call model's forward
vloss = loss_fn(voutputs, vlabels)
running_vloss += vloss
Expand All @@ -172,15 +177,11 @@ def train_one_epoch(epoch_index, tb_writer=None):

# Log the running loss averaged per batch
# for both training and validation
writer.add_scalars('Training vs. Validation Loss',
{ 'Training' : avg_loss, 'Validation' : avg_vloss },
epoch_number + 1)
writer.flush()

# Track best performance, and save the model's state
if avg_vloss < best_vloss:
best_vloss = avg_vloss
model_path = 'model_{}_{}'.format(timestamp, epoch_number)
torch.save(model.state_dict(), model_path)

# # Track best performance, and save the model's state
# if avg_vloss < best_vloss:
# best_vloss = avg_vloss
# model_path = 'model_{}_{}'.format(timestamp, epoch_number)
# torch.save(model.state_dict(), model_path)

epoch_number += 1
172 changes: 86 additions & 86 deletions experimental/torch_xla2/test/llama/test_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,101 +12,101 @@
class LlamaTest(test_base.TestCase):

def test_can_run(self):
sample_args = (
torch.randint(0, 32000, (1, 2048)),
torch.arange(0, 2048),
)
sample_args = pytree.tree_map(tensor.t2j, sample_args)
with torch_xla2.default_env():
sample_args = (
torch.randint(0, 32000, (1, 2048), device='jax:0'),
torch.arange(0, 2048, device='jax:0'),
)

model_args = llama_model.ModelArgs(
block_size=2048,
vocab_size=32000,
n_layer=2,
n_head=4,
dim=256,
)
m = llama_model.Transformer(model_args)
m.to(torch.bfloat16)
m.setup_caches(1, 2048)
model_args = llama_model.ModelArgs(
block_size=2048,
vocab_size=32000,
n_layer=2,
n_head=4,
dim=256,
)
m = llama_model.Transformer(model_args)
m.to(torch.bfloat16)
m.setup_caches(1, 2048)
m = m.to('jax')

print(m(*sample_args))

# NOTE: this API does NOT use torch export
weights, jax_func = torch_xla2.extract_jax(m)
print(jax_func(weights, sample_args))

def test_can_run_exportable(self):
model_args = model_exportable.ModelArgs(
vocab_size=32000,
n_layers=2,
n_heads=4,
dim=256,
)
m = model_exportable.Transformer(model_args)
context_length = 2048
input_shape_prefill = (1, context_length)
input_shape_decode = (1, 1)
model_args = model_exportable.ModelArgs(
vocab_size=32000,
n_layers=2,
n_heads=4,
dim=256,
)
m = model_exportable.Transformer(model_args)
context_length = 2048
input_shape_prefill = (1, context_length)
input_shape_decode = (1, 1)

def make_cache(args, batch_size):
n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
n_local_heads = args.n_heads
n_local_kv_heads = n_kv_heads
n_rep = n_local_heads // n_local_kv_heads
head_dim = args.dim // args.n_heads
res = []
for i in range(args.n_layers):
if batch_size is None:
size = (
args.max_seq_len,
n_local_kv_heads,
head_dim,
)
else:
size = (
batch_size,
args.max_seq_len,
n_local_kv_heads,
head_dim,
)
res.append(
(torch.zeros(
size,
dtype=torch.bfloat16 if args.bf16_enable else torch.float),
torch.zeros(
size,
dtype=torch.bfloat16 if args.bf16_enable else torch.float)))
return res
def make_cache(args, batch_size):
n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
n_local_heads = args.n_heads
n_local_kv_heads = n_kv_heads
n_rep = n_local_heads // n_local_kv_heads
head_dim = args.dim // args.n_heads
res = []
for i in range(args.n_layers):
if batch_size is None:
size = (
args.max_seq_len,
n_local_kv_heads,
head_dim,
)
else:
size = (
batch_size,
args.max_seq_len,
n_local_kv_heads,
head_dim,
)
res.append(
(torch.zeros(
size,
dtype=torch.bfloat16 if args.bf16_enable else torch.float),
torch.zeros(
size,
dtype=torch.bfloat16 if args.bf16_enable else torch.float)))
return res

prefill_caches = make_cache(model_args, 1)
prefill_caches = make_cache(model_args, 1)

sample_input_prefill = (
torch.randint(0, 1000, input_shape_prefill,
dtype=torch.int32), # len seq length
torch.arange(0, context_length, dtype=torch.int32), # input indexes
torch.arange(0, context_length, dtype=torch.int32), # context indexes
prefill_caches,
True, # prefill
)
with torch.no_grad():
m_prefill = torch.export.export(m, sample_input_prefill)
sample_input_prefill = (
torch.randint(0, 1000, input_shape_prefill,
dtype=torch.int32), # len seq length
torch.arange(0, context_length, dtype=torch.int32), # input indexes
torch.arange(0, context_length, dtype=torch.int32), # context indexes
prefill_caches,
True, # prefill
)
with torch.no_grad():
m_prefill = torch.export.export(m, sample_input_prefill)

weights, mj_prefill = torch_xla2.export.exported_program_to_jax(m_prefill)
sample_inputs = pytree.tree_map_only(torch.Tensor, tensor.t2j,
sample_input_prefill)
print('Prefill', mj_prefill(weights, sample_inputs))
weights, mj_prefill = torch_xla2.export.exported_program_to_jax(m_prefill)
sample_inputs = pytree.tree_map_only(torch.Tensor, tensor.t2j,
sample_input_prefill)
print('Prefill', mj_prefill(weights, sample_inputs))

sample_input_decode = (
torch.randint(0, 1000, input_shape_decode,
dtype=torch.int32), # len = 1
torch.tensor([0], dtype=torch.int32),
torch.roll(torch.arange(context_length, dtype=torch.int32), 1, 0),
prefill_caches,
False # prefill
)
with torch.no_grad():
m_decode = torch.export.export(m, sample_input_decode)
weights, mj_decode = torch_xla2.export.exported_program_to_jax(m_decode)
sample_inputs = pytree.tree_map_only(torch.Tensor, tensor.t2j,
sample_input_decode)
print('Decode', mj_decode(weights, sample_inputs))
sample_input_decode = (
torch.randint(0, 1000, input_shape_decode,
dtype=torch.int32), # len = 1
torch.tensor([0], dtype=torch.int32),
torch.roll(torch.arange(context_length, dtype=torch.int32), 1, 0),
prefill_caches,
False # prefill
)
with torch.no_grad():
m_decode = torch.export.export(m, sample_input_decode)
weights, mj_decode = torch_xla2.export.exported_program_to_jax(m_decode)
sample_inputs = pytree.tree_map_only(torch.Tensor, tensor.t2j,
sample_input_decode)
print('Decode', mj_decode(weights, sample_inputs))


if __name__ == "__main__":
Expand Down
7 changes: 7 additions & 0 deletions experimental/torch_xla2/test/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@

class TestContext(unittest.TestCase):

def setUp(self):
self.old_var = xla_env.config.use_torch_native_for_cpu_tensor
xla_env.config.use_torch_native_for_cpu_tensor = False

def tearDown(self):
xla_env.config.use_torch_native_for_cpu_tensor = self.old_var

def test_mode_context_manager(self):
with xla_env:
x = torch.full((3, 3), -1)
Expand Down
5 changes: 5 additions & 0 deletions experimental/torch_xla2/test/test_core_aten_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ def setUp(self):
super().setUp()
torch.manual_seed(0)
self.env = tensor.Environment()
self.old_var = self.env.config.use_torch_native_for_cpu_tensor
self.env.config.use_torch_native_for_cpu_tensor = False

def tearDown(self):
self.env.config.use_torch_native_for_cpu_tensor = self.old_var

def test_aten_abs_0(self):
args = (torch.randn((10, 10)).to(torch.float32),)
Expand Down
1 change: 1 addition & 0 deletions experimental/torch_xla2/test/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class TestTorchFunctions(parameterized.TestCase):

def setUp(self):
self.env = torch_xla2.tensor.Environment()
self.env.config.use_torch_native_for_cpu_tensor = False
torch_xla2.enable_accuracy_mode()

@parameterized.named_parameters(
Expand Down
9 changes: 6 additions & 3 deletions experimental/torch_xla2/test/test_libraries.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import unittest
import jax
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.library import Library, impl, impl_abstract
import torch_xla2
from torch_xla2 import tensor
import torch_xla2.export
from torch_xla2.ops import jaten
from torch_xla2.ops import jlibrary

Expand Down Expand Up @@ -56,6 +54,7 @@ class LibraryTest(unittest.TestCase):

def setUp(self):
  torch.manual_seed(0)
  # The flag is changed on the environment returned by default_env(), which
  # appears to be shared beyond this test class (NOTE(review): confirm).
  # Save the current value and register a cleanup so the change does not leak
  # into tests that run afterwards.
  config = torch_xla2.default_env().config
  previous = config.use_torch_native_for_cpu_tensor
  config.use_torch_native_for_cpu_tensor = False
  self.addCleanup(setattr, config, 'use_torch_native_for_cpu_tensor', previous)

def test_basic_sdpa_library(self):

Expand All @@ -78,3 +77,7 @@ def forward(self, q,k,v):
## stablehlo.composite ops.
self.assertIn("call @mylib.scaled_dot_product_attention", module_str)
self.assertIn("call @mylib.softmax", module_str)


if __name__ == '__main__':
unittest.main()
5 changes: 5 additions & 0 deletions experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,11 @@ def setUp(self):
torch_xla2.enable_accuracy_mode()
#self.env.config.debug_accuracy_for_each_op = True
torch.manual_seed(0)
self.old_var = self.env.config.use_torch_native_for_cpu_tensor
self.env.config.use_torch_native_for_cpu_tensor = False

def tearDown(self):
self.env.config.use_torch_native_for_cpu_tensor = self.old_var

# Replaces all values in the input torch_tensor that are less than the given threshold
# with the threshold value itself.
Expand Down
2 changes: 1 addition & 1 deletion experimental/torch_xla2/test/test_tf_integration.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import jax
import os
import tempfile
import numpy as np
import tensorflow as tf
import torch
import torch.nn.functional as F
Expand Down
10 changes: 9 additions & 1 deletion experimental/torch_xla2/test/test_unbounded_dynamism.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
import sys
import unittest

import numpy as np
import torch
from torch.export import Dim, export
from torch_xla2.export import exported_program_to_stablehlo as exp2shlo
import torch_xla2

## This file is copied from `xla/test/stablehlo/test_unbounded_dynamism.py`
## To test that torch_xla2 has identical behavior.
Expand Down Expand Up @@ -44,6 +44,14 @@ def forward(self, *args):

class UnboundedDynamismExportTest(unittest.TestCase):

def setUp(self):
  # Turn off use_torch_native_for_cpu_tensor for the duration of each test,
  # remembering the value that was in effect so tearDown can put it back.
  self.env = torch_xla2.default_env()
  self.old_var = self.env.config.use_torch_native_for_cpu_tensor
  self.env.config.use_torch_native_for_cpu_tensor = False
  torch_xla2.enable_accuracy_mode()

def tearDown(self):
  # Restore the saved value rather than hard-coding True, so a setting that
  # was configured before the test ran is not overwritten.
  self.env.config.use_torch_native_for_cpu_tensor = self.old_var

def test_add(self):
args = (torch.rand((10, 197, 768)), torch.rand((10, 197, 768)))
dynamic_shapes = (({0: Dim("dim")}, {0: Dim("dim")}),)
Expand Down
10 changes: 10 additions & 0 deletions experimental/torch_xla2/torch_xla2/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextlib
from typing import List, Dict, Any, Optional
import dataclasses
import jax
Expand Down Expand Up @@ -73,6 +74,15 @@ def disable_globally():
global env
default_env().__exit__(None, None, None)

@contextlib.contextmanager
def disable_temporarily():
  """Disable the default environment for the duration of a `with` block.

  If the default environment is enabled on entry, it is turned off with
  `disable_globally()` and turned back on with `enable_globally()` on exit.
  If it was already disabled, nothing is changed on entry or exit.

  The previous state is restored in a `finally` clause, so an exception
  raised inside the `with` body does not leave the environment disabled.

  Yields:
    None.
  """
  was_enabled = default_env().enabled
  if was_enabled:
    disable_globally()
  try:
    yield
  finally:
    # Re-enable only if this context manager was the one that disabled it.
    if was_enabled:
      enable_globally()


torch.utils.rename_privateuse1_backend('jax')
unsupported_dtype = [torch.quint8]
Expand Down
2 changes: 1 addition & 1 deletion experimental/torch_xla2/torch_xla2/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@ class Configuration:

# device
treat_cuda_as_jax_device: bool = True
use_torch_native_for_cpu_tensor: bool = False
use_torch_native_for_cpu_tensor: bool = True
internal_respect_torch_return_dtypes: bool = False
Loading

0 comments on commit 20f5166

Please sign in to comment.