diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py
index 4f22bdde3..c91d0724f 100644
--- a/optimum/habana/transformers/trainer.py
+++ b/optimum/habana/transformers/trainer.py
@@ -1621,6 +1621,73 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
         # Good practice: save your training arguments together with the trained model
         torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
 
+    def evaluate(
+        self,
+        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
+        ignore_keys: Optional[List[str]] = None,
+        metric_key_prefix: str = "eval",
+    ) -> Dict[str, float]:
+        """
+        From https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/trainer.py#L3162 with the following modifications:
+        1. comment out TPU-related code
+        2. use throughput_warmup_steps in the evaluation throughput calculation
+        """
+        # handle multiple eval datasets
+        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
+        if isinstance(eval_dataset, dict):
+            metrics = {}
+            for eval_dataset_name, _eval_dataset in eval_dataset.items():
+                dataset_metrics = self.evaluate(
+                    eval_dataset=_eval_dataset,
+                    ignore_keys=ignore_keys,
+                    metric_key_prefix=f"{metric_key_prefix}_{eval_dataset_name}",
+                )
+                metrics.update(dataset_metrics)
+            return metrics
+
+        # memory metrics - must set up as early as possible
+        self._memory_tracker.start()
+
+        eval_dataloader = self.get_eval_dataloader(eval_dataset)
+
+        start_time = time.time()
+        self.start_time_after_warmup = None
+
+        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
+        output = eval_loop(
+            eval_dataloader,
+            description="Evaluation",
+            # No point gathering the predictions if there are no metrics, otherwise we defer to
+            # self.args.prediction_loss_only
+            prediction_loss_only=True if self.compute_metrics is None else None,
+            ignore_keys=ignore_keys,
+            metric_key_prefix=metric_key_prefix,
+        )
+
+        total_batch_size = self.args.eval_batch_size * self.args.world_size
+        if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
+            start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
+        num_samples = output.num_samples - self.args.throughput_warmup_steps * total_batch_size
+        num_steps = math.ceil(output.num_samples / total_batch_size) - self.args.throughput_warmup_steps
+
+        output.metrics.update(
+            speed_metrics(
+                metric_key_prefix,
+                start_time,
+                num_samples=num_samples,
+                num_steps=num_steps,
+                start_time_after_warmup=self.start_time_after_warmup,
+            )
+        )
+
+        self.log(output.metrics)
+
+        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
+
+        self._memory_tracker.stop_and_update_metrics(output.metrics)
+
+        return output.metrics
+
     def evaluation_loop(
         self,
         dataloader: DataLoader,
@@ -1716,6 +1783,12 @@ def evaluation_loop(
         observed_num_examples = 0
         # Main evaluation loop
         for step, inputs in enumerate(dataloader):
+            if (
+                self.args.throughput_warmup_steps > 0
+                and not self.is_in_train
+                and step == self.args.throughput_warmup_steps
+            ):
+                self.start_time_after_warmup = time.time()
             # Update the observed num examples
             observed_batch_size = find_batch_size(inputs)
             if observed_batch_size is not None:
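
For context on the throughput adjustment in the diff: `num_samples` and `num_steps` drop the first `throughput_warmup_steps` evaluation steps, and `speed_metrics` additionally receives the timestamp recorded by the evaluation loop once those steps have run. The sketch below reproduces that arithmetic in isolation; `adjusted_eval_speed_metrics` is a hypothetical helper written for illustration, not optimum-habana's actual `speed_metrics` implementation.

```python
import math
import time
from typing import Dict, Optional


def adjusted_eval_speed_metrics(
    prefix: str,
    start_time: float,
    total_samples: int,
    eval_batch_size: int,
    world_size: int,
    throughput_warmup_steps: int,
    start_time_after_warmup: Optional[float] = None,
) -> Dict[str, float]:
    """Warmup-adjusted runtime/throughput metrics (illustrative sketch only)."""
    # Same counts as in `evaluate()`: exclude the warmup steps and the samples they consumed.
    total_batch_size = eval_batch_size * world_size
    num_samples = total_samples - throughput_warmup_steps * total_batch_size
    num_steps = math.ceil(total_samples / total_batch_size) - throughput_warmup_steps

    end_time = time.time()
    runtime = end_time - start_time
    # If the evaluation loop recorded a post-warmup timestamp, compute throughput over
    # that shorter window so graph compilation/warmup does not deflate the numbers.
    throughput_window = runtime if start_time_after_warmup is None else end_time - start_time_after_warmup
    return {
        f"{prefix}_runtime": round(runtime, 4),
        f"{prefix}_samples_per_second": round(num_samples / throughput_window, 3),
        f"{prefix}_steps_per_second": round(num_steps / throughput_window, 3),
    }
```

Note that the diff only records `start_time_after_warmup` when `not self.is_in_train`, i.e. for standalone `evaluate()` calls, and with `throughput_warmup_steps=0` the counts and the measurement window reduce to the stock `transformers` behavior.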