Merge pull request #831 from kohya-ss/dev
update versions of accelerate and diffusers
kohya-ss authored Sep 24, 2023
2 parents 20e929e + 28272de commit 9861516
Showing 3 changed files with 22 additions and 12 deletions.
9 changes: 8 additions & 1 deletion README.md
@@ -22,7 +22,14 @@ __Stable Diffusion web UI now seems to support LoRA trained by ``sd-scripts``.__

 The feature of SDXL training is now available in sdxl branch as an experimental feature.

-Sep 3, 2023: The feature will be merged into the main branch soon. Following are the changes from the previous version.
+Sep 24, 2023: The feature will be merged into the main branch very soon. Following are the changes from the previous version.
+
+- `accelerate` is updated to 0.23.0, and `diffusers` is updated to 0.21.2. The dependency on `invisible-watermark` is removed. Please update them with the upgrade instructions below; a quick version check is sketched after this file's diff.
+- Intel ARC support with IPEX is added. [#825](https://github.com/kohya-ss/sd-scripts/pull/825)
+- Other changes and fixes.
+- Thanks for contributions from Disty0, sdbds, jvkap, rockerBOO, Symbiomatrix and others!
+
+Sep 3, 2023:

 - ControlNet-LLLite is added. See [documentation](./docs/train_lllite_README.md) for details.
 - JPEG XL is supported. [#786](https://github.com/kohya-ss/sd-scripts/pull/786)
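As a quick sanity check after following the upgrade instructions, a snippet along these lines (hypothetical, not part of the commit) confirms the new pins are active:

```python
# Hypothetical post-upgrade check, not part of this commit: confirm the
# installed packages match the new pins in requirements.txt.
import accelerate
import diffusers

assert accelerate.__version__ == "0.23.0", accelerate.__version__
assert diffusers.__version__ == "0.21.2", diffusers.__version__
print("accelerate and diffusers match the pinned versions")
```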
19 changes: 12 additions & 7 deletions networks/lora_diffusers.py
@@ -117,7 +117,7 @@ def __init__(
         super().__init__()
         self.lora_name = lora_name

-        if org_module.__class__.__name__ == "Conv2d":
+        if org_module.__class__.__name__ == "Conv2d" or org_module.__class__.__name__ == "LoRACompatibleConv":
             in_dim = org_module.in_channels
             out_dim = org_module.out_channels
         else:
@@ -126,7 +126,7 @@

         self.lora_dim = lora_dim

-        if org_module.__class__.__name__ == "Conv2d":
+        if org_module.__class__.__name__ == "Conv2d" or org_module.__class__.__name__ == "LoRACompatibleConv":
             kernel_size = org_module.kernel_size
             stride = org_module.stride
             padding = org_module.padding
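Both hunks above extend the same class-name check so that diffusers' LoRACompatibleConv wrapper is treated like a plain Conv2d. The kernel size, stride, and padding read here typically feed a down/up projection pair; a minimal sketch of that pattern (illustrative, assuming the usual LoRA-for-conv layout, not the repository's exact code):

```python
import torch.nn as nn

# Illustrative sketch: lora_down mirrors the original kernel_size/stride/
# padding, lora_up is a 1x1 conv back to out_channels; zero-init makes the
# pair a no-op until trained.
def make_conv_lora(org: nn.Conv2d, lora_dim: int):
    lora_down = nn.Conv2d(org.in_channels, lora_dim,
                          org.kernel_size, org.stride, org.padding, bias=False)
    lora_up = nn.Conv2d(lora_dim, org.out_channels, (1, 1), (1, 1), bias=False)
    nn.init.zeros_(lora_up.weight)
    return lora_down, lora_up
```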
@@ -166,7 +166,8 @@ def unapply_to(self):
         self.org_module[0].forward = self.org_forward

     # forward with lora
-    def forward(self, x):
+    # scale is used by LoRACompatibleConv, but we ignore it because we have multiplier
+    def forward(self, x, scale=1.0):
         if not self.enabled:
             return self.org_forward(x)
         return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
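diffusers 0.21 calls its LoRACompatible* layers with an extra `scale` argument, so the patched-in forward has to accept it even though this implementation applies its own multiplier instead. A self-contained toy of the swap-in pattern (names such as ToyLoRA are ours, for illustration only):

```python
import torch
import torch.nn as nn

# Toy version of the monkey-patch above, not the repository's class. `scale`
# is accepted only so callers that pass it (e.g. diffusers' LoRACompatible
# layers) do not crash; the module's own multiplier is used instead.
class ToyLoRA(nn.Module):
    def __init__(self, org_module: nn.Linear, lora_dim: int = 4, multiplier: float = 1.0):
        super().__init__()
        self.lora_down = nn.Linear(org_module.in_features, lora_dim, bias=False)
        self.lora_up = nn.Linear(lora_dim, org_module.out_features, bias=False)
        nn.init.zeros_(self.lora_up.weight)  # LoRA starts as a no-op
        self.multiplier = multiplier
        self.org_forward = org_module.forward
        org_module.forward = self.forward  # swap in

    def forward(self, x, scale=1.0):  # scale is deliberately ignored
        return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier

lin = nn.Linear(8, 8)
x = torch.randn(2, 8)
before = lin(x)
ToyLoRA(lin)
assert torch.allclose(before, lin(x))  # unchanged while lora_up is zero
```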
@@ -318,8 +319,12 @@ def create_modules(
         for name, module in root_module.named_modules():
             if module.__class__.__name__ in target_replace_modules:
                 for child_name, child_module in module.named_modules():
-                    is_linear = child_module.__class__.__name__ == "Linear"
-                    is_conv2d = child_module.__class__.__name__ == "Conv2d"
+                    is_linear = (
+                        child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "LoRACompatibleLinear"
+                    )
+                    is_conv2d = (
+                        child_module.__class__.__name__ == "Conv2d" or child_module.__class__.__name__ == "LoRACompatibleConv"
+                    )

                     if is_linear or is_conv2d:
                         lora_name = prefix + "." + name + "." + child_name
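In diffusers 0.21 the U-Net's projections are LoRACompatibleLinear/LoRACompatibleConv rather than plain Linear/Conv2d, hence the widened checks above. A condensed sketch of the same scan (illustrative, with simplified arguments):

```python
import torch.nn as nn

LINEAR_CLASSES = ("Linear", "LoRACompatibleLinear")
CONV_CLASSES = ("Conv2d", "LoRACompatibleConv")

def find_targets(root_module: nn.Module, target_replace_modules) -> list:
    # Walk blocks whose class name is targeted, then collect their Linear/
    # Conv2d children, including diffusers' LoRACompatible* wrappers.
    found = []
    for name, module in root_module.named_modules():
        if module.__class__.__name__ in target_replace_modules:
            for child_name, child in module.named_modules():
                if child.__class__.__name__ in LINEAR_CLASSES + CONV_CLASSES:
                    found.append(f"{name}.{child_name}")
    return found
```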
@@ -359,7 +364,7 @@ def create_modules(
             skipped_te += skipped
         print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
         if len(skipped_te) > 0:
-            print(f"skipped {len(skipped_te)} modules because of missing weight.")
+            print(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.")

         # extend U-Net target modules to include Conv2d 3x3
         target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
@@ -368,7 +373,7 @@
         self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
         print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
         if len(skipped_un) > 0:
-            print(f"skipped {len(skipped_un)} modules because of missing weight.")
+            print(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.")

         # assertion
         names = set()
6 changes: 2 additions & 4 deletions requirements.txt
@@ -1,6 +1,6 @@
-accelerate==0.19.0
+accelerate==0.23.0
 transformers==4.30.2
-diffusers[torch]==0.18.2
+diffusers[torch]==0.21.2
 ftfy==6.1.1
 # albumentations==1.3.0
 opencv-python==4.7.0.68
@@ -15,8 +15,6 @@ easygui==0.98.3
 toml==0.10.2
 voluptuous==0.13.1
 huggingface-hub==0.15.1
-# for loading Diffusers' SDXL
-invisible-watermark==0.2.0
 # for BLIP captioning
 # requests==2.28.2
 # timm==0.6.12
