Commit

resolve comments

flybird11111 committed Apr 12, 2024
1 parent f08e084 commit 14a4342
Showing 7 changed files with 84 additions and 28 deletions.
7 changes: 2 additions & 5 deletions colossalai/booster/plugin/gemini_plugin.py
@@ -46,11 +46,8 @@ def get_param_info(model: nn.Module, optim: Optimizer):
     # 1. A mapping from integer param_id to param32 shape.

     param_info = {"id2shape": {}, "name2shape": {}}
-    for m_name, m_var in model.named_modules():
-        for p_name, p_var in m_var.named_parameters(recurse=False):
-            param_name = m_name + "." + p_name if m_name else p_name
-            original_shape = p_var.shape if isinstance(p_var, torch.Tensor) else None
-            param_info["name2shape"][param_name] = original_shape
+    for p_name, param in model.named_parameters(remove_duplicate=False):
+        param_info["name2shape"][p_name] = param.shape

     if optim is None:
         return param_info
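For reference, a minimal sketch (hypothetical module, not from the repository) of what the switch to named_parameters(remove_duplicate=False) buys: a tied weight is recorded under every name it is reachable by, so either name can be resolved at checkpoint time.

# Sketch: tied weights appear under both names with remove_duplicate=False.
import torch
import torch.nn as nn

class TiedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.head = nn.Linear(4, 10, bias=False)
        self.head.weight = self.embed.weight  # weight tying

model = TiedModel()
name2shape = {name: p.shape for name, p in model.named_parameters(remove_duplicate=False)}
# Both "embed.weight" and "head.weight" are present; with the default
# remove_duplicate=True only one of the two names would be recorded.
print(sorted(name2shape))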
28 changes: 22 additions & 6 deletions colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -32,6 +32,7 @@
     save_param_groups,
     save_state_dict,
     save_state_dict_shards,
+    search_padding_dim,
     search_tp_partition_dim,
     sharded_optimizer_loading_epilogue,
 )
@@ -937,14 +938,29 @@ def shard_from_complete_optimizer_state(
             if isinstance(v, torch.Tensor) and k != "step":
                 # Shard state along tensor parallel group.
                 partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size)
+                global_shape = current_shape
                 if partition_dim is not None:
-                    slice_size = current_shape[partition_dim]
-                    # pad embedding params
-                    if partition_dim == 0:
-                        padding_size = current_shape[0] * self.tp_size - original_shape[0]
-                        if padding_size > 0:
-                            padding_data = torch.zeros_like(v[:padding_size, ...])
-                            v = torch.cat((v, padding_data), dim=0).contiguous()
+                    global_shape = (
+                        *current_shape[:partition_dim],
+                        current_shape[partition_dim] * self.tp_size,
+                        *current_shape[partition_dim + 1 :],
+                    )
+
+                padding_dim = search_padding_dim(global_shape, original_shape)
+                if padding_dim is not None:
+                    padding_size = global_shape[padding_dim] - original_shape[padding_dim]
+                    padding_data = torch.zeros(
+                        *v.shape[:padding_dim],
+                        padding_size,
+                        *v.shape[padding_dim + 1 :],
+                        device=v.device,
+                        dtype=v.dtype,
+                    )
+                    v = torch.cat((v, padding_data), dim=padding_dim).contiguous()
+
+                if partition_dim is not None:
+                    slice_size = current_shape[partition_dim]
                     v = v.split(slice_size, dim=partition_dim)[self.tp_rank]

                 # Shard state along data parallel group when using Zero.
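The hunk above reorders the old logic into a pad-then-shard flow: derive the padded global shape from the per-rank shape, zero-pad the saved state along the padded dimension, then take this rank's tensor-parallel slice. A sketch with made-up shapes (a vocab padded from 50257 to 50304 rows, tp_size=4), not the checkpoint IO code itself:

# Hypothetical shapes; illustrates the same arithmetic as the hunk above.
import torch

tp_size, tp_rank, partition_dim = 4, 1, 0
original_shape = torch.Size([50257, 512])             # unpadded shape stored in the checkpoint
current_shape = torch.Size([50304 // tp_size, 512])   # padded, per-rank shape of the running model
v = torch.randn(original_shape)                       # full (unpadded) optimizer state from disk

# Padded global shape implied by the per-rank shape.
global_shape = (*current_shape[:partition_dim],
                current_shape[partition_dim] * tp_size,
                *current_shape[partition_dim + 1:])

# Zero-pad up to the global shape, then keep this rank's slice.
padding_size = global_shape[partition_dim] - original_shape[partition_dim]
pad = torch.zeros(padding_size, *v.shape[1:], dtype=v.dtype)
v = torch.cat((v, pad), dim=partition_dim)
v = v.split(current_shape[partition_dim], dim=partition_dim)[tp_rank]
assert v.shape == current_shape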
9 changes: 9 additions & 0 deletions colossalai/checkpoint_io/utils.py
@@ -120,6 +120,15 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Size
     return partition_dim


+def search_padding_dim(global_shape: torch.Size, original_shape: torch.Size) -> Optional[int]:
+    padding_dim = None
+    for dim, length in enumerate(global_shape):
+        if length > original_shape[dim]:
+            padding_dim = dim
+            break
+    return padding_dim
+
+
 # ======================================
 # Helper classes and functions for saving shard file
 # ======================================
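search_padding_dim returns the first dimension where the padded (global) shape exceeds the original shape, or None when nothing was padded. A short usage sketch with hypothetical shapes, assuming the helper is importable from colossalai.checkpoint_io.utils as the imports elsewhere in this commit suggest:

# Hypothetical shapes: a vocab dimension rounded up for divisibility.
import torch
from colossalai.checkpoint_io.utils import search_padding_dim

padded = torch.Size([50304, 768])     # shape after padding
original = torch.Size([50257, 768])   # shape recorded at save time

assert search_padding_dim(padded, original) == 0        # rows were padded
assert search_padding_dim(original, original) is None   # nothing padded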
8 changes: 6 additions & 2 deletions colossalai/shardformer/layer/parallel_module.py
@@ -298,7 +298,9 @@ def _load_from_state_dict(

             if self.new_num_embeddings > self.old_num_embeddings:
                 num_padding_tokens = self.new_num_embeddings - self.old_num_embeddings
-                padding_embeddings = torch.zeros_like(input_param[:num_padding_tokens, ...])
+                padding_embeddings = torch.zeros(
+                    num_padding_tokens, *input_param.shape[1:], device=input_param.device, dtype=input_param.dtype
+                )
                 input_param.data = torch.cat((input_param.data, padding_embeddings), dim=0).contiguous()

             if is_distributed_tensor(param):
@@ -359,7 +361,9 @@ def _load_from_state_dict(
     def resize_embedding_weight(self):
         num_padding_tokens = self.new_num_embeddings - self.old_num_embeddings
         valid_weight = self.weight.data
-        padding_weight = torch.zeros_like(self.weight[:num_padding_tokens, ...])
+        padding_weight = torch.zeros(
+            num_padding_tokens, *self.weight.shape[1:], device=self.weight.device, dtype=self.weight.dtype
+        )
         # padding to embedding
         self.weight.data = torch.cat((valid_weight, padding_weight), dim=0).contiguous()

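Both hunks above replace torch.zeros_like on a slice of the existing tensor with an explicit torch.zeros of the required row count. One failure mode the explicit form avoids, shown here with made-up sizes: slicing clamps to the rows that already exist, so zeros_like can under-allocate when more padding rows are needed than are present.

# Hypothetical sizes; demonstrates the clamping behaviour only.
import torch

weight = torch.ones(3, 4)      # only 3 rows exist
num_padding_tokens = 5         # but 5 padding rows are needed

short = torch.zeros_like(weight[:num_padding_tokens, ...])     # shape (3, 4): slice is clamped
right = torch.zeros(num_padding_tokens, *weight.shape[1:],
                    device=weight.device, dtype=weight.dtype)  # shape (5, 4)
print(short.shape, right.shape)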
35 changes: 28 additions & 7 deletions colossalai/zero/gemini/gemini_ddp.py
@@ -11,7 +11,7 @@
 from torch.distributed.distributed_c10d import _get_default_group

 from colossalai.accelerator import get_accelerator
-from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param
+from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param, search_padding_dim
 from colossalai.interface import ModelWrapper
 from colossalai.lazy import LazyTensor
 from colossalai.logging import get_dist_logger
@@ -524,7 +524,13 @@ def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
                 else:
                     if self.params_info is not None:
                         origin_shape = self.params_info["name2shape"][name]
-                        destination[prefix + name] = p_mapping[param][: origin_shape[0], ...]
+                        padding_dim = search_padding_dim(p_mapping[param].shape, origin_shape)
+                        if padding_dim is not None:
+                            unpadding_slices = [slice(None)] * p_mapping[param].dim()
+                            unpadding_slices[padding_dim] = slice(None, origin_shape[0])
+                            destination[prefix + name] = p_mapping[param][tuple(unpadding_slices)]
+                        else:
+                            destination[prefix + name] = p_mapping[param]
                     else:
                         destination[prefix + name] = p_mapping[param]
         del p_mapping
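Instead of always dropping trailing rows ([: origin_shape[0], ...]), the hunk above builds a per-dimension slice list so padding can be stripped along whichever dimension was padded while other dimensions are left untouched. A sketch of the unpadding-slice idea with hypothetical shapes:

# Hypothetical shapes; rows 5..7 of the padded tensor are zero padding.
import torch

origin_shape = torch.Size([5, 4])
padded = torch.zeros(8, 4)
padding_dim = 0

unpadding_slices = [slice(None)] * padded.dim()
unpadding_slices[padding_dim] = slice(None, origin_shape[padding_dim])
restored = padded[tuple(unpadding_slices)]
assert restored.shape == origin_shape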
@@ -653,12 +659,23 @@ def load(
         if state_key in state_dict:
             input_param = state_dict[state_key]

+            global_shape = dest_tensor.shape
             if source_device_mesh is not None and source_sharding_spec is not None:
                 global_shape = get_global_shape(dest_tensor)
-                padding_num = global_shape[0] - input_param.shape[0]
-                if padding_num > 0:
-                    padding_data = torch.zeros_like(input_param[:padding_num, ...])
-                    input_param = torch.cat((input_param, padding_data), dim=0)
+
+            padding_dim = search_padding_dim(global_shape, input_param.shape)
+            if padding_dim is not None:
+                padding_num = global_shape[padding_dim] - input_param.shape[padding_dim]
+                padding_data = torch.zeros(
+                    *input_param.shape[:padding_dim],
+                    padding_num,
+                    *input_param.shape[padding_dim + 1 :],
+                    device=input_param.device,
+                    dtype=input_param.dtype,
+                )
+                input_param = torch.cat((input_param, padding_data), dim=padding_dim)
+
+            if source_device_mesh is not None and source_sharding_spec is not None:
                 input_param = distribute_tensor(input_param, source_device_mesh, source_sharding_spec)
             elif shard_fn is not None and gather_fn is not None:
                 input_param = distribute_tensor_with_customization(
@@ -896,7 +913,11 @@ def state_dict_shard(

             if self.params_info is not None:
                 origin_shape = self.params_info["name2shape"][name]
-                gathered_param = gathered_param[: origin_shape[0], ...]
+                padding_dim = search_padding_dim(gathered_param.shape, origin_shape)
+                if padding_dim is not None:
+                    unpadding_slices = [slice(None)] * gathered_param.dim()
+                    unpadding_slices[padding_dim] = slice(None, origin_shape[0])
+                    gathered_param = gathered_param[tuple(unpadding_slices)]

             block, block_size = sharder.append_param(prefix + name, gathered_param)
             if block is not None:
24 changes: 17 additions & 7 deletions colossalai/zero/gemini/gemini_optimizer.py
@@ -13,7 +13,7 @@

 from colossalai.accelerator import get_accelerator
 from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
-from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param
+from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param, search_padding_dim
 from colossalai.interface import OptimizerWrapper
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
@@ -705,7 +705,7 @@ def load_single_param_states(self, param_id: int, saved_states: dict):
         Load saved optimizer states into parameter with given id.
         """

-        def cast(param, state_range, value, global_shape, key=None):
+        def cast(param, state_range, value, global_shape, origin_shape, key=None):
             """
             Make a copy of the needed segment of value and cast it to device of param.
             """
@@ -722,11 +722,21 @@ def cast(param, state_range, value, global_shape, key=None):

             if is_dtensor:
                 global_shape = get_global_shape(real_param)
-                padding_num = global_shape[0] - origin_shape[0]
+
+            padding_dim = search_padding_dim(global_shape, origin_shape)
+            if padding_dim is not None:
+                padding_num = global_shape[padding_dim] - origin_shape[padding_dim]
                 value = torch.reshape(value, origin_shape)
-                if padding_num > 0:
-                    padding_data = torch.zeros_like(value[:padding_num, ...])
-                    value = torch.cat((value, padding_data), dim=0).contiguous()
+                padding_data = torch.zeros(
+                    *value.shape[:padding_dim],
+                    padding_num,
+                    *value.shape[padding_dim + 1 :],
+                    device=value.device,
+                    dtype=value.dtype,
+                )
+                value = torch.cat((value, padding_data), dim=padding_dim).contiguous()
+
+            if is_dtensor:
                 value = distribute_tensor(value, sharding_spec=shard_spec, device_mesh=device_mesh)
             elif is_customized_distributed:
                 value = torch.reshape(value, global_shape)
@@ -753,7 +763,7 @@ def cast(param, state_range, value, global_shape, key=None):
             origin_shape = global_shape

         for k, v in saved_states.items():
-            updated_states[k] = cast(fake_param, state_range, v, global_shape, k)
+            updated_states[k] = cast(fake_param, state_range, v, global_shape, origin_shape, k)
             del v # clean loaded states
         self.optim.state[fake_param].update(updated_states)

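With origin_shape now passed into cast() explicitly, the saved state segment (which arrives flattened and unpadded) is reshaped to its original shape and zero-padded up to the runtime global shape before being re-distributed. A sketch of that reshape-then-pad step with hypothetical shapes; the distribute_tensor call that follows in the real code is omitted:

# Hypothetical shapes; mirrors the padding arithmetic in cast() above.
import torch

origin_shape = torch.Size([5, 4])              # shape recorded in the checkpoint
global_shape = torch.Size([8, 4])              # padded shape of the running parameter
value = torch.arange(20, dtype=torch.float32)  # saved state arrives flattened

padding_dim = 0                                # what search_padding_dim would return here
padding_num = global_shape[padding_dim] - origin_shape[padding_dim]
value = torch.reshape(value, origin_shape)
pad = torch.zeros(*value.shape[:padding_dim], padding_num, *value.shape[padding_dim + 1:],
                  dtype=value.dtype)
value = torch.cat((value, pad), dim=padding_dim).contiguous()
assert value.shape == global_shape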
1 change: 0 additions & 1 deletion tests/test_checkpoint_io/test_gemini_checkpoint_io.py
@@ -120,7 +120,6 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard
     for group in optimizer.param_groups:
         group["lr"] = 0.1

-    optimizer.zero_grad()
     with shared_tempdir() as tempdir:
         model_ckpt_path = f"{tempdir}/model"
         optimizer_ckpt_path = f"{tempdir}/optimizer"
