From 9069d776b60524e566340bd30dfcff4dcc25b31f Mon Sep 17 00:00:00 2001 From: Han Qi Date: Tue, 26 Nov 2024 14:04:15 -0800 Subject: [PATCH] also fix examples --- .../torch_xla2/examples/basic_training.py | 31 ++++++++++--------- experimental/torch_xla2/torch_xla2/tensor.py | 1 + 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/experimental/torch_xla2/examples/basic_training.py b/experimental/torch_xla2/examples/basic_training.py index a723f647ca8..fb814fcf978 100644 --- a/experimental/torch_xla2/examples/basic_training.py +++ b/experimental/torch_xla2/examples/basic_training.py @@ -51,7 +51,8 @@ def matplotlib_imshow(img, one_channel=False): plt.imshow(npimg, cmap="Greys") else: plt.imshow(np.transpose(npimg, (1, 2, 0))) - +#torch_xla2.env.config.debug_print_each_op = True +#torch_xla2.env.config.debug_mixed_tensor = True dataiter = iter(training_loader) images, labels = next(dataiter) @@ -80,15 +81,15 @@ def forward(self, x): return x -model = GarmentClassifier() +model = GarmentClassifier().to('jax') loss_fn = torch.nn.CrossEntropyLoss() # NB: Loss functions expect data in batches, so we're creating batches of 4 # Represents the model's confidence in each of the 10 classes for a given input -dummy_outputs = torch.rand(4, 10) +dummy_outputs = torch.rand(4, 10, device='jax') # Represents the correct class among the 10 being tested -dummy_labels = torch.tensor([1, 5, 3, 7]) +dummy_labels = torch.tensor([1, 5, 3, 7], device='jax') print(dummy_outputs) print(dummy_labels) @@ -110,6 +111,8 @@ def train_one_epoch(epoch_index, tb_writer=None): # Every data instance is an input + label pair # NEW: Move model to XLA device inputs, labels = data + inputs = inputs.to('jax') + labels = labels.to('jax') # Zero your gradients for every batch! optimizer.zero_grad() @@ -162,7 +165,9 @@ def train_one_epoch(epoch_index, tb_writer=None): # Disable gradient computation and reduce memory consumption. with torch.no_grad(): for i, vdata in enumerate(validation_loader): - # NOTE: move to XLA device + vinputs, vlabels = vdata + vinputs = vinputs.to('jax') + vlabels = vlabels.to('jax') voutputs = model(vinputs) # call model's forward vloss = loss_fn(voutputs, vlabels) running_vloss += vloss @@ -172,15 +177,11 @@ def train_one_epoch(epoch_index, tb_writer=None): # Log the running loss averaged per batch # for both training and validation - writer.add_scalars('Training vs. Validation Loss', - { 'Training' : avg_loss, 'Validation' : avg_vloss }, - epoch_number + 1) - writer.flush() - - # Track best performance, and save the model's state - if avg_vloss < best_vloss: - best_vloss = avg_vloss - model_path = 'model_{}_{}'.format(timestamp, epoch_number) - torch.save(model.state_dict(), model_path) + + # # Track best performance, and save the model's state + # if avg_vloss < best_vloss: + # best_vloss = avg_vloss + # model_path = 'model_{}_{}'.format(timestamp, epoch_number) + # torch.save(model.state_dict(), model_path) epoch_number += 1 diff --git a/experimental/torch_xla2/torch_xla2/tensor.py b/experimental/torch_xla2/torch_xla2/tensor.py index 6424349b24c..35d69eb7326 100644 --- a/experimental/torch_xla2/torch_xla2/tensor.py +++ b/experimental/torch_xla2/torch_xla2/tensor.py @@ -248,6 +248,7 @@ def _name_of_func(func): torch.rand, torch.randint, torch.full, + torch.as_tensor, }