From e4d1ccb8caac26d5fa2ef1be89a6930497bc9def Mon Sep 17 00:00:00 2001
From: ver217
Date: Mon, 18 Nov 2024 19:12:24 +0800
Subject: [PATCH] [checkpointio] fix size compute

---
 colossalai/zero/low_level/low_level_optim.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index e3c301640867..24ebae1c74d9 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -18,6 +18,7 @@
     FP16MixedPrecisionMixin,
     MixedPrecisionMixin,
 )
+from colossalai.checkpoint_io.utils import calculate_tensor_size
 from colossalai.interface import OptimizerWrapper
 from colossalai.logging import get_dist_logger
 from colossalai.quantization.fp8 import all_gather_fp8, all_reduce_fp8, reduce_scatter_fp8
@@ -865,19 +866,17 @@ def state_dict_shard(

             for k, v in states.items():
                 if isinstance(v, torch.Tensor) and k != "step":
-                    if pinned_state_dicts and k not in pinned_state_dicts[param_idx]:
-                        pinned_state_dicts[param_idx][k] = torch.empty_like(
-                            working_param, pin_memory=True, device="cpu"
-                        )
                     state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
                     all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
                     state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param)
+                    if pinned_state_dicts and k not in pinned_state_dicts[param_idx]:
+                        pinned_state_dicts[param_idx][k] = torch.empty_like(state_tensor, pin_memory=True, device="cpu")
                     if pinned_state_dicts:
                         pinned_state_dicts[param_idx][k].copy_(state_tensor)
                         current_block[k] = pinned_state_dicts[param_idx][k]
                     else:
                         current_block[k] = state_tensor.cpu()
-                    current_block_size += state_tensor.numel()
+                    current_block_size += calculate_tensor_size(state_tensor)

             if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0:
                 yield ret_block, ret_block_size
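
Note on the fix: `current_block_size` is compared against `max_shard_size` when deciding where to cut a shard, so accumulating `state_tensor.numel()` (a dtype-agnostic element count) mis-sizes the blocks; `calculate_tensor_size` accounts for the element width. The pinned buffer is now also allocated from `state_tensor` rather than `working_param`, so its shape and dtype match the tensor actually copied into it. Below is a minimal sketch of the size distinction; `calculate_tensor_size_mb` is a hypothetical stand-in for the library helper, assumed to follow the usual numel * element_size convention (check `colossalai.checkpoint_io.utils` for the exact unit).

    import torch

    def calculate_tensor_size_mb(tensor: torch.Tensor) -> float:
        # Assumption: size in MB = element count * bytes per element / 1024^2.
        # This is an illustrative stand-in, not the library's verified implementation.
        return tensor.numel() * tensor.element_size() / 1024 / 1024

    state = torch.zeros(1024 * 1024, dtype=torch.float32)
    print(state.numel())                    # 1048576, element count regardless of dtype
    print(calculate_tensor_size_mb(state))  # 4.0, actual size, comparable to a shard budget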