Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add seq2seq eval benchmark callback #1274

Merged
merged 6 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,8 @@ save_total_limit: # Checkpoints saved at a time
max_steps:

eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", chrf]

loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
Expand Down
2 changes: 1 addition & 1 deletion examples/llama-2/loftq.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
Expand Down
2 changes: 1 addition & 1 deletion examples/llama-2/lora.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
Expand Down
2 changes: 1 addition & 1 deletion examples/mamba/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ flash_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
Expand Down
2 changes: 1 addition & 1 deletion examples/mistral/Mistral-7b-example/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
#default deepspeed, can use more aggresive if needed like zero2, zero3
Expand Down
2 changes: 1 addition & 1 deletion examples/mistral/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
Expand Down
2 changes: 1 addition & 1 deletion examples/mistral/mixtral.yml
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed: deepspeed_configs/zero2.json
Expand Down
2 changes: 1 addition & 1 deletion examples/mistral/qlora.yml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
Expand Down
2 changes: 1 addition & 1 deletion examples/qwen/lora.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ flash_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
Expand Down
2 changes: 1 addition & 1 deletion examples/qwen/qlora.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ flash_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
Expand Down
2 changes: 1 addition & 1 deletion examples/yi-34B-chat/qlora.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ num_epochs: 1
val_set_size: 0.1
evals_per_epoch: 5
eval_table_size:
eval_table_max_new_tokens: 128
eval_max_new_tokens: 128
eval_sample_packing: false
eval_batch_size: 1

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ numba
numpy>=1.24.4
mlflow
# qlora things
evaluate==0.4.0
evaluate==0.4.1
scipy
scikit-learn==1.2.2
pynvml
Expand Down
11 changes: 11 additions & 0 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory,
log_prediction_callback_factory,
)
from axolotl.utils.collators import (
Expand Down Expand Up @@ -143,6 +144,9 @@ class AxolotlTrainingArguments(TrainingArguments):
do_bench_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
)
do_causal_lm_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
)
max_bench_samples: Optional[int] = field(
default=None,
metadata={
Expand Down Expand Up @@ -643,6 +647,11 @@ def get_post_trainer_create_callbacks(self, trainer):

if self.cfg.do_bench_eval:
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
if self.cfg.do_causal_lm_eval:
CausalLMBenchEvalCallback = causal_lm_bench_eval_callback_factory(
trainer, self.tokenizer
)
callbacks.append(CausalLMBenchEvalCallback(self.cfg))

if self.cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback(
Expand Down Expand Up @@ -791,6 +800,8 @@ def build(self, total_num_steps):
training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
if self.cfg.bench_dataset:
training_arguments_kwargs["bench_dataset"] = self.cfg.bench_dataset
if self.cfg.do_causal_lm_eval:
training_arguments_kwargs["do_causal_lm_eval"] = self.cfg.do_causal_lm_eval
if self.cfg.metric_for_best_model:
training_arguments_kwargs[
"metric_for_best_model"
Expand Down
183 changes: 182 additions & 1 deletion src/axolotl/utils/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,187 @@ def on_evaluate(
return BenchEvalCallback


def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
class CausalLMBenchEvalCallback(TrainerCallback):
"""Callback to log prediction values during each evaluation"""

def __init__(self, cfg):
self.cfg = cfg
self.logged = False
self.metrics = self.__maybe_load_metrics()

def __maybe_load_metrics(self):
metrics = {}
for metric in self.cfg.eval_causal_lm_metrics:
try:
metrics[metric] = evaluate.load(metric)
except Exception as exc: # pylint: disable=broad-exception-caught
LOG.warning(f"{metric}: {exc.args}")
return metrics

def on_evaluate(
self,
args: AxolotlTrainingArguments, # pylint: disable=unused-argument
state: TrainerState,
control: TrainerControl,
train_dataloader, # pylint: disable=unused-argument
eval_dataloader,
**kwargs, # pylint: disable=unused-argument
):
trainer.model.eval()
device = torch.device(self.cfg.device)

# pylint: disable=duplicate-code
generation_config = GenerationConfig(
max_new_tokens=self.cfg.eval_max_new_tokens,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=False,
use_cache=True,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
)

def find_ranges(lst):
ranges = []
start = 0
for i in range(1, len(lst)):
if lst[i] == 0:
ranges.append((start, i - 1))
start = i
end = len(lst) - 1
ranges.append((start, end))
return ranges

def compute(metric: evaluate.Metric, **kwargs):
# safely compute a metric and return the score if the format is correct
metric_score = None
try:
metric_score = metric.compute(**kwargs)
return (
metric_score["score"]
if "score" in metric_score
else metric_score["mean_score"]
)
except Exception: # pylint: disable=broad-exception-caught
LOG.debug(
f"Failed to compute metric {metric.name} with kwargs {kwargs.keys()}"
)
return metric_score

def evaluate_preds(sources, predictions, references):
scores = {}

for metric_name, metric in self.metrics.items():
score = compute(
metric,
references=references,
predictions=predictions,
sources=sources,
)
score = score or compute(
metric,
references=[[r] for r in references],
predictions=predictions,
)
scores[metric_name] = score
return scores

def predict_with_generate():
eval_src, eval_pred, eval_ref = [], [], []

for batch in tqdm(eval_dataloader):
batch_labels = batch["labels"].to(device)
batch_input_ids = batch["input_ids"].to(device)

if "position_ids" in batch:
batch_pos_ids = batch["position_ids"].tolist()
else:
batch_pos_ids = [None] * len(batch["input_ids"])

prompt_token_ids_list = []
completion_token_ids_list = []

for input_ids_all, labels_all, pos_ids in zip(
batch_input_ids,
batch_labels,
batch_pos_ids,
):
if pos_ids is None:
pos_ranges = [(0, len(input_ids_all) - 1)]
else:
pos_ranges = find_ranges(pos_ids)

for pos_range in pos_ranges:
start, end = pos_range
if start == end:
continue

input_ids = input_ids_all[start : end + 1]
labels = labels_all[start : end + 1]

tokens_without_loss = labels == IGNORE_INDEX
tokens_with_loss = labels != IGNORE_INDEX
tokens_exclude_padding = input_ids != tokenizer.pad_token_id
prompt_token_includes = (
tokens_without_loss & tokens_exclude_padding
)

prompt_token_ids = input_ids[prompt_token_includes]
prompt_token_ids_list.append(prompt_token_ids)

completion_token_ids = input_ids[tokens_with_loss]
completion_token_ids_list.append(completion_token_ids)

prompt_texts = tokenizer.batch_decode(
prompt_token_ids_list, skip_special_tokens=True
)
completion_texts = tokenizer.batch_decode(
completion_token_ids_list, skip_special_tokens=True
)

with torch.no_grad():
prompt_encoding = tokenizer(
prompt_texts, padding=True, return_tensors="pt"
).to(self.cfg.device)
predictions = trainer.model.generate(
**prompt_encoding, generation_config=generation_config
)

prediction_all_tokens = predictions["sequences"].cpu().tolist()
prediction_without_prompt_tokens_list = []
for prompt_token_ids, prediction_tokens in zip(
prompt_token_ids_list, prediction_all_tokens
):
prediction_without_prompt_tokens = prediction_tokens[
len(prompt_token_ids) :
]
prediction_without_prompt_tokens_list.append(
prediction_without_prompt_tokens
)

predicted_texts = tokenizer.batch_decode(
prediction_without_prompt_tokens_list, skip_special_tokens=True
)

eval_src.extend(prompt_texts)
eval_pred.extend(predicted_texts)
eval_ref.extend(completion_texts)

return eval_src, eval_pred, eval_ref

if is_main_process():
eval_preds = predict_with_generate()
trainer.log(evaluate_preds(*eval_preds))

return control

return CausalLMBenchEvalCallback


def log_prediction_callback_factory(trainer: Trainer, tokenizer):
class LogPredictionCallback(TrainerCallback):
"""Callback to log prediction values during each evaluation"""
Expand Down Expand Up @@ -388,7 +569,7 @@ def on_evaluate(

# pylint: disable=duplicate-code
generation_config = GenerationConfig(
max_new_tokens=self.cfg.eval_table_max_new_tokens,
max_new_tokens=self.cfg.eval_max_new_tokens,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
Expand Down
23 changes: 22 additions & 1 deletion src/axolotl/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,13 @@ def normalize_config(cfg):
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
cfg.eval_table_size = cfg.eval_table_size or 0
cfg.eval_table_max_new_tokens = cfg.eval_table_max_new_tokens or 128
cfg.eval_max_new_tokens = cfg.eval_max_new_tokens or 128
cfg.eval_causal_lm_metrics = cfg.eval_causal_lm_metrics or [
"sacrebleu",
"comet",
"ter",
"chrf",
]
choose_device(cfg)
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
if cfg.ddp:
Expand Down Expand Up @@ -550,6 +556,21 @@ def validate_config(cfg):
if cfg.fsdp and "bnb" in cfg.optimizer:
raise ValueError(f"FSDP not compatible with {cfg.optimizer}")

if cfg.do_causal_lm_eval and cfg.eval_sample_packing:
raise ValueError(
"do_causal_lm_eval is enabled, eval_sample_packing must be set to False"
)

if cfg.eval_causal_lm_metrics:
supported_metrics = ["sacrebleu", "comet", "ter", "chrf"]
if not isinstance(cfg.eval_causal_lm_metrics, list):
raise ValueError("eval_causal_lm_metrics must be a list")
# only ["sacrebleu", "comet", "ter", "chrf"] supported
if set(cfg.eval_causal_lm_metrics) - set(supported_metrics):
raise ValueError(
f"eval_causal_lm_metrics must be one of {supported_metrics}"
)

# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25
Expand Down
Loading