diff --git a/optimum/onnxruntime/trainer.py b/optimum/onnxruntime/trainer.py index 12474d98f12..591d67eac28 100644 --- a/optimum/onnxruntime/trainer.py +++ b/optimum/onnxruntime/trainer.py @@ -542,6 +542,11 @@ def _inner_training_loop( # Wrap the model with `ORTModule` logger.info("Wrap ORTModule for ONNX Runtime training.") + + from onnxruntime import version as ort_version + ort_support_stage3 = version.parse(ort_version) >= version.parse("1.17.0") + os.environ["ORTMODULE_ENABLE_ZERO_STAGE3"] = str(int(ort_support_stage3 and self.is_deepspeed_enabled)) + model = ORTModule(self.model) self.model_wrapped = model self.model = model @@ -552,14 +557,10 @@ def _inner_training_loop( self._created_lr_scheduler = False if self.is_deepspeed_enabled: - if is_deepspeed_zero3_enabled(): - from onnxruntime import version as ort_version - if version.parse(ort_version) < version.parse("1.17.0"): - raise NotImplementedError( - "`ORTTrainer` does not support ZeRO stage 3 for the moment. Please use DeepSpeed stage 1 or 2 instead." - ) - else: - os.environ["ORTMODULE_ENABLE_ZERO_STAGE3"] = '1' + if is_deepspeed_zero3_enabled() and not ort_support_stage3: + raise NotImplementedError( + "`ORTTrainer` does not support ZeRO stage 3 for the moment. Please use DeepSpeed stage 1 or 2 instead." + ) if args.bf16: warnings.warn( "ONNX Runtime doesn't support BF16 when executing some operators. The execution will fail if there are any"