Skip to content

Commit

Permalink
Reverting changes. Fix unwrap_model.
Browse files Browse the repository at this point in the history
  • Loading branch information
ggoggam committed Dec 10, 2024
1 parent f601b8c commit ea03570
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 4 deletions.
4 changes: 3 additions & 1 deletion src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2632,7 +2632,9 @@ def unwrap_model(self, model, keep_fp32_wrapper: bool = True):
MyModel
```
"""
return extract_model_from_parallel(model, keep_fp32_wrapper)
while not isinstance(model, torch.nn.Module):
model = extract_model_from_parallel(model, keep_fp32_wrapper)
return model

def wait_for_everyone(self):
"""
Expand Down
5 changes: 3 additions & 2 deletions src/accelerate/utils/other.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ def extract_model_from_parallel(model, keep_fp32_wrapper: bool = True, recursive
"""
options = (torch.nn.parallel.DistributedDataParallel, torch.nn.DataParallel)

if is_compiled_module(model):
is_compiled = is_compiled_module(model)
if is_compiled:
compiled_model = model
model = model._orig_mod

Expand Down Expand Up @@ -123,7 +124,7 @@ def _recursive_unwrap(module):
if getattr(model, "_converted_to_transformer_engine", False):
convert_model(model, to_transformer_engine=False)

if is_compiled_module(model):
if is_compiled:
compiled_model._orig_mod = model
model = compiled_model

Expand Down
2 changes: 1 addition & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def test_dynamo_extract_model(self):
distributed_compiled_model = torch.compile(distributed_model)
compiled_model_unwrapped = extract_model_from_parallel(distributed_compiled_model)

assert compiled_model._orig_mod == compiled_model_unwrapped
assert compiled_model._orig_mod == compiled_model_unwrapped._orig_mod

def test_find_device(self):
assert find_device([1, "a", torch.tensor([1, 2, 3])]) == torch.device("cpu")
Expand Down

0 comments on commit ea03570

Please sign in to comment.