Commit

fix retie
SunMarc committed Nov 9, 2023
1 parent 0b0d921 commit 960b08b
Showing 2 changed files with 12 additions and 5 deletions.
15 changes: 11 additions & 4 deletions src/accelerate/utils/modeling.py
@@ -535,15 +535,22 @@ def retie_parameters(model, tied_params):
     """
     for tied_group in tied_params:
         param_to_tie = None
-        # First iteration of the loop will set param_to_tie, next ones will tie it to the others
+        # two loops: the first one to set param_to_tie, the second one to change the values of tied_group
         for param_name in tied_group:
             module = model
             splits = param_name.split(".")
             for split in splits[:-1]:
                 module = getattr(module, split)
-            if param_to_tie is None:
-                param_to_tie = getattr(module, splits[-1])
-            else:
+            param = getattr(module, splits[-1])
+            if param_to_tie is None and param.device != torch.device("meta"):
+                param_to_tie = param
+                break
+        if param_to_tie is not None:
+            for param_name in tied_group:
+                module = model
+                splits = param_name.split(".")
+                for split in splits[:-1]:
+                    module = getattr(module, split)
                 setattr(module, splits[-1], param_to_tie)


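With this change, retie_parameters first scans a tied group for a parameter that is not on the meta device and only then points every name in the group at that parameter, instead of tying everything to whichever entry happens to come first. A minimal sketch of the behavior (the toy module, parameter names, and tied_params list below are illustrative assumptions, not part of the commit):

import torch
import torch.nn as nn

from accelerate.utils.modeling import retie_parameters


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(4, 2)
        self.lm_head = nn.Linear(2, 4, bias=False)


model = TinyModel()
# Simulate a partially materialized checkpoint: the head is still on the meta
# device while the embedding holds real weights.
model.lm_head.weight = nn.Parameter(torch.empty(4, 2, device="meta"))

# The meta-device copy comes first in the group; before this fix it would have
# been picked as param_to_tie.
tied_params = [["lm_head.weight", "embed.weight"]]
retie_parameters(model, tied_params)

# The first loop skips the meta copy and picks embed.weight; the second loop
# ties both names to it.
assert model.lm_head.weight is model.embed.weight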
2 changes: 1 addition & 1 deletion tests/test_quantization.py
@@ -444,7 +444,7 @@ def test_int8_serialization_offload(self):
             model_8bit_from_saved = load_and_quantize_model(
                 model_8bit_from_saved,
                 bnb_quantization_config,
-                weights_location=tmpdirname + "/pytorch_model.bin",
+                weights_location=tmpdirname,
                 device_map=device_map,
                 no_split_module_classes=["BloomBlock"],
                 offload_folder=tmpdirname + "/tmp",
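The test now passes the checkpoint directory itself as weights_location instead of a single pytorch_model.bin file, which is what the offloaded-loading path expects. A hedged usage sketch of the call pattern being exercised (the model name, config values, and paths are illustrative assumptions, not taken from the test):

from transformers import AutoConfig, AutoModelForCausalLM

from accelerate import init_empty_weights
from accelerate.utils import BnbQuantizationConfig, load_and_quantize_model

# Build an empty (meta-device) model skeleton, then load and quantize it from
# a checkpoint directory, offloading what does not fit on the GPU.
with init_empty_weights():
    empty_model = AutoModelForCausalLM.from_config(
        AutoConfig.from_pretrained("bigscience/bloom-560m")  # assumed example checkpoint
    )

bnb_quantization_config = BnbQuantizationConfig(load_in_8bit=True)

quantized_model = load_and_quantize_model(
    empty_model,
    bnb_quantization_config,
    weights_location="/path/to/saved_checkpoint",  # directory, not a single .bin file
    device_map="auto",
    no_split_module_classes=["BloomBlock"],
    offload_folder="/path/to/offload",
)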
