Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: Automatic checkpoint path inference issue #1989

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2876,7 +2876,7 @@ def _inner(folder):
return list(map(int, re.findall(r"[\/]?([0-9]+)(?=[^\/]*$)", folder)))[0]

folders.sort(key=_inner)
input_dir = os.path.join(input_dir, folders[-1])
input_dir = folders[-1]
else:
raise ValueError("No input_dir provided and automatic checkpoint naming is disabled.")
logger.info(f"Loading states from {input_dir}")
Expand Down
67 changes: 67 additions & 0 deletions tests/test_state_checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import shutil
import tempfile
import unittest
import uuid
from contextlib import contextmanager

import pytest
import torch
Expand Down Expand Up @@ -201,6 +203,71 @@ def test_can_resume_training(self):
self.assertEqual(opt_state1, opt_state3)
self.assertEqual(ground_truth_rands, test_rands)

def test_can_resume_training_checkpoints_relative_path(self):
# See #1983
# This test is like test_can_resume_training but uses a relative path for the checkpoint and automatically
# infers the checkpoint path when loading.
@contextmanager
def temporary_relative_directory():
# This is equivalent to tempfile.TemporaryDirectory() except that it returns a relative path
rand_dir = f"test_path_{uuid.uuid4()}"
os.mkdir(rand_dir)
try:
yield rand_dir
finally:
shutil.rmtree(rand_dir)

with temporary_relative_directory() as tmpdir:
set_seed(42)
model = DummyModel()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
train_dataloader, valid_dataloader = dummy_dataloaders()
project_config = ProjectConfiguration(automatic_checkpoint_naming=True)

# Train baseline
accelerator = Accelerator(project_dir=tmpdir, project_config=project_config)
model, optimizer, train_dataloader, valid_dataloader = accelerator.prepare(
model, optimizer, train_dataloader, valid_dataloader
)
# Save initial
accelerator.save_state()
(a, b) = model.a.item(), model.b.item()
opt_state = optimizer.state_dict()
ground_truth_rands = train(3, model, train_dataloader, optimizer, accelerator)
(a1, b1) = model.a.item(), model.b.item()
opt_state1 = optimizer.state_dict()

# Train partially
set_seed(42)
model = DummyModel()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
train_dataloader, valid_dataloader = dummy_dataloaders()
project_config = ProjectConfiguration(iteration=1, automatic_checkpoint_naming=True)
accelerator = Accelerator(project_dir=tmpdir, project_config=project_config)
model, optimizer, train_dataloader, valid_dataloader = accelerator.prepare(
model, optimizer, train_dataloader, valid_dataloader
)
accelerator.load_state() # <= infer the directory automatically
(a2, b2) = model.a.item(), model.b.item()
opt_state2 = optimizer.state_dict()
self.assertEqual(a, a2)
self.assertEqual(b, b2)
self.assertEqual(opt_state, opt_state2)

test_rands = train(2, model, train_dataloader, optimizer, accelerator)
# Save everything
accelerator.save_state()

# Load everything back in and make sure all states work
accelerator.load_state(os.path.join(tmpdir, "checkpoints", "checkpoint_1"))
test_rands += train(1, model, train_dataloader, optimizer, accelerator)
(a3, b3) = model.a.item(), model.b.item()
opt_state3 = optimizer.state_dict()
self.assertEqual(a1, a3)
self.assertEqual(b1, b3)
self.assertEqual(opt_state1, opt_state3)
self.assertEqual(ground_truth_rands, test_rands)

def test_invalid_registration(self):
t = torch.tensor([1, 2, 3])
t1 = torch.tensor([2, 3, 4])
Expand Down
Loading