Add warmup for eval (huggingface#855)
Co-authored-by: regisss <[email protected]>
libinta and regisss authored Mar 31, 2024
1 parent 32625b1 commit 753da20
Showing 1 changed file with 73 additions and 0 deletions.
optimum/habana/transformers/trainer.py: 73 additions & 0 deletions
@@ -1621,6 +1621,73 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

    def evaluate(
        self,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> Dict[str, float]:
        """
        From https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/trainer.py#L3162 with the following modifications:
        1. comment out TPU-related code
        2. use throughput_warmup_steps in the evaluation throughput calculation
        """
        # handle multiple eval datasets
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        if isinstance(eval_dataset, dict):
            metrics = {}
            for eval_dataset_name, _eval_dataset in eval_dataset.items():
                dataset_metrics = self.evaluate(
                    eval_dataset=_eval_dataset,
                    ignore_keys=ignore_keys,
                    metric_key_prefix=f"{metric_key_prefix}_{eval_dataset_name}",
                )
                metrics.update(dataset_metrics)
            return metrics

        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        eval_dataloader = self.get_eval_dataloader(eval_dataset)

        start_time = time.time()
        self.start_time_after_warmup = None

        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
        output = eval_loop(
            eval_dataloader,
            description="Evaluation",
            # No point gathering the predictions if there are no metrics, otherwise we defer to
            # self.args.prediction_loss_only
            prediction_loss_only=True if self.compute_metrics is None else None,
            ignore_keys=ignore_keys,
            metric_key_prefix=metric_key_prefix,
        )

        total_batch_size = self.args.eval_batch_size * self.args.world_size
        if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
            start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
        num_samples = output.num_samples - self.args.throughput_warmup_steps * total_batch_size
        num_steps = math.ceil(output.num_samples / total_batch_size) - self.args.throughput_warmup_steps

        output.metrics.update(
            speed_metrics(
                metric_key_prefix,
                start_time,
                num_samples=num_samples,
                num_steps=num_steps,
                start_time_after_warmup=self.start_time_after_warmup,
            )
        )

        self.log(output.metrics)

        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)

        self._memory_tracker.stop_and_update_metrics(output.metrics)

        return output.metrics

    def evaluation_loop(
        self,
        dataloader: DataLoader,
@@ -1716,6 +1783,12 @@ def evaluation_loop(
        observed_num_examples = 0
        # Main evaluation loop
        for step, inputs in enumerate(dataloader):
            if (
                self.args.throughput_warmup_steps > 0
                and not self.is_in_train
                and step == self.args.throughput_warmup_steps
            ):
                self.start_time_after_warmup = time.time()
            # Update the observed num examples
            observed_batch_size = find_batch_size(inputs)
            if observed_batch_size is not None:
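To make the throughput adjustment above concrete, here is a minimal, self-contained sketch. The warmup_adjusted_speed_metrics helper is hypothetical, not the speed_metrics function imported in optimum-habana (whose exact behavior may differ); it only illustrates how timing from the post-warmup timestamp set in evaluation_loop changes the reported eval throughput.

import time

def warmup_adjusted_speed_metrics(prefix, start_time, num_samples, num_steps, start_time_after_warmup=None):
    # Hypothetical sketch: time the run from the timestamp taken once the warmup
    # steps are done (set at step == throughput_warmup_steps in evaluation_loop)
    # when it is available, otherwise from the start of evaluation.
    end_time = time.time()
    reference = start_time_after_warmup if start_time_after_warmup is not None else start_time
    runtime = end_time - reference
    metrics = {f"{prefix}_runtime": round(end_time - start_time, 4)}
    if runtime > 0:
        metrics[f"{prefix}_samples_per_second"] = round(num_samples / runtime, 3)
        metrics[f"{prefix}_steps_per_second"] = round(num_steps / runtime, 3)
    return metrics

As a worked example of the counts computed in evaluate, using illustrative values: with eval_batch_size=32 on a single device (world_size=1), 1000 evaluation samples, and throughput_warmup_steps=2, total_batch_size is 32, so num_samples = 1000 - 2 * 32 = 936 and num_steps = ceil(1000 / 32) - 2 = 30; throughput is then reported over those post-warmup samples and steps only.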
