From 591c397922ad4506141796f78557f98f860cff87 Mon Sep 17 00:00:00 2001
From: JackCaoG <59073027+JackCaoG@users.noreply.github.com>
Date: Mon, 2 Dec 2024 01:07:44 -0800
Subject: [PATCH] Reenable the distributed checkpointing test (#8424)

---
 test/tpu/run_tests.sh     |  3 +--
 torch_xla/csrc/tensor.cpp | 11 ++++++++++-
 2 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh
index 8d5e74bde03..6ad06b07740 100755
--- a/test/tpu/run_tests.sh
+++ b/test/tpu/run_tests.sh
@@ -8,8 +8,7 @@ python3 test/pjrt/test_collective_ops_tpu.py
 python3 test/spmd/test_mp_input_sharding.py
 python3 test/spmd/test_xla_sharding.py
 python3 test/spmd/test_xla_virtual_device.py
-# TODO(JackCaoG): to reenable
-# python3 test/spmd/test_xla_distributed_checkpoint.py
+python3 test/spmd/test_xla_distributed_checkpoint.py
 python3 test/spmd/test_train_spmd_linear_model.py
 python3 test/spmd/test_xla_spmd_python_api_interaction.py
 python3 test/spmd/test_xla_auto_sharding.py
diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp
index 095a6ce4163..01306c53d38 100644
--- a/torch_xla/csrc/tensor.cpp
+++ b/torch_xla/csrc/tensor.cpp
@@ -562,7 +562,16 @@ void XLATensor::UpdateFromTensor(at::Tensor tensor, bool sync) {
   at::Tensor coyped_tensor = torch::lazy::CopyTensor(tensor, dtype());
   SetTensorData(coyped_tensor);
   data()->handle = nullptr;
-  data()->sharding = nullptr;
+  // Only clear the sharding spec when the shape has changed.
+  if (data()->sharding) {
+    auto coyped_tensor_dims = XlaHelpers::I64List(coyped_tensor.sizes());
+    auto sharding_dims = data()->sharding->shape.dimensions();
+    if (coyped_tensor_dims != sharding_dims) {
+      // The sharding shape from the original tensor differs from the new
+      // CPU tensor, so the stale sharding spec must be cleared here.
+      ClearShardingSpec();
+    }
+  }
   AssignIrValue(torch::lazy::Value());
   if (data()->view != nullptr) {
     torch::lazy::Value ir_value = GetIrValueForTensor(coyped_tensor, device);
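
For reference, a minimal Python sketch of the behavior this change preserves: an
in-place update from a CPU tensor of the same shape no longer drops the sharding
spec. It assumes a multi-device SPMD environment (e.g. a TPU VM) and a recent
torch_xla build; _get_xla_sharding_spec is an internal helper, used here only to
inspect the sharding spec.

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.runtime as xr
    import torch_xla.distributed.spmd as xs

    xr.use_spmd()  # enable SPMD execution mode

    num_devices = xr.global_runtime_device_count()
    mesh = xs.Mesh(list(range(num_devices)), (num_devices, 1))

    # Create a sharded tensor on the XLA device.
    t = torch.randn(8, 8).to(xm.xla_device())
    xs.mark_sharding(t, mesh, (0, 1))  # shard dim 0 across the mesh

    # In-place update from a CPU tensor of the *same* shape: with this
    # patch the sharding spec is kept instead of being cleared, which is
    # what the distributed checkpoint restore path relies on.
    t.copy_(torch.zeros(8, 8))
    print(torch_xla._XLAC._get_xla_sharding_spec(t))  # still sharded

When the shapes differ (as can happen on internal update paths during
checkpoint loading), the stale sharding spec is cleared instead of being
applied to mismatched dimensions.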