From c3e69c5e33628ac072f0fa093c1d17fb251931a2 Mon Sep 17 00:00:00 2001
From: Oh Joon Kwon
Date: Thu, 12 Dec 2024 01:06:04 +0900
Subject: [PATCH] Adding `keep_torch_compile` argument.

---
 src/accelerate/accelerator.py |  9 ++++-----
 src/accelerate/utils/other.py | 13 +++++++++----
 tests/test_accelerator.py     | 12 ++++++++++++
 tests/test_utils.py           | 16 ++++++++++++++--
 4 files changed, 39 insertions(+), 11 deletions(-)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 943695b2b93..8771359f3a0 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -2601,7 +2601,7 @@ def pad_across_processes(self, tensor, dim=0, pad_index=0, pad_first=False):
         """
         return pad_across_processes(tensor, dim=dim, pad_index=pad_index, pad_first=pad_first)
 
-    def unwrap_model(self, model, keep_fp32_wrapper: bool = True):
+    def unwrap_model(self, model, keep_fp32_wrapper: bool = True, keep_torch_compile: bool = False):
         """
         Unwraps the `model` from the additional layer possible added by [`~Accelerator.prepare`]. Useful before saving
         the model.
@@ -2611,7 +2611,8 @@ def unwrap_model(self, model, keep_fp32_wrapper: bool = True):
                 The model to unwrap.
             keep_fp32_wrapper (`bool`, *optional*, defaults to `True`):
                 Whether to not remove the mixed precision hook if it was added.
-
+            keep_torch_compile (`bool`, *optional*, defaults to `False`):
+                Whether to not unwrap the model if it was compiled with `torch.compile`.
         Returns:
             `torch.nn.Module`: The unwrapped model.
 
@@ -2632,9 +2633,7 @@ def unwrap_model(self, model, keep_fp32_wrapper: bool = True):
             MyModel
         ```
         """
-        while not isinstance(model, torch.nn.Module):
-            model = extract_model_from_parallel(model, keep_fp32_wrapper)
-        return model
+        return extract_model_from_parallel(model, keep_fp32_wrapper, keep_torch_compile)
 
     def wait_for_everyone(self):
         """
diff --git a/src/accelerate/utils/other.py b/src/accelerate/utils/other.py
index f0a8f59424d..ce358a6acdc 100644
--- a/src/accelerate/utils/other.py
+++ b/src/accelerate/utils/other.py
@@ -59,7 +59,9 @@ def is_compiled_module(module):
     return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
 
 
-def extract_model_from_parallel(model, keep_fp32_wrapper: bool = True, recursive: bool = False):
+def extract_model_from_parallel(
+    model, keep_fp32_wrapper: bool = True, keep_torch_compile: bool = True, recursive: bool = False
+):
     """
     Extract a model from its distributed containers.
 
@@ -68,6 +70,8 @@ def extract_model_from_parallel(model, keep_fp32_wrapper: bool = True, recursive
             The model to extract.
         keep_fp32_wrapper (`bool`, *optional*):
             Whether to remove mixed precision hooks from the model.
+        keep_torch_compile (`bool`, *optional*):
+            Whether to not unwrap the model if it was compiled with `torch.compile`.
         recursive (`bool`, *optional*, defaults to `False`):
             Whether to recursively extract all cases of `module.module` from `model` as well as unwrap child sublayers
             recursively, not just the top-level distributed containers.
@@ -124,9 +128,10 @@ def _recursive_unwrap(module):
     if getattr(model, "_converted_to_transformer_engine", False):
         convert_model(model, to_transformer_engine=False)
 
-    if is_compiled:
-        compiled_model._orig_mod = model
-        model = compiled_model
+    if keep_torch_compile:
+        if is_compiled:
+            compiled_model._orig_mod = model
+            model = compiled_model
 
     return model
 
diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py
index 651dc17da5e..7a4cd219e05 100644
--- a/tests/test_accelerator.py
+++ b/tests/test_accelerator.py
@@ -611,6 +611,18 @@ def test_can_unwrap_model(self):
         model_loaded = pickle.loads(pickle.dumps(model))
         model_loaded(inputs)
 
+    def test_can_unwrap_distributed_compiled_model(self):
+        model = create_components()[0]
+        accelerator = Accelerator()
+
+        compiled_model = torch.compile(model)
+
+        distributed_model = torch.nn.DataParallel(model)
+        distributed_compiled_model = torch.compile(distributed_model)
+        unwrapped_model = accelerator.unwrap_model(distributed_compiled_model)
+
+        assert compiled_model._orig_mod == unwrapped_model
+
     @parameterized.expand([True, False])
     def test_can_pickle_dataloader(self, dispatch_batches):
         """
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 2f3c6703de8..796dfde2261 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -240,17 +240,29 @@ def nested_wrap(model):
         assert original_key == new_key, f"Keys did not align: {original_key} != {new_key}"
 
     @require_torch_min_version(version="2.0")
-    def test_dynamo_extract_model(self):
+    def test_dynamo_extract_model_keep_compilation(self):
         model = RegressionModel()
         compiled_model = torch.compile(model)
 
         # could also do a test with DistributedDataParallel, but difficult to run on CPU or single GPU
         distributed_model = torch.nn.parallel.DataParallel(model)
         distributed_compiled_model = torch.compile(distributed_model)
-        compiled_model_unwrapped = extract_model_from_parallel(distributed_compiled_model)
+        compiled_model_unwrapped = extract_model_from_parallel(distributed_compiled_model, keep_torch_compile=True)
 
         assert compiled_model._orig_mod == compiled_model_unwrapped._orig_mod
 
+    @require_torch_min_version(version="2.0")
+    def test_dynamo_extract_model_remove_compilation(self):
+        model = RegressionModel()
+        compiled_model = torch.compile(model)
+
+        # could also do a test with DistributedDataParallel, but difficult to run on CPU or single GPU
+        distributed_model = torch.nn.parallel.DataParallel(model)
+        distributed_compiled_model = torch.compile(distributed_model)
+        compiled_model_unwrapped = extract_model_from_parallel(distributed_compiled_model, keep_torch_compile=False)
+
+        assert compiled_model._orig_mod == compiled_model_unwrapped
+
     def test_find_device(self):
         assert find_device([1, "a", torch.tensor([1, 2, 3])]) == torch.device("cpu")
         assert find_device({"a": 1, "b": torch.tensor([1, 2, 3])}) == torch.device("cpu")
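
A minimal usage sketch of the behavior this patch introduces, assuming the patch is applied and PyTorch >= 2.0 is available for `torch.compile`; the `torch.nn.Linear` model and variable names are illustrative, not part of the diff:

```python
import torch

from accelerate import Accelerator
from accelerate.utils import extract_model_from_parallel

# Illustrative model; any torch.nn.Module behaves the same way.
model = torch.nn.Linear(4, 4)

accelerator = Accelerator()
model = accelerator.prepare(model)
compiled_model = torch.compile(model)

# Accelerator.unwrap_model defaults to keep_torch_compile=False, so the
# torch.compile wrapper is stripped along with the distributed containers.
plain_model = accelerator.unwrap_model(compiled_model)

# Pass keep_torch_compile=True to keep the compiled wrapper on the result.
still_compiled = accelerator.unwrap_model(compiled_model, keep_torch_compile=True)

# The lower-level utility exposes the same switch but defaults to True,
# matching its previous behavior; pass False to also drop the compilation.
plain_model_too = extract_model_from_parallel(compiled_model, keep_torch_compile=False)
```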