diff --git a/optimum/onnxruntime/trainer.py b/optimum/onnxruntime/trainer.py index afc90e405bb..4582b402c02 100644 --- a/optimum/onnxruntime/trainer.py +++ b/optimum/onnxruntime/trainer.py @@ -459,6 +459,14 @@ def _inner_training_loop( # Wrap the model with `ORTModule` logger.info("Wrap ORTModule for ONNX Runtime training.") + + from onnxruntime import version as ort_version + + ort_support_stage3 = version.parse(ort_version) >= version.parse("1.17.0") + os.environ["ORTMODULE_ENABLE_ZERO_STAGE3"] = str( + int(ort_support_stage3 and self.is_deepspeed_enabled and is_deepspeed_zero3_enabled()) + ) + model = ORTModule(self.model) self.model_wrapped = model self.model = model @@ -469,9 +477,9 @@ def _inner_training_loop( self._created_lr_scheduler = False if self.is_deepspeed_enabled: - if is_deepspeed_zero3_enabled(): + if is_deepspeed_zero3_enabled() and not ort_support_stage3: raise NotImplementedError( - "`ORTTrainer` does not support ZeRO stage 3 for the moment. Please use DeepSpeed stage 1 or 2 instead." + "`ORTTrainer` does not support ZeRO stage 3 for the moment. Please use DeepSpeed stage 1 or 2 instead or consider updating ONNX Runtime to 1.17 or later." ) if args.bf16: warnings.warn(