From 6bbdcd28aee104db1ae83e2146512dd45dbbad6e Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Tue, 27 Aug 2024 13:55:37 -0400 Subject: [PATCH] Support weight padding on diff weight patch (#4576) --- comfy/lora.py | 48 ++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 44 insertions(+), 4 deletions(-) diff --git a/comfy/lora.py b/comfy/lora.py index a3e7d9cc0c4..a3e33a27ec0 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -16,6 +16,7 @@ along with this program. If not, see . """ +from __future__ import annotations import comfy.utils import comfy.model_management import comfy.model_base @@ -347,6 +348,39 @@ def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediat weight[:] = weight_calc return weight +def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor: + """ + Pad a tensor to a new shape with zeros. + + Args: + tensor (torch.Tensor): The original tensor to be padded. + new_shape (List[int]): The desired shape of the padded tensor. + + Returns: + torch.Tensor: A new tensor padded with zeros to the specified shape. + + Note: + If the new shape is smaller than the original tensor in any dimension, + the original tensor will be truncated in that dimension. + """ + if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]): + raise ValueError("The new shape must be larger than the original tensor in all dimensions") + + if len(new_shape) != len(tensor.shape): + raise ValueError("The new shape must have the same number of dimensions as the original tensor") + + # Create a new tensor filled with zeros + padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) + + # Create slicing tuples for both tensors + orig_slices = tuple(slice(0, dim) for dim in tensor.shape) + new_slices = tuple(slice(0, dim) for dim in tensor.shape) + + # Copy the original tensor into the new tensor + padded_tensor[new_slices] = tensor[orig_slices] + + return padded_tensor + def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32): for p in patches: strength = p[0] @@ -375,12 +409,18 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32): v = v[1] if patch_type == "diff": - w1 = v[0] + diff: torch.Tensor = v[0] + # An extra flag to pad the weight if the diff's shape is larger than the weight + do_pad_weight = len(v) > 1 and v[1]['pad_weight'] + if do_pad_weight and diff.shape != weight.shape: + logging.info("Pad weight {} from {} to shape: {}".format(key, weight.shape, diff.shape)) + weight = pad_tensor_to_shape(weight, diff.shape) + if strength != 0.0: - if w1.shape != weight.shape: - logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape)) + if diff.shape != weight.shape: + logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape)) else: - weight += function(strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype)) + weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype)) elif patch_type == "lora": #lora/locon mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype) mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)