feat: add block swap for FLUX.1/SD3 LoRA training
kohya-ss committed Nov 12, 2024
1 parent 17cf249 · commit 2cb7a6d
Showing 14 changed files with 291 additions and 632 deletions.
212 changes: 46 additions & 166 deletions README.md

Large diffs are not rendered by default.
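For orientation, the option this commit adds, `--blocks_to_swap`, trades VRAM for speed: only part of the model's transformer blocks stay resident on the GPU, and the rest are streamed in from CPU as they are needed. The following is a minimal sketch of that idea, kept synchronous and without the prefetching or CUDA-stream overlap of the real ModelOffloader in library/custom_offloading_utils.py; the block count, layer sizes, and names are illustrative only.

```python
# Simplified illustration of block swapping during the forward pass.
# Not the repository's implementation: the real ModelOffloader overlaps
# transfers with compute on a separate CUDA stream and prefetches blocks.
import torch
import torch.nn as nn

num_blocks = 8      # stand-in for the FLUX double/single blocks
blocks_to_swap = 3  # analogous to --blocks_to_swap
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

blocks = nn.ModuleList([nn.Linear(64, 64) for _ in range(num_blocks)])

# Keep the first (num_blocks - blocks_to_swap) blocks resident on the device;
# the remaining blocks start on the CPU.
for i, block in enumerate(blocks):
    block.to(device if i < num_blocks - blocks_to_swap else "cpu")

x = torch.randn(4, 64, device=device)
for i, block in enumerate(blocks):
    block.to(device)  # make sure the current block is on the device
    x = block(x)
    if i >= num_blocks - blocks_to_swap:
        block.to("cpu")  # evict a swapped block once it has run
```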

56 changes: 5 additions & 51 deletions flux_train.py
@@ -78,6 +78,10 @@ def train(args):
)
args.gradient_checkpointing = True

assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0
) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"

cache_latents = args.cache_latents
use_dreambooth_method = args.in_json is None

@@ -518,47 +522,6 @@ def grad_hook(parameter: torch.Tensor):
parameter_optimizer_map[parameter] = opt_idx
num_parameters_per_group[opt_idx] += 1

# add hooks for block swapping: this hook is called after fused_backward_pass hook or blockwise_fused_optimizers hook
if False: # is_swapping_blocks:
import library.custom_offloading_utils as custom_offloading_utils

num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
double_blocks_to_swap = args.blocks_to_swap // 2
single_blocks_to_swap = (args.blocks_to_swap - double_blocks_to_swap) * 2

offloader_double = custom_offloading_utils.TrainOffloader(num_double_blocks, double_blocks_to_swap, accelerator.device)
offloader_single = custom_offloading_utils.TrainOffloader(num_single_blocks, single_blocks_to_swap, accelerator.device)

param_name_pairs = []
if not args.blockwise_fused_optimizers:
for param_group, param_name_group in zip(optimizer.param_groups, param_names):
param_name_pairs.extend(zip(param_group["params"], param_name_group))
else:
# named_parameters is a list of (name, parameter) pairs
param_name_pairs.extend([(p, n) for n, p in flux.named_parameters()])

for parameter, param_name in param_name_pairs:
if not parameter.requires_grad:
continue

is_double = param_name.startswith("double_blocks")
is_single = param_name.startswith("single_blocks")
if not is_double and not is_single:
continue

block_index = int(param_name.split(".")[1])
if is_double:
blocks = flux.double_blocks
offloader = offloader_double
else:
blocks = flux.single_blocks
offloader = offloader_single

grad_hook = offloader.create_grad_hook(blocks, block_index)
if grad_hook is not None:
parameter.register_post_accumulate_grad_hook(grad_hook)

# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
@@ -827,6 +790,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
add_custom_train_arguments(parser) # TODO remove this from here
train_util.add_dit_training_arguments(parser)
flux_train_utils.add_flux_train_arguments(parser)

parser.add_argument(
@@ -851,16 +815,6 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください",
)
parser.add_argument(
"--blocks_to_swap",
type=int,
default=None,
help="[EXPERIMENTAL] "
"Sets the number of blocks (~640MB) to swap during the forward and backward passes."
"Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)."
" / 順伝播および逆伝播中にスワップするブロック(約640MB)の数を設定します。"
"この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。",
)
parser.add_argument(
"--double_blocks_to_swap",
type=int,
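The block removed above wired block swapping into the backward pass through torch.Tensor.register_post_accumulate_grad_hook. As a hedged, self-contained illustration of that mechanism (not the repository's offloader; the block count, layer sizes, and helper names are invented), the sketch below registers a hook on each block's parameters and, once block i's gradients are ready, moves block i+1, whose backward pass has already completed, to the CPU:

```python
# Simplified illustration of offloading blocks during the backward pass with
# torch.Tensor.register_post_accumulate_grad_hook (PyTorch >= 2.1).
# Not the repository's offloader: no async transfers, no prefetching.
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_blocks = 4
blocks = nn.ModuleList([nn.Linear(32, 32) for _ in range(num_blocks)]).to(device)

def make_offload_hook(blocks: nn.ModuleList, block_index: int):
    def hook(param: torch.Tensor):
        # Backward runs from the last block to the first, so when a gradient
        # of block `block_index` is ready, block `block_index + 1` has already
        # finished its backward pass and can leave the device. Repeated calls
        # are harmless once the block is on the CPU.
        blocks[block_index + 1].to("cpu")
    return hook

for i in range(num_blocks - 1):
    hook = make_offload_hook(blocks, i)
    for param in blocks[i].parameters():
        if param.requires_grad:
            param.register_post_accumulate_grad_hook(hook)

x = torch.randn(2, 32, device=device)
for block in blocks:
    x = block(x)
x.sum().backward()  # completed blocks are evicted as backward walks back
```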
95 changes: 55 additions & 40 deletions flux_train_network.py
@@ -52,10 +52,23 @@ def assert_extra_args(self, args, train_dataset_group):
if args.max_token_length is not None:
logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")

assert not args.split_mode or not args.cpu_offload_checkpointing, (
"split_mode and cpu_offload_checkpointing cannot be used together"
" / split_modeとcpu_offload_checkpointingは同時に使用できません"
)
assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0
) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"

# deprecated split_mode option
if args.split_mode:
if args.blocks_to_swap is not None:
logger.warning(
"split_mode is deprecated. Because `--blocks_to_swap` is set, `--split_mode` is ignored."
" / split_modeは非推奨です。`--blocks_to_swap`が設定されているため、`--split_mode`は無視されます。"
)
else:
logger.warning(
"split_mode is deprecated. Please use `--blocks_to_swap` instead. `--blocks_to_swap 18` is automatically set."
" / split_modeは非推奨です。代わりに`--blocks_to_swap`を使用してください。`--blocks_to_swap 18`が自動的に設定されました。"
)
args.blocks_to_swap = 18 # 18 is safe for most cases

train_dataset_group.verify_bucket_reso_steps(32) # TODO check this

@@ -75,9 +88,15 @@ def load_target_model(self, args, weight_dtype, accelerator):
raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
elif model.dtype == torch.float8_e4m3fn:
logger.info("Loaded fp8 FLUX model")
else:
logger.info(
"Cast FLUX model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint."
" / FLUXモデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。"
)
model.to(torch.float8_e4m3fn)

if args.split_mode:
model = self.prepare_split_model(model, weight_dtype, accelerator)
# if args.split_mode:
# model = self.prepare_split_model(model, weight_dtype, accelerator)

self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
if self.is_swapping_blocks:
@@ -108,6 +127,7 @@ def load_target_model(self, args, weight_dtype, accelerator):

return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model

"""
def prepare_split_model(self, model, weight_dtype, accelerator):
from accelerate import init_empty_weights
@@ -144,6 +164,7 @@ def prepare_split_model(self, model, weight_dtype, accelerator):
logger.info("split model prepared")
return flux_lower
"""

def get_tokenize_strategy(self, args):
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
@@ -291,14 +312,12 @@ def sample_images(self, accelerator, args, epoch, global_step, device, ae, token
text_encoders = text_encoder # for compatibility
text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)

if not args.split_mode:
if self.is_swapping_blocks:
accelerator.unwrap_model(flux).prepare_block_swap_before_forward()
flux_train_utils.sample_images(
accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs
)
return
flux_train_utils.sample_images(
accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs
)
# return

"""
class FluxUpperLowerWrapper(torch.nn.Module):
def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device):
super().__init__()
@@ -325,6 +344,7 @@ def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_a
accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs
)
clean_memory_on_device(accelerator.device)
"""

def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
@@ -383,20 +403,21 @@ def get_noise_pred_and_target(
t5_attn_mask = None

def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
if not args.split_mode:
# normal forward
with accelerator.autocast():
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
model_pred = unet(
img=img,
img_ids=img_ids,
txt=t5_out,
txt_ids=txt_ids,
y=l_pooled,
timesteps=timesteps / 1000,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
# if not args.split_mode:
# normal forward
with accelerator.autocast():
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
model_pred = unet(
img=img,
img_ids=img_ids,
txt=t5_out,
txt_ids=txt_ids,
y=l_pooled,
timesteps=timesteps / 1000,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
"""
else:
# split forward to reduce memory usage
assert network.train_blocks == "single", "train_blocks must be single for split mode"
@@ -430,6 +451,7 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t
vec.requires_grad_(True)
pe.requires_grad_(True)
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
"""

return model_pred

@@ -558,30 +580,23 @@ def prepare_unet_with_accelerator(
flux: flux_models.Flux = unet
flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks])
accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
accelerator.unwrap_model(flux).prepare_block_swap_before_forward()

return flux


def setup_parser() -> argparse.ArgumentParser:
parser = train_network.setup_parser()
train_util.add_dit_training_arguments(parser)
flux_train_utils.add_flux_train_arguments(parser)

parser.add_argument(
"--split_mode",
action="store_true",
help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required"
+ "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要",
)

parser.add_argument(
"--blocks_to_swap",
type=int,
default=None,
help="[EXPERIMENTAL] "
"Sets the number of blocks to swap during the forward and backward passes."
"Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)."
" / 順伝播および逆伝播中にスワップするブロックの数を設定します。"
"この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。",
# help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required"
# + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要",
help="[Deprecated] This option is deprecated. Please use `--blocks_to_swap` instead."
" / このオプションは非推奨です。代わりに`--blocks_to_swap`を使用してください。",
)
return parser

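One detail worth noting from prepare_unet_with_accelerator above: accelerator.prepare is called with a per-object device_placement list so that, when block swapping is enabled, accelerate does not move the whole FLUX model to the GPU and the offloader keeps control of block placement. A hedged, minimal sketch of that accelerate pattern (the model and flag below are placeholders):

```python
# Minimal sketch of opting a model out of automatic device placement in
# accelerate, as done for FLUX when block swapping is enabled.
# The model and the is_swapping_blocks flag below are placeholders.
import torch.nn as nn
from accelerate import Accelerator

accelerator = Accelerator()
model = nn.Linear(16, 16)
is_swapping_blocks = True  # analogous to self.is_swapping_blocks

# device_placement takes one boolean per prepared object; False leaves the
# model where it is so the block-swap offloader can manage devices itself.
model = accelerator.prepare(model, device_placement=[not is_swapping_blocks])
```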
75 changes: 24 additions & 51 deletions library/custom_offloading_utils.py
@@ -16,13 +16,29 @@ def synchronize_device(device: torch.device):
torch.mps.synchronize()


def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
assert layer_to_cpu.__class__ == layer_to_cuda.__class__

weight_swap_jobs = []
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))

# This is not working for all cases (e.g. SD3), so we need to find the corresponding modules
# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
# print(module_to_cpu.__class__, module_to_cuda.__class__)
# if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
# weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))

modules_to_cpu = {k: v for k, v in layer_to_cpu.named_modules()}
for module_to_cuda_name, module_to_cuda in layer_to_cuda.named_modules():
if hasattr(module_to_cuda, "weight") and module_to_cuda.weight is not None:
module_to_cpu = modules_to_cpu.get(module_to_cuda_name, None)
if module_to_cpu is not None and module_to_cpu.weight.shape == module_to_cuda.weight.shape:
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
else:
if module_to_cuda.weight.data.device.type != device.type:
# print(
# f"Module {module_to_cuda_name} not found in CPU model or shape mismatch, so not swapping and moving to device"
# )
module_to_cuda.weight.data = module_to_cuda.weight.data.to(device)

torch.cuda.current_stream().synchronize() # this prevents the illegal loss value

@@ -92,7 +108,7 @@ def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, d

def swap_weight_devices(self, block_to_cpu: nn.Module, block_to_cuda: nn.Module):
if self.cuda_available:
swap_weight_devices(block_to_cpu, block_to_cuda)
swap_weight_devices_cuda(self.device, block_to_cpu, block_to_cuda)
else:
swap_weight_devices_no_cuda(self.device, block_to_cpu, block_to_cuda)

@@ -132,52 +148,6 @@ def _wait_blocks_move(self, block_idx):
print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")


class TrainOffloader(Offloader):
"""
supports backward offloading
"""

def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
super().__init__(num_blocks, blocks_to_swap, device, debug)
self.hook_added = set()

def create_grad_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]:
if block_index in self.hook_added:
return None
self.hook_added.add(block_index)

# -1 for 0-based index, -1 for current block is not fully backpropagated yet
num_blocks_propagated = self.num_blocks - block_index - 2
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
waiting = block_index > 0 and block_index <= self.blocks_to_swap

if not swapping and not waiting:
return None

# create hook
block_idx_to_cpu = self.num_blocks - num_blocks_propagated
block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
block_idx_to_wait = block_index - 1

if self.debug:
print(
f"Backward: Created grad hook for block {block_index} with {block_idx_to_cpu}, {block_idx_to_cuda}, {block_idx_to_wait}"
)
if swapping:

def grad_hook(tensor: torch.Tensor):
self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)

return grad_hook

else:

def grad_hook(tensor: torch.Tensor):
self._wait_blocks_move(block_idx_to_wait)

return grad_hook


class ModelOffloader(Offloader):
"""
supports forward offloading
@@ -228,6 +198,9 @@ def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return

if self.debug:
print("Prepare block devices before forward")

for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
b.to(self.device)
weighs_to_device(b, self.device) # make sure weights are on device
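Finally, the change to swap_weight_devices_cuda above replaces a positional zip over Module.modules() with a lookup over named_modules(), because the positional pairing can match unrelated layers for some models (the inline comment cites SD3). A hedged sketch of that name-based pairing, using toy layers rather than FLUX/SD3 blocks:

```python
# Simplified illustration of pairing submodules by name before swapping
# weights, mirroring the idea in swap_weight_devices_cuda (toy layers only).
import torch.nn as nn

layer_to_cpu = nn.ModuleDict({"attn": nn.Linear(8, 8), "mlp": nn.Linear(8, 16)})
layer_to_cuda = nn.ModuleDict({"attn": nn.Linear(8, 8), "mlp": nn.Linear(8, 16)})

modules_to_cpu = dict(layer_to_cpu.named_modules())
weight_swap_jobs = []
for name, module_to_cuda in layer_to_cuda.named_modules():
    if getattr(module_to_cuda, "weight", None) is None:
        continue  # skip containers and parameter-free modules
    module_to_cpu = modules_to_cpu.get(name)
    if module_to_cpu is not None and module_to_cpu.weight.shape == module_to_cuda.weight.shape:
        weight_swap_jobs.append((module_to_cpu, module_to_cuda))
    # otherwise the real code simply moves the unmatched weight to the device

print(len(weight_swap_jobs))  # 2 matched pairs ready to swap
```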