diff --git a/optimum/onnxruntime/trainer.py b/optimum/onnxruntime/trainer.py
index 66273cbcf96..4933ed73f4e 100644
--- a/optimum/onnxruntime/trainer.py
+++ b/optimum/onnxruntime/trainer.py
@@ -14,6 +14,7 @@
 """
 The ORTTrainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task with ONNX Runtime.
 """
+
 import functools
 import math
 import os
@@ -131,11 +132,11 @@ def __init__(self, model, args, label_smoother):
         # Label smoothing
         self.label_smoother = label_smoother
 
-    def forward(self, inputs: Dict[str, Union[torch.Tensor, Any]], return_outputs):
+    def forward(self, inputs: Dict[str, Union[torch.Tensor, Any]], return_outputs, num_items_in_batch):
         # The compute_model_plus_loss_internal is assigned once the class is instantiated.
         # It should have same signature as Trainer.compute_loss().
         # We do this to avoid potential un-synced states if we duplicated compute loss codes .
-        return self.compute_model_plus_loss_internal(self._original_model, inputs, return_outputs)
+        return self.compute_model_plus_loss_internal(self._original_model, inputs, return_outputs, num_items_in_batch)
 
     @property
     def module(self):
@@ -291,14 +292,14 @@ def _set_signature_columns_if_needed(self):
             # Labels may be named label or label_ids, the default data collator handles that.
             self._signature_columns += list(set(["label", "label_ids"] + self.label_names))
 
-    def compute_loss(self, model_with_loss, inputs, return_outputs=False):
+    def compute_loss(self, model_with_loss, inputs, return_outputs=False, num_items_in_batch=None):
         # Run model forward + loss compute.
         if isinstance(self.model, ModuleWithLoss):
             # ORTModule Does not support the BatchEncoding Type so we have to convert to a dict.
             dict_inputs = dict(inputs.items())
-            return model_with_loss(dict_inputs, return_outputs)
+            return model_with_loss(dict_inputs, return_outputs, num_items_in_batch)
         else:
-            return super().compute_loss(model_with_loss, inputs, return_outputs)
+            return super().compute_loss(model_with_loss, inputs, return_outputs, num_items_in_batch)
 
     def train(
         self,
@@ -803,7 +804,9 @@ def get_dataloader_sampler(dataloader):
                     self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
                     self.control = self.callback_handler.on_step_end(args, self.state, self.control)
 
-                    self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
+                    self._maybe_log_save_evaluate(
+                        tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time
+                    )
                 else:
                     self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
@@ -818,7 +821,7 @@ def get_dataloader_sampler(dataloader):
             self.control.should_training_stop = True
 
             self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
-            self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
+            self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time)
 
             if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
                 logger.warning(
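
Note for reviewers: this patch tracks two upstream transformers.Trainer signature changes, the num_items_in_batch argument added to compute_loss() (used to normalize the loss correctly under gradient accumulation, introduced around transformers v4.46) and the start_time argument that recent releases thread through _maybe_log_save_evaluate(). The sketch below is a minimal, hypothetical illustration of the compute_loss half and is not part of the patch; PassthroughTrainer is an invented name, and it assumes a transformers version recent enough to pass num_items_in_batch:

    from transformers import Trainer

    class PassthroughTrainer(Trainer):  # hypothetical name, for illustration only
        def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
            # Accept and forward the extra argument; an override that keeps
            # the old three-argument signature raises a TypeError when a
            # recent Trainer calls compute_loss(..., num_items_in_batch=...),
            # and dropping the value would skip the gradient-accumulation
            # loss normalization it exists to support.
            return super().compute_loss(
                model, inputs, return_outputs=return_outputs, num_items_in_batch=num_items_in_batch
            )

This is exactly the shape of the change applied to ORTTrainer.compute_loss and ModuleWithLoss.forward above: accept the new parameter and pass it straight through to the parent implementation rather than duplicating the loss logic.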