From 0e87f23705f94775ad8670461e15cf6043463ccb Mon Sep 17 00:00:00 2001
From: Leonardo Emili
Date: Wed, 7 Feb 2024 16:12:38 +0000
Subject: [PATCH 1/4] Add CausalLMBenchEvalCallback for measuring seq2seq
 performance

---
 README.md                                      |   3 +-
 examples/llama-2/loftq.yml                     |   2 +-
 examples/llama-2/lora.yml                      |   2 +-
 examples/mamba/config.yml                      |   2 +-
 .../mistral/Mistral-7b-example/config.yml      |   2 +-
 examples/mistral/config.yml                    |   2 +-
 examples/mistral/mixtral.yml                   |   2 +-
 examples/mistral/qlora.yml                     |   2 +-
 examples/qwen/lora.yml                         |   2 +-
 examples/qwen/qlora.yml                        |   2 +-
 examples/yi-34B-chat/qlora.yml                 |   2 +-
 requirements.txt                               |   2 +-
 src/axolotl/core/trainer_builder.py            |  11 ++
 src/axolotl/utils/callbacks.py                 | 171 +++++++++++++++++-
 src/axolotl/utils/config.py                    |  11 +-
 15 files changed, 203 insertions(+), 15 deletions(-)

diff --git a/README.md b/README.md
index 56ff5407f6..bdeb8352cb 100644
--- a/README.md
+++ b/README.md
@@ -768,7 +768,8 @@ save_total_limit: # Checkpoints saved at a time
 max_steps:

 eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
-eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
+eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
+eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf"]

 loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
 loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
diff --git a/examples/llama-2/loftq.yml b/examples/llama-2/loftq.yml
index 2abbb78478..d0d78098d7 100644
--- a/examples/llama-2/loftq.yml
+++ b/examples/llama-2/loftq.yml
@@ -60,7 +60,7 @@ s2_attention:
 warmup_steps: 10
 evals_per_epoch: 4
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 deepspeed:
diff --git a/examples/llama-2/lora.yml b/examples/llama-2/lora.yml
index 90a9cfd2c7..45df96c562 100644
--- a/examples/llama-2/lora.yml
+++ b/examples/llama-2/lora.yml
@@ -57,7 +57,7 @@ s2_attention:
 warmup_steps: 10
 evals_per_epoch: 4
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 deepspeed:
diff --git a/examples/mamba/config.yml b/examples/mamba/config.yml
index 9b697892a9..0a5223bcac 100644
--- a/examples/mamba/config.yml
+++ b/examples/mamba/config.yml
@@ -49,7 +49,7 @@ flash_attention:
 warmup_steps: 10
 evals_per_epoch: 4
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 deepspeed:
diff --git a/examples/mistral/Mistral-7b-example/config.yml b/examples/mistral/Mistral-7b-example/config.yml
index d28d8f6b75..45e69e5486 100644
--- a/examples/mistral/Mistral-7b-example/config.yml
+++ b/examples/mistral/Mistral-7b-example/config.yml
@@ -61,7 +61,7 @@ flash_attention: true
 warmup_steps: 10
 evals_per_epoch: 4
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 #default deepspeed, can use more aggresive if needed like zero2, zero3
diff --git a/examples/mistral/config.yml b/examples/mistral/config.yml
index df70478672..a5297fae81 100644
--- a/examples/mistral/config.yml
+++ b/examples/mistral/config.yml
@@ -49,7 +49,7 @@ flash_attention: true
 warmup_steps: 10
 evals_per_epoch: 4
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 deepspeed:
diff --git a/examples/mistral/mixtral.yml b/examples/mistral/mixtral.yml
index 4489a272a8..7c18e7098c 100644
--- a/examples/mistral/mixtral.yml
+++ b/examples/mistral/mixtral.yml
@@ -81,7 +81,7 @@ loss_watchdog_patience: 3
 warmup_steps: 10
 evals_per_epoch: 4
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 deepspeed: deepspeed_configs/zero2.json
diff --git a/examples/mistral/qlora.yml b/examples/mistral/qlora.yml
index 44ab5691bb..70099b0e33 100644
--- a/examples/mistral/qlora.yml
+++ b/examples/mistral/qlora.yml
@@ -68,7 +68,7 @@ loss_watchdog_patience: 3
 warmup_steps: 10
 evals_per_epoch: 4
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 deepspeed:
diff --git a/examples/qwen/lora.yml b/examples/qwen/lora.yml
index c14e5f8d66..1a006ac4e1 100644
--- a/examples/qwen/lora.yml
+++ b/examples/qwen/lora.yml
@@ -58,7 +58,7 @@ flash_attention:
 warmup_steps: 10
 evals_per_epoch: 4
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 deepspeed:
diff --git a/examples/qwen/qlora.yml b/examples/qwen/qlora.yml
index cb3666d256..462746a59f 100644
--- a/examples/qwen/qlora.yml
+++ b/examples/qwen/qlora.yml
@@ -58,7 +58,7 @@ flash_attention:
 warmup_steps: 10
 evals_per_epoch: 4
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 deepspeed:
diff --git a/examples/yi-34B-chat/qlora.yml b/examples/yi-34B-chat/qlora.yml
index fedbc26b7e..5d55e143b5 100644
--- a/examples/yi-34B-chat/qlora.yml
+++ b/examples/yi-34B-chat/qlora.yml
@@ -29,7 +29,7 @@ num_epochs: 1
 val_set_size: 0.1
 evals_per_epoch: 5
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 eval_sample_packing: false
 eval_batch_size: 1
diff --git a/requirements.txt b/requirements.txt
index 2e978c16da..56a634655f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -22,7 +22,7 @@ numba
 numpy>=1.24.4
 mlflow
 # qlora things
-evaluate==0.4.0
+evaluate==0.4.1
 scipy
 scikit-learn==1.2.2
 pynvml
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 73eddd4260..02faaad994 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -37,6 +37,7 @@
     SaveAxolotlConfigtoWandBCallback,
     SaveBetterTransformerModelCallback,
     bench_eval_callback_factory,
+    causal_lm_bench_eval_callback_factory,
     log_prediction_callback_factory,
 )
 from axolotl.utils.collators import (
@@ -142,6 +143,9 @@ class AxolotlTrainingArguments(TrainingArguments):
     do_bench_eval: Optional[bool] = field(
         default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
     )
+    do_causal_lm_eval: Optional[bool] = field(
+        default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
+    )
     max_bench_samples: Optional[int] = field(
         default=None,
         metadata={
@@ -642,6 +646,11 @@ def get_post_trainer_create_callbacks(self, trainer):
         if self.cfg.do_bench_eval:
             callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
+        if self.cfg.do_causal_lm_eval:
+            CausalLMBenchEvalCallback = causal_lm_bench_eval_callback_factory(
+                trainer, self.tokenizer
+            )
+            callbacks.append(CausalLMBenchEvalCallback(self.cfg))

         if self.cfg.early_stopping_patience:
             early_stop_cb = EarlyStoppingCallback(
@@ -790,6 +799,8 @@ def build(self, total_num_steps):
             training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
             if self.cfg.bench_dataset:
training_arguments_kwargs["bench_dataset"] = self.cfg.bench_dataset + if self.cfg.do_causal_lm_eval: + training_arguments_kwargs["do_causal_lm_eval"] = self.cfg.do_causal_lm_eval if self.cfg.metric_for_best_model: training_arguments_kwargs[ "metric_for_best_model" diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index 9de266be10..390290965c 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -15,7 +15,7 @@ import torch import torch.distributed as dist import wandb -from datasets import load_dataset +from datasets import load_dataset, Metric from optimum.bettertransformer import BetterTransformer from tqdm import tqdm from transformers import ( @@ -360,6 +360,173 @@ def on_evaluate( return BenchEvalCallback +def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer): + class CausalLMBenchEvalCallback(TrainerCallback): + """Callback to log prediction values during each evaluation""" + + def __init__(self, cfg): + self.cfg = cfg + self.logged = False + self.metrics = self.__maybe_load_metrics() + + def __maybe_load_metrics(self): + metrics = dict() + for metric in self.cfg.eval_causal_lm_metrics: + try: + metrics[metric] = evaluate.load(metric) + except Exception as e: + LOG.warning(e.args) + return metrics + + def on_evaluate( + self, + args: AxolotlTrainingArguments, # pylint: disable=unused-argument + state: TrainerState, + control: TrainerControl, + train_dataloader, # pylint: disable=unused-argument + eval_dataloader, + **kwargs, # pylint: disable=unused-argument + ): + trainer.model.eval() + device = torch.device(self.cfg.device) + + # pylint: disable=duplicate-code + generation_config = GenerationConfig( + max_new_tokens=self.cfg.eval_max_new_tokens, + bos_token_id=tokenizer.bos_token_id, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + do_sample=False, + use_cache=True, + return_dict_in_generate=True, + output_attentions=False, + output_hidden_states=False, + output_scores=False, + ) + + def find_ranges(lst): + ranges = [] + start = 0 + for i in range(1, len(lst)): + if lst[i] == 0: + ranges.append((start, i - 1)) + start = i + end = len(lst) - 1 + ranges.append((start, end)) + return ranges + + def compute(metric: Metric, **kwargs): + # safely compute a metric and return the score if the format is correct + metric_score = None + try: + metric_score = metric.compute(**kwargs) + return metric_score["score"] if "score" in metric_score else metric_score["mean_score"] + except: + pass + return metric_score + + def evaluate_preds(sources, predictions, references): + if len(self.metrics) == 0: + return + + scores = dict() + for metric_name, metric in self.metrics.items(): + score = compute(metric, references=references, predictions=predictions, sources=sources) + score = score or compute(metric, references=[[r] for r in references], predictions=predictions) + scores[metric_name] = score + return scores + + def predict_with_generate(): + eval_src, eval_pred, eval_ref = list(), list(), list() + + for batch in tqdm(eval_dataloader): + batch_labels = batch["labels"].to(device) + batch_input_ids = batch["input_ids"].to(device) + + if "position_ids" in batch: + batch_pos_ids = batch["position_ids"].tolist() + else: + batch_pos_ids = [None] * len(batch["input_ids"]) + + prompt_token_ids_list = [] + completion_token_ids_list = [] + + for input_ids_all, labels_all, pos_ids in zip( + batch_input_ids, + batch_labels, + batch_pos_ids, + ): + if pos_ids is None: + pos_ranges = [(0, len(input_ids_all) - 
+                        else:
+                            pos_ranges = find_ranges(pos_ids)
+
+                        for pos_range in pos_ranges:
+                            start, end = pos_range
+                            if start == end:
+                                continue
+
+                            input_ids = input_ids_all[start : end + 1]
+                            labels = labels_all[start : end + 1]
+
+                            tokens_without_loss = labels == IGNORE_INDEX
+                            tokens_with_loss = labels != IGNORE_INDEX
+                            tokens_exclude_padding = input_ids != tokenizer.pad_token_id
+                            prompt_token_includes = (
+                                tokens_without_loss & tokens_exclude_padding
+                            )
+
+                            prompt_token_ids = input_ids[prompt_token_includes]
+                            prompt_token_ids_list.append(prompt_token_ids)
+
+                            completion_token_ids = input_ids[tokens_with_loss]
+                            completion_token_ids_list.append(completion_token_ids)
+
+                    prompt_texts = tokenizer.batch_decode(
+                        prompt_token_ids_list, skip_special_tokens=True
+                    )
+                    completion_texts = tokenizer.batch_decode(
+                        completion_token_ids_list, skip_special_tokens=True
+                    )
+
+                    with torch.no_grad():
+                        prompt_encoding = tokenizer(
+                            prompt_texts, padding=True, return_tensors="pt"
+                        ).to(self.cfg.device)
+                        predictions = trainer.model.generate(
+                            **prompt_encoding, generation_config=generation_config
+                        )
+
+                    prediction_all_tokens = predictions["sequences"].cpu().tolist()
+                    prediction_without_prompt_tokens_list = []
+                    for prompt_token_ids, prediction_tokens in zip(
+                        prompt_token_ids_list, prediction_all_tokens
+                    ):
+                        prediction_without_prompt_tokens = prediction_tokens[
+                            len(prompt_token_ids) :
+                        ]
+                        prediction_without_prompt_tokens_list.append(
+                            prediction_without_prompt_tokens
+                        )
+
+                    predicted_texts = tokenizer.batch_decode(
+                        prediction_without_prompt_tokens_list, skip_special_tokens=True
+                    )
+
+                    eval_src.extend(prompt_texts)
+                    eval_pred.extend(predicted_texts)
+                    eval_ref.extend(completion_texts)
+
+                return eval_src, eval_pred, eval_ref
+
+            if is_main_process():
+                eval_preds = predict_with_generate()
+                trainer.log(evaluate_preds(*eval_preds))
+
+            return control
+
+    return CausalLMBenchEvalCallback
+

 def log_prediction_callback_factory(trainer: Trainer, tokenizer):
     class LogPredictionCallback(TrainerCallback):
@@ -388,7 +555,7 @@ def on_evaluate(

             # pylint: disable=duplicate-code
             generation_config = GenerationConfig(
-                max_new_tokens=self.cfg.eval_table_max_new_tokens,
+                max_new_tokens=self.cfg.eval_max_new_tokens,
                 bos_token_id=tokenizer.bos_token_id,
                 eos_token_id=tokenizer.eos_token_id,
                 pad_token_id=tokenizer.pad_token_id,
diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index ac77968675..b2a6409aee 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -56,7 +56,8 @@ def normalize_config(cfg):
     cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
     cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
     cfg.eval_table_size = cfg.eval_table_size or 0
-    cfg.eval_table_max_new_tokens = cfg.eval_table_max_new_tokens or 128
+    cfg.eval_max_new_tokens = cfg.eval_max_new_tokens or 128
+    cfg.eval_causal_lm_metrics = cfg.eval_causal_lm_metrics or ["sacrebleu", "comet", "ter", "chrf"]
     choose_device(cfg)
     cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
     if cfg.ddp:
@@ -550,6 +551,14 @@ def validate_config(cfg):
     if cfg.fsdp and "bnb" in cfg.optimizer:
         raise ValueError(f"FSDP not compatible with {cfg.optimizer}")

+    if cfg.eval_causal_lm_metrics:
+        supported_metrics = ["sacrebleu", "comet", "ter", "chrf"]
+        if not isinstance(cfg.eval_causal_lm_metrics, list):
+            raise ValueError("eval_causal_lm_metrics must be a list")
+        # only ["sacrebleu", "comet", "ter", "chrf"] supported
+        if set(cfg.eval_causal_lm_metrics) - set(supported_metrics):
+            raise ValueError(f"eval_causal_lm_metrics must be one of {supported_metrics}")
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
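
For reference, the options introduced in PATCH 1/4 might be combined in a user config roughly as follows. This is a sketch with illustrative values, not project defaults; only the option names come from the patch itself:

# illustrative only: enable the new causal LM evaluation callback
do_causal_lm_eval: true
eval_causal_lm_metrics: ["sacrebleu", "chrf"]  # example subset of the supported metrics
eval_max_new_tokens: 128                       # generation budget per eval prediction
val_set_size: 0.1                              # an eval split is needed for the callback to run
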
ValueError(f"eval_causal_lm_metrics must be one of {supported_metrics}") + # TODO # MPT 7b # https://github.com/facebookresearch/bitsandbytes/issues/25 From 3392def7140dc6e4355f7061f8405edcdc38306c Mon Sep 17 00:00:00 2001 From: Leonardo Emili Date: Fri, 9 Feb 2024 16:49:26 +0000 Subject: [PATCH 2/4] Fix code for pre-commit --- src/axolotl/utils/callbacks.py | 50 ++++++++++++++++++++++------------ 1 file changed, 32 insertions(+), 18 deletions(-) diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index 390290965c..14f7bdfca0 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -15,7 +15,7 @@ import torch import torch.distributed as dist import wandb -from datasets import load_dataset, Metric +from datasets import Metric, load_dataset from optimum.bettertransformer import BetterTransformer from tqdm import tqdm from transformers import ( @@ -360,6 +360,7 @@ def on_evaluate( return BenchEvalCallback + def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer): class CausalLMBenchEvalCallback(TrainerCallback): """Callback to log prediction values during each evaluation""" @@ -368,14 +369,14 @@ def __init__(self, cfg): self.cfg = cfg self.logged = False self.metrics = self.__maybe_load_metrics() - + def __maybe_load_metrics(self): - metrics = dict() + metrics = {} for metric in self.cfg.eval_causal_lm_metrics: try: metrics[metric] = evaluate.load(metric) - except Exception as e: - LOG.warning(e.args) + except Exception as exc: # pylint: disable=broad-exception-caught + LOG.warning(exc.args) return metrics def on_evaluate( @@ -414,30 +415,43 @@ def find_ranges(lst): end = len(lst) - 1 ranges.append((start, end)) return ranges - + def compute(metric: Metric, **kwargs): # safely compute a metric and return the score if the format is correct metric_score = None try: metric_score = metric.compute(**kwargs) - return metric_score["score"] if "score" in metric_score else metric_score["mean_score"] - except: - pass + return ( + metric_score["score"] + if "score" in metric_score + else metric_score["mean_score"] + ) + except Exception: # pylint: disable=broad-exception-caught + LOG.warning( + f"Failed to compute metric {metric} with kwargs {kwargs.keys()}" + ) return metric_score - + def evaluate_preds(sources, predictions, references): - if len(self.metrics) == 0: - return - - scores = dict() + scores = {} + for metric_name, metric in self.metrics.items(): - score = compute(metric, references=references, predictions=predictions, sources=sources) - score = score or compute(metric, references=[[r] for r in references], predictions=predictions) + score = compute( + metric, + references=references, + predictions=predictions, + sources=sources, + ) + score = score or compute( + metric, + references=[[r] for r in references], + predictions=predictions, + ) scores[metric_name] = score return scores def predict_with_generate(): - eval_src, eval_pred, eval_ref = list(), list(), list() + eval_src, eval_pred, eval_ref = [], [], [] for batch in tqdm(eval_dataloader): batch_labels = batch["labels"].to(device) @@ -512,7 +526,7 @@ def predict_with_generate(): predicted_texts = tokenizer.batch_decode( prediction_without_prompt_tokens_list, skip_special_tokens=True ) - + eval_src.extend(prompt_texts) eval_pred.extend(predicted_texts) eval_ref.extend(completion_texts) From 61204d644f56d45cb56a024c3bd2782f266da283 Mon Sep 17 00:00:00 2001 From: Leonardo Emili Date: Fri, 9 Feb 2024 17:23:58 +0000 Subject: [PATCH 3/4] Fix typing and improve logging --- 
From 61204d644f56d45cb56a024c3bd2782f266da283 Mon Sep 17 00:00:00 2001
From: Leonardo Emili
Date: Fri, 9 Feb 2024 17:23:58 +0000
Subject: [PATCH 3/4] Fix typing and improve logging

---
 src/axolotl/utils/callbacks.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py
index 14f7bdfca0..10c0a33826 100644
--- a/src/axolotl/utils/callbacks.py
+++ b/src/axolotl/utils/callbacks.py
@@ -15,7 +15,7 @@
 import torch
 import torch.distributed as dist
 import wandb
-from datasets import Metric, load_dataset
+from datasets import load_dataset
 from optimum.bettertransformer import BetterTransformer
 from tqdm import tqdm
 from transformers import (
@@ -376,7 +376,7 @@ def __maybe_load_metrics(self):
                 try:
                     metrics[metric] = evaluate.load(metric)
                 except Exception as exc:  # pylint: disable=broad-exception-caught
-                    LOG.warning(exc.args)
+                    LOG.warning(f"{metric}: {exc.args}")
             return metrics

         def on_evaluate(
@@ -416,7 +416,7 @@ def find_ranges(lst):
                 ranges.append((start, end))
                 return ranges

-            def compute(metric: Metric, **kwargs):
+            def compute(metric: evaluate.Metric, **kwargs):
                 # safely compute a metric and return the score if the format is correct
                 metric_score = None
                 try:
@@ -427,8 +427,8 @@ def compute(metric: evaluate.Metric, **kwargs):
                         else metric_score["mean_score"]
                     )
                 except Exception:  # pylint: disable=broad-exception-caught
-                    LOG.warning(
-                        f"Failed to compute metric {metric} with kwargs {kwargs.keys()}"
+                    LOG.debug(
+                        f"Failed to compute metric {metric.name} with kwargs {kwargs.keys()}"
                     )
                 return metric_score


From 2757a6c79c2e2691d91a4929cfe412aa38d687f7 Mon Sep 17 00:00:00 2001
From: Leonardo Emili
Date: Fri, 9 Feb 2024 17:52:55 +0000
Subject: [PATCH 4/4] eval_sample_packing must be false with
 CausalLMBenchEvalCallback

---
 src/axolotl/utils/config.py | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index b2a6409aee..1fc470da9e 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -57,7 +57,12 @@ def normalize_config(cfg):
     cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
     cfg.eval_table_size = cfg.eval_table_size or 0
     cfg.eval_max_new_tokens = cfg.eval_max_new_tokens or 128
-    cfg.eval_causal_lm_metrics = cfg.eval_causal_lm_metrics or ["sacrebleu", "comet", "ter", "chrf"]
+    cfg.eval_causal_lm_metrics = cfg.eval_causal_lm_metrics or [
+        "sacrebleu",
+        "comet",
+        "ter",
+        "chrf",
+    ]
     choose_device(cfg)
     cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
     if cfg.ddp:
@@ -551,13 +556,20 @@ def validate_config(cfg):
     if cfg.fsdp and "bnb" in cfg.optimizer:
         raise ValueError(f"FSDP not compatible with {cfg.optimizer}")

+    if cfg.do_causal_lm_eval and cfg.eval_sample_packing:
+        raise ValueError(
+            "do_causal_lm_eval is enabled, eval_sample_packing must be set to False"
+        )
+
     if cfg.eval_causal_lm_metrics:
         supported_metrics = ["sacrebleu", "comet", "ter", "chrf"]
         if not isinstance(cfg.eval_causal_lm_metrics, list):
             raise ValueError("eval_causal_lm_metrics must be a list")
         # only ["sacrebleu", "comet", "ter", "chrf"] supported
         if set(cfg.eval_causal_lm_metrics) - set(supported_metrics):
-            raise ValueError(f"eval_causal_lm_metrics must be one of {supported_metrics}")
+            raise ValueError(
+                f"eval_causal_lm_metrics must be one of {supported_metrics}"
+            )

     # TODO
     # MPT 7b
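
Taken together, the four patches imply an eval configuration along these lines. This is a hedged sketch with example values; the only constraints taken from the patches are that the metric list must be a subset of the four supported metrics and that sample packing must be disabled for evaluation when the callback is enabled:

# sketch of a config that passes the new validation rules; values are examples
do_causal_lm_eval: true
eval_sample_packing: false     # required when do_causal_lm_eval is enabled
eval_causal_lm_metrics:        # must be a subset of sacrebleu / comet / ter / chrf
  - sacrebleu
  - chrf
eval_max_new_tokens: 128       # falls back to 128 when unset
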