Fix emb_dim to work.
kohya-ss committed Oct 29, 2024
1 parent c9a1417 commit b502f58
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions networks/lora_sd3.py
@@ -307,6 +307,7 @@ def create_modules(
target_replace_modules: List[str],
filter: Optional[str] = None,
default_dim: Optional[int] = None,
include_conv2d_if_filter: bool = False,
) -> List[LoRAModule]:
prefix = (
self.LORA_PREFIX_SD3
@@ -332,8 +333,11 @@ def create_modules(
lora_name = prefix + "." + (name + "." if name else "") + child_name
lora_name = lora_name.replace(".", "_")

if filter is not None and not filter in lora_name:
continue
force_incl_conv2d = False
if filter is not None:
if not filter in lora_name:
continue
force_incl_conv2d = include_conv2d_if_filter

dim = None
alpha = None
@@ -373,6 +377,10 @@ def create_modules(
elif self.conv_lora_dim is not None:
dim = self.conv_lora_dim
alpha = self.conv_alpha
elif force_incl_conv2d:
# x_embedder
dim = default_dim if default_dim is not None else self.lora_dim
alpha = self.alpha

if dim is None or dim == 0:
# output information about the skipped module
@@ -428,15 +436,20 @@ def create_modules(
for filter, in_dim in zip(
[
"context_embedder",
"_t_embedder", # don't use "t_embedder" because it's used in "context_embedder"
"_t_embedder", # don't use "t_embedder" because it's used in "context_embedder"
"x_embedder",
"y_embedder",
"final_layer_adaLN_modulation",
"final_layer_linear",
],
self.emb_dims,
):
loras, _ = create_modules(True, None, unet, None, filter=filter, default_dim=in_dim)
# x_embedder is conv2d, so we need to include it
loras, _ = create_modules(
True, None, unet, None, filter=filter, default_dim=in_dim, include_conv2d_if_filter=filter == "x_embedder"
)
# if len(loras) > 0:
# logger.info(f"create LoRA for {filter}: {len(loras)} modules.")
self.unet_loras.extend(loras)

logger.info(f"create LoRA for SD3 MMDiT: {len(self.unet_loras)} modules.")