Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
pengwa committed Oct 8, 2023
1 parent a7f4372 commit a81a59a
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions optimum/onnxruntime/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,11 @@ def _inner_training_loop(

# Wrap the model with `ORTModule`
logger.info("Wrap ORTModule for ONNX Runtime training.")

from onnxruntime import version as ort_version
ort_support_stage3 = version.parse(ort_version) >= version.parse("1.17.0")
os.environ["ORTMODULE_ENABLE_ZERO_STAGE3"] = str(int(ort_support_stage3 and self.is_deepspeed_enabled))

model = ORTModule(self.model)
self.model_wrapped = model
self.model = model
Expand All @@ -552,14 +557,10 @@ def _inner_training_loop(
self._created_lr_scheduler = False

if self.is_deepspeed_enabled:
if is_deepspeed_zero3_enabled():
from onnxruntime import version as ort_version
if version.parse(ort_version) < version.parse("1.17.0"):
raise NotImplementedError(
"`ORTTrainer` does not support ZeRO stage 3 for the moment. Please use DeepSpeed stage 1 or 2 instead."
)
else:
os.environ["ORTMODULE_ENABLE_ZERO_STAGE3"] = '1'
if is_deepspeed_zero3_enabled() and not ort_support_stage3:
raise NotImplementedError(
"`ORTTrainer` does not support ZeRO stage 3 for the moment. Please use DeepSpeed stage 1 or 2 instead."
)
if args.bf16:
warnings.warn(
"ONNX Runtime doesn't support BF16 when executing some operators. The execution will fail if there are any"
Expand Down

0 comments on commit a81a59a

Please sign in to comment.