Commit

update split_param loading
DesmonDay committed Oct 24, 2024
1 parent 4ab0df1 commit 0d10c4c
Showing 2 changed files with 20 additions and 12 deletions.
27 changes: 15 additions & 12 deletions paddlenlp/trainer/plugins/unified_checkpoint_sharding_v2.py
@@ -224,12 +224,10 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
        for shard_file in resolved_archive_file:
            if expected_keys.isdisjoint(sharded_metadata["file_map"][os.path.split(shard_file)[-1]]):
                continue
-
            if model.config.tensor_parallel_degree > 1:
-               state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="expected")
+               state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="cpu")
            else:
-               state_dict = load_state_dict(shard_file, None, expected_keys, device="expected")
-
+               state_dict = load_state_dict(shard_file, None, expected_keys, device="cpu")
            returned_state_dict.update(state_dict)
            del state_dict
            gc.collect()
@@ -238,13 +236,6 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected

    # get tp params
    state_dict_optim = load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys_optim)
-   if has_master_weights:
-       state_dict_master_weight = load_resolved_archive_file(
-           resolved_archive_file_mw,
-           sharded_metadata_mw,
-           expected_keys,
-           is_master_weights=True,
-       )

    # need to split param for different sharding rank, maybe need to deal with oom issue.
    for key in list(state_dict_optim.keys()):
@@ -266,15 +257,24 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
                    paddle.zeros([padding_end - padding_start], dtype=state_dict_optim[key].dtype),
                )
            )

        if has_master_weights:
            key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
        else:
            key_name = "_".join([static_name, key_name[1]])

+       state_dict_optim[key] = state_dict_optim[key]._copy_to(paddle.framework._current_expected_place(), False)
+
        returned_optim_state_dict[key_name] = state_dict_optim.pop(key)
        returned_optim_state_dict[key_name].name = key_name

    if has_master_weights:
+       state_dict_master_weight = load_resolved_archive_file(
+           resolved_archive_file_mw,
+           sharded_metadata_mw,
+           expected_keys,
+           is_master_weights=True,
+       )
+
        for key in list(state_dict_master_weight.keys()):
            static_name = struct2static_name_mappings.get(key, None)
            if state_dict_master_weight[key].numel().item() > 1:
@@ -292,6 +292,9 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
                        paddle.zeros([padding_end - padding_start], dtype=state_dict_master_weight[key].dtype),
                    )
                )
+           state_dict_master_weight[key] = state_dict_master_weight[key]._copy_to(
+               paddle.framework._current_expected_place(), False
+           )
            returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight.pop(key)
            returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])

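The hunks above change the shard loading path in two coupled ways: load_state_dict now materializes each shard on CPU (device="cpu" instead of device="expected"), and every optimizer tensor is moved to the device only at the point it is handed to the returned state dict, via _copy_to(paddle.framework._current_expected_place(), False), so the padding/concat work runs on the CPU copy first. Below is a minimal sketch of that load-on-CPU, copy-per-tensor pattern; move_to_expected_place is a hypothetical helper name, not the PaddleNLP function.

import paddle

# Sketch only (not the PaddleNLP implementation): keep loaded tensors on CPU and
# move each one to the expected device only when it is handed off.
def move_to_expected_place(cpu_state_dict):
    """Move tensors one at a time, dropping each CPU copy as its device copy is made."""
    out = {}
    for key in list(cpu_state_dict.keys()):
        tensor = cpu_state_dict.pop(key)  # remove the dict's CPU reference as we go
        # Same call the diff uses: non-blocking copy to the current expected place.
        out[key] = tensor._copy_to(paddle.framework._current_expected_place(), False)
    return out

if __name__ == "__main__":
    cpu_state = {"moment1": paddle.zeros([8], dtype="float32").cpu()}
    print(move_to_expected_place(cpu_state))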
5 changes: 5 additions & 0 deletions paddlenlp/trainer/training_args.py
@@ -1402,6 +1402,11 @@ def is_segment_parallel_supported():
                        f"but got logging_steps={self.logging_steps}."
                    )

+               if "split_param" in sharding_parallel_config:
+                   assert (
+                       self.amp_master_grad
+                   ), "If `split_param` in sharding_parallel_config, `amp_master_grad` must be True."
+
                fleet.init(is_collective=True, strategy=strategy)
                logger.info(strategy)

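The training_args.py hunk adds a configuration guard: when split_param is set in sharding_parallel_config, amp_master_grad must also be enabled. A standalone sketch of the same check, with assumed parameter names outside the real TrainingArguments setup:

# Sketch only: the guard added above, lifted into a standalone function with
# assumed parameter names; the real check lives inside TrainingArguments.
def check_split_param_requires_master_grad(sharding_parallel_config, amp_master_grad):
    if "split_param" in sharding_parallel_config:
        assert (
            amp_master_grad
        ), "If `split_param` in sharding_parallel_config, `amp_master_grad` must be True."


# Passes: split_param together with amp_master_grad.
check_split_param_requires_master_grad({"split_param"}, amp_master_grad=True)

# Would raise AssertionError: split_param without amp_master_grad.
# check_split_param_requires_master_grad({"split_param"}, amp_master_grad=False)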
