From 55dea5a882f75bfef55982747492fe7498705649 Mon Sep 17 00:00:00 2001 From: jonb377 Date: Wed, 4 Oct 2023 15:17:28 -0700 Subject: [PATCH] Move distributed checkpointing to a subdirectory of experimental (#5656) --- test/spmd/test_xla_distributed_checkpoint.py | 2 +- torch_xla/experimental/distributed_checkpoint/__init__.py | 6 ++++++ .../_helpers.py} | 0 .../planners.py} | 7 +------ 4 files changed, 8 insertions(+), 7 deletions(-) create mode 100644 torch_xla/experimental/distributed_checkpoint/__init__.py rename torch_xla/experimental/{_distributed_checkpoint_helpers.py => distributed_checkpoint/_helpers.py} (100%) rename torch_xla/experimental/{distributed_checkpoint.py => distributed_checkpoint/planners.py} (99%) diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py index 7b3d5eb86fbd..276571e59793 100644 --- a/test/spmd/test_xla_distributed_checkpoint.py +++ b/test/spmd/test_xla_distributed_checkpoint.py @@ -16,7 +16,7 @@ create_default_global_save_plan, ) from torch_xla.experimental.distributed_checkpoint import SPMDLoadPlanner, SPMDSavePlanner -from torch_xla.experimental._distributed_checkpoint_helpers import ( +from torch_xla.experimental.distributed_checkpoint._helpers import ( _sharded_cpu_state_dict, _CpuShards, _is_sharded_tensor) diff --git a/torch_xla/experimental/distributed_checkpoint/__init__.py b/torch_xla/experimental/distributed_checkpoint/__init__.py new file mode 100644 index 000000000000..7c91aba0126d --- /dev/null +++ b/torch_xla/experimental/distributed_checkpoint/__init__.py @@ -0,0 +1,6 @@ +from .planners import SPMDSavePlanner, SPMDLoadPlanner + +__all__ = [ + "SPMDSavePlanner", + "SPMDLoadPlanner", +] diff --git a/torch_xla/experimental/_distributed_checkpoint_helpers.py b/torch_xla/experimental/distributed_checkpoint/_helpers.py similarity index 100% rename from torch_xla/experimental/_distributed_checkpoint_helpers.py rename to torch_xla/experimental/distributed_checkpoint/_helpers.py diff --git a/torch_xla/experimental/distributed_checkpoint.py b/torch_xla/experimental/distributed_checkpoint/planners.py similarity index 99% rename from torch_xla/experimental/distributed_checkpoint.py rename to torch_xla/experimental/distributed_checkpoint/planners.py index 5b1ee97b7d64..fbf466ff28a9 100644 --- a/torch_xla/experimental/distributed_checkpoint.py +++ b/torch_xla/experimental/distributed_checkpoint/planners.py @@ -35,16 +35,11 @@ from torch.distributed.checkpoint.utils import find_state_dict_object from torch.utils._pytree import tree_map from torch_xla.experimental.xla_sharding import XLAShardedTensor, XLAShard -from torch_xla.experimental._distributed_checkpoint_helpers import ( +from torch_xla.experimental.distributed_checkpoint._helpers import ( FLATTEN_MAPPING, flatten_state_dict, dedup_tensors, _is_sharded_tensor, set_element, narrow_tensor_by_index, _unwrap_xla_sharded_tensor, _CpuShards) from typing import Any, Dict, List, Tuple, Union -__all__ = [ - "SPMDSavePlanner", - "SPMDLoadPlanner", -] - class SPMDSavePlanner(SavePlanner): """