Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve block swap speed and apply to LoRA #1779

Merged
merged 12 commits on Nov 14, 2024
212 changes: 46 additions & 166 deletions README.md

Large diffs are not rendered by default.

214 changes: 33 additions & 181 deletions flux_train.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -17,12 +17,14 @@
import os
from multiprocessing import Value
import time
from typing import List
from typing import List, Optional, Tuple, Union
import toml

from tqdm import tqdm

import torch
import torch.nn as nn
from library import utils
from library.device_utils import init_ipex, clean_memory_on_device

init_ipex()
Expand Down Expand Up @@ -76,6 +78,12 @@ def train(args):
)
args.gradient_checkpointing = True

assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0
) or not args.cpu_offload_checkpointing, (
"blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"
)

cache_latents = args.cache_latents
use_dreambooth_method = args.in_json is None

Expand Down Expand Up @@ -293,7 +301,7 @@ def train(args):
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
# This idea is based on 2kpr's great work. Thank you!
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
flux.enable_block_swap(args.blocks_to_swap)
flux.enable_block_swap(args.blocks_to_swap, accelerator.device)

if not cache_latents:
# load VAE here if not cached
Expand Down Expand Up @@ -336,15 +344,15 @@ def train(args):
# determine target layer and block index for each parameter
block_type = "other" # double, single or other
if np[0].startswith("double_blocks"):
block_idx = int(np[0].split(".")[1])
block_index = int(np[0].split(".")[1])
block_type = "double"
elif np[0].startswith("single_blocks"):
block_idx = int(np[0].split(".")[1])
block_index = int(np[0].split(".")[1])
block_type = "single"
else:
block_idx = -1
block_index = -1

param_group_key = (block_type, block_idx)
param_group_key = (block_type, block_index)
if param_group_key not in param_group:
param_group[param_group_key] = []
param_group[param_group_key].append(p)
Expand Down Expand Up @@ -464,132 +472,26 @@ def train(args):
# resumeする
train_util.resume_from_local_or_hf_if_specified(accelerator, args)

# memory efficient block swapping

def get_block_unit(dbl_blocks, sgl_blocks, index: int):
if index < len(dbl_blocks):
return (dbl_blocks[index],)
else:
index -= len(dbl_blocks)
index *= 2
return (sgl_blocks[index], sgl_blocks[index + 1])

def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, dbl_blocks, sgl_blocks, device):
def move_blocks(bidx_to_cpu, blocks_to_cpu, bidx_to_cuda, blocks_to_cuda, dvc):
# print(f"Backward: Move block {bidx_to_cpu} to CPU")
for block in blocks_to_cpu:
block = block.to("cpu", non_blocking=True)
torch.cuda.empty_cache()

# print(f"Backward: Move block {bidx_to_cuda} to CUDA")
for block in blocks_to_cuda:
block = block.to(dvc, non_blocking=True)

torch.cuda.synchronize()
# print(f"Backward: Moved blocks {bidx_to_cpu} and {bidx_to_cuda}")
return bidx_to_cpu, bidx_to_cuda

blocks_to_cpu = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cpu)
blocks_to_cuda = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cuda)

futures[block_idx_to_cuda] = thread_pool.submit(
move_blocks, block_idx_to_cpu, blocks_to_cpu, block_idx_to_cuda, blocks_to_cuda, device
)

def wait_blocks_move(block_idx, futures):
if block_idx not in futures:
return
# print(f"Backward: Wait for block {block_idx}")
# start_time = time.perf_counter()
future = futures.pop(block_idx)
future.result()
# print(f"Backward: Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")
# torch.cuda.synchronize()
# print(f"Backward: Synchronized: {time.perf_counter()-start_time:.2f}s")

if args.fused_backward_pass:
# use fused optimizer for backward pass: other optimizers will be supported in the future
import library.adafactor_fused

library.adafactor_fused.patch_adafactor_fused(optimizer)

blocks_to_swap = args.blocks_to_swap
num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
num_block_units = num_double_blocks + num_single_blocks // 2
handled_unit_indices = set()

n = 1 # only asynchronous purpose, no need to increase this number
# n = 2
# n = max(1, os.cpu_count() // 2)
thread_pool = ThreadPoolExecutor(max_workers=n)
futures = {}

for param_group, param_name_group in zip(optimizer.param_groups, param_names):
for parameter, param_name in zip(param_group["params"], param_name_group):
if parameter.requires_grad:
grad_hook = None

if blocks_to_swap:
is_double = param_name.startswith("double_blocks")
is_single = param_name.startswith("single_blocks")
if is_double or is_single:
block_idx = int(param_name.split(".")[1])
unit_idx = block_idx if is_double else num_double_blocks + block_idx // 2
if unit_idx not in handled_unit_indices:
# swap following (already backpropagated) block
handled_unit_indices.add(unit_idx)

# if n blocks were already backpropagated
num_blocks_propagated = num_block_units - unit_idx - 1
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap
waiting = unit_idx > 0 and unit_idx <= blocks_to_swap
if swapping or waiting:
block_idx_to_cpu = num_block_units - num_blocks_propagated
block_idx_to_cuda = blocks_to_swap - num_blocks_propagated
block_idx_to_wait = unit_idx - 1

# create swap hook
def create_swap_grad_hook(
bidx_to_cpu, bidx_to_cuda, bidx_to_wait, uidx: int, swpng: bool, wtng: bool
):
def __grad_hook(tensor: torch.Tensor):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
optimizer.step_param(tensor, param_group)
tensor.grad = None

# print(f"Backward: {uidx}, {swpng}, {wtng}")
if swpng:
submit_move_blocks(
futures,
thread_pool,
bidx_to_cpu,
bidx_to_cuda,
flux.double_blocks,
flux.single_blocks,
accelerator.device,
)
if wtng:
wait_blocks_move(bidx_to_wait, futures)

return __grad_hook

grad_hook = create_swap_grad_hook(
block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, unit_idx, swapping, waiting
)

if grad_hook is None:

def __grad_hook(tensor: torch.Tensor, param_group=param_group):

def create_grad_hook(p_name, p_group):
def grad_hook(tensor: torch.Tensor):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
optimizer.step_param(tensor, param_group)
optimizer.step_param(tensor, p_group)
tensor.grad = None

grad_hook = __grad_hook
return grad_hook

parameter.register_post_accumulate_grad_hook(grad_hook)
parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group))

elif args.blockwise_fused_optimizers:
# prepare for additional optimizers and lr schedulers
Expand All @@ -606,63 +508,22 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group):
num_parameters_per_group = [0] * len(optimizers)
parameter_optimizer_map = {}

blocks_to_swap = args.blocks_to_swap
num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
num_block_units = num_double_blocks + num_single_blocks // 2

n = 1 # only asynchronous purpose, no need to increase this number
# n = max(1, os.cpu_count() // 2)
thread_pool = ThreadPoolExecutor(max_workers=n)
futures = {}

for opt_idx, optimizer in enumerate(optimizers):
for param_group in optimizer.param_groups:
for parameter in param_group["params"]:
if parameter.requires_grad:
block_type, block_idx = block_types_and_indices[opt_idx]

def create_optimizer_hook(btype, bidx):
def optimizer_hook(parameter: torch.Tensor):
# print(f"optimizer_hook: {btype}, {bidx}")
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(parameter, args.max_grad_norm)

i = parameter_optimizer_map[parameter]
optimizer_hooked_count[i] += 1
if optimizer_hooked_count[i] == num_parameters_per_group[i]:
optimizers[i].step()
optimizers[i].zero_grad(set_to_none=True)

# swap blocks if necessary
if blocks_to_swap and (btype == "double" or (btype == "single" and bidx % 2 == 0)):
unit_idx = bidx if btype == "double" else num_double_blocks + bidx // 2
num_blocks_propagated = num_block_units - unit_idx

swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap
waiting = unit_idx > 0 and unit_idx <= blocks_to_swap

if swapping:
block_idx_to_cpu = num_block_units - num_blocks_propagated
block_idx_to_cuda = blocks_to_swap - num_blocks_propagated
# print(f"Backward: Swap blocks {block_idx_to_cpu} and {block_idx_to_cuda}")
submit_move_blocks(
futures,
thread_pool,
block_idx_to_cpu,
block_idx_to_cuda,
flux.double_blocks,
flux.single_blocks,
accelerator.device,
)

if waiting:
block_idx_to_wait = unit_idx - 1
wait_blocks_move(block_idx_to_wait, futures)

return optimizer_hook

parameter.register_post_accumulate_grad_hook(create_optimizer_hook(block_type, block_idx))

def grad_hook(parameter: torch.Tensor):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(parameter, args.max_grad_norm)

i = parameter_optimizer_map[parameter]
optimizer_hooked_count[i] += 1
if optimizer_hooked_count[i] == num_parameters_per_group[i]:
optimizers[i].step()
optimizers[i].zero_grad(set_to_none=True)

parameter.register_post_accumulate_grad_hook(grad_hook)
parameter_optimizer_map[parameter] = opt_idx
num_parameters_per_group[opt_idx] += 1

Expand Down Expand Up @@ -934,6 +795,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
add_custom_train_arguments(parser) # TODO remove this from here
train_util.add_dit_training_arguments(parser)
flux_train_utils.add_flux_train_arguments(parser)

parser.add_argument(
Expand All @@ -958,16 +820,6 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください",
)
parser.add_argument(
"--blocks_to_swap",
type=int,
default=None,
help="[EXPERIMENTAL] "
"Sets the number of blocks (~640MB) to swap during the forward and backward passes."
"Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)."
" / 順伝播および逆伝播中にスワップするブロック(約640MB)の数を設定します。"
"この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。",
)
parser.add_argument(
"--double_blocks_to_swap",
type=int,
Expand Down
Loading