Reproducible checkpoint for npu (#27208)
* Save the NPU's RNG states when saving a checkpoint and restore them after the data-skipping phase when resuming training (a sketch of the saved object follows below).

* re-trigger ci

* re-trigger ci
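For context, here is a rough sketch of the object the Trainer checkpoints, with the new "npu" entry. The key names other than "npu", the rng_state.pth filename, and output_dir are assumptions taken from the rest of trainer.py and are not part of this diff:

import os
import random

import numpy as np
import torch
import torch_npu  # noqa: F401  # registers torch.npu; assumed installed

output_dir = "checkpoint-500"  # hypothetical checkpoint directory

# Hypothetical shape of the RNG-state dict saved next to a checkpoint
# (single-device case; distributed training stores a list per device instead).
rng_states = {
    "python": random.getstate(),
    "numpy": np.random.get_state(),
    "cpu": torch.random.get_rng_state(),
    "npu": torch.npu.random.get_rng_state(),  # the entry added by this commit
}
torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))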
statelesshz authored Nov 2, 2023
1 parent 7adaefe commit c52e429
Showing 1 changed file with 18 additions and 0 deletions.
src/transformers/trainer.py
@@ -144,6 +144,7 @@
     is_sagemaker_mp_enabled,
     is_torch_compile_available,
     is_torch_neuroncore_available,
+    is_torch_npu_available,
     is_torch_tpu_available,
     logging,
     strtobool,
@@ -2321,6 +2322,17 @@ def _load_rng_state(self, checkpoint):
                     )
         if is_torch_tpu_available():
             xm.set_rng_state(checkpoint_rng_state["xla"])
+        if is_torch_npu_available():
+            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+                torch.npu.random.set_rng_state_all(checkpoint_rng_state["npu"])
+            else:
+                try:
+                    torch.npu.random.set_rng_state(checkpoint_rng_state["npu"])
+                except Exception as e:
+                    logger.info(
+                        f"Didn't manage to set back the RNG states of the NPU because of the following error:\n {e}"
+                        "\nThis won't yield the same results as if the training had not been interrupted."
+                    )
 
     def _save_checkpoint(self, model, trial, metrics=None):
         # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
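The checkpoint_rng_state dict restored above is loaded earlier in _load_rng_state, outside this diff. A minimal sketch of that step, assuming the rng_state.pth filename the Trainer uses elsewhere in trainer.py:

import os
import torch

# Sketch (assumption): earlier in _load_rng_state, the saved dict is read back
# from the checkpoint directory before the per-backend restore code runs.
rng_file = os.path.join(checkpoint, "rng_state.pth")
if os.path.isfile(rng_file):
    checkpoint_rng_state = torch.load(rng_file)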
@@ -2423,6 +2435,12 @@ def _save_checkpoint(self, model, trial, metrics=None):
         if is_torch_tpu_available():
             rng_states["xla"] = xm.get_rng_state()
 
+        if is_torch_npu_available():
+            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+                rng_states["npu"] = torch.npu.random.get_rng_state_all()
+            else:
+                rng_states["npu"] = torch.npu.random.get_rng_state()
+
         # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
         # not yet exist.
         os.makedirs(output_dir, exist_ok=True)
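A minimal standalone sketch of the guarantee this commit provides: restoring the saved state makes the NPU generator replay the same draws. Only the torch.npu.random calls that appear in the diff are used; torch_npu and an Ascend device are assumed:

import torch
import torch_npu  # noqa: F401  # registers the "npu" device; assumed installed

# Capture the NPU RNG state, draw once, restore, draw again:
# the two tensors must match exactly.
state = torch.npu.random.get_rng_state()
first = torch.randn(4, device="npu")

torch.npu.random.set_rng_state(state)
second = torch.randn(4, device="npu")

assert torch.equal(first.cpu(), second.cpu())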
