Skip to content

Commit

Permalink
[checkpointio] fix zero optimizer async save memory (#6151)
Browse files Browse the repository at this point in the history
* [checkpointio] fix zero optimizer async save memory

* [checkpointio] fit new tensornvme api

* [checkpointio] fit new tensornvme api
  • Loading branch information
ver217 authored Nov 25, 2024
1 parent 8ecff0c commit ab856fd
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 42 deletions.
16 changes: 8 additions & 8 deletions colossalai/booster/plugin/low_level_zero_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,22 +128,20 @@ def save_unsharded_optimizer(
# the `state_dict` in LowLevelZeroOptimizer has communication
# if only the master rank collect state_dict and save,
# the communication on each rank would not match
if use_async:
if use_async and self.coordinator.is_master():
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = {}
pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
else:
pinned_state_dicts = None
state_dict = optimizer.state_dict(pinned_state_dicts)
state_dict = optimizer.state_dict(pinned_state_dicts, only_on_master=True)
if self.coordinator.is_master():
if use_async:
from tensornvme.async_file_io import AsyncFileWriter

from colossalai.utils.safetensors import save_nested

f_writer = AsyncFileWriter(
fp=open(checkpoint, "wb", buffering=0), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
)
f_writer = AsyncFileWriter(checkpoint, n_entries=self.N_WRITE_ENTRIES, backend="pthread")
save_nested(f_writer, state_dict)
self.async_writers.append(f_writer)
else:
Expand Down Expand Up @@ -192,13 +190,15 @@ def save_sharded_optimizer(
# state_dict only provide only 'param_groups'
state_dict = optimizer.optim.state_dict()
# state shard would be handled by the low-level zero optimizer
if use_async:
if use_async and self.coordinator.is_master():
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = {}
pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
else:
pinned_state_dicts = None
sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts)
sharded_state = optimizer.state_dict_shard(
max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts, only_on_master=True
)

# Preparing file paths and index file.
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
Expand Down Expand Up @@ -227,7 +227,7 @@ def save_sharded_optimizer(
from colossalai.utils.safetensors import save_nested

f_writer = AsyncFileWriter(
fp=open(checkpoint_file_path, "wb", buffering=0),
checkpoint_file_path,
n_entries=self.N_WRITE_ENTRIES,
backend="pthread",
)
Expand Down
1 change: 0 additions & 1 deletion colossalai/checkpoint_io/checkpoint_io_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ def __init__(self):
def _sync_io(self):
for writer in self.async_writers:
writer.synchronize()
writer.fp.close()
self.async_writers.clear()

def _sync_d2h(self):
Expand Down
2 changes: 1 addition & 1 deletion colossalai/checkpoint_io/general_checkpoint_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def save_unsharded_model(
if use_async:
from tensornvme.async_file_io import AsyncFileWriter

writer = AsyncFileWriter(open(checkpoint, "wb", buffering=0), self.N_WRITE_ENTRIES, backend="pthread")
writer = AsyncFileWriter(checkpoint, self.N_WRITE_ENTRIES, backend="pthread")
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
self.async_writers.append(writer)
Expand Down
4 changes: 1 addition & 3 deletions colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,9 +690,7 @@ def save_unsharded_model(

from colossalai.utils.safetensors import move_and_save

writer = AsyncFileWriter(
open(checkpoint, "wb", buffering=0), self.N_WRITE_ENTRIES, backend="pthread"
)
writer = AsyncFileWriter(checkpoint, self.N_WRITE_ENTRIES, backend="pthread")
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
self.async_writers.append(writer)
Expand Down
2 changes: 1 addition & 1 deletion colossalai/checkpoint_io/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def async_save_state_dict_shards(
index_file.append_weight_map(key, shard_file)
checkpoint_file_path = os.path.join(checkpoint, shard_file)

writer = AsyncFileWriter(open(checkpoint_file_path, "wb", buffering=0), n_write_entries, backend="pthread")
writer = AsyncFileWriter(checkpoint_file_path, n_write_entries, backend="pthread")
writers.append(writer)

if pinned_state_dict is not None:
Expand Down
57 changes: 37 additions & 20 deletions colossalai/zero/low_level/low_level_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,9 @@ def pack_group(group):

return {"state": packed_state, "param_groups": param_groups}

def state_dict(self, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None) -> Dict:
def state_dict(
self, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None, only_on_master: bool = False
) -> Dict:
"""Return a state_dict same with DDP
Returns:
Expand All @@ -785,23 +787,29 @@ def state_dict(self, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tens
zero_state = dict()
device = get_accelerator().get_current_device()
for param, state in self.optim.state.items():
working_param = self.master_to_working_param[id(param)]
pg = self.param_to_pg[working_param]
if not only_on_master or get_nd_rank(pg) == 0:
zero_state[param] = copy.deepcopy(state)
else:
zero_state[param] = {}

if pinned_state_dicts is not None and param not in pinned_state_dicts:
pinned_state_dicts[param] = {}
zero_state[param] = copy.deepcopy(state)

for k, v in state.items():
if isinstance(v, torch.Tensor) and k != "step":
working_param = self.master_to_working_param[id(param)]
pg = self.param_to_pg[working_param]
gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg)
param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param)
if pinned_state_dicts is not None and k not in pinned_state_dicts[param]:
pinned_state_dicts[param][k] = torch.empty_like(param_state, pin_memory=True, device="cpu")
if pinned_state_dicts is not None:
pinned_state_dicts[param][k].copy_(param_state)
zero_state[param][k] = pinned_state_dicts[param][k]
else:
zero_state[param][k] = param_state.cpu()
if not only_on_master or get_nd_rank(pg) == 0:
if pinned_state_dicts is not None and k not in pinned_state_dicts[param]:
pinned_state_dicts[param][k] = torch.empty_like(param_state, pin_memory=True, device="cpu")
if pinned_state_dicts is not None:
pinned_state_dicts[param][k].copy_(param_state)
zero_state[param][k] = pinned_state_dicts[param][k]
else:
zero_state[param][k] = param_state.cpu()

states_dict = self._pack_state(zero_state)

Expand Down Expand Up @@ -837,7 +845,10 @@ def load_state_dict(self, state_dict: Dict):
self.optim.load_state_dict(zero_state_dict)

def state_dict_shard(
self, max_shard_size: int = 1024, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None
self,
max_shard_size: int = 1024,
pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None,
only_on_master: bool = False,
) -> Iterator[Tuple[Dict, int]]:
"""Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
Only include the 'state' in state_dict.
Expand All @@ -862,25 +873,31 @@ def state_dict_shard(
cnt += 1
for param_idx, states in local_states.items():
current_block_size = 0
current_block = copy.deepcopy(states)
if pinned_state_dicts is not None and param_idx not in pinned_state_dicts:
pinned_state_dicts[param_idx] = {}
master_param = idx2master[param_idx]
working_param = self.master_to_working_param[id(master_param)]
pg = self.param_to_pg[working_param]
if not only_on_master or get_nd_rank(pg) == 0:
current_block = copy.deepcopy(states)
else:
current_block = {}

for k, v in states.items():
if isinstance(v, torch.Tensor) and k != "step":
state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param)
if pinned_state_dicts is not None and k not in pinned_state_dicts[param_idx]:
pinned_state_dicts[param_idx][k] = torch.empty_like(state_tensor, pin_memory=True, device="cpu")
if pinned_state_dicts is not None:
pinned_state_dicts[param_idx][k].copy_(state_tensor)
current_block[k] = pinned_state_dicts[param_idx][k]
else:
current_block[k] = state_tensor.cpu()
if not only_on_master or get_nd_rank(pg) == 0:
if pinned_state_dicts is not None and k not in pinned_state_dicts[param_idx]:
pinned_state_dicts[param_idx][k] = torch.empty_like(
state_tensor, pin_memory=True, device="cpu"
)
if pinned_state_dicts is not None:
pinned_state_dicts[param_idx][k].copy_(state_tensor)
current_block[k] = pinned_state_dicts[param_idx][k]
else:
current_block[k] = state_tensor.cpu()
current_block_size += calculate_tensor_size(state_tensor)

if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0:
Expand Down
17 changes: 9 additions & 8 deletions tests/test_checkpoint_io/test_safetensors_async_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
except ModuleNotFoundError:
raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")


from colossalai.testing import check_state_dict_equal
from colossalai.utils import get_current_device

Expand Down Expand Up @@ -110,20 +111,20 @@ def test_save_load():
}

optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors"
f_writer = AsyncFileWriter(fp=open(optimizer_saved_path, "wb"), n_entries=191, backend="pthread")
f_writer = AsyncFileWriter(optimizer_saved_path, n_entries=191, backend="pthread")
save_nested(f_writer, optimizer_state_dict)
f_writer.sync_before_step()
f_writer.synchronize()
f_writer.fp.close()
del f_writer
load_state_dict = load_flat(optimizer_saved_path)
check_state_dict_equal(load_state_dict, optimizer_state_dict)

optimizer_shard_saved_path = f"{tempdir}/save_optimizer_shard.safetensors"
f_writer = AsyncFileWriter(fp=open(optimizer_shard_saved_path, "wb"), n_entries=191, backend="pthread")
f_writer = AsyncFileWriter(optimizer_shard_saved_path, n_entries=191, backend="pthread")
save_nested(f_writer, optimizer_state_dict["state"])
f_writer.sync_before_step()
f_writer.synchronize()
f_writer.fp.close()
del f_writer
load_state_dict_shard = load_flat(optimizer_shard_saved_path)
check_state_dict_equal(load_state_dict_shard, optimizer_state_dict["state"])

Expand All @@ -133,21 +134,21 @@ def test_save_load():
"module.weight2": torch.rand((1024, 1024)),
}
model_saved_path = f"{tempdir}/save_model.safetensors"
f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
f_writer = AsyncFileWriter(model_saved_path, n_entries=191, backend="pthread")
save(f_writer, model_state_dict)
f_writer.sync_before_step()
f_writer.synchronize()
f_writer.fp.close()
del f_writer
load_state_dict = load_file(model_saved_path)
check_state_dict_equal(model_state_dict, load_state_dict)

model_state_dict_cuda = {k: v.to(get_current_device()) for k, v in model_state_dict.items()}
model_state_pinned = {k: v.pin_memory() for k, v in model_state_dict.items()}
model_saved_path = f"{tempdir}/save_model_cuda.safetensors"
f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
f_writer = AsyncFileWriter(model_saved_path, n_entries=191, backend="pthread")
move_and_save(f_writer, model_state_dict_cuda, model_state_pinned)
f_writer.sync_before_step()
f_writer.synchronize()
f_writer.fp.close()
del f_writer
load_state_dict = load_file(model_saved_path)
check_state_dict_equal(model_state_dict, load_state_dict)

0 comments on commit ab856fd

Please sign in to comment.