From 5a5d47458d9aaf7ead798d15291ba3d9bef785c5 Mon Sep 17 00:00:00 2001
From: Leonardo Emili
Date: Tue, 13 Feb 2024 17:24:30 +0100
Subject: [PATCH] Add seq2seq eval benchmark callback (#1274)

* Add CausalLMBenchEvalCallback for measuring seq2seq performance

* Fix code for pre-commit

* Fix typing and improve logging

* eval_sample_packing must be false with CausalLMBenchEvalCallback
---
 README.md                                 |   3 +-
 examples/llama-2/loftq.yml                |   2 +-
 examples/llama-2/lora.yml                 |   2 +-
 examples/mamba/config.yml                 |   2 +-
 .../mistral/Mistral-7b-example/config.yml |   2 +-
 examples/mistral/config.yml               |   2 +-
 examples/mistral/mixtral.yml              |   2 +-
 examples/mistral/qlora.yml                |   2 +-
 examples/qwen/lora.yml                    |   2 +-
 examples/qwen/qlora.yml                   |   2 +-
 examples/yi-34B-chat/qlora.yml            |   2 +-
 requirements.txt                          |   2 +-
 src/axolotl/core/trainer_builder.py       |  11 ++
 src/axolotl/utils/callbacks.py            | 183 +++++++++++++++++-
 src/axolotl/utils/config.py               |  23 ++-
 15 files changed, 228 insertions(+), 14 deletions(-)

diff --git a/README.md b/README.md
index 5cb1df3246..999fd72706 100644
--- a/README.md
+++ b/README.md
@@ -784,7 +784,8 @@ save_total_limit: # Checkpoints saved at a time
 max_steps:
 
 eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
-eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
+eval_max_new_tokens: # Maximum number of tokens generated for predictions sent to wandb. Default is 128
+eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf"]
 
 loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
 loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
diff --git a/examples/llama-2/loftq.yml b/examples/llama-2/loftq.yml
index 2abbb78478..d0d78098d7 100644
--- a/examples/llama-2/loftq.yml
+++ b/examples/llama-2/loftq.yml
@@ -60,7 +60,7 @@ s2_attention:
 warmup_steps: 10
 evals_per_epoch: 4
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 deepspeed:
diff --git a/examples/llama-2/lora.yml b/examples/llama-2/lora.yml
index 90a9cfd2c7..45df96c562 100644
--- a/examples/llama-2/lora.yml
+++ b/examples/llama-2/lora.yml
@@ -57,7 +57,7 @@ s2_attention:
 warmup_steps: 10
 evals_per_epoch: 4
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 deepspeed:
diff --git a/examples/mamba/config.yml b/examples/mamba/config.yml
index 9b697892a9..0a5223bcac 100644
--- a/examples/mamba/config.yml
+++ b/examples/mamba/config.yml
@@ -49,7 +49,7 @@ flash_attention:
 warmup_steps: 10
 evals_per_epoch: 4
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 deepspeed:
diff --git a/examples/mistral/Mistral-7b-example/config.yml b/examples/mistral/Mistral-7b-example/config.yml
index d28d8f6b75..45e69e5486 100644
--- a/examples/mistral/Mistral-7b-example/config.yml
+++ b/examples/mistral/Mistral-7b-example/config.yml
@@ -61,7 +61,7 @@ flash_attention: true
 warmup_steps: 10
 evals_per_epoch: 4
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 #default deepspeed, can use more aggresive if needed like zero2, zero3
diff --git a/examples/mistral/config.yml b/examples/mistral/config.yml
index df70478672..a5297fae81 100644
--- a/examples/mistral/config.yml
+++ b/examples/mistral/config.yml
@@ -49,7 +49,7 @@ flash_attention: true
 warmup_steps: 10
 evals_per_epoch: 4
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 deepspeed:
diff --git a/examples/mistral/mixtral.yml b/examples/mistral/mixtral.yml
index 4489a272a8..7c18e7098c 100644
--- a/examples/mistral/mixtral.yml
+++ b/examples/mistral/mixtral.yml
@@ -81,7 +81,7 @@ loss_watchdog_patience: 3
 warmup_steps: 10
 evals_per_epoch: 4
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 deepspeed: deepspeed_configs/zero2.json
diff --git a/examples/mistral/qlora.yml b/examples/mistral/qlora.yml
index 44ab5691bb..70099b0e33 100644
--- a/examples/mistral/qlora.yml
+++ b/examples/mistral/qlora.yml
@@ -68,7 +68,7 @@ loss_watchdog_patience: 3
 warmup_steps: 10
 evals_per_epoch: 4
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 deepspeed:
diff --git a/examples/qwen/lora.yml b/examples/qwen/lora.yml
index c14e5f8d66..1a006ac4e1 100644
--- a/examples/qwen/lora.yml
+++ b/examples/qwen/lora.yml
@@ -58,7 +58,7 @@ flash_attention:
 warmup_steps: 10
 evals_per_epoch: 4
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 deepspeed:
diff --git a/examples/qwen/qlora.yml b/examples/qwen/qlora.yml
index cb3666d256..462746a59f 100644
--- a/examples/qwen/qlora.yml
+++ b/examples/qwen/qlora.yml
@@ -58,7 +58,7 @@ flash_attention:
 warmup_steps: 10
 evals_per_epoch: 4
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 deepspeed:
diff --git a/examples/yi-34B-chat/qlora.yml b/examples/yi-34B-chat/qlora.yml
index fedbc26b7e..5d55e143b5 100644
--- a/examples/yi-34B-chat/qlora.yml
+++ b/examples/yi-34B-chat/qlora.yml
@@ -29,7 +29,7 @@ num_epochs: 1
 val_set_size: 0.1
 evals_per_epoch: 5
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 eval_sample_packing: false
 eval_batch_size: 1
 
diff --git a/requirements.txt b/requirements.txt
index 4d1073500f..e20940f649 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -23,7 +23,7 @@ numba
 numpy>=1.24.4
 mlflow
 # qlora things
-evaluate==0.4.0
+evaluate==0.4.1
 scipy
 scikit-learn==1.2.2
 pynvml
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index e2667aea43..7d39aefeac 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -38,6 +38,7 @@
     SaveAxolotlConfigtoWandBCallback,
     SaveBetterTransformerModelCallback,
     bench_eval_callback_factory,
+    causal_lm_bench_eval_callback_factory,
     log_prediction_callback_factory,
 )
 from axolotl.utils.collators import (
@@ -148,6 +149,9 @@ class AxolotlTrainingArguments(TrainingArguments):
     do_bench_eval: Optional[bool] = field(
         default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
     )
+    do_causal_lm_eval: Optional[bool] = field(
+        default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
+    )
     max_bench_samples: Optional[int] = field(
         default=None,
         metadata={
@@ -664,6 +668,11 @@ def get_post_trainer_create_callbacks(self, trainer):
 
         if self.cfg.do_bench_eval:
             callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
+        if self.cfg.do_causal_lm_eval:
+            CausalLMBenchEvalCallback = causal_lm_bench_eval_callback_factory(
+                trainer, self.tokenizer
+            )
+            callbacks.append(CausalLMBenchEvalCallback(self.cfg))
 
         if self.cfg.early_stopping_patience:
             early_stop_cb = EarlyStoppingCallback(
@@ -812,6 +821,8 @@ def build(self, total_num_steps):
             training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
             if self.cfg.bench_dataset:
                 training_arguments_kwargs["bench_dataset"] = self.cfg.bench_dataset
+        if self.cfg.do_causal_lm_eval:
+            training_arguments_kwargs["do_causal_lm_eval"] = self.cfg.do_causal_lm_eval
         if self.cfg.metric_for_best_model:
             training_arguments_kwargs[
                 "metric_for_best_model"
diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py
index 9de266be10..10c0a33826 100644
--- a/src/axolotl/utils/callbacks.py
+++ b/src/axolotl/utils/callbacks.py
@@ -361,6 +361,187 @@ def on_evaluate(
 
     return BenchEvalCallback
 
+
+def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
+    class CausalLMBenchEvalCallback(TrainerCallback):
+        """Callback to log prediction values during each evaluation"""
+
+        def __init__(self, cfg):
+            self.cfg = cfg
+            self.logged = False
+            self.metrics = self.__maybe_load_metrics()
+
+        def __maybe_load_metrics(self):
+            metrics = {}
+            for metric in self.cfg.eval_causal_lm_metrics:
+                try:
+                    metrics[metric] = evaluate.load(metric)
+                except Exception as exc:  # pylint: disable=broad-exception-caught
+                    LOG.warning(f"{metric}: {exc.args}")
+            return metrics
+
+        def on_evaluate(
+            self,
+            args: AxolotlTrainingArguments,  # pylint: disable=unused-argument
+            state: TrainerState,
+            control: TrainerControl,
+            train_dataloader,  # pylint: disable=unused-argument
+            eval_dataloader,
+            **kwargs,  # pylint: disable=unused-argument
+        ):
+            trainer.model.eval()
+            device = torch.device(self.cfg.device)
+
+            # pylint: disable=duplicate-code
+            generation_config = GenerationConfig(
+                max_new_tokens=self.cfg.eval_max_new_tokens,
+                bos_token_id=tokenizer.bos_token_id,
+                eos_token_id=tokenizer.eos_token_id,
+                pad_token_id=tokenizer.pad_token_id,
+                do_sample=False,
+                use_cache=True,
+                return_dict_in_generate=True,
+                output_attentions=False,
+                output_hidden_states=False,
+                output_scores=False,
+            )
+
+            def find_ranges(lst):
+                ranges = []
+                start = 0
+                for i in range(1, len(lst)):
+                    if lst[i] == 0:
+                        ranges.append((start, i - 1))
+                        start = i
+                end = len(lst) - 1
+                ranges.append((start, end))
+                return ranges
+
+            def compute(metric: evaluate.Metric, **kwargs):
+                # safely compute a metric and return the score if the format is correct
+                metric_score = None
+                try:
+                    metric_score = metric.compute(**kwargs)
+                    return (
+                        metric_score["score"]
+                        if "score" in metric_score
+                        else metric_score["mean_score"]
+                    )
+                except Exception:  # pylint: disable=broad-exception-caught
+                    LOG.debug(
+                        f"Failed to compute metric {metric.name} with kwargs {kwargs.keys()}"
+                    )
+                return metric_score
+
+            def evaluate_preds(sources, predictions, references):
+                scores = {}
+
+                for metric_name, metric in self.metrics.items():
+                    score = compute(
+                        metric,
+                        references=references,
+                        predictions=predictions,
+                        sources=sources,
+                    )
+                    score = score or compute(
+                        metric,
+                        references=[[r] for r in references],
+                        predictions=predictions,
+                    )
+                    scores[metric_name] = score
+                return scores
+
+            def predict_with_generate():
+                eval_src, eval_pred, eval_ref = [], [], []
+
+                for batch in tqdm(eval_dataloader):
+                    batch_labels = batch["labels"].to(device)
+                    batch_input_ids = batch["input_ids"].to(device)
+
+                    if "position_ids" in batch:
+                        batch_pos_ids = batch["position_ids"].tolist()
+                    else:
+                        batch_pos_ids = [None] * len(batch["input_ids"])
+
+                    prompt_token_ids_list = []
+                    completion_token_ids_list = []
+
+                    for input_ids_all, labels_all, pos_ids in zip(
+                        batch_input_ids,
+                        batch_labels,
+                        batch_pos_ids,
+                    ):
+                        if pos_ids is None:
+                            pos_ranges = [(0, len(input_ids_all) - 1)]
+                        else:
+                            pos_ranges = find_ranges(pos_ids)
+
+                        for pos_range in pos_ranges:
+                            start, end = pos_range
+                            if start == end:
+                                continue
+
+                            input_ids = input_ids_all[start : end + 1]
+                            labels = labels_all[start : end + 1]
+
+                            tokens_without_loss = labels == IGNORE_INDEX
+                            tokens_with_loss = labels != IGNORE_INDEX
+                            tokens_exclude_padding = input_ids != tokenizer.pad_token_id
+                            prompt_token_includes = (
+                                tokens_without_loss & tokens_exclude_padding
+                            )
+
+                            prompt_token_ids = input_ids[prompt_token_includes]
+                            prompt_token_ids_list.append(prompt_token_ids)
+
+                            completion_token_ids = input_ids[tokens_with_loss]
+                            completion_token_ids_list.append(completion_token_ids)
+
+                    prompt_texts = tokenizer.batch_decode(
+                        prompt_token_ids_list, skip_special_tokens=True
+                    )
+                    completion_texts = tokenizer.batch_decode(
+                        completion_token_ids_list, skip_special_tokens=True
+                    )
+
+                    with torch.no_grad():
+                        prompt_encoding = tokenizer(
+                            prompt_texts, padding=True, return_tensors="pt"
+                        ).to(self.cfg.device)
+                        predictions = trainer.model.generate(
+                            **prompt_encoding, generation_config=generation_config
+                        )
+
+                    prediction_all_tokens = predictions["sequences"].cpu().tolist()
+                    prediction_without_prompt_tokens_list = []
+                    for prompt_token_ids, prediction_tokens in zip(
+                        prompt_token_ids_list, prediction_all_tokens
+                    ):
+                        prediction_without_prompt_tokens = prediction_tokens[
+                            len(prompt_token_ids) :
+                        ]
+                        prediction_without_prompt_tokens_list.append(
+                            prediction_without_prompt_tokens
+                        )
+
+                    predicted_texts = tokenizer.batch_decode(
+                        prediction_without_prompt_tokens_list, skip_special_tokens=True
+                    )
+
+                    eval_src.extend(prompt_texts)
+                    eval_pred.extend(predicted_texts)
+                    eval_ref.extend(completion_texts)
+
+                return eval_src, eval_pred, eval_ref
+
+            if is_main_process():
+                eval_preds = predict_with_generate()
+                trainer.log(evaluate_preds(*eval_preds))
+
+            return control
+
+    return CausalLMBenchEvalCallback
+
+
 def log_prediction_callback_factory(trainer: Trainer, tokenizer):
     class LogPredictionCallback(TrainerCallback):
         """Callback to log prediction values during each evaluation"""
@@ -388,7 +569,7 @@ def on_evaluate(
 
             # pylint: disable=duplicate-code
             generation_config = GenerationConfig(
-                max_new_tokens=self.cfg.eval_table_max_new_tokens,
+                max_new_tokens=self.cfg.eval_max_new_tokens,
                 bos_token_id=tokenizer.bos_token_id,
                 eos_token_id=tokenizer.eos_token_id,
                 pad_token_id=tokenizer.pad_token_id,
diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index ac77968675..1fc470da9e 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -56,7 +56,13 @@ def normalize_config(cfg):
     cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
     cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
     cfg.eval_table_size = cfg.eval_table_size or 0
-    cfg.eval_table_max_new_tokens = cfg.eval_table_max_new_tokens or 128
+    cfg.eval_max_new_tokens = cfg.eval_max_new_tokens or 128
+    cfg.eval_causal_lm_metrics = cfg.eval_causal_lm_metrics or [
+        "sacrebleu",
+        "comet",
+        "ter",
+        "chrf",
+    ]
     choose_device(cfg)
     cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
     if cfg.ddp:
@@ -550,6 +556,21 @@ def validate_config(cfg):
     if cfg.fsdp and "bnb" in cfg.optimizer:
         raise ValueError(f"FSDP not compatible with {cfg.optimizer}")
 
+    if cfg.do_causal_lm_eval and cfg.eval_sample_packing:
+        raise ValueError(
+            "do_causal_lm_eval is enabled, eval_sample_packing must be set to False"
+        )
+
+    if cfg.eval_causal_lm_metrics:
+        supported_metrics = ["sacrebleu", "comet", "ter", "chrf"]
+        if not isinstance(cfg.eval_causal_lm_metrics, list):
+            raise ValueError("eval_causal_lm_metrics must be a list")
+        # only ["sacrebleu", "comet", "ter", "chrf"] supported
+        if set(cfg.eval_causal_lm_metrics) - set(supported_metrics):
+            raise ValueError(
+                f"eval_causal_lm_metrics must be one of {supported_metrics}"
+            )
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
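
Example usage: a minimal config fragment that exercises the new options. The values are
illustrative and the rest of a normal axolotl config is assumed; only the keys shown here
are introduced or constrained by this patch.

    do_causal_lm_eval: true                         # registers CausalLMBenchEvalCallback after trainer creation
    eval_causal_lm_metrics: ["sacrebleu", "chrf"]   # any subset of ["sacrebleu", "comet", "ter", "chrf"]
    eval_max_new_tokens: 128                        # max tokens generated per prediction during evaluation
    eval_sample_packing: false                      # required: validate_config rejects packing with do_causal_lm_eval

Metrics that fail to load (for example because optional dependencies are missing) are
skipped with a warning by __maybe_load_metrics, and the computed scores are reported via
trainer.log on the main process after each evaluation pass.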