Skip to content

Commit

Permalink
also fix examples
Browse files Browse the repository at this point in the history
  • Loading branch information
qihqi committed Nov 26, 2024
1 parent 1557122 commit 9069d77
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 15 deletions.
31 changes: 16 additions & 15 deletions experimental/torch_xla2/examples/basic_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def matplotlib_imshow(img, one_channel=False):
plt.imshow(npimg, cmap="Greys")
else:
plt.imshow(np.transpose(npimg, (1, 2, 0)))

#torch_xla2.env.config.debug_print_each_op = True
#torch_xla2.env.config.debug_mixed_tensor = True
dataiter = iter(training_loader)
images, labels = next(dataiter)

Expand Down Expand Up @@ -80,15 +81,15 @@ def forward(self, x):
return x


model = GarmentClassifier()
model = GarmentClassifier().to('jax')

loss_fn = torch.nn.CrossEntropyLoss()

# NB: Loss functions expect data in batches, so we're creating batches of 4
# Represents the model's confidence in each of the 10 classes for a given input
dummy_outputs = torch.rand(4, 10)
dummy_outputs = torch.rand(4, 10, device='jax')
# Represents the correct class among the 10 being tested
dummy_labels = torch.tensor([1, 5, 3, 7])
dummy_labels = torch.tensor([1, 5, 3, 7], device='jax')

print(dummy_outputs)
print(dummy_labels)
Expand All @@ -110,6 +111,8 @@ def train_one_epoch(epoch_index, tb_writer=None):
# Every data instance is an input + label pair
# NEW: Move the inputs and labels to the 'jax' device
inputs, labels = data
inputs = inputs.to('jax')
labels = labels.to('jax')

# Zero your gradients for every batch!
optimizer.zero_grad()
Expand Down Expand Up @@ -162,7 +165,9 @@ def train_one_epoch(epoch_index, tb_writer=None):
# Disable gradient computation and reduce memory consumption.
with torch.no_grad():
for i, vdata in enumerate(validation_loader):
# NOTE: move to XLA device
vinputs, vlabels = vdata
vinputs = vinputs.to('jax')
vlabels = vlabels.to('jax')
voutputs = model(vinputs) # call model's forward
vloss = loss_fn(voutputs, vlabels)
running_vloss += vloss
Expand All @@ -172,15 +177,11 @@ def train_one_epoch(epoch_index, tb_writer=None):

# Log the running loss averaged per batch
# for both training and validation
writer.add_scalars('Training vs. Validation Loss',
{ 'Training' : avg_loss, 'Validation' : avg_vloss },
epoch_number + 1)
writer.flush()

# Track best performance, and save the model's state
if avg_vloss < best_vloss:
best_vloss = avg_vloss
model_path = 'model_{}_{}'.format(timestamp, epoch_number)
torch.save(model.state_dict(), model_path)

# # Track best performance, and save the model's state
# if avg_vloss < best_vloss:
# best_vloss = avg_vloss
# model_path = 'model_{}_{}'.format(timestamp, epoch_number)
# torch.save(model.state_dict(), model_path)

epoch_number += 1
1 change: 1 addition & 0 deletions experimental/torch_xla2/torch_xla2/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ def _name_of_func(func):
torch.rand,
torch.randint,
torch.full,
torch.as_tensor,
}


Expand Down

0 comments on commit 9069d77

Please sign in to comment.