From cde90b8903870b6b28dae274d07ed27978055e3c Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Tue, 12 Nov 2024 08:49:05 +0900
Subject: [PATCH] feat: implement block swapping for FLUX.1 LoRA (WIP)

---
 flux_train.py                      |  2 +-
 flux_train_network.py              | 33 ++++++++++++++++++++++++
 library/custom_offloading_utils.py | 40 +++++++++++++++++++++++++++++-
 library/flux_models.py             |  8 ++++--
 train_network.py                   |  9 ++++++-
 5 files changed, 87 insertions(+), 5 deletions(-)

diff --git a/flux_train.py b/flux_train.py
index 02dede45e..346fe8fbd 100644
--- a/flux_train.py
+++ b/flux_train.py
@@ -519,7 +519,7 @@ def grad_hook(parameter: torch.Tensor):
                         num_parameters_per_group[opt_idx] += 1
 
     # add hooks for block swapping: this hook is called after fused_backward_pass hook or blockwise_fused_optimizers hook
-    if is_swapping_blocks:
+    if False:  # is_swapping_blocks:
         import library.custom_offloading_utils as custom_offloading_utils
 
         num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
diff --git a/flux_train_network.py b/flux_train_network.py
index 2b71a8979..376cc1597 100644
--- a/flux_train_network.py
+++ b/flux_train_network.py
@@ -25,6 +25,7 @@ def __init__(self):
         super().__init__()
         self.sample_prompts_te_outputs = None
         self.is_schnell: Optional[bool] = None
+        self.is_swapping_blocks: bool = False
 
     def assert_extra_args(self, args, train_dataset_group):
         super().assert_extra_args(args, train_dataset_group)
@@ -78,6 +79,12 @@ def load_target_model(self, args, weight_dtype, accelerator):
         if args.split_mode:
             model = self.prepare_split_model(model, weight_dtype, accelerator)
 
+        self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
+        if self.is_swapping_blocks:
+            # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
+ logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + model.enable_block_swap(args.blocks_to_swap, accelerator.device) + clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) clip_l.eval() @@ -285,6 +292,8 @@ def sample_images(self, accelerator, args, epoch, global_step, device, ae, token text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders) if not args.split_mode: + if self.is_swapping_blocks: + accelerator.unwrap_model(flux).prepare_block_swap_before_forward() flux_train_utils.sample_images( accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs ) @@ -539,6 +548,19 @@ def forward(hidden_states): text_encoder.to(te_weight_dtype) # fp8 prepare_fp8(text_encoder, weight_dtype) + def prepare_unet_with_accelerator( + self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module + ) -> torch.nn.Module: + if not self.is_swapping_blocks: + return super().prepare_unet_with_accelerator(args, accelerator, unet) + + # if we doesn't swap blocks, we can move the model to device + flux: flux_models.Flux = unet + flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks]) + accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + + return flux + def setup_parser() -> argparse.ArgumentParser: parser = train_network.setup_parser() @@ -550,6 +572,17 @@ def setup_parser() -> argparse.ArgumentParser: help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required" + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要", ) + + parser.add_argument( + "--blocks_to_swap", + type=int, + default=None, + help="[EXPERIMENTAL] " + "Sets the number of blocks to swap during the forward and backward passes." + "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)." 
+ " / 順伝播および逆伝播中にスワップするブロックの数を設定します。" + "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", + ) return parser diff --git a/library/custom_offloading_utils.py b/library/custom_offloading_utils.py index 33a413004..70da93902 100644 --- a/library/custom_offloading_utils.py +++ b/library/custom_offloading_utils.py @@ -183,9 +183,47 @@ class ModelOffloader(Offloader): supports forward offloading """ - def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False): + def __init__(self, blocks: list[nn.Module], num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False): super().__init__(num_blocks, blocks_to_swap, device, debug) + # register backward hooks + self.remove_handles = [] + for i, block in enumerate(blocks): + hook = self.create_backward_hook(blocks, i) + if hook is not None: + handle = block.register_full_backward_hook(hook) + self.remove_handles.append(handle) + + def __del__(self): + for handle in self.remove_handles: + handle.remove() + + def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]: + # -1 for 0-based index + num_blocks_propagated = self.num_blocks - block_index - 1 + swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap + waiting = block_index > 0 and block_index <= self.blocks_to_swap + + if not swapping and not waiting: + return None + + # create hook + block_idx_to_cpu = self.num_blocks - num_blocks_propagated + block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated + block_idx_to_wait = block_index - 1 + + def backward_hook(module, grad_input, grad_output): + if self.debug: + print(f"Backward hook for block {block_index}") + + if swapping: + self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda) + if waiting: + self._wait_blocks_move(block_idx_to_wait) + return None + + return backward_hook + def prepare_block_devices_before_forward(self, blocks: list[nn.Module]): if self.blocks_to_swap is None or self.blocks_to_swap == 0: return diff --git a/library/flux_models.py b/library/flux_models.py index e0bee160f..4fa272522 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -970,8 +970,12 @@ def enable_block_swap(self, num_blocks: int, device: torch.device): double_blocks_to_swap = num_blocks // 2 single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 - self.offloader_double = custom_offloading_utils.ModelOffloader(self.num_double_blocks, double_blocks_to_swap, device) - self.offloader_single = custom_offloading_utils.ModelOffloader(self.num_single_blocks, single_blocks_to_swap, device) + self.offloader_double = custom_offloading_utils.ModelOffloader( + self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device #, debug=True + ) + self.offloader_single = custom_offloading_utils.ModelOffloader( + self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device #, debug=True + ) print( f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}." 
         )
diff --git a/train_network.py b/train_network.py
index b90aa420e..d70f14ad3 100644
--- a/train_network.py
+++ b/train_network.py
@@ -18,6 +18,7 @@
 init_ipex()
 
 from accelerate.utils import set_seed
+from accelerate import Accelerator
 from diffusers import DDPMScheduler
 from library import deepspeed_utils, model_util, strategy_base, strategy_sd
 
@@ -272,6 +273,11 @@ def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
     def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
         text_encoder.text_model.embeddings.to(dtype=weight_dtype)
 
+    def prepare_unet_with_accelerator(
+        self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
+    ) -> torch.nn.Module:
+        return accelerator.prepare(unet)
+
     def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
         pass
 
@@ -627,7 +633,8 @@ def train(self, args):
                 training_model = ds_model
             else:
                 if train_unet:
-                    unet = accelerator.prepare(unet)
+                    # default implementation is: unet = accelerator.prepare(unet)
+                    unet = self.prepare_unet_with_accelerator(args, accelerator, unet)  # accelerator does some magic here
                 else:
                     unet.to(accelerator.device, dtype=unet_weight_dtype)  # move to device because unet is not prepared by accelerator
                 if train_text_encoder:
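Note (not part of the patch): the index arithmetic in ModelOffloader.create_backward_hook is the core of the backward-pass scheduling, so a minimal standalone sketch of it follows. It only re-derives, for made-up toy num_blocks/blocks_to_swap values, which block each hook would offload to CPU, which it would prefetch back to CUDA, and which pending move it would wait on; the real class performs the transfers via _submit_move_blocks/_wait_blocks_move.

# Standalone sketch of the swap schedule implied by create_backward_hook (toy values, no torch needed).
def swap_schedule(num_blocks: int, blocks_to_swap: int):
    schedule = []
    for block_index in range(num_blocks):
        # same arithmetic as create_backward_hook
        num_blocks_propagated = num_blocks - block_index - 1
        swapping = 0 < num_blocks_propagated <= blocks_to_swap
        waiting = 0 < block_index <= blocks_to_swap
        if not swapping and not waiting:
            continue
        entry = {"hooked_block": block_index}
        if swapping:
            entry["to_cpu"] = num_blocks - num_blocks_propagated
            entry["to_cuda"] = blocks_to_swap - num_blocks_propagated
        if waiting:
            entry["wait_for"] = block_index - 1
        schedule.append(entry)
    return schedule

if __name__ == "__main__":
    # e.g. 19 double blocks with 4 of them swapped (illustrative numbers only)
    for entry in swap_schedule(19, 4):
        print(entry)

With these toy numbers, hooks on blocks 14-17 submit the CPU/CUDA moves as the backward pass reaches them, and hooks on blocks 1-4 wait for the previous block's move to finish before gradients flow further back.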