From b502f584886fbf52f9a180981efe276ea8509de7 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Tue, 29 Oct 2024 23:29:50 +0900 Subject: [PATCH] Fix emb_dim to work. --- networks/lora_sd3.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/networks/lora_sd3.py b/networks/lora_sd3.py index c1eb68b8a..efe202451 100644 --- a/networks/lora_sd3.py +++ b/networks/lora_sd3.py @@ -307,6 +307,7 @@ def create_modules( target_replace_modules: List[str], filter: Optional[str] = None, default_dim: Optional[int] = None, + include_conv2d_if_filter: bool = False, ) -> List[LoRAModule]: prefix = ( self.LORA_PREFIX_SD3 @@ -332,8 +333,11 @@ def create_modules( lora_name = prefix + "." + (name + "." if name else "") + child_name lora_name = lora_name.replace(".", "_") - if filter is not None and not filter in lora_name: - continue + force_incl_conv2d = False + if filter is not None: + if not filter in lora_name: + continue + force_incl_conv2d = include_conv2d_if_filter dim = None alpha = None @@ -373,6 +377,10 @@ def create_modules( elif self.conv_lora_dim is not None: dim = self.conv_lora_dim alpha = self.conv_alpha + elif force_incl_conv2d: + # x_embedder + dim = default_dim if default_dim is not None else self.lora_dim + alpha = self.alpha if dim is None or dim == 0: # skipした情報を出力 @@ -428,7 +436,7 @@ def create_modules( for filter, in_dim in zip( [ "context_embedder", - "_t_embedder", # don't use "t_embedder" because it's used in "context_embedder" + "_t_embedder", # don't use "t_embedder" because it's used in "context_embedder" "x_embedder", "y_embedder", "final_layer_adaLN_modulation", @@ -436,7 +444,12 @@ def create_modules( ], self.emb_dims, ): - loras, _ = create_modules(True, None, unet, None, filter=filter, default_dim=in_dim) + # x_embedder is conv2d, so we need to include it + loras, _ = create_modules( + True, None, unet, None, filter=filter, default_dim=in_dim, include_conv2d_if_filter=filter == "x_embedder" + ) + # if len(loras) > 0: + # logger.info(f"create LoRA for {filter}: {len(loras)} modules.") self.unet_loras.extend(loras) logger.info(f"create LoRA for SD3 MMDiT: {len(self.unet_loras)} modules.")