diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index 146e5250a676..196dd2e77251 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -46,11 +46,13 @@ def get_param_info(model: nn.Module, optim: Optimizer):
     # 1. A mapping from integer param_id to param32 shape.
 
     param_info = {"id2shape": {}, "name2shape": {}}
-    for m_name, m_var in model.named_modules():
-        for p_name, p_var in m_var.named_parameters(recurse=False):
-            param_name = m_name + "." + p_name if m_name else p_name
-            original_shape = p_var.shape if isinstance(p_var, torch.Tensor) else None
-            param_info["name2shape"][param_name] = original_shape
+    for p_name, param in model.named_parameters(remove_duplicate=False):
+        param_info["name2shape"][p_name] = param.shape
+    # for m_name, m_var in model.named_modules():
+    #     for p_name, p_var in m_var.named_parameters(recurse=False):
+    #         param_name = m_name + "." + p_name if m_name else p_name
+    #         original_shape = p_var.shape if isinstance(p_var, torch.Tensor) else None
+    #         param_info["name2shape"][param_name] = original_shape
 
     if optim is None:
         return param_info
diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index 1e59ce8620b2..d0d809c02c5c 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -32,6 +32,7 @@
     save_param_groups,
     save_state_dict,
     save_state_dict_shards,
+    search_padding_dim,
     search_tp_partition_dim,
     sharded_optimizer_loading_epilogue,
 )
@@ -937,14 +938,30 @@ def shard_from_complete_optimizer_state(
             if isinstance(v, torch.Tensor) and k != "step":
                 # Shard state along tensor parallel group.
                 partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size)
+                global_shape = current_shape
                 if partition_dim is not None:
-                    slice_size = current_shape[partition_dim]
                     # pad embedding params
-                    if partition_dim == 0:
-                        padding_size = current_shape[0] * self.tp_size - original_shape[0]
-                        if padding_size > 0:
-                            padding_data = torch.zeros_like(v[:padding_size, ...])
-                            v = torch.cat((v, padding_data), dim=0).contiguous()
+                    global_shape = (
+                        *current_shape[:partition_dim],
+                        current_shape[partition_dim] * self.tp_size,
+                        *current_shape[partition_dim + 1 :],
+                    )
+
+                padding_dim = search_padding_dim(global_shape, original_shape)
+                if padding_dim is not None:
+                    padding_size = global_shape[padding_dim] - original_shape[padding_dim]
+                    if padding_size > 0:
+                        padding_data = torch.zeros(
+                            *v.shape[:padding_dim],
+                            padding_size,
+                            *v.shape[padding_dim + 1 :],
+                            device=v.device,
+                            dtype=v.dtype,
+                        )
+                        v = torch.cat((v, padding_data), dim=padding_dim).contiguous()
+
+                if partition_dim is not None:
+                    slice_size = current_shape[partition_dim]
                     v = v.split(slice_size, dim=partition_dim)[self.tp_rank]
 
                 # Shard state along data parallel group when using Zero.
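The hunk above changes `shard_from_complete_optimizer_state` so that a full optimizer state is first padded up to the tensor-parallel global shape and only then split across ranks. Below is a minimal standalone sketch of that flow, reusing the `search_padding_dim` helper introduced in `colossalai/checkpoint_io/utils.py` in the next diff; the `tp_size`, `tp_rank`, vocab sizes, and the hard-coded `partition_dim` are illustrative stand-ins, not values from this PR.

```python
import torch
from typing import Optional


def search_padding_dim(global_shape, original_shape) -> Optional[int]:
    # Same logic as the helper added in colossalai/checkpoint_io/utils.py:
    # the first dim where the padded global shape exceeds the original shape.
    padding_dim = None
    for dim, length in enumerate(global_shape):
        if length > original_shape[dim]:
            padding_dim = dim
            break
    return padding_dim


# Illustrative numbers: a vocab of 50257 padded to 50304 so it divides tp_size=4.
tp_size, tp_rank = 4, 1
original_shape = torch.Size([50257, 128])  # checkpointed (unpadded) shape
current_shape = torch.Size([12576, 128])   # this rank's padded shard: 50304 / 4
v = torch.randn(original_shape)            # full optimizer state loaded from disk

partition_dim = 0                          # stand-in for search_tp_partition_dim's result
global_shape = (
    *current_shape[:partition_dim],
    current_shape[partition_dim] * tp_size,
    *current_shape[partition_dim + 1 :],
)                                          # (50304, 128)

# Pad the full state up to the global (padded) shape along the detected dim.
padding_dim = search_padding_dim(global_shape, original_shape)
if padding_dim is not None:
    padding_size = global_shape[padding_dim] - original_shape[padding_dim]
    if padding_size > 0:
        padding_data = torch.zeros(
            *v.shape[:padding_dim], padding_size, *v.shape[padding_dim + 1 :],
            device=v.device, dtype=v.dtype,
        )
        v = torch.cat((v, padding_data), dim=padding_dim).contiguous()

# Only after padding is the state split along the TP dim and this rank's slice kept.
v = v.split(current_shape[partition_dim], dim=partition_dim)[tp_rank]
assert v.shape == current_shape
```

Padding before the split guarantees the padded global length divides evenly by `tp_size`, so every rank receives a slice whose shape matches its local (padded) parameter.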
diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py
index 2a1d4de9b036..6197be9d1c8d 100644
--- a/colossalai/checkpoint_io/utils.py
+++ b/colossalai/checkpoint_io/utils.py
@@ -120,6 +120,15 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz
     return partition_dim
 
 
+def search_padding_dim(global_shape: torch.Size, original_shape: torch.Size) -> Optional[int]:
+    padding_dim = None
+    for dim, length in enumerate(global_shape):
+        if length > original_shape[dim]:
+            padding_dim = dim
+            break
+    return padding_dim
+
+
 # ======================================
 # Helper classes and functions for saving shard file
 # ======================================
diff --git a/colossalai/shardformer/layer/parallel_module.py b/colossalai/shardformer/layer/parallel_module.py
index e535416150b5..eae31215c58d 100644
--- a/colossalai/shardformer/layer/parallel_module.py
+++ b/colossalai/shardformer/layer/parallel_module.py
@@ -298,7 +298,9 @@ def _load_from_state_dict(
 
             if self.new_num_embeddings > self.old_num_embeddings:
                 num_padding_tokens = self.new_num_embeddings - self.old_num_embeddings
-                padding_embeddings = torch.zeros_like(input_param[:num_padding_tokens, ...])
+                padding_embeddings = torch.zeros(
+                    num_padding_tokens, *input_param.shape[1:], device=input_param.device, dtype=input_param.dtype
+                )
                 input_param.data = torch.cat((input_param.data, padding_embeddings), dim=0).contiguous()
 
             if is_distributed_tensor(param):
@@ -359,7 +361,9 @@ def resize_embedding_weight(self):
         num_padding_tokens = self.new_num_embeddings - self.old_num_embeddings
         valid_weight = self.weight.data
-        padding_weight = torch.zeros_like(self.weight[:num_padding_tokens, ...])
+        padding_weight = torch.zeros(
+            num_padding_tokens, *self.weight.shape[1:], device=self.weight.device, dtype=self.weight.dtype
+        )
         # padding to embedding
         self.weight.data = torch.cat((valid_weight, padding_weight), dim=0).contiguous()
diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index bc26bbe1a66e..5008986852f1 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -11,7 +11,7 @@
 from torch.distributed.distributed_c10d import _get_default_group
 
 from colossalai.accelerator import get_accelerator
-from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param
+from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param, search_padding_dim
 from colossalai.interface import ModelWrapper
 from colossalai.lazy import LazyTensor
 from colossalai.logging import get_dist_logger
@@ -524,7 +524,13 @@ def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
                 else:
                     if self.params_info is not None:
                         origin_shape = self.params_info["name2shape"][name]
-                        destination[prefix + name] = p_mapping[param][: origin_shape[0], ...]
+                        padding_dim = search_padding_dim(p_mapping[param].shape, origin_shape)
+                        if padding_dim is not None:
+                            unpadding_slices = [slice(None)] * p_mapping[param].dim()
+                            unpadding_slices[padding_dim] = slice(None, origin_shape[0])
+                            destination[prefix + name] = p_mapping[param][tuple(unpadding_slices)]
+                        else:
+                            destination[prefix + name] = p_mapping[param]
                     else:
                         destination[prefix + name] = p_mapping[param]
         del p_mapping
@@ -653,12 +659,24 @@ def load(
            if state_key in state_dict:
                input_param = state_dict[state_key]
 
+                global_shape = dest_tensor.shape
                if source_device_mesh is not None and source_sharding_spec is not None:
                    global_shape = get_global_shape(dest_tensor)
-                    padding_num = global_shape[0] - input_param.shape[0]
+
+                padding_dim = search_padding_dim(global_shape, input_param.shape)
+                if padding_dim is not None:
+                    padding_num = global_shape[padding_dim] - input_param.shape[padding_dim]
                    if padding_num > 0:
-                        padding_data = torch.zeros_like(input_param[:padding_num, ...])
-                        input_param = torch.cat((input_param, padding_data), dim=0)
+                        padding_data = torch.zeros(
+                            *input_param.shape[:padding_dim],
+                            padding_num,
+                            *input_param.shape[padding_dim + 1 :],
+                            device=input_param.device,
+                            dtype=input_param.dtype,
+                        )
+                        input_param = torch.cat((input_param, padding_data), dim=padding_dim)
+
+                if source_device_mesh is not None and source_sharding_spec is not None:
                    input_param = distribute_tensor(input_param, source_device_mesh, source_sharding_spec)
                elif shard_fn is not None and gather_fn is not None:
                    input_param = distribute_tensor_with_customization(
@@ -896,7 +914,11 @@ def state_dict_shard(
 
                if self.params_info is not None:
                    origin_shape = self.params_info["name2shape"][name]
-                    gathered_param = gathered_param[: origin_shape[0], ...]
+                    padding_dim = search_padding_dim(gathered_param.shape, origin_shape)
+                    if padding_dim is not None:
+                        unpadding_slices = [slice(None)] * gathered_param.dim()
+                        unpadding_slices[padding_dim] = slice(None, origin_shape[0])
+                        gathered_param = gathered_param[tuple(unpadding_slices)]
 
                block, block_size = sharder.append_param(prefix + name, gathered_param)
                if block is not None:
diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py
index e670f8ccedba..98f1984d1061 100644
--- a/colossalai/zero/gemini/gemini_optimizer.py
+++ b/colossalai/zero/gemini/gemini_optimizer.py
@@ -13,7 +13,7 @@
 from colossalai.accelerator import get_accelerator
 from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
-from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param
+from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param, search_padding_dim
 from colossalai.interface import OptimizerWrapper
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
@@ -705,7 +705,7 @@ def load_single_param_states(self, param_id: int, saved_states: dict):
        """
        Load saved optimizer states into parameter with given id.
        """
-        def cast(param, state_range, value, global_shape, key=None):
+        def cast(param, state_range, value, global_shape, origin_shape, key=None):
            """
            Make a copy of the needed segment of value and cast it to device of param.
""" @@ -722,11 +722,22 @@ def cast(param, state_range, value, global_shape, key=None): if is_dtensor: global_shape = get_global_shape(real_param) - padding_num = global_shape[0] - origin_shape[0] + + padding_dim = search_padding_dim(global_shape, origin_shape) + if padding_dim is not None: + padding_num = global_shape[padding_dim] - origin_shape[padding_dim] value = torch.reshape(value, origin_shape) if padding_num > 0: - padding_data = torch.zeros_like(value[:padding_num, ...]) - value = torch.cat((value, padding_data), dim=0).contiguous() + padding_data = torch.zeros( + *value.shape[:padding_dim], + padding_num, + *value.shape[padding_dim + 1 :], + device=value.device, + dtype=value.dtype, + ) + value = torch.cat((value, padding_data), dim=padding_dim).contiguous() + + if is_dtensor: value = distribute_tensor(value, sharding_spec=shard_spec, device_mesh=device_mesh) elif is_customized_distributed: value = torch.reshape(value, global_shape) @@ -753,7 +764,7 @@ def cast(param, state_range, value, global_shape, key=None): origin_shape = global_shape for k, v in saved_states.items(): - updated_states[k] = cast(fake_param, state_range, v, global_shape, k) + updated_states[k] = cast(fake_param, state_range, v, global_shape, origin_shape, k) del v # clean loaded states self.optim.state[fake_param].update(updated_states) diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index 89c44ec92c36..ac6f8caef816 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -120,7 +120,6 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha for group in optimizer.param_groups: group["lr"] = 0.1 - optimizer.zero_grad() with shared_tempdir() as tempdir: model_ckpt_path = f"{tempdir}/model" optimizer_ckpt_path = f"{tempdir}/optimizer"