diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index 0718b2a60889..7946d9b9c197 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -14,7 +14,12 @@
 
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.tensor.p_tensor import init_as_ptensor, is_padded_tensor, to_padded_tensor, to_unpadded_tensor
+from colossalai.tensor.padded_tensor import (
+    init_as_padded_tensor,
+    is_padded_tensor,
+    to_padded_tensor,
+    to_unpadded_tensor,
+)
 from colossalai.utils import get_current_device
 
 from .general_checkpoint_io import GeneralCheckpointIO
@@ -873,7 +878,7 @@ def gather_from_sharded_optimizer_state(
 
                 padding_dim = search_padding_dim(v.shape, original_shape)
                 if padding_dim is not None:
-                    v = init_as_ptensor(v, v.shape[padding_dim], original_shape[padding_dim], padding_dim)
+                    v = init_as_padded_tensor(v, v.shape[padding_dim], original_shape[padding_dim], padding_dim)
                     v = to_unpadded_tensor(v)
 
             state_[k] = v.detach().clone().to(device)
diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py
index a3234663c26d..6197be9d1c8d 100644
--- a/colossalai/checkpoint_io/utils.py
+++ b/colossalai/checkpoint_io/utils.py
@@ -19,7 +19,6 @@
     to_global,
     to_global_for_customized_distributed_tensor,
 )
-from colossalai.tensor.p_tensor.api import init_as_ptensor, is_padded_tensor
 
 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"
@@ -208,13 +207,11 @@ def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False) -> to
     """
     param_ = param if keep_vars else param.detach()
     if is_distributed_tensor(param_):
-        param_ = to_global(param_)
+        return to_global(param_)
     elif is_customized_distributed_tensor(param_):
-        param_ = to_global_for_customized_distributed_tensor(param_)
-
-    if is_padded_tensor(param):
-        param_ = init_as_ptensor(param_, param.current_length, param.origin_length, param.padding_dim)
-    return param_
+        return to_global_for_customized_distributed_tensor(param_)
+    else:
+        return param_
 
 
 def save_state_dict_shards(
diff --git a/colossalai/shardformer/layer/parallel_module.py b/colossalai/shardformer/layer/parallel_module.py
index 55114281d1ac..11ef73538c36 100644
--- a/colossalai/shardformer/layer/parallel_module.py
+++ b/colossalai/shardformer/layer/parallel_module.py
@@ -20,7 +20,7 @@
     is_distributed_tensor,
     sharded_tensor_to_param,
 )
-from colossalai.tensor.p_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor
+from colossalai.tensor.padded_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor
 
 __all__ = ["ParallelModule"]
 
@@ -297,8 +297,7 @@ def _load_from_state_dict(
                     continue
 
                 if is_padded_tensor(param):
-                    print("is_padded_tensor(param)", is_padded_tensor(param))
-                    input_param = to_padded_tensor(input_param, param.current_length, param.padding_dim)
+                    input_param = to_padded_tensor(input_param, param._current_length, param._padding_dim)
 
                 if is_distributed_tensor(param):
                     # shard the input param
diff --git a/colossalai/tensor/d_tensor/layout_converter.py b/colossalai/tensor/d_tensor/layout_converter.py
index 667a7b78e4f5..7bfd7c5357b5 100644
--- a/colossalai/tensor/d_tensor/layout_converter.py
+++ b/colossalai/tensor/d_tensor/layout_converter.py
@@ -10,6 +10,7 @@
 from colossalai.tensor.d_tensor.comm_spec import *
 from colossalai.tensor.d_tensor.layout import Layout
 from colossalai.tensor.d_tensor.misc import LayoutException
+from colossalai.tensor.padded_tensor.api import init_as_padded_tensor, is_padded_tensor
 from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
 
 from .sharding_spec import ShardingSpec
@@ -607,8 +608,16 @@ def apply(self, tensor: torch.Tensor, source_layout: Layout, target_layout: Layo
                     [3.],
                     [3.]])
         """
+
         _, comm_action_sequence = self.layout_converting(source_layout, target_layout)
         for comm_spec in comm_action_sequence:
-            tensor = comm_spec.covert_spec_to_action(tensor)
-        tensor.dist_layout = target_layout
-        return tensor
+            target_tensor = comm_spec.covert_spec_to_action(tensor)
+        target_tensor.dist_layout = target_layout
+
+        # restore the padding information
+        if is_padded_tensor(tensor) and not is_padded_tensor(target_tensor):
+            target_tensor = init_as_padded_tensor(
+                target_tensor, tensor._current_length, tensor._origin_length, tensor._padding_dim
+            )
+
+        return target_tensor
diff --git a/colossalai/tensor/p_tensor/__init__.py b/colossalai/tensor/p_tensor/__init__.py
deleted file mode 100644
index 84490fc2a538..000000000000
--- a/colossalai/tensor/p_tensor/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .api import init_as_ptensor, is_padded_tensor, to_padded_tensor, to_unpadded_tensor
-
-__all__ = ["is_padded_tensor", "to_padded_tensor", "to_unpadded_tensor", "init_as_ptensor"]
diff --git a/colossalai/tensor/padded_tensor/__init__.py b/colossalai/tensor/padded_tensor/__init__.py
new file mode 100644
index 000000000000..353ff35f84ca
--- /dev/null
+++ b/colossalai/tensor/padded_tensor/__init__.py
@@ -0,0 +1,3 @@
+from .api import init_as_padded_tensor, is_padded_tensor, to_padded_tensor, to_unpadded_tensor
+
+__all__ = ["is_padded_tensor", "to_padded_tensor", "to_unpadded_tensor", "init_as_padded_tensor"]
diff --git a/colossalai/tensor/p_tensor/api.py b/colossalai/tensor/padded_tensor/api.py
similarity index 74%
rename from colossalai/tensor/p_tensor/api.py
rename to colossalai/tensor/padded_tensor/api.py
index 7f95b7fe9457..5b66c016b399 100644
--- a/colossalai/tensor/p_tensor/api.py
+++ b/colossalai/tensor/padded_tensor/api.py
@@ -16,16 +16,16 @@ def _hijack_detach_and_clone(ptensor: torch.Tensor) -> torch.Tensor:
 
     def new_detach(self):
         t_ = self._unpad_detach()
-        t_.padding_dim = self.padding_dim
-        t_.origin_length = self.origin_length
-        t_.current_length = self.current_length
+        t_._padding_dim = self._padding_dim
+        t_._origin_length = self._origin_length
+        t_._current_length = self._current_length
         return t_
 
     def new_clone(self, *args, **kwargs):
         t_ = self._unpad_clone(*args, **kwargs)
-        t_.padding_dim = self.padding_dim
-        t_.origin_length = self.origin_length
-        t_.current_length = self.current_length
+        t_._padding_dim = self._padding_dim
+        t_._origin_length = self._origin_length
+        t_._current_length = self._current_length
         return t_
 
     # bind the new methods to the tensor
@@ -63,7 +63,7 @@ def is_padded_tensor(tensor: torch.Tensor) -> bool:
     Returns:
         bool: Whether the given tensor is a padding tensor.
     """
-    return hasattr(tensor, "padding_dim")
+    return hasattr(tensor, "_padding_dim")
 
 
 def to_padded_tensor(
@@ -89,9 +89,9 @@ def to_padded_tensor(
         )
         tensor.data = torch.cat((tensor.data, padding_data), dim=padding_dim).contiguous()
 
-    setattr(tensor, "padding_dim", padding_dim)
-    setattr(tensor, "origin_length", origin_length)
-    setattr(tensor, "current_length", current_length)
+    tensor._padding_dim = padding_dim
+    tensor._origin_length = origin_length
+    tensor._current_length = current_length
 
     _hijack_detach_and_clone(tensor)
 
@@ -103,25 +103,25 @@ def to_unpadded_tensor(ptensor: torch.Tensor):
         return ptensor
 
     unpad_slices = [slice(None)] * ptensor.dim()
-    unpad_slices[ptensor.padding_dim] = slice(None, ptensor.origin_length)
+    unpad_slices[ptensor._padding_dim] = slice(None, ptensor._origin_length)
     ptensor.data = ptensor.data[tuple(unpad_slices)]
 
-    delattr(ptensor, "padding_dim")
-    delattr(ptensor, "origin_length")
-    delattr(ptensor, "current_length")
+    delattr(ptensor, "_padding_dim")
+    delattr(ptensor, "_origin_length")
+    delattr(ptensor, "_current_length")
 
     _hijack_back_detach_and_clone(ptensor)
 
     return ptensor
 
 
-def init_as_ptensor(tensor: torch.Tensor, current_length: int, origin_length: int, padding_dim: int):
+def init_as_padded_tensor(tensor: torch.Tensor, current_length: int, origin_length: int, padding_dim: int):
     if is_padded_tensor(tensor):
        return tensor
 
-    setattr(tensor, "padding_dim", padding_dim)
-    setattr(tensor, "origin_length", origin_length)
-    setattr(tensor, "current_length", current_length)
+    tensor._padding_dim = padding_dim
+    tensor._origin_length = origin_length
+    tensor._current_length = current_length
 
     _hijack_detach_and_clone(tensor)
 
diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index 22351d26e9a6..c79422171f1b 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -27,7 +27,12 @@
     is_customized_distributed_tensor,
     is_distributed_tensor,
 )
-from colossalai.tensor.p_tensor import init_as_ptensor, is_padded_tensor, to_padded_tensor, to_unpadded_tensor
+from colossalai.tensor.padded_tensor import (
+    init_as_padded_tensor,
+    is_padded_tensor,
+    to_padded_tensor,
+    to_unpadded_tensor,
+)
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.utils import _cast_float, free_storage, is_ddp_ignored
 
@@ -462,8 +467,8 @@ def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool) -> Dict:
                 )
                 record_tensor = gather_distributed_param(record_tensor, keep_vars=False).cpu()
             if is_padded_tensor(tensor):
-                record_tensor = init_as_ptensor(
-                    record_tensor, tensor.current_length, tensor.origin_length, tensor.padding_dim
+                record_tensor = init_as_padded_tensor(
+                    record_tensor, tensor._current_length, tensor._origin_length, tensor._padding_dim
                 )
                 record_tensor = to_unpadded_tensor(record_tensor)
 
@@ -661,7 +666,7 @@ def load(
            global_shape = get_global_shape(dest_tensor)
 
            if is_padded_tensor(dest_tensor):
-               padding_dim = dest_tensor.padding_dim
+               padding_dim = dest_tensor._padding_dim
               input_param = to_padded_tensor(input_param, global_shape[padding_dim], padding_dim)
 
            if source_device_mesh is not None and source_sharding_spec is not None:
diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py
index 135927e4f295..ae02fe297d88 100644
--- a/colossalai/zero/gemini/gemini_optimizer.py
+++ b/colossalai/zero/gemini/gemini_optimizer.py
@@ -28,7 +28,12 @@
     is_customized_distributed_tensor,
     is_distributed_tensor,
 )
-from colossalai.tensor.p_tensor import init_as_ptensor, is_padded_tensor, to_padded_tensor, to_unpadded_tensor
+from colossalai.tensor.padded_tensor import (
+    init_as_padded_tensor,
+    is_padded_tensor,
+    to_padded_tensor,
+    to_unpadded_tensor,
+)
 from colossalai.utils import disposable, is_ddp_ignored
 
 from .chunk import Chunk, ChunkManager
@@ -495,8 +500,8 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict:
                 state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()
                 state_tensor = state_tensor.reshape(global_shape)
                 if is_padded_tensor(param):
-                    state_tensor = init_as_ptensor(
-                        state_tensor, param.current_length, param.origin_length, param.padding_dim
+                    state_tensor = init_as_padded_tensor(
+                        state_tensor, param._current_length, param._origin_length, param._padding_dim
                     )
                     state_tensor = to_unpadded_tensor(state_tensor)
                 collected_states[state_name] = state_tensor
@@ -555,8 +560,8 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict:
                 )
                 state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()
                 if is_padded_tensor(param):
-                    state_tensor = init_as_ptensor(
-                        state_tensor, param.current_length, param.origin_length, param.padding_dim
+                    state_tensor = init_as_padded_tensor(
+                        state_tensor, param._current_length, param._origin_length, param._padding_dim
                     )
                     state_tensor = to_unpadded_tensor(state_tensor)
 
@@ -732,7 +737,7 @@ def cast(param, state_range, value, global_shape, origin_shape, key=None):
 
             if is_padded_tensor(real_param):
                 value = torch.reshape(value, origin_shape)
-                padding_dim = real_param.padding_dim
+                padding_dim = real_param._padding_dim
                 value = to_padded_tensor(value, global_shape[padding_dim], padding_dim)
 
             if is_dtensor:
diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py
index 0488b8ba43c2..a77ba39a122c 100644
--- a/tests/test_shardformer/test_model/_utils.py
+++ b/tests/test_shardformer/test_model/_utils.py
@@ -21,7 +21,7 @@
 from colossalai.shardformer._utils import getattr_
 from colossalai.shardformer.policies.auto_policy import Policy
 from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
-from colossalai.tensor.p_tensor.api import is_padded_tensor, to_unpadded_tensor
+from colossalai.tensor.padded_tensor.api import is_padded_tensor, to_unpadded_tensor
 
 
 def build_model(
diff --git a/tests/test_tensor/test_padded_tensor/test_padded_tensor.py b/tests/test_tensor/test_padded_tensor/test_padded_tensor.py
new file mode 100644
index 000000000000..e34613ad1104
--- /dev/null
+++ b/tests/test_tensor/test_padded_tensor/test_padded_tensor.py
@@ -0,0 +1,48 @@
+import torch
+
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.tensor.d_tensor import ShardingSpec, distribute_tensor, is_distributed_tensor, to_global
+from colossalai.tensor.padded_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+
+
+def check_padded_tensor(rank, world_size, port):
+    disable_existing_loggers()
+    launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    original_tensor = torch.rand(32, 64).to("cuda")
+
+    device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
+    target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]})
+    d_tensor = distribute_tensor(original_tensor, device_mesh, target_sharding_spec)
+
+    padded_tensor = to_padded_tensor(d_tensor, current_length=64, padding_dim=0)
+    assert padded_tensor.shape == (64, 64)
+
+    tensor_copy = padded_tensor.clone()
+    assert is_padded_tensor(tensor_copy)
+    assert is_distributed_tensor(tensor_copy)
+
+    tensor_detached = padded_tensor.detach()
+    assert is_padded_tensor(tensor_detached)
+    assert is_distributed_tensor(tensor_detached)
+    assert tensor_detached.requires_grad == False
+    assert tensor_detached.grad == None
+
+    unpadded_tensor = to_unpadded_tensor(padded_tensor)
+    assert unpadded_tensor.shape == d_tensor.shape
+    assert is_distributed_tensor(unpadded_tensor)
+
+    global_tensor = to_global(unpadded_tensor)
+    assert global_tensor.shape == original_tensor.shape
+
+
+@rerun_if_address_is_in_use()
+def test_padded_tensor():
+    world_size = 4
+    spawn(check_padded_tensor, world_size)
+
+
+if __name__ == "__main__":
+    test_padded_tensor()
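Usage note (not part of the patch): a minimal, single-process sketch of how the renamed colossalai.tensor.padded_tensor API is meant to be used, based only on the functions and call sites visible in the diff above; the shapes and padding lengths here are illustrative, not taken from the patch.

import torch

from colossalai.tensor.padded_tensor import (
    init_as_padded_tensor,
    is_padded_tensor,
    to_padded_tensor,
    to_unpadded_tensor,
)

# Pad a 30-row weight up to 32 rows along dim 0; the tensor's data is padded
# in place and the same tensor object is returned.
weight = torch.rand(30, 64)
padded = to_padded_tensor(weight, current_length=32, padding_dim=0)
assert is_padded_tensor(padded) and padded.shape == (32, 64)

# The bookkeeping now lives in underscore-prefixed attributes.
assert padded._current_length == 32
assert padded._origin_length == 30
assert padded._padding_dim == 0

# detach() and clone() are hijacked so the padding metadata survives both calls.
assert is_padded_tensor(padded.detach()) and is_padded_tensor(padded.clone())

# Strip the padding again; the metadata attributes are removed.
unpadded = to_unpadded_tensor(padded)
assert unpadded.shape == (30, 64) and not is_padded_tensor(unpadded)

# init_as_padded_tensor only re-attaches metadata to a buffer that is already
# padded (e.g. a gathered optimizer state) without touching its data.
restored = init_as_padded_tensor(torch.zeros(32, 64), 32, 30, 0)
assert is_padded_tensor(restored) and restored._origin_length == 30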