Skip to content

Commit

Permalink
Adding keep_torch_compile argument.
Browse files Browse the repository at this point in the history
  • Loading branch information
ggoggam committed Dec 11, 2024
1 parent ea03570 commit c3e69c5
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 11 deletions.
9 changes: 4 additions & 5 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2601,7 +2601,7 @@ def pad_across_processes(self, tensor, dim=0, pad_index=0, pad_first=False):
"""
return pad_across_processes(tensor, dim=dim, pad_index=pad_index, pad_first=pad_first)

def unwrap_model(self, model, keep_fp32_wrapper: bool = True):
def unwrap_model(self, model, keep_fp32_wrapper: bool = True, keep_torch_compile: bool = False):
"""
Unwraps the `model` from the additional layer possible added by [`~Accelerator.prepare`]. Useful before saving
the model.
Expand All @@ -2611,7 +2611,8 @@ def unwrap_model(self, model, keep_fp32_wrapper: bool = True):
The model to unwrap.
keep_fp32_wrapper (`bool`, *optional*, defaults to `True`):
Whether to not remove the mixed precision hook if it was added.
keep_torch_compile (`bool`, *optional*, defaults to `False`):
Whether to not unwrap compiled model if compiled.
Returns:
`torch.nn.Module`: The unwrapped model.
Expand All @@ -2632,9 +2633,7 @@ def unwrap_model(self, model, keep_fp32_wrapper: bool = True):
MyModel
```
"""
while not isinstance(model, torch.nn.Module):
model = extract_model_from_parallel(model, keep_fp32_wrapper)
return model
return extract_model_from_parallel(model, keep_fp32_wrapper, keep_torch_compile)

def wait_for_everyone(self):
"""
Expand Down
13 changes: 9 additions & 4 deletions src/accelerate/utils/other.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ def is_compiled_module(module):
return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)


def extract_model_from_parallel(model, keep_fp32_wrapper: bool = True, recursive: bool = False):
def extract_model_from_parallel(
model, keep_fp32_wrapper: bool = True, keep_torch_compile: bool = True, recursive: bool = False
):
"""
Extract a model from its distributed containers.
Expand All @@ -68,6 +70,8 @@ def extract_model_from_parallel(model, keep_fp32_wrapper: bool = True, recursive
The model to extract.
keep_fp32_wrapper (`bool`, *optional*):
Whether to remove mixed precision hooks from the model.
keep_torch_compile (`bool`, *optional*):
Whether to unwrap compiled model.
recursive (`bool`, *optional*, defaults to `False`):
Whether to recursively extract all cases of `module.module` from `model` as well as unwrap child sublayers
recursively, not just the top-level distributed containers.
Expand Down Expand Up @@ -124,9 +128,10 @@ def _recursive_unwrap(module):
if getattr(model, "_converted_to_transformer_engine", False):
convert_model(model, to_transformer_engine=False)

if is_compiled:
compiled_model._orig_mod = model
model = compiled_model
if keep_torch_compile:
if is_compiled:
compiled_model._orig_mod = model
model = compiled_model

return model

Expand Down
12 changes: 12 additions & 0 deletions tests/test_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,18 @@ def test_can_unwrap_model(self):
model_loaded = pickle.loads(pickle.dumps(model))
model_loaded(inputs)

def test_can_unwrap_distributed_compiled_model(self):
model = create_components()[0]
accelerator = Accelerator()

compiled_model = torch.compile(model)

distributed_model = torch.nn.DataParallel(model)
distributed_compiled_model = torch.compile(distributed_model)
unwrapped_model = accelerator.unwrap_model(distributed_compiled_model)

assert compiled_model._orig_mod == unwrapped_model

@parameterized.expand([True, False])
def test_can_pickle_dataloader(self, dispatch_batches):
"""
Expand Down
16 changes: 14 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,17 +240,29 @@ def nested_wrap(model):
assert original_key == new_key, f"Keys did not align: {original_key} != {new_key}"

@require_torch_min_version(version="2.0")
def test_dynamo_extract_model(self):
def test_dynamo_extract_model_keep_compilation(self):
model = RegressionModel()
compiled_model = torch.compile(model)

# could also do a test with DistributedDataParallel, but difficult to run on CPU or single GPU
distributed_model = torch.nn.parallel.DataParallel(model)
distributed_compiled_model = torch.compile(distributed_model)
compiled_model_unwrapped = extract_model_from_parallel(distributed_compiled_model)
compiled_model_unwrapped = extract_model_from_parallel(distributed_compiled_model, keep_torch_compile=True)

assert compiled_model._orig_mod == compiled_model_unwrapped._orig_mod

@require_torch_min_version(version="2.0")
def test_dynamo_extract_model_remove_compilation(self):
model = RegressionModel()
compiled_model = torch.compile(model)

# could also do a test with DistributedDataParallel, but difficult to run on CPU or single GPU
distributed_model = torch.nn.parallel.DataParallel(model)
distributed_compiled_model = torch.compile(distributed_model)
compiled_model_unwrapped = extract_model_from_parallel(distributed_compiled_model, keep_torch_compile=False)

assert compiled_model._orig_mod == compiled_model_unwrapped

def test_find_device(self):
assert find_device([1, "a", torch.tensor([1, 2, 3])]) == torch.device("cpu")
assert find_device({"a": 1, "b": torch.tensor([1, 2, 3])}) == torch.device("cpu")
Expand Down

0 comments on commit c3e69c5

Please sign in to comment.