move code to Trainer.evaluate to enable use of that function with multiple datasets (#27844)

* move code to Trainer.evaluate to enable use of that function with multiple datasets

* test

* update doc string

* and a tip

* forgot the type

---------

Co-authored-by: Prof. Peter Schneider-Kamp <[email protected]>
peter-sk authored Dec 20, 2023
1 parent cd9f9d6 commit 769a954
Showing 2 changed files with 62 additions and 15 deletions.
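
For orientation, a minimal sketch of the behavior this commit enables. This is not code from the diff: the tiny model id ("sshleifer/tiny-gpt2"), the toy dataset class, and the output directory are illustrative assumptions. The point it shows is that `Trainer.evaluate` can now be handed a dictionary of datasets directly, and each entry is evaluated separately with its key folded into the metric prefix.

    # Hedged usage sketch; model id, ToyLMDataset, and paths are illustrative only.
    from torch.utils.data import Dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments


    class ToyLMDataset(Dataset):
        """A few fixed sentences tokenized for causal LM evaluation (labels = input_ids)."""

        def __init__(self, tokenizer, texts):
            self.examples = [tokenizer(t) for t in texts]
            for ex in self.examples:
                ex["labels"] = list(ex["input_ids"])

        def __len__(self):
            return len(self.examples)

        def __getitem__(self, i):
            return self.examples[i]


    tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
    model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

    data1 = ToyLMDataset(tokenizer, ["hello world", "the quick brown fox"])
    data2 = ToyLMDataset(tokenizer, ["jumps over the lazy dog"])

    trainer = Trainer(
        model=model,
        args=TrainingArguments(output_dir="out", use_cpu=True, per_device_eval_batch_size=1),
        eval_dataset={"data1": data1, "data2": data2},
    )

    # evaluate() now handles the dict itself, so it can be called directly
    # (not only from inside the training loop) with multiple datasets:
    metrics = trainer.evaluate()
    # Each dictionary key is folded into the metric prefix:
    print(metrics["eval_data1_loss"], metrics["eval_data2_loss"])
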
47 changes: 32 additions & 15 deletions src/transformers/trainer.py
@@ -2261,17 +2261,7 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for

         metrics = None
         if self.control.should_evaluate:
-            if isinstance(self.eval_dataset, dict):
-                metrics = {}
-                for eval_dataset_name, eval_dataset in self.eval_dataset.items():
-                    dataset_metrics = self.evaluate(
-                        eval_dataset=eval_dataset,
-                        ignore_keys=ignore_keys_for_eval,
-                        metric_key_prefix=f"eval_{eval_dataset_name}",
-                    )
-                    metrics.update(dataset_metrics)
-            else:
-                metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
+            metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
             self._report_to_hp_search(trial, self.state.global_step, metrics)

             # Run delayed LR scheduler now that metrics are populated
@@ -2997,7 +2987,7 @@ def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:

     def evaluate(
         self,
-        eval_dataset: Optional[Dataset] = None,
+        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
         ignore_keys: Optional[List[str]] = None,
         metric_key_prefix: str = "eval",
     ) -> Dict[str, float]:
@@ -3010,10 +3000,24 @@
         You can also subclass and override this method to inject custom behavior.

         Args:
-            eval_dataset (`Dataset`, *optional*):
+            eval_dataset (Union[`Dataset`, Dict[str, `Dataset`]], *optional*):
                 Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns
-                not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
-                method.
+                not accepted by the `model.forward()` method are automatically removed. If it is a dictionary, it will
+                evaluate on each dataset, prepending the dictionary key to the metric name. Datasets must implement the
+                `__len__` method.
+
+                <Tip>
+
+                If you pass a dictionary with names of datasets as keys and datasets as values, evaluate will run
+                separate evaluations on each dataset. This can be useful to monitor how training affects other
+                datasets or simply to get a more fine-grained evaluation.
+                When used with `load_best_model_at_end`, make sure `metric_for_best_model` references exactly one
+                of the datasets. If you, for example, pass in `{"data1": data1, "data2": data2}` for two datasets
+                `data1` and `data2`, you could specify `metric_for_best_model="eval_data1_loss"` to use the
+                loss on `data1` and `metric_for_best_model="eval_data2_loss"` to use the loss on `data2`.
+
+                </Tip>
+
             ignore_keys (`List[str]`, *optional*):
                 A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                 gathering predictions.
@@ -3025,6 +3029,19 @@
             A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
             dictionary also contains the epoch number which comes from the training state.
         """
+        # handle multiple eval datasets
+        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
+        if isinstance(eval_dataset, dict):
+            metrics = {}
+            for eval_dataset_name, _eval_dataset in eval_dataset.items():
+                dataset_metrics = self.evaluate(
+                    eval_dataset=_eval_dataset,
+                    ignore_keys=ignore_keys,
+                    metric_key_prefix=f"{metric_key_prefix}_{eval_dataset_name}",
+                )
+                metrics.update(dataset_metrics)
+            return metrics
+
         # memory metrics - must set up as early as possible
         self._memory_tracker.start()

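The <Tip> added to the docstring above ties multi-dataset evaluation to checkpoint selection. Below is a hedged sketch of how the training arguments might be combined for that case; the output directory and the dataset key "data1" are illustrative and assume the Trainer was built with eval_dataset={"data1": ..., "data2": ...}.

    # Illustrative sketch: pick the best checkpoint by the loss on one named eval dataset.
    from transformers import TrainingArguments

    args = TrainingArguments(
        output_dir="out",
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        # "data1" must match a key of the eval_dataset dict passed to Trainer, so
        # metric_for_best_model resolves to exactly one prefixed metric.
        metric_for_best_model="eval_data1_loss",
        greater_is_better=False,  # lower loss is better
    )
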
30 changes: 30 additions & 0 deletions tests/trainer/test_trainer.py
@@ -103,6 +103,7 @@

 import transformers.optimization
 from transformers import (
+    AutoModelForCausalLM,
     AutoModelForSequenceClassification,
     EarlyStoppingCallback,
     GlueDataset,
@@ -1845,6 +1846,35 @@ def test_trainer_eval_mrpc(self):
         result = trainer.evaluate()
         self.assertLess(result["eval_loss"], 0.2)

+    @slow
+    def test_trainer_eval_multiple(self):
+        MODEL_ID = "gpt2"
+        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+        model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
+        dataset = LineByLineTextDataset(
+            tokenizer=tokenizer,
+            file_path=PATH_SAMPLE_TEXT,
+            block_size=tokenizer.max_len_single_sentence,
+        )
+        for example in dataset.examples:
+            example["labels"] = example["input_ids"]
+        training_args = TrainingArguments(
+            output_dir="./examples",
+            use_cpu=True,
+            per_device_eval_batch_size=1,
+        )
+        trainer = Trainer(
+            model=model,
+            args=training_args,
+            eval_dataset={
+                "data1": dataset,
+                "data2": dataset,
+            },
+        )
+        result = trainer.evaluate()
+        self.assertIn("eval_data1_loss", result)
+        self.assertIn("eval_data2_loss", result)
+
     @slow
     def test_trainer_eval_lm(self):
         MODEL_ID = "distilroberta-base"
