diff --git a/library/flux_models.py b/library/flux_models.py index 48dea4fc9..4721fa02e 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -1077,7 +1077,7 @@ def forward( def submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda): def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda): - start_time = time.perf_counter() + # start_time = time.perf_counter() # print(f"Moving {bidx_to_cpu} to cpu and {bidx_to_cuda} to cuda.") utils.swap_weight_devices(block_to_cpu, block_to_cuda) # print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.") @@ -1123,7 +1123,7 @@ def wait_for_blocks_move(block_idx, ftrs): if block_idx < self.single_blocks_to_swap: block_idx_to_cpu = block_idx - block_idx_to_cuda = self.num_single_blocks - self.blocks_to_swap + block_idx + block_idx_to_cuda = self.num_single_blocks - self.single_blocks_to_swap + block_idx future = submit_move_blocks(self.single_blocks, block_idx_to_cpu, block_idx_to_cuda) single_futures[block_idx_to_cuda] = future