Skip to content

Commit

Permalink
style
Browse files Browse the repository at this point in the history
  • Loading branch information
SunMarc committed Nov 7, 2023
1 parent 09f3426 commit eda082f
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2583,7 +2583,7 @@ def save_model(
for param in model.parameters():
if param.device == torch.device("meta"):
raise RuntimeError("You can't save the model since some parameters are on the meta device.")

# get the state_dict of the model
state_dict = self.get_state_dict(model)

Expand Down
2 changes: 2 additions & 0 deletions tests/test_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def load_random_weights(model):
state = torch.nn.Linear(*tuple(model.weight.T.shape)).state_dict()
model.load_state_dict(state)


class ModelForTest(torch.nn.Module):
def __init__(self):
super().__init__()
Expand All @@ -44,6 +45,7 @@ def __init__(self):
def forward(self, x):
return self.linear2(self.batchnorm(self.linear1(x)))


class AcceleratorTester(AccelerateTestCase):
@require_cuda
def test_accelerator_can_be_reinstantiated(self):
Expand Down

0 comments on commit eda082f

Please sign in to comment.