
Commit 258a398

working.

sayakpaul committed Dec 17, 2024
1 parent d3e177c commit 258a398
Showing 1 changed file with 42 additions and 41 deletions.
83 changes: 42 additions & 41 deletions src/diffusers/loaders/lora_pipeline.py
@@ -2312,7 +2312,6 @@ def _maybe_expand_transformer_param_shape_or_error_(
 
         # Expand transformer parameter shapes if they don't match lora
         has_param_with_shape_update = False
-
         for name, module in transformer.named_modules():
             if isinstance(module, torch.nn.Linear):
                 module_weight = module.weight.data
@@ -2332,54 +2331,52 @@ def _maybe_expand_transformer_param_shape_or_error_(
                     continue
 
                 module_out_features, module_in_features = module_weight.shape
-                if out_features < module_out_features or in_features < module_in_features:
-                    raise NotImplementedError(
-                        f"Only LoRAs with input/output features higher than the current module's input/output features "
-                        f"are currently supported. The provided LoRA contains {in_features=} and {out_features=}, which "
-                        f"are lower than {module_in_features=} and {module_out_features=}. If you require support for "
-                        f"this please open an issue at https://github.com/huggingface/diffusers/issues."
-                    )
-
-                debug_message = (
-                    f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
-                    f"checkpoint contains higher number of features than expected. The number of input_features will be "
-                    f"expanded from {module_in_features} to {in_features}"
-                )
-                if module_out_features != out_features:
+                debug_message = ""
+                if in_features > module_in_features:
+                    debug_message += (
+                        f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
+                        f"checkpoint contains higher number of features than expected. The number of input_features will be "
+                        f"expanded from {module_in_features} to {in_features}"
+                    )
+                if out_features > module_out_features:
                     debug_message += (
                         ", and the number of output features will be "
                         f"expanded from {module_out_features} to {out_features}."
                     )
                 else:
                     debug_message += "."
-                logger.debug(debug_message)
+                if debug_message:
+                    logger.debug(debug_message)
 
-                has_param_with_shape_update = True
-                parent_module_name, _, current_module_name = name.rpartition(".")
-                parent_module = transformer.get_submodule(parent_module_name)
+                if out_features > module_out_features or in_features > module_in_features:
+                    has_param_with_shape_update = True
+                    parent_module_name, _, current_module_name = name.rpartition(".")
+                    parent_module = transformer.get_submodule(parent_module_name)
 
-                # TODO: consider initializing this under meta device for optims.
-                expanded_module = torch.nn.Linear(
-                    in_features, out_features, bias=bias, device=module_weight.device, dtype=module_weight.dtype
-                )
-                # Only weights are expanded and biases are not.
-                new_weight = torch.zeros_like(
-                    expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
-                )
-                slices = tuple(slice(0, dim) for dim in module_weight.shape)
-                new_weight[slices] = module_weight
-                expanded_module.weight.data.copy_(new_weight)
-                if module_bias is not None:
-                    expanded_module.bias.data.copy_(module_bias)
-
-                setattr(parent_module, current_module_name, expanded_module)
-
-                if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
-                    attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
-                    new_value = int(expanded_module.weight.data.shape[1])
-                    old_value = getattr(transformer.config, attribute_name)
-                    setattr(transformer.config, attribute_name, new_value)
-                    logger.info(f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}.")
+                    # TODO: consider initializing this under meta device for optims.
+                    expanded_module = torch.nn.Linear(
+                        in_features, out_features, bias=bias, device=module_weight.device, dtype=module_weight.dtype
+                    )
+                    # Only weights are expanded and biases are not.
+                    new_weight = torch.zeros_like(
+                        expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
+                    )
+                    slices = tuple(slice(0, dim) for dim in module_weight.shape)
+                    new_weight[slices] = module_weight
+                    expanded_module.weight.data.copy_(new_weight)
+                    if module_bias is not None:
+                        expanded_module.bias.data.copy_(module_bias)
+
+                    setattr(parent_module, current_module_name, expanded_module)
+
+                    if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
+                        attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
+                        new_value = int(expanded_module.weight.data.shape[1])
+                        old_value = getattr(transformer.config, attribute_name)
+                        setattr(transformer.config, attribute_name, new_value)
+                        logger.info(
+                            f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
+                        )
 
         return has_param_with_shape_update
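
Note: for readers skimming the hunk above, here is a minimal standalone sketch of the zero-padded expansion that the new conditional branch performs. This is not the diffusers code; the helper name expand_linear and the 64 -> 128 feature sizes are invented for illustration, and in the real path the target sizes come from the LoRA checkpoint's lora_A/lora_B shapes.

import torch


def expand_linear(module: torch.nn.Linear, in_features: int, out_features: int) -> torch.nn.Linear:
    # Build a larger Linear on the same device/dtype, zero-fill its weight, and
    # copy the original weight into the top-left block; the bias is copied as-is.
    bias = module.bias is not None
    expanded = torch.nn.Linear(
        in_features, out_features, bias=bias, device=module.weight.device, dtype=module.weight.dtype
    )
    new_weight = torch.zeros_like(expanded.weight.data)
    slices = tuple(slice(0, dim) for dim in module.weight.shape)
    new_weight[slices] = module.weight.data
    expanded.weight.data.copy_(new_weight)
    if bias:
        expanded.bias.data.copy_(module.bias.data)
    return expanded


small = torch.nn.Linear(64, 32)
big = expand_linear(small, in_features=128, out_features=32)
x = torch.zeros(1, 128)
x[:, :64] = torch.randn(1, 64)
# The padded weight columns are zero, so the expanded layer matches the original
# whenever the extra input features are zero.
assert torch.allclose(big(x), small(x[:, :64]), atol=1e-6)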

@@ -2405,10 +2402,14 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
                 expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
                 lora_state_dict[f"{k}.lora_A.weight"] = expanded_state_dict_weight
                 expanded_module_names.add(k)
+            elif base_weight_param.shape[1] < lora_A_param.shape[1]:
+                raise NotImplementedError(
+                    "We currently don't support loading LoRAs for this use case. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new."
+                )
 
         if expanded_module_names:
             logger.info(
-                f"Found some LoRA modules for which the weights were expanded: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new."
+                f"Found some LoRA modules for which the weights were zero-padded: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new."
             )
         return lora_state_dict
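
Note: the second hunk is the state-dict side of the same idea: lora_A weights are zero-padded on the input dimension to match a wider base layer, and the opposite (LoRA wider than the base layer) is rejected. A hedged sketch with invented shapes (rank 4, 64 -> 128 input features); the tensor names only stand in for the real state-dict entries.

import torch

# Base layer that was expanded to 128 input features; a LoRA trained against 64.
base_weight = torch.randn(32, 128)   # stands in for base_layer.weight
lora_A = torch.randn(4, 64)          # stands in for <module>.lora_A.weight (rank 4)

if base_weight.shape[1] > lora_A.shape[1]:
    # Zero-pad lora_A so it matches the wider base layer; the LoRA contribution
    # for the newly added input features therefore starts out as zero.
    expanded = torch.zeros(lora_A.shape[0], base_weight.shape[1], device=base_weight.device)
    expanded[:, : lora_A.shape[1]].copy_(lora_A)
    lora_A = expanded
elif base_weight.shape[1] < lora_A.shape[1]:
    # The narrower-base case is rejected, mirroring the NotImplementedError added above.
    raise NotImplementedError("A LoRA wider than the base layer is not handled in this sketch.")

assert lora_A.shape == (4, 128)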

