From 02bd76e6c719ad85c108a177405846c5c958bd78 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 11 Nov 2024 21:15:36 +0900 Subject: [PATCH] Refactor block swapping to utilize custom offloading utilities --- flux_train.py | 228 ++++++++--------------------- library/custom_offloading_utils.py | 216 +++++++++++++++++++++++++++ library/flux_models.py | 113 ++------------ 3 files changed, 295 insertions(+), 262 deletions(-) create mode 100644 library/custom_offloading_utils.py diff --git a/flux_train.py b/flux_train.py index afddc897f..02dede45e 100644 --- a/flux_train.py +++ b/flux_train.py @@ -295,7 +295,7 @@ def train(args): # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. # This idea is based on 2kpr's great work. Thank you! logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") - flux.enable_block_swap(args.blocks_to_swap) + flux.enable_block_swap(args.blocks_to_swap, accelerator.device) if not cache_latents: # load VAE here if not cached @@ -338,15 +338,15 @@ def train(args): # determine target layer and block index for each parameter block_type = "other" # double, single or other if np[0].startswith("double_blocks"): - block_idx = int(np[0].split(".")[1]) + block_index = int(np[0].split(".")[1]) block_type = "double" elif np[0].startswith("single_blocks"): - block_idx = int(np[0].split(".")[1]) + block_index = int(np[0].split(".")[1]) block_type = "single" else: - block_idx = -1 + block_index = -1 - param_group_key = (block_type, block_idx) + param_group_key = (block_type, block_index) if param_group_key not in param_group: param_group[param_group_key] = [] param_group[param_group_key].append(p) @@ -466,123 +466,21 @@ def train(args): # resumeする train_util.resume_from_local_or_hf_if_specified(accelerator, args) - # memory efficient block swapping - - def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, blocks, block_id): - def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda): - # start_time = time.perf_counter() - # print(f"Backward: Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to CUDA") - utils.swap_weight_devices(block_to_cpu, block_to_cuda) - # print(f"Backward: Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s") - return bidx_to_cpu, bidx_to_cuda # , event - - block_to_cpu = blocks[block_idx_to_cpu] - block_to_cuda = blocks[block_idx_to_cuda] - - futures[block_id] = thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda) - - def wait_blocks_move(block_id, futures): - if block_id not in futures: - return - # print(f"Backward: Wait for block {block_id}") - # start_time = time.perf_counter() - future = futures.pop(block_id) - _, bidx_to_cuda = future.result() - assert block_id[1] == bidx_to_cuda, f"Block index mismatch: {block_id[1]} != {bidx_to_cuda}" - # print(f"Backward: Waited for block {block_id}: {time.perf_counter()-start_time:.2f}s") - # print(f"Backward: Synchronized: {time.perf_counter()-start_time:.2f}s") - if args.fused_backward_pass: # use fused optimizer for backward pass: other optimizers will be supported in the future import library.adafactor_fused library.adafactor_fused.patch_adafactor_fused(optimizer) - double_blocks_to_swap = args.blocks_to_swap // 2 - single_blocks_to_swap = (args.blocks_to_swap - double_blocks_to_swap) * 2 - num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks) - num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks) - 
handled_block_ids = set() - - n = 1 # only asynchronous purpose, no need to increase this number - # n = 2 - # n = max(1, os.cpu_count() // 2) - thread_pool = ThreadPoolExecutor(max_workers=n) - futures = {} - for param_group, param_name_group in zip(optimizer.param_groups, param_names): for parameter, param_name in zip(param_group["params"], param_name_group): if parameter.requires_grad: - grad_hook = None - - if double_blocks_to_swap > 0 or single_blocks_to_swap > 0: - is_double = param_name.startswith("double_blocks") - is_single = param_name.startswith("single_blocks") - if is_double and double_blocks_to_swap > 0 or is_single and single_blocks_to_swap > 0: - block_idx = int(param_name.split(".")[1]) - block_id = (is_double, block_idx) # double or single, block index - if block_id not in handled_block_ids: - # swap following (already backpropagated) block - handled_block_ids.add(block_id) - - # if n blocks were already backpropagated - if is_double: - num_blocks = num_double_blocks - blocks_to_swap = double_blocks_to_swap - else: - num_blocks = num_single_blocks - blocks_to_swap = single_blocks_to_swap - - # -1 for 0-based index, -1 for current block is not fully backpropagated yet - num_blocks_propagated = num_blocks - block_idx - 2 - swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap - waiting = block_idx > 0 and block_idx <= blocks_to_swap - - if swapping or waiting: - block_idx_to_cpu = num_blocks - num_blocks_propagated - block_idx_to_cuda = blocks_to_swap - num_blocks_propagated - block_idx_to_wait = block_idx - 1 - - # create swap hook - def create_swap_grad_hook( - is_dbl, bidx_to_cpu, bidx_to_cuda, bidx_to_wait, swpng: bool, wtng: bool - ): - def __grad_hook(tensor: torch.Tensor): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(tensor, args.max_grad_norm) - optimizer.step_param(tensor, param_group) - tensor.grad = None - - # print( - # f"Backward: Block {is_dbl}, {bidx_to_cpu}, {bidx_to_cuda}, {bidx_to_wait}, {swpng}, {wtng}" - # ) - if swpng: - submit_move_blocks( - futures, - thread_pool, - bidx_to_cpu, - bidx_to_cuda, - flux.double_blocks if is_dbl else flux.single_blocks, - (is_dbl, bidx_to_cuda), # wait for this block - ) - if wtng: - wait_blocks_move((is_dbl, bidx_to_wait), futures) - - return __grad_hook - - grad_hook = create_swap_grad_hook( - is_double, block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, swapping, waiting - ) - - if grad_hook is None: - - def __grad_hook(tensor: torch.Tensor, param_group=param_group): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(tensor, args.max_grad_norm) - optimizer.step_param(tensor, param_group) - tensor.grad = None - grad_hook = __grad_hook + def grad_hook(tensor: torch.Tensor, param_group=param_group): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None parameter.register_post_accumulate_grad_hook(grad_hook) @@ -601,66 +499,66 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): num_parameters_per_group = [0] * len(optimizers) parameter_optimizer_map = {} - blocks_to_swap = args.blocks_to_swap - num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks) - num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks) - num_block_units = num_double_blocks + num_single_blocks // 2 - - n = 1 # only asynchronous purpose, no need to increase this 
number - # n = max(1, os.cpu_count() // 2) - thread_pool = ThreadPoolExecutor(max_workers=n) - futures = {} - for opt_idx, optimizer in enumerate(optimizers): for param_group in optimizer.param_groups: for parameter in param_group["params"]: if parameter.requires_grad: - block_type, block_idx = block_types_and_indices[opt_idx] - - def create_optimizer_hook(btype, bidx): - def optimizer_hook(parameter: torch.Tensor): - # print(f"optimizer_hook: {btype}, {bidx}") - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(parameter, args.max_grad_norm) - - i = parameter_optimizer_map[parameter] - optimizer_hooked_count[i] += 1 - if optimizer_hooked_count[i] == num_parameters_per_group[i]: - optimizers[i].step() - optimizers[i].zero_grad(set_to_none=True) - - # swap blocks if necessary - if blocks_to_swap and (btype == "double" or (btype == "single" and bidx % 2 == 0)): - unit_idx = bidx if btype == "double" else num_double_blocks + bidx // 2 - num_blocks_propagated = num_block_units - unit_idx - - swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap - waiting = unit_idx > 0 and unit_idx <= blocks_to_swap - - if swapping: - block_idx_to_cpu = num_block_units - num_blocks_propagated - block_idx_to_cuda = blocks_to_swap - num_blocks_propagated - # print(f"Backward: Swap blocks {block_idx_to_cpu} and {block_idx_to_cuda}") - submit_move_blocks( - futures, - thread_pool, - block_idx_to_cpu, - block_idx_to_cuda, - flux.double_blocks, - flux.single_blocks, - accelerator.device, - ) - - if waiting: - block_idx_to_wait = unit_idx - 1 - wait_blocks_move(block_idx_to_wait, futures) - - return optimizer_hook - - parameter.register_post_accumulate_grad_hook(create_optimizer_hook(block_type, block_idx)) + + def grad_hook(parameter: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) + + parameter.register_post_accumulate_grad_hook(grad_hook) parameter_optimizer_map[parameter] = opt_idx num_parameters_per_group[opt_idx] += 1 + # add hooks for block swapping: this hook is called after fused_backward_pass hook or blockwise_fused_optimizers hook + if is_swapping_blocks: + import library.custom_offloading_utils as custom_offloading_utils + + num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks) + num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks) + double_blocks_to_swap = args.blocks_to_swap // 2 + single_blocks_to_swap = (args.blocks_to_swap - double_blocks_to_swap) * 2 + + offloader_double = custom_offloading_utils.TrainOffloader(num_double_blocks, double_blocks_to_swap, accelerator.device) + offloader_single = custom_offloading_utils.TrainOffloader(num_single_blocks, single_blocks_to_swap, accelerator.device) + + param_name_pairs = [] + if not args.blockwise_fused_optimizers: + for param_group, param_name_group in zip(optimizer.param_groups, param_names): + param_name_pairs.extend(zip(param_group["params"], param_name_group)) + else: + # named_parameters is a list of (name, parameter) pairs + param_name_pairs.extend([(p, n) for n, p in flux.named_parameters()]) + + for parameter, param_name in param_name_pairs: + if not parameter.requires_grad: + continue + + is_double = param_name.startswith("double_blocks") + is_single = 
param_name.startswith("single_blocks") + if not is_double and not is_single: + continue + + block_index = int(param_name.split(".")[1]) + if is_double: + blocks = flux.double_blocks + offloader = offloader_double + else: + blocks = flux.single_blocks + offloader = offloader_single + + grad_hook = offloader.create_grad_hook(blocks, block_index) + if grad_hook is not None: + parameter.register_post_accumulate_grad_hook(grad_hook) + # epoch数を計算する num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) diff --git a/library/custom_offloading_utils.py b/library/custom_offloading_utils.py new file mode 100644 index 000000000..33a413004 --- /dev/null +++ b/library/custom_offloading_utils.py @@ -0,0 +1,216 @@ +from concurrent.futures import ThreadPoolExecutor +import time +from typing import Optional +import torch +import torch.nn as nn + +from library.device_utils import clean_memory_on_device + + +def synchronize_device(device: torch.device): + if device.type == "cuda": + torch.cuda.synchronize() + elif device.type == "xpu": + torch.xpu.synchronize() + elif device.type == "mps": + torch.mps.synchronize() + + +def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): + assert layer_to_cpu.__class__ == layer_to_cuda.__class__ + + weight_swap_jobs = [] + for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): + if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: + weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + + torch.cuda.current_stream().synchronize() # this prevents the illegal loss value + + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + # cuda to cpu + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.record_stream(stream) + module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True) + + stream.synchronize() + + # cpu to cuda + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) + module_to_cuda.weight.data = cuda_data_view + + stream.synchronize() + torch.cuda.current_stream().synchronize() # this prevents the illegal loss value + + +def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): + """ + not tested + """ + assert layer_to_cpu.__class__ == layer_to_cuda.__class__ + + weight_swap_jobs = [] + for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): + if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: + weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + + # device to cpu + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True) + + synchronize_device() + + # cpu to device + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) + module_to_cuda.weight.data = cuda_data_view + + synchronize_device() + + +def weighs_to_device(layer: nn.Module, device: torch.device): + for module in layer.modules(): + if hasattr(module, "weight") and module.weight is not None: + 
module.weight.data = module.weight.data.to(device, non_blocking=True) + + +class Offloader: + """ + common offloading class + """ + + def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False): + self.num_blocks = num_blocks + self.blocks_to_swap = blocks_to_swap + self.device = device + self.debug = debug + + self.thread_pool = ThreadPoolExecutor(max_workers=1) + self.futures = {} + self.cuda_available = device.type == "cuda" + + def swap_weight_devices(self, block_to_cpu: nn.Module, block_to_cuda: nn.Module): + if self.cuda_available: + swap_weight_devices(block_to_cpu, block_to_cuda) + else: + swap_weight_devices_no_cuda(self.device, block_to_cpu, block_to_cuda) + + def _submit_move_blocks(self, blocks, block_idx_to_cpu, block_idx_to_cuda): + def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda): + if self.debug: + start_time = time.perf_counter() + print(f"Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}") + + self.swap_weight_devices(block_to_cpu, block_to_cuda) + + if self.debug: + print(f"Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s") + return bidx_to_cpu, bidx_to_cuda # , event + + block_to_cpu = blocks[block_idx_to_cpu] + block_to_cuda = blocks[block_idx_to_cuda] + + self.futures[block_idx_to_cuda] = self.thread_pool.submit( + move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda + ) + + def _wait_blocks_move(self, block_idx): + if block_idx not in self.futures: + return + + if self.debug: + print(f"Wait for block {block_idx}") + start_time = time.perf_counter() + + future = self.futures.pop(block_idx) + _, bidx_to_cuda = future.result() + + assert block_idx == bidx_to_cuda, f"Block index mismatch: {block_idx} != {bidx_to_cuda}" + + if self.debug: + print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s") + + +class TrainOffloader(Offloader): + """ + supports backward offloading + """ + + def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False): + super().__init__(num_blocks, blocks_to_swap, device, debug) + self.hook_added = set() + + def create_grad_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]: + if block_index in self.hook_added: + return None + self.hook_added.add(block_index) + + # -1 for 0-based index, -1 for current block is not fully backpropagated yet + num_blocks_propagated = self.num_blocks - block_index - 2 + swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap + waiting = block_index > 0 and block_index <= self.blocks_to_swap + + if not swapping and not waiting: + return None + + # create hook + block_idx_to_cpu = self.num_blocks - num_blocks_propagated + block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated + block_idx_to_wait = block_index - 1 + + if self.debug: + print( + f"Backward: Created grad hook for block {block_index} with {block_idx_to_cpu}, {block_idx_to_cuda}, {block_idx_to_wait}" + ) + if swapping: + + def grad_hook(tensor: torch.Tensor): + self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda) + + return grad_hook + + else: + + def grad_hook(tensor: torch.Tensor): + self._wait_blocks_move(block_idx_to_wait) + + return grad_hook + + +class ModelOffloader(Offloader): + """ + supports forward offloading + """ + + def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False): + 
super().__init__(num_blocks, blocks_to_swap, device, debug) + + def prepare_block_devices_before_forward(self, blocks: list[nn.Module]): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + + for b in blocks[0 : self.num_blocks - self.blocks_to_swap]: + b.to(self.device) + weighs_to_device(b, self.device) # make sure weights are on device + + for b in blocks[self.num_blocks - self.blocks_to_swap :]: + b.to(self.device) # move block to device first + weighs_to_device(b, "cpu") # make sure weights are on cpu + + synchronize_device(self.device) + clean_memory_on_device(self.device) + + def wait_for_block(self, block_idx: int): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + self._wait_blocks_move(block_idx) + + def submit_move_blocks(self, blocks: list[nn.Module], block_idx: int): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + if block_idx >= self.blocks_to_swap: + return + block_idx_to_cpu = block_idx + block_idx_to_cuda = self.num_blocks - self.blocks_to_swap + block_idx + self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda) diff --git a/library/flux_models.py b/library/flux_models.py index 4721fa02e..e0bee160f 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -18,6 +18,7 @@ from einops import rearrange from torch import Tensor, nn from torch.utils.checkpoint import checkpoint +from library import custom_offloading_utils # USE_REENTRANT = True @@ -923,7 +924,8 @@ def __init__(self, params: FluxParams): self.cpu_offload_checkpointing = False self.blocks_to_swap = None - self.thread_pool: Optional[ThreadPoolExecutor] = None + self.offloader_double = None + self.offloader_single = None self.num_double_blocks = len(self.double_blocks) self.num_single_blocks = len(self.single_blocks) @@ -963,17 +965,17 @@ def disable_gradient_checkpointing(self): print("FLUX: Gradient checkpointing disabled.") - def enable_block_swap(self, num_blocks: int): + def enable_block_swap(self, num_blocks: int, device: torch.device): self.blocks_to_swap = num_blocks - self.double_blocks_to_swap = num_blocks // 2 - self.single_blocks_to_swap = (num_blocks - self.double_blocks_to_swap) * 2 + double_blocks_to_swap = num_blocks // 2 + single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 + + self.offloader_double = custom_offloading_utils.ModelOffloader(self.num_double_blocks, double_blocks_to_swap, device) + self.offloader_single = custom_offloading_utils.ModelOffloader(self.num_single_blocks, single_blocks_to_swap, device) print( - f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {self.double_blocks_to_swap}, single blocks: {self.single_blocks_to_swap}." + f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}." ) - n = 1 # async block swap. 1 is enough - self.thread_pool = ThreadPoolExecutor(max_workers=n) - def move_to_device_except_swap_blocks(self, device: torch.device): # assume model is on cpu. 
do not move blocks to device to reduce temporary memory usage if self.blocks_to_swap: @@ -988,56 +990,11 @@ def move_to_device_except_swap_blocks(self, device: torch.device): self.double_blocks = save_double_blocks self.single_blocks = save_single_blocks - # def get_block_unit(self, index: int): - # if index < len(self.double_blocks): - # return (self.double_blocks[index],) - # else: - # index -= len(self.double_blocks) - # index *= 2 - # return self.single_blocks[index], self.single_blocks[index + 1] - - # def get_unit_index(self, is_double: bool, index: int): - # if is_double: - # return index - # else: - # return len(self.double_blocks) + index // 2 - def prepare_block_swap_before_forward(self): - # # make: first n blocks are on cuda, and last n blocks are on cpu - # if self.blocks_to_swap is None or self.blocks_to_swap == 0: - # # raise ValueError("Block swap is not enabled.") - # return - # for i in range(self.num_block_units - self.blocks_to_swap): - # for b in self.get_block_unit(i): - # b.to(self.device) - # for i in range(self.num_block_units - self.blocks_to_swap, self.num_block_units): - # for b in self.get_block_unit(i): - # b.to("cpu") - # clean_memory_on_device(self.device) - - # all blocks are on device, but some weights are on cpu - # make first n blocks weights on device, and last n blocks weights on cpu if self.blocks_to_swap is None or self.blocks_to_swap == 0: - # raise ValueError("Block swap is not enabled.") return - - for b in self.double_blocks[0 : self.num_double_blocks - self.double_blocks_to_swap]: - b.to(self.device) - utils.weighs_to_device(b, self.device) # make sure weights are on device - for b in self.double_blocks[self.num_double_blocks - self.double_blocks_to_swap :]: - b.to(self.device) # move block to device first - utils.weighs_to_device(b, "cpu") # make sure weights are on cpu - torch.cuda.synchronize() - clean_memory_on_device(self.device) - - for b in self.single_blocks[0 : self.num_single_blocks - self.single_blocks_to_swap]: - b.to(self.device) - utils.weighs_to_device(b, self.device) # make sure weights are on device - for b in self.single_blocks[self.num_single_blocks - self.single_blocks_to_swap :]: - b.to(self.device) # move block to device first - utils.weighs_to_device(b, "cpu") # make sure weights are on cpu - torch.cuda.synchronize() - clean_memory_on_device(self.device) + self.offloader_double.prepare_block_devices_before_forward(self.double_blocks) + self.offloader_single.prepare_block_devices_before_forward(self.single_blocks) def forward( self, @@ -1073,59 +1030,21 @@ def forward( for block in self.single_blocks: img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) else: - # device = self.device - - def submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda): - def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda): - # start_time = time.perf_counter() - # print(f"Moving {bidx_to_cpu} to cpu and {bidx_to_cuda} to cuda.") - utils.swap_weight_devices(block_to_cpu, block_to_cuda) - # print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.") - - # print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds") - return block_idx_to_cpu, block_idx_to_cuda # , event - - block_to_cpu = blocks[block_idx_to_cpu] - block_to_cuda = blocks[block_idx_to_cuda] - # print(f"Submit move blocks. 
{block_idx_to_cpu} to cpu, {block_idx_to_cuda} to cuda.") - return self.thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda) - - def wait_for_blocks_move(block_idx, ftrs): - if block_idx not in ftrs: - return - # print(f"Waiting for move blocks: {block_idx}") - # start_time = time.perf_counter() - ftr = ftrs.pop(block_idx) - ftr.result() - # print(f"{block_idx} move blocks took {time.perf_counter() - start_time:.2f} seconds") - - double_futures = {} for block_idx, block in enumerate(self.double_blocks): - # print(f"Double block {block_idx}") - wait_for_blocks_move(block_idx, double_futures) + self.offloader_double.wait_for_block(block_idx) img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - if block_idx < self.double_blocks_to_swap: - block_idx_to_cpu = block_idx - block_idx_to_cuda = self.num_double_blocks - self.double_blocks_to_swap + block_idx - future = submit_move_blocks(self.double_blocks, block_idx_to_cpu, block_idx_to_cuda) - double_futures[block_idx_to_cuda] = future + self.offloader_double.submit_move_blocks(self.double_blocks, block_idx) img = torch.cat((txt, img), 1) - single_futures = {} for block_idx, block in enumerate(self.single_blocks): - # print(f"Single block {block_idx}") - wait_for_blocks_move(block_idx, single_futures) + self.offloader_single.wait_for_block(block_idx) img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - if block_idx < self.single_blocks_to_swap: - block_idx_to_cpu = block_idx - block_idx_to_cuda = self.num_single_blocks - self.single_blocks_to_swap + block_idx - future = submit_move_blocks(self.single_blocks, block_idx_to_cpu, block_idx_to_cuda) - single_futures[block_idx_to_cuda] = future + self.offloader_single.submit_move_blocks(self.single_blocks, block_idx) img = img[:, txt.shape[1] :, ...]
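
Note (not part of the patch): a minimal sketch of how the new offloading utilities in library/custom_offloading_utils.py are driven, mirroring the call sites above (prepare_block_devices_before_forward, wait_for_block, submit_move_blocks in FLUX's forward). The nn.Linear blocks, sizes, and counts are illustrative stand-ins for FLUX's double_blocks/single_blocks; it assumes a CUDA device and the repository's library/ package on the path.

    # Illustrative usage sketch only; toy blocks stand in for FLUX's double/single blocks.
    import torch
    import torch.nn as nn

    from library.custom_offloading_utils import ModelOffloader

    device = torch.device("cuda")
    num_blocks = 8        # illustrative; FLUX has 19 double / 38 single blocks
    blocks_to_swap = 4    # the last 4 blocks keep their weights on CPU between uses

    blocks = [nn.Linear(1024, 1024) for _ in range(num_blocks)]
    offloader = ModelOffloader(num_blocks, blocks_to_swap, device)

    # Before each forward pass: weights of the first (num_blocks - blocks_to_swap)
    # blocks are placed on the device, the remaining weights stay on CPU.
    offloader.prepare_block_devices_before_forward(blocks)

    x = torch.randn(1, 1024, device=device)
    for block_idx, block in enumerate(blocks):
        offloader.wait_for_block(block_idx)              # block until this block's weights have arrived
        x = block(x)
        offloader.submit_move_blocks(blocks, block_idx)  # asynchronously swap this block's weights out again

As in the updated forward() in flux_models.py, each finished block is handed to a single worker thread that swaps its weights with the next block still waiting on CPU, so the transfer overlaps with compute on the following blocks. The backward pass is handled analogously by TrainOffloader.create_grad_hook, registered per parameter in flux_train.py after the fused-optimizer hooks.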