Fix model copy after dispatch_model #1971

Merged · 3 commits · Sep 19, 2023
10 changes: 5 additions & 5 deletions src/accelerate/hooks.py
@@ -155,17 +155,17 @@ def add_hook_to_module(module: nn.Module, hook: ModelHook, append: bool = False)
     module = hook.init_hook(module)
     module._hf_hook = hook

     @functools.wraps(old_forward)
-    def new_forward(*args, **kwargs):
+    def new_forward(module, *args, **kwargs):
         args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
         if module._hf_hook.no_grad:
             with torch.no_grad():
-                output = old_forward(*args, **kwargs)
+                output = module._old_forward(*args, **kwargs)
         else:
-            output = old_forward(*args, **kwargs)
+            output = module._old_forward(*args, **kwargs)
         return module._hf_hook.post_forward(module, output)

-    module.forward = new_forward
+    module.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)
Member (review comment on the module.forward line, marked as resolved):
I wonder about this line versus:

module.forward = MethodType(functools.update_wrapper(new_forward, old_forward), module)

@austinapatel Do you think that would be preferable? I thought it might affect pickling, but unfortunately neither variant allows it, though the error messages differ.

Contributor (PR author):
Thanks for the suggestion and review! I just tried your code snippet and ran into the same issue as this comment: #1971 (comment), since functools.update_wrapper needs to be called last (test_add_and_remove_hooks in test_hooks.py fails). If I instead try module.forward = functools.update_wrapper(MethodType(new_forward, module), old_forward), I get AttributeError: 'method' object has no attribute '__module__' when update_wrapper runs. My understanding is that update_wrapper is then being applied to a bound method rather than to a function, and method objects do not accept the attribute assignments it performs.
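For reference, a minimal standalone repro of the error described above (my own sketch, not taken from the PR): functools.update_wrapper assigns attributes such as __module__ onto the wrapper object, and a bound method rejects attribute assignment.

import functools
from types import MethodType


class Toy:
    pass


def new_forward(module, *args, **kwargs):
    pass


def old_forward(*args, **kwargs):
    """Stands in for the module's original forward."""


toy = Toy()
try:
    # update_wrapper tries setattr(wrapper, "__module__", ...) and friends,
    # which a bound method object does not allow.
    functools.update_wrapper(MethodType(new_forward, toy), old_forward)
except AttributeError as err:
    print(err)  # e.g. 'method' object has no attribute '__module__' (wording varies by Python version)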

Member:
Thanks for explaining and testing, this makes sense. I'm a bit wary about binding the module via partial to essentially imitate a bound method, but since none of the other approaches work, I think it's the best we can do.


    return module
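To illustrate the discussion above, here is a minimal sketch using a toy class (not the accelerate implementation) of why binding the module with functools.partial behaves well under copy.deepcopy while a closure over the original bound method does not: deepcopy memoizes the module instance and rebuilds the partial around the copy, whereas plain functions and lambdas are not deepcopied at all.

import copy
import functools


class Toy:
    def forward(self, x):
        return x


def new_forward(module, x):
    # Everything is resolved through `module`, so a deepcopied instance
    # uses its own attributes rather than the original's.
    return module._old_forward(x)


original = Toy()
old_forward = original.forward
original._old_forward = old_forward

# Closure-style wrapper (pre-PR behavior): the copy keeps the exact same
# function object, which still closes over the original's bound method.
original.forward = lambda x: old_forward(x)
clone = copy.deepcopy(original)
assert clone.forward is original.forward

# Partial-style wrapper (this PR): deepcopy rebuilds the partial with the
# memoized copy of `module`, so the clone's forward is bound to the clone.
original.forward = functools.update_wrapper(functools.partial(new_forward, original), old_forward)
clone = copy.deepcopy(original)
assert clone.forward.args[0] is clone
assert clone.forward.args[0] is not original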


33 changes: 32 additions & 1 deletion tests/test_big_modeling.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import copy
 import os
 import unittest
 from tempfile import TemporaryDirectory
@@ -45,6 +45,18 @@ def forward(self, x):
         return self.linear2(self.batchnorm(self.linear1(x)))


+class ModelForTestCopy(nn.Module):
+    def __init__(self, id: int):
+        super().__init__()
+        self.id = id
+        self.linear1 = nn.Linear(3, 4)
+        self.batchnorm = nn.BatchNorm1d(4)
+        self.linear2 = nn.Linear(4, 5)
+
+    def forward(self, x):
+        return self.linear2(self.batchnorm(self.linear1(x))), self.id
+
+
 class ModelForTestTiedWeights(nn.Module):
     def __init__(self):
         super().__init__()
@@ -325,6 +337,25 @@ def test_dispatch_model_multi_gpu(self):
         output = model(x)
         self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5))

+    @require_cuda
+    def test_dispatch_model_copy(self):
+        original_model = ModelForTestCopy(id=1)
+        device_map = {"linear1": 0, "batchnorm": "cpu", "linear2": 0}
+
+        x = torch.randn(2, 3)
+        expected, original_output_id = original_model(x)
+
+        dispatch_model(original_model, device_map)
+
+        copied_model = copy.deepcopy(original_model)
+        copied_model.id = 2
+        output, copied_output_id = copied_model(x)
+
+        self.assertEqual(original_model.id, original_output_id)
+        self.assertEqual(copied_model.id, copied_output_id)
+        self.assertFalse(copied_model.linear1.forward is original_model.linear1.forward)
+        self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5))
+
     @require_cuda
     def test_dispatch_model_move_offloaded_model(self):
         model = ModelForTest()