Merge branch 'pr/1151'
gesen2egee committed Mar 13, 2024
2 parents fd468db + 2629117 commit f0fe563
Showing 4 changed files with 51 additions and 35 deletions.
6 changes: 3 additions & 3 deletions library/sdxl_train_util.py
@@ -12,7 +12,7 @@
 from transformers import CLIPTokenizer
 from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
 from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
-from library import token_merging
+from library import token_downsampling
 from .utils import setup_logging
 setup_logging()
 import logging
@@ -60,8 +60,8 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):

     # apply token merging patch
     if args.todo_factor:
-        token_merging.patch_attention(unet, args, is_sdxl=True)
-        logger.info(f"enable token downsampling optimization | {unet._tome_info['args']}")
+        token_downsampling.apply_patch(unet, args, is_sdxl=True)
+        logger.info(f"enable token downsampling optimization: downsample_factor={args.todo_factor}, max_depth={args.todo_max_depth}")

     return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
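For orientation, a minimal sketch (illustration only, not part of the diff) of what this is_sdxl=True path configures. apply_patch delegates to parse_todo_args in library/token_downsampling.py (shown in the next file) to resolve the two CLI flags into per-depth factors:

    from argparse import Namespace
    from library import token_downsampling

    # Typical SDXL invocation: one factor, depth left to its default.
    args = Namespace(todo_factor=[2.0], todo_max_depth=None)
    print(token_downsampling.parse_todo_args(args, is_sdxl=True))
    # expected: {'downsample_factor': {2: 2.0}, 'max_depth': 2}
    # i.e. only self-attention layers running at 1/2 of the latent resolution are
    # affected, with keys/values downsampled by 2 in each spatial direction.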
62 changes: 39 additions & 23 deletions library/token_merging.py → library/token_downsampling.py
@@ -18,82 +18,98 @@ def up_or_downsample(item, cur_w, cur_h, new_w, new_h, method="nearest-exact"):
     return item


-def compute_merge(x: torch.Tensor, tome_info: dict):
-    original_h, original_w = tome_info["size"]
+def compute_merge(x: torch.Tensor, todo_info: dict):
+    original_h, original_w = todo_info["size"]
     original_tokens = original_h * original_w
     downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1])))
     cur_h = original_h // downsample
     cur_w = original_w // downsample

-    args = tome_info["args"]
-    downsample_factor = args["downsample_factor"]
+    args = todo_info["args"]

     merge_op = lambda x: x
-    if downsample <= args["max_downsample"]:
+    if downsample <= args["max_depth"]:
+        downsample_factor = args["downsample_factor"][downsample]
         new_h = int(cur_h / downsample_factor)
         new_w = int(cur_w / downsample_factor)
         merge_op = lambda x: up_or_downsample(x, cur_w, cur_h, new_w, new_h)

     return merge_op


-def hook_tome_model(model: torch.nn.Module):
+def hook_unet(model: torch.nn.Module):
     """ Adds a forward pre hook to get the image size. This hook can be removed with remove_patch. """
     def hook(module, args):
-        module._tome_info["size"] = (args[0].shape[2], args[0].shape[3])
+        module._todo_info["size"] = (args[0].shape[2], args[0].shape[3])
         return None

-    model._tome_info["hooks"].append(model.register_forward_pre_hook(hook))
+    model._todo_info["hooks"].append(model.register_forward_pre_hook(hook))


 def hook_attention(attn: torch.nn.Module):
     """ Adds a forward pre hook to downsample attention keys and values. This hook can be removed with remove_patch. """
     def hook(module, args, kwargs):
         hidden_states = args[0]
-        m = compute_merge(hidden_states, module._tome_info)
+        m = compute_merge(hidden_states, module._todo_info)
         kwargs["context"] = m(hidden_states)
         return args, kwargs

-    attn._tome_info["hooks"].append(attn.register_forward_pre_hook(hook, with_kwargs=True))
+    attn._todo_info["hooks"].append(attn.register_forward_pre_hook(hook, with_kwargs=True))


 def parse_todo_args(args, is_sdxl: bool = False) -> dict:
-    if args.todo_max_downsample is None:
-        args.todo_max_downsample = 2 if is_sdxl else 1
-    if is_sdxl and args.todo_max_downsample not in (2, 4):
-        raise ValueError(f"--todo_max_downsample for SDXL must be 2 or 4, received {args.todo_factor}")
+    # validate max_depth
+    if args.todo_max_depth is None:
+        args.todo_max_depth = min(len(args.todo_factor), 4)
+    if is_sdxl and args.todo_max_depth > 2:
+        raise ValueError(f"todo_max_depth for SDXL cannot be larger than 2, received {args.todo_max_depth}")
+
+    # validate factor
+    if len(args.todo_factor) > 1:
+        if len(args.todo_factor) != args.todo_max_depth:
+            raise ValueError(f"todo_factor number of values must be 1 or same as todo_max_depth, received {len(args.todo_factor)}")
+
+    # create dict of factors to support per-depth override
+    factors = args.todo_factor
+    if len(factors) == 1:
+        factors *= args.todo_max_depth
+    factors = {2**(i + int(is_sdxl)): factor for i, factor in enumerate(factors)}
+
+    # convert depth to powers of 2 to match layer dimensions: [1,2,3,4] -> [1,2,4,8]
+    # offset by 1 for sdxl which starts at 2
+    max_depth = 2**(args.todo_max_depth + int(is_sdxl) - 1)

     todo_kwargs = {
-        "downsample_factor": args.todo_factor,
-        "max_downsample": args.todo_max_downsample,
+        "downsample_factor": factors,
+        "max_depth": max_depth,
     }

     return todo_kwargs


-def patch_attention(unet: torch.nn.Module, args, is_sdxl=False):
+def apply_patch(unet: torch.nn.Module, args, is_sdxl=False):
     """ Patches the UNet's transformer blocks to apply token downsampling. """
     todo_kwargs = parse_todo_args(args, is_sdxl)

-    unet._tome_info = {
+    unet._todo_info = {
         "size": None,
         "hooks": [],
         "args": todo_kwargs,
     }
-    hook_tome_model(unet)
+    hook_unet(unet)

     for _, module in unet.named_modules():
         if module.__class__.__name__ == "BasicTransformerBlock":
-            module.attn1._tome_info = unet._tome_info
+            module.attn1._todo_info = unet._todo_info
             hook_attention(module.attn1)

     return unet


 def remove_patch(unet: torch.nn.Module):
-    if hasattr(unet, "_tome_info"):
-        for hook in unet._tome_info["hooks"]:
+    if hasattr(unet, "_todo_info"):
+        for hook in unet._todo_info["hooks"]:
             hook.remove()
-        unet._tome_info["hooks"].clear()
+        unet._todo_info["hooks"].clear()

     return unet
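A self-contained sketch (illustration only, not part of the commit) of the hook mechanics above, using stand-in modules. The only structural assumption is a child module whose class is literally named BasicTransformerBlock and exposes attn1, since that is what apply_patch matches on; FakeAttention and FakeUNet are invented for the demo. Queries keep the full token count while the hooked context (keys and values) is downsampled:

    import torch
    from argparse import Namespace
    from library import token_downsampling

    class FakeAttention(torch.nn.Module):
        # Stand-in for attn1: just reports the shapes the ToDo hook hands it.
        def forward(self, hidden_states, context=None):
            print("query tokens:", hidden_states.shape[1], "| key/value tokens:", context.shape[1])
            return hidden_states

    class BasicTransformerBlock(torch.nn.Module):
        # The class name is what apply_patch looks for while walking named_modules().
        def __init__(self):
            super().__init__()
            self.attn1 = FakeAttention()

        def forward(self, hidden_states):
            return self.attn1(hidden_states)

    class FakeUNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.block = BasicTransformerBlock()

        def forward(self, x):
            # The UNet-level pre-hook records (H, W) from this 4D input.
            tokens = x.flatten(2).transpose(1, 2)  # (B, H*W, C), as self-attention sees it
            return self.block(tokens)

    unet = FakeUNet()
    args = Namespace(todo_factor=[2.0], todo_max_depth=None)  # SD1.5-style defaults: depth 1 only
    token_downsampling.apply_patch(unet, args)

    unet(torch.randn(1, 4, 64, 64))        # prints: query tokens: 4096 | key/value tokens: 1024
    token_downsampling.remove_patch(unet)  # detaches both pre-hooks again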
16 changes: 8 additions & 8 deletions library/train_util.py
@@ -70,7 +70,7 @@
 import library.model_util as model_util
 import library.huggingface_util as huggingface_util
 import library.sai_model_spec as sai_model_spec
-from library import token_merging
+from library import token_downsampling
 from library.utils import setup_logging

 setup_logging()
@@ -3226,15 +3226,15 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
     parser.add_argument(
         "--todo_factor",
         type=float,
-        help="token downsampling (ToDo) factor > 1 (recommend around 2-4)",
+        nargs="+",
+        help="token downsampling (ToDo) factor > 1 (recommend around 2-4). Specify multiple to set factor for each depth",
     )
     parser.add_argument(
-        "--todo_max_downsample",
+        "--todo_max_depth",
         type=int,
-        choices=[1, 2, 4, 8],
+        choices=[1, 2, 3, 4],
         help=(
-            "apply ToDo to layers with at most this amount of downsampling."
-            " SDXL only accepts 2 and 4. Recommend 1 or 2. Default 1 (or 2 for SDXL)"
+            "apply ToDo to deeper layers (lower quality for slight speed increase). SDXL only accepts 2 and 3. Recommend 1 or 2. Default 1 (or 2 for SDXL)"
         ),
     )

@@ -4405,8 +4405,8 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio

     # apply token merging patch
     if args.todo_factor:
-        token_merging.patch_attention(unet, args)
-        logger.info(f"enable token downsampling optimization | {unet._tome_info['args']}")
+        token_downsampling.apply_patch(unet, args)
+        logger.info(f"enable token downsampling optimization: downsample_factor={args.todo_factor}, max_depth={args.todo_max_depth}")

     return text_encoder, vae, unet, load_stable_diffusion_format
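Relating the new flags to the internals, a minimal sketch (illustration only): with nargs="+", --todo_factor takes one value per depth, and todo_max_depth defaults to the number of factors given (capped at 4):

    from argparse import Namespace
    from library import token_downsampling

    # SD1.5-style run with a per-depth override: depth 1 (full-resolution
    # attention) downsampled by 2, depth 2 downsampled by 4.
    args = Namespace(todo_factor=[2.0, 4.0], todo_max_depth=None)
    print(token_downsampling.parse_todo_args(args, is_sdxl=False))
    # expected: {'downsample_factor': {1: 2.0, 2: 4.0}, 'max_depth': 2}

On the command line this would correspond to something like --todo_factor 2 4 with --todo_max_depth left unset.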
2 changes: 1 addition & 1 deletion train_network.py
@@ -780,7 +780,7 @@ def train(self, args):

         if args.todo_factor:
             metadata["ss_todo_factor"] = args.todo_factor
-            metadata["ss_todo_max_downsample"] = args.todo_max_downsample
+            metadata["ss_todo_max_depth"] = args.todo_max_depth

         metadata = {k: str(v) for k, v in metadata.items()}

