Commit 8db699a

Merge branch 'release/3.0-beta2' of https://github.com/PaddlePaddle/PaddleNLP into release/3.0-beta2
DesmonDay committed Dec 12, 2024
2 parents a79783d + e473a81 commit 8db699a
Showing 10 changed files with 112 additions and 41 deletions.
30 changes: 21 additions & 9 deletions paddlenlp/quantization/unified_checkpoint_quantization.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle
from paddle.distributed import fleet

@@ -33,7 +34,7 @@
from paddlenlp.utils.log import logger


def dequant_unified_optimizer(state_dict, ckpt_quant_stage, scale_dict):
def dequant_unified_optimizer(state_dict, ckpt_quant_stage, scale_dict, use_pd=False):
"""
dequantize unified optimizer state dict.
Args:
@@ -44,6 +45,7 @@ def dequant_unified_optimizer(state_dict, ckpt_quant_stage, scale_dict):
scale_dict (`dict`):
compression checkpoint scale dict.
"""
logger.info(f"Start unified checkpoint dequantization, stage {ckpt_quant_stage}.")
tp_rank, tp_degree = -1, 1
if paddle.distributed.get_world_size() > 1:
hcg = fleet.get_hybrid_communicate_group()
@@ -68,7 +70,7 @@ def dequant_unified_optimizer(state_dict, ckpt_quant_stage, scale_dict):
dequant=True,
tp_rank=tp_rank,
tp_degree=tp_degree,
use_pd=True,
use_pd=use_pd,
)
state_dict[quant_key] = weight
elif is_moment2:
@@ -85,10 +87,13 @@ def dequant_unified_optimizer(state_dict, ckpt_quant_stage, scale_dict):
dequant=True,
tp_rank=tp_rank,
tp_degree=tp_degree,
use_pd=True,
use_pd=use_pd,
)
# cal m2
weight = paddle.square(1.0 / weight - eps)
if use_pd:
weight = paddle.square(1.0 / weight - eps)
else:
weight = np.square(1.0 / weight - eps)
state_dict[quant_key] = weight
elif ckpt_quant_stage == "O2":
# set eps
@@ -117,7 +122,7 @@ def dequant_unified_optimizer(state_dict, ckpt_quant_stage, scale_dict):
quant=False,
tp_rank=tp_rank,
tp_degree=tp_degree,
use_pd=True,
use_pd=use_pd,
symmetry=True,
)
ratio_weight = group_wise_quant_dequant(
@@ -128,14 +133,19 @@ def dequant_unified_optimizer(state_dict, ckpt_quant_stage, scale_dict):
quant=False,
tp_rank=tp_rank,
tp_degree=tp_degree,
use_pd=True,
use_pd=use_pd,
)

ratio_weight = paddle.square(1.0 / ratio_weight - eps)
if use_pd:
ratio_weight = paddle.square(1.0 / ratio_weight - eps)
else:
ratio_weight = np.square(1.0 / ratio_weight - eps)
state_dict[quant_key] = ratio_weight
m1_state_dict[quant_key[: -len(MOMENT2_KEYNAME)] + MOMENT1_KEYNAME] = m1_weight
state_dict.update(m1_state_dict)

logger.info(f"Unified checkpoint dequantization done, stage {ckpt_quant_stage}.")

return state_dict


@@ -152,14 +162,15 @@ def quant_unified_optimizer(state_dict, state_dict_type, ckpt_quant_stage, async
async_save (`bool`):
whether to use async_save.
"""
logger.info(f"Start unified checkpoint quantization, stage {ckpt_quant_stage}.")

quant = False
if ckpt_quant_stage != "O0":
quant = True
del_key = []
if quant and state_dict_type == "optimizer_weight":
scales_dict = {}
opt_keys = state_dict.keys()
for k in opt_keys:
for k in state_dict.keys():
momentum1 = k.endswith(MOMENT1_KEYNAME)
momentum2 = k.endswith(MOMENT2_KEYNAME)

@@ -205,5 +216,6 @@ def quant_unified_optimizer(state_dict, state_dict_type, ckpt_quant_stage, async
state_dict.pop(k, None)

state_dict.update(scales_dict)
logger.info(f"Unified checkpoint quantization done, stage {ckpt_quant_stage}.")

return state_dict
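
Note: the recurring change in this file is the new use_pd flag, which lets dequantization run either on Paddle tensors or on NumPy arrays instead of always assuming Paddle. A minimal sketch of the dispatch pattern, assuming ratio holds the dequantized moment2 ratio; the helper name recover_moment2 is illustrative, not from the source:

import numpy as np
import paddle

def recover_moment2(ratio, eps, use_pd=False):
    # Mirror the hunks above: Paddle ops when working on tensors,
    # NumPy ops when working on plain arrays.
    if use_pd:
        return paddle.square(1.0 / ratio - eps)
    return np.square(1.0 / ratio - eps)

Callers that hold Paddle tensors (for example load_state_dict in model_utils.py below) pass use_pd=True; callers working on NumPy arrays keep the default.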
8 changes: 6 additions & 2 deletions paddlenlp/trainer/trainer.py
@@ -1592,7 +1592,9 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa

additional_configs = {}
if is_iterable_dataset:
if self.args.dataset_world_size > 1 and eval_dataset is not None:
if (
self.args.dataset_world_size > 1 or self.args.pipeline_parallel_degree > 1
) and eval_dataset is not None:
eval_dataset = IterableDatasetShard(
eval_dataset,
batch_size=self.args.per_device_eval_batch_size,
@@ -3099,7 +3101,9 @@ def evaluation_loop(

# Metrics!
if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
# all_labels may be a tuple when prediction_step outputs label_mask
batch_labels = all_labels[0] if isinstance(all_labels, (list, tuple)) else all_labels
metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=batch_labels))
else:
metrics = {}

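
A small, self-contained illustration of the label unwrapping added above, using placeholder NumPy arrays (all_preds, label_ids, and label_mask are assumptions for the example):

import numpy as np

all_preds = np.zeros((8, 5), dtype="float32")
label_ids = np.zeros((8,), dtype="int64")
label_mask = np.ones((8,), dtype="bool")

# prediction_step may return the labels together with a label mask,
# so all_labels can be a tuple; only the label ids go to compute_metrics.
all_labels = (label_ids, label_mask)
batch_labels = all_labels[0] if isinstance(all_labels, (list, tuple)) else all_labels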
10 changes: 6 additions & 4 deletions paddlenlp/trainer/unified_checkpoint/async_handler.py
@@ -80,7 +80,8 @@ def _file_save_async_or_sync(
if isinstance(state_dict[k], paddle.Tensor):
state_dict[k] = state_dict.pop(k).cpu().numpy()

state_dict = quant_unified_optimizer(state_dict, state_dict_type, ckpt_quant_stage)
if state_dict_type == "optimizer_weight" and ckpt_quant_stage != "O0":
state_dict = quant_unified_optimizer(state_dict, state_dict_type, ckpt_quant_stage)
safe_save_file(state_dict, path, metadata={"format": "np"})
else:
if len(state_dict.keys()) == 0:
@@ -206,9 +207,10 @@ def _save_file_async_in_process(
signal_path = shared_save_signal_path[:].decode("utf-8").rstrip("\x00")
logger.info(f"Start to async save {path}")
state_dict = _read_state_dict_from_shm(meta_dict, shm) # numpy array
state_dict = quant_unified_optimizer(
state_dict, state_dict_type, ckpt_quant_stage, async_save=True
) # ckpt quantization
if state_dict_type == "optimizer_weight" and ckpt_quant_stage != "O0":
state_dict = quant_unified_optimizer(
state_dict, state_dict_type, ckpt_quant_stage, async_save=True
) # ckpt quantization
safe_save_file(state_dict, path, {"format": "np"})
del state_dict
saved_signal_path = os.path.join(signal_path, f".{state_dict_type}.done.{global_rank}")
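
Both hunks in this file apply the same guard before saving. A compact sketch of the intended behavior; the wrapper name maybe_quantize is illustrative:

from paddlenlp.quantization.unified_checkpoint_quantization import quant_unified_optimizer  # path taken from the file shown earlier

def maybe_quantize(state_dict, state_dict_type, ckpt_quant_stage, async_save=False):
    # Only optimizer weights are quantized, and only when a quant stage is enabled;
    # model weights and master weights pass through untouched.
    if state_dict_type == "optimizer_weight" and ckpt_quant_stage != "O0":
        return quant_unified_optimizer(state_dict, state_dict_type, ckpt_quant_stage, async_save=async_save)
    return state_dict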
2 changes: 1 addition & 1 deletion paddlenlp/trainer/unified_checkpoint/check_completion.py
@@ -186,7 +186,7 @@ def check_dynamic_load(args, weight_map, existed_files, is_master_weights=False,
state_dict = get_expected_state_dict(model)

for key in state_dict.keys():
if model._keys_to_ignore_on_load_massing is not None and key in model._keys_to_ignore_on_load_missing:
if model._keys_to_ignore_on_load_missing is not None and key in model._keys_to_ignore_on_load_missing:
continue
if sharding_group.nranks > 1:
static_name = struct2static_name_mappings.get(key, None)
41 changes: 34 additions & 7 deletions paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py
@@ -29,8 +29,10 @@
unwrap_model,
)
from paddlenlp.transformers.utils import dtype_byte_size
from paddlenlp.utils import infohub
from paddlenlp.utils.env import (
LORA_WEIGHTS_NAME,
MAX_QUANTIZATION_TIMES,
PADDLE_MASTER_WEIGHTS_NAME,
PADDLE_OPTIMIZER_NAME,
PADDLE_WEIGHTS_NAME,
@@ -239,9 +241,16 @@ def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, outp
optimizer_name = _add_variant(SAFE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
master_weights_name = _add_variant(SAFE_MASTER_WEIGHTS_NAME, self.args.optimizer_name_suffix)

sharded_optim_index = {}
# save opt index json if checkpoint quantization is on.
if self.args.ckpt_quant_stage != "O0":
sharded_optim_index = {"ckpt_quant_stage": self.args.ckpt_quant_stage}
if self.args.ckpt_quant_stage != "O0" and "quant_reach_limit" not in infohub:
sharded_optim_index["ckpt_quant_stage"] = self.args.ckpt_quant_stage

sharded_optim_index["quant_ckpt_resume_times"] = (
infohub["quant_ckpt_resume_times"] if "quant_ckpt_resume_times" in infohub else 0
)

if len(sharded_optim_index) > 0:
optimizer_index_name = SAFE_OPTIMIZER_INDEX_NAME
path = os.path.join(output_dir, optimizer_index_name)
if self.args.should_save:
@@ -257,7 +266,7 @@ def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, outp
signal_path=signal_dir,
is_sync=is_sync_save,
state_dict_type="optimizer_weight",
ckpt_quant_stage=self.args.ckpt_quant_stage,
ckpt_quant_stage=self.args.ckpt_quant_stage if "quant_reach_limit" not in infohub else "O0",
)
if master_weights is not None:
self.async_handler._file_save_async_or_sync(
@@ -277,7 +286,7 @@ def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint, ckp
optimizer_path = os.path.join(resume_from_checkpoint, optimizer_name)
master_weights_path = os.path.join(resume_from_checkpoint, master_weights_name)
# no quantization & no master weight represent O1 AMP strategy.
is_amp_o1 = True if not os.path.isfile(master_weights_path) and ckpt_quant_stage == "O0" else False
is_amp_o1 = self.args.fp16_opt_level == "O1"

model_state_dict = get_expected_state_dict(model)
struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()} # get optimizer param mappings
@@ -379,7 +388,7 @@ def save_unified_optimizer(self, model, optimizer, output_dir, signal_dir):
signal_path=signal_dir,
is_sync=is_sync_save,
state_dict_type="optimizer_weight",
ckpt_quant_stage=self.args.ckpt_quant_stage,
ckpt_quant_stage=self.args.ckpt_quant_stage if "quant_reach_limit" not in infohub else "O0",
)
if master_weight_state_dict is not None:
self.async_handler._file_save_async_or_sync(
@@ -429,10 +438,24 @@ def load_unified_optimizer(self, model, optimizer, resume_from_checkpoint):
with open(os.path.join(resume_from_checkpoint, SAFE_OPTIMIZER_INDEX_NAME), "r") as f:
index = json.loads(f.read())

# get quant ckpt info `ckpt_quant_stage` and `quant_ckpt_resume_times`
ckpt_quant_stage = "O0"
if "ckpt_quant_stage" in index:
ckpt_quant_stage = index["ckpt_quant_stage"]

quant_ckpt_resume_times = 0
if "quant_ckpt_resume_times" in index:
quant_ckpt_resume_times = index["quant_ckpt_resume_times"]
# increment and save resume times in infohub
if ckpt_quant_stage != "O0":
quant_ckpt_resume_times += 1
infohub["quant_ckpt_resume_times"] = quant_ckpt_resume_times

# Quantization resume count exceeds the limit; turn off the quantization strategy.
if quant_ckpt_resume_times >= MAX_QUANTIZATION_TIMES:
infohub["quant_reach_limit"] = True
logger.info("Checkpoint quantization time reach limit and will be closed.")

# If there is no merged optimizer, load the non-merge optimizer.
if "weight_map" not in index:
if self.args.data_parallel_rank == 0 or self.args.use_expert_parallel:
@@ -647,8 +670,12 @@ def unified_optimizer_into_shards(
)
sharded_optim_index = get_sharded_index(index_optimizer_filelist, total_optim_size_list)

if args.should_save and args.ckpt_quant_stage in ["O1", "O2"]:
sharded_optim_index["ckpt_quant_stage"] = args.ckpt_quant_stage
if args.should_save:
if args.ckpt_quant_stage in ["O1", "O2"] and "quant_reach_limit" not in infohub:
sharded_optim_index["ckpt_quant_stage"] = args.ckpt_quant_stage
sharded_optim_index["quant_ckpt_resume_times"] = (
infohub["quant_ckpt_resume_times"] if "quant_ckpt_resume_times" in infohub else 0
)

if master_weights is not None:
index_master_weight_filelist, total_master_weight_size_list = gather_sharded_object(
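
Taken together, the hunks in this file add a cap on how many times a quantized checkpoint can be resumed and re-quantized. A simplified sketch of the bookkeeping, assuming infohub behaves like a dict and reusing MAX_QUANTIZATION_TIMES; the helper name is illustrative:

def update_quant_resume_state(index, infohub, max_times):
    # Read what the optimizer index json recorded at save time.
    ckpt_quant_stage = index.get("ckpt_quant_stage", "O0")
    resume_times = index.get("quant_ckpt_resume_times", 0)
    if ckpt_quant_stage != "O0":
        # Each resume from a quantized checkpoint bumps the counter.
        resume_times += 1
        infohub["quant_ckpt_resume_times"] = resume_times
        if resume_times >= max_times:
            # Later saves will see this flag and fall back to stage "O0".
            infohub["quant_reach_limit"] = True
    return ckpt_quant_stage

On the save side, the stage passed to _file_save_async_or_sync becomes "O0" whenever "quant_reach_limit" is present in infohub, which is how reaching the limit actually disables further quantization.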
2 changes: 1 addition & 1 deletion paddlenlp/transformers/model_utils.py
@@ -476,7 +476,7 @@ def load_state_dict(
if len(scale_dict) != 0:
if ckpt_quant_stage == "O0":
raise ValueError('optimizer weight has quantization scales but `ckpt_quant_stage` is set to "O0"')
state_dict = dequant_unified_optimizer(state_dict, ckpt_quant_stage, scale_dict)
state_dict = dequant_unified_optimizer(state_dict, ckpt_quant_stage, scale_dict, use_pd=True)

return state_dict

40 changes: 23 additions & 17 deletions paddlenlp/transformers/tensor_parallel_utils.py
@@ -244,14 +244,16 @@ def forward(
grad_hidden_states = None

# initialize outputs
token_loss = paddle.empty((n_tokens,), dtype=hidden_states.dtype)
token_loss = paddle.empty((n_tokens,), dtype=paddle.float32)

# blockwise calculations
for i in range(0, n_tokens, loop_chunk_size):
token_start_idx = i
token_end_idx = min(i + loop_chunk_size, n_tokens)
hidden_states_chunk = hidden_states[token_start_idx:token_end_idx]
labels_chunk = labels[token_start_idx:token_end_idx]
cur_chunk_range = paddle.arange(token_start_idx, token_end_idx)
hidden_states_chunk = paddle.gather(hidden_states, cur_chunk_range, axis=0)
labels_chunk = paddle.gather(labels, cur_chunk_range, axis=0)
loss_mask_chunk = paddle.gather(loss_mask, cur_chunk_range, axis=0)

# logits calculations
logits_chunk_cast = paddle.matmul(
@@ -304,9 +306,9 @@ def forward(
group=model_parallel_group,
)
token_loss_chunk = (log_sum_exp_logits - label_logits).squeeze(1) / divisor
cond = loss_mask[token_start_idx:token_end_idx].astype("bool")
cond = loss_mask_chunk.astype("bool")
token_loss_chunk = paddle.where(cond, token_loss_chunk, paddle.zeros_like(token_loss_chunk))
token_loss[token_start_idx:token_end_idx] = token_loss_chunk * loss_mask[token_start_idx:token_end_idx]
paddle.scatter_(token_loss, cur_chunk_range, token_loss_chunk, overwrite=True)

# gradients calculations
if not return_token_loss:
@@ -324,10 +326,11 @@ def forward(
)

if grad_hidden_states is not None:
grad_hidden_states[token_start_idx:token_end_idx] = paddle.matmul(
grad_logits_chunk,
lm_head_weight_cast,
transpose_y=not transpose_y,
paddle.scatter_(
grad_hidden_states,
cur_chunk_range,
paddle.matmul(grad_logits_chunk, lm_head_weight_cast, transpose_y=not transpose_y),
overwrite=True,
)
if grad_lm_head_weight is not None:
if transpose_y:
@@ -487,8 +490,10 @@ def backward(ctx, grad_output):
for i in range(0, n_tokens, loop_chunk_size):
token_start_idx = i
token_end_idx = min(i + loop_chunk_size, n_tokens)
hidden_states_chunk = hidden_states[token_start_idx:token_end_idx]
labels_chunk = labels[token_start_idx:token_end_idx]
cur_chunk_range = paddle.arange(token_start_idx, token_end_idx)
hidden_states_chunk = paddle.gather(hidden_states, cur_chunk_range, axis=0)
labels_chunk = paddle.gather(labels, cur_chunk_range, axis=0)
loss_mask_chunk = paddle.gather(loss_mask, cur_chunk_range, axis=0)

# logits calculations
logits_chunk_cast = paddle.matmul(
@@ -528,20 +533,21 @@ def backward(ctx, grad_output):
labels_one_hot = labels_one_hot.split(model_parallel_group.nranks, axis=-1)[model_parallel_group.rank]
grad_logits_chunk = exp_logits / sum_exp_logits - labels_one_hot.astype("float32")
# NOTE(hehuang): scaling grad_logits_chunk by grad_token_loss
grad_logits_chunk *= grad_token_loss[token_start_idx:token_end_idx].unsqueeze(1)
grad_logits_chunk *= paddle.gather(grad_token_loss, cur_chunk_range, axis=0).unsqueeze(1)
grad_logits_chunk = grad_logits_chunk.astype(dtype)
cond = loss_mask[token_start_idx:token_end_idx].astype("bool")
cond = loss_mask_chunk.astype("bool")
grad_logits_chunk = paddle.where(
cond.unsqueeze(1),
grad_logits_chunk,
paddle.zeros_like(grad_logits_chunk),
)

if grad_hidden_states is not None:
grad_hidden_states[token_start_idx:token_end_idx] = paddle.matmul(
grad_logits_chunk,
lm_head_weight_cast,
transpose_y=not transpose_y,
paddle.scatter_(
grad_hidden_states,
cur_chunk_range,
paddle.matmul(grad_logits_chunk, lm_head_weight_cast, transpose_y=not transpose_y),
overwrite=True,
)
if grad_lm_head_weight is not None:
if transpose_y:
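
The edits in this file replace direct slice reads and writes with explicit paddle.gather / paddle.scatter_ over a per-chunk index tensor. A standalone sketch of the pattern on toy data (all names are illustrative, not from the source):

import paddle

n_tokens, hidden, chunk_size = 10, 4, 3
hidden_states = paddle.randn([n_tokens, hidden])
token_out = paddle.zeros([n_tokens], dtype="float32")

for start in range(0, n_tokens, chunk_size):
    end = min(start + chunk_size, n_tokens)
    idx = paddle.arange(start, end)
    # gather a chunk of rows instead of slicing hidden_states[start:end]
    rows = paddle.gather(hidden_states, idx, axis=0)
    vals = rows.sum(axis=-1)  # stand-in for the per-token loss
    # write the chunk back in place of token_out[start:end] = vals
    paddle.scatter_(token_out, idx, vals, overwrite=True)

Routing loss_mask through paddle.gather as well keeps every per-chunk read on the same index tensor, which is presumably why the hunks above switch it over together with the hidden states and labels.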
2 changes: 2 additions & 0 deletions paddlenlp/trl/dpo_trainer.py
@@ -279,6 +279,7 @@ def prediction_pipeline_step(
"""
prediction_step function for pipeline parallel mode.
"""
model._p2p_helper.clear_meta_cache()
concatenated_inputs = {}
# consider no drop last
per_device_train_batch_size = self.args.per_device_train_batch_size
@@ -349,6 +350,7 @@ def prediction_pipeline_step(
)
self.log_metric(**metric_inputs)
self.reset_dpo_infohub()
model._p2p_helper.clear_meta_cache()
return (loss, None, None)

def log_metric(
7 changes: 7 additions & 0 deletions paddlenlp/utils/env.py
@@ -120,3 +120,10 @@ def _get_bool_env(env_key: str, default_value: str) -> bool:
SYMMETRY_QUANT_SCALE = "@scales"
ASYMMETRY_QUANT_SCALE_MIN = "@min_scales"
ASYMMETRY_QUANT_SCALE_MAX = "@max_scales"
MAX_QUANTIZATION_TIMES = 1

# LLM Inference related environment variables
# Note(@Wanglongzhi2001): MAX_BSZ, SPECULATE_MAX_BSZ, MAX_DRAFT_TOKENS must match the definitions in get_output / save_output
MAX_BSZ = 512
SPECULATE_MAX_BSZ = 256
MAX_DRAFT_TOKENS = 6