[LoRA] feat: support unload_lora_weights() for Flux Control. #10206
base: main
Conversation
current_param_weight = overwritten_params[f"{name}.weight"]
in_features, out_features = current_param_weight.shape[1], current_param_weight.shape[0]
with torch.device("meta"):
Since we already pin the torch version, this is safe enough.
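For reference, a minimal sketch (not code from the PR) of the pattern being discussed: torch.device("meta") as a context manager requires torch >= 2.0 and load_state_dict(assign=True) requires torch >= 2.1, which is why the pinned torch version matters here.

```python
# Minimal sketch (not PR code): build the module on the "meta" device so no real
# storage is allocated, then materialize it from an existing tensor via
# load_state_dict(assign=True). Needs torch >= 2.0 for the device context manager
# and torch >= 2.1 for `assign=True`.
import torch

restored_weight = torch.randn(16, 32)  # hypothetical weight to restore

with torch.device("meta"):
    linear = torch.nn.Linear(32, 16, bias=False)  # shapes/dtypes only, no storage

# `assign=True` swaps the meta parameters for the provided tensors directly.
linear.load_state_dict({"weight": restored_weight}, assign=True, strict=True)
print(linear.weight.device)  # cpu
```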
Also cc: @a-r-r-o-w. Something we should consider doing in:
https://github.com/huggingface/diffusers/blob/main/src/diffusers/loaders/lora_pipeline.py#L2351-L2354
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Looks pretty good, well tested, no issues from my side.
@yiyixuxu @a-r-r-o-w could you give this a look?
Thanks for supporting this! Changes look good to me
@@ -2378,6 +2422,14 @@ def _maybe_expand_transformer_param_shape_or_error_(
setattr(transformer.config, attribute_name, new_value)
logger.info(f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}.")

# For `unload_lora_weights()`.
overwritten_params[f"{current_module_name}.weight"] = module_weight
I think this would have a small but significant memory overhead. For inference-only use with LoRAs, maybe this could be made opt-out if we know we will never call unload_lora_weights(). Not a blocker though, and it can be tackled in a different PR, but lmk your thoughts.
Yeah, this could be tackled with discard_original_layers. For now, I have added a note about it as a comment.
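To make the trade-off concrete, here is a hypothetical sketch of what such an opt-out could look like; the function and flag names (expand_linear, cache_original_params) are illustrative, not part of the diffusers API.

```python
# Hypothetical sketch only; `expand_linear` and `cache_original_params` are
# illustrative names, not the diffusers API.
from typing import Dict

import torch


def expand_linear(
    module: torch.nn.Linear,
    new_in_features: int,
    overwritten_params: Dict[str, torch.Tensor],
    name: str,
    cache_original_params: bool = True,
) -> torch.nn.Linear:
    if cache_original_params:
        # Keep a copy of the original weight so a later unload can restore the
        # original (non-expanded) shape; skipping this saves memory for
        # inference-only LoRA usage.
        overwritten_params[f"{name}.weight"] = module.weight.data.clone()
    expanded = torch.nn.Linear(
        new_in_features, module.out_features, bias=module.bias is not None
    )
    with torch.no_grad():
        # Zero-pad the extra input columns and copy the original weight/bias over.
        expanded.weight.zero_()
        expanded.weight[:, : module.in_features].copy_(module.weight)
        if module.bias is not None:
            expanded.bias.copy_(module.bias)
    return expanded
```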
original_module = torch.nn.Linear(
    in_features,
    out_features,
    bias=bias,
    dtype=module_weight.dtype,
)

tmp_state_dict = {"weight": current_param_weight}
if module_bias is not None:
    tmp_state_dict.update({"bias": overwritten_params[f"{name}.bias"]})
original_module.load_state_dict(tmp_state_dict, assign=True, strict=True)
setattr(parent_module, current_module_name, original_module)
@a-r-r-o-w thanks for flagging the device assignment while initializing original_module. The device argument takes priority, so original_module was not getting initialized on "meta", rendering the previous copy_() ops ineffective.
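For anyone following along, a minimal sketch (not PR code) of the behaviour described above: an explicit device= argument wins over the surrounding torch.device("meta") context.

```python
# Minimal sketch (not PR code): an explicit `device=` kwarg takes priority over
# the ambient torch.device("meta") context manager.
import torch

with torch.device("meta"):
    implicit = torch.nn.Linear(4, 4)                # parameters land on "meta"
    explicit = torch.nn.Linear(4, 4, device="cpu")  # kwarg wins: parameters on "cpu"

print(implicit.weight.device)  # meta
print(explicit.weight.device)  # cpu
```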
LMK what you think about the current changes (I have run the corresponding tests on a GPU and they pass).
@DN6 LMK your comments here too.
What does this PR do?
Fixes: #10202.
Will request reviews from others later.
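A rough usage sketch of what this PR enables; the checkpoint and LoRA ids below are assumptions for illustration, not taken from the PR's tests.

```python
# Rough usage sketch; model/LoRA ids are assumptions, not from the PR.
import torch
from diffusers import FluxControlPipeline

pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)

# Loading a Flux Control LoRA expands some transformer input projections; the
# original parameter shapes are cached so they can be restored later.
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")

# ... run control-conditioned inference ...

# Restores the transformer to its original (non-expanded) parameters.
pipe.unload_lora_weights()
```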