From cebff674375b73623b9911e4cf607915e0f7b62a Mon Sep 17 00:00:00 2001
From: able2
Date: Wed, 4 Dec 2024 12:57:07 +0800
Subject: [PATCH] Change model loading sequence to reduce VRAM cost.

Use mmap to reduce RAM usage during main model loading. Cast the model to
fp16 before moving it to the GPU to save VRAM.

Changes to be committed:
	modified:   configs/UniAnimate_infer.yaml
	modified:   tools/inferences/inference_unianimate_entrance.py
---
 configs/UniAnimate_infer.yaml            |   2 +-
 .../inference_unianimate_entrance.py     | 108 +++++++++---------
 2 files changed, 56 insertions(+), 54 deletions(-)

diff --git a/configs/UniAnimate_infer.yaml b/configs/UniAnimate_infer.yaml
index 4c8b6a2..203363e 100644
--- a/configs/UniAnimate_infer.yaml
+++ b/configs/UniAnimate_infer.yaml
@@ -31,7 +31,7 @@ use_fp16: True
 batch_size: 1
 latent_random_ref: True
 chunk_size: 2
-decoder_bs: 2
+decoder_bs: 1
 scale: 8
 use_fps_condition: False
 test_model: checkpoints/unianimate_16f_32f_non_ema_223000.pth
diff --git a/tools/inferences/inference_unianimate_entrance.py b/tools/inferences/inference_unianimate_entrance.py
index 882f78e..24a081c 100644
--- a/tools/inferences/inference_unianimate_entrance.py
+++ b/tools/inferences/inference_unianimate_entrance.py
@@ -272,6 +272,28 @@ def is_libuv_supported():
     clip_encoder.model.to(gpu)
     with torch.no_grad():
         _, _, zero_y = clip_encoder(text="")
+
+    # initialize reference_image, pose_sequence, frame_interval, max_frames, resolution_x,
+    vit_frame, video_data, misc_data, dwpose_data, random_ref_frame_data, random_ref_dwpose_data = load_video_frames(reference_image, ref_pose, pose_sequence, train_trans, vit_transforms, train_trans_pose, max_frames, frame_interval, resolution)
+    misc_data = misc_data.unsqueeze(0).to(gpu)
+    vit_frame = vit_frame.unsqueeze(0).to(gpu)
+    dwpose_data = dwpose_data.unsqueeze(0).to(gpu)
+    random_ref_frame_data = random_ref_frame_data.unsqueeze(0).to(gpu)
+    random_ref_dwpose_data = random_ref_dwpose_data.unsqueeze(0).to(gpu)
+
+    with torch.no_grad():
+        y_visual = []
+        if 'image' in cfg.video_compositions:
+            with torch.no_grad():
+                vit_frame = vit_frame.squeeze(1)
+                y_visual = clip_encoder.encode_image(vit_frame).unsqueeze(1) # [60, 1024]
+                y_visual0 = y_visual.clone()
+
+    if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
+        clip_encoder.cpu() # add this line
+        del clip_encoder # Delete this object to free memory
+        import gc
+        gc.collect()
 
     # [Model] auotoencoder
@@ -281,39 +303,6 @@ def is_libuv_supported():
         param.requires_grad = False
     autoencoder.cuda()
 
-    # [Model] UNet
-    if "config" in cfg.UNet:
-        cfg.UNet["config"] = cfg
-    cfg.UNet["zero_y"] = zero_y
-    model = MODEL.build(cfg.UNet)
-    # Here comes the UniAnimate model
-    # inferences folder
-    current_directory = os.path.dirname(os.path.abspath(__file__))
-    # tools folder
-    parent_directory = os.path.dirname(current_directory)
-    # uniAnimate folder
-    root_directory = os.path.dirname(parent_directory)
-    unifiedModel = os.path.join(root_directory, 'checkpoints/unianimate_16f_32f_non_ema_223000.pth')
-    state_dict = torch.load(unifiedModel, map_location='cpu')
-    if 'state_dict' in state_dict:
-        state_dict = state_dict['state_dict']
-    if 'step' in state_dict:
-        resume_step = state_dict['step']
-    else:
-        resume_step = 0
-    status = model.load_state_dict(state_dict, strict=True)
-    # logging.info('Load model from {} with status {}'.format(unifiedModel, status))
-    print(f'Load model from ({unifiedModel}) with status ({status})')
-    model = model.to(gpu)
-    model.eval()
-    if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
-        print("Avoiding DistributedDataParallel to reduce memory usage")
-        model.to(torch.float16)
-    else:
-        model = DistributedDataParallel(model, device_ids=[gpu]) if not cfg.debug else model
-    torch.cuda.empty_cache()
-
     # Where the input image and pose images come in
     test_list = cfg.test_list_path
     # num_videos = len(test_list)
@@ -332,17 +321,6 @@ def is_libuv_supported():
         # logging.info(f"[{idx}]/[{len(test_list)}] Begin to sample {ref_image_key}, pose sequence from {pose_seq_key} init seed {manual_seed} ...")
         print(f"Seed: {manual_seed}")
-
-        # initialize reference_image, pose_sequence, frame_interval, max_frames, resolution_x,
-        vit_frame, video_data, misc_data, dwpose_data, random_ref_frame_data, random_ref_dwpose_data = load_video_frames(reference_image, ref_pose, pose_sequence, train_trans, vit_transforms, train_trans_pose, max_frames, frame_interval, resolution)
-        misc_data = misc_data.unsqueeze(0).to(gpu)
-        vit_frame = vit_frame.unsqueeze(0).to(gpu)
-        dwpose_data = dwpose_data.unsqueeze(0).to(gpu)
-        random_ref_frame_data = random_ref_frame_data.unsqueeze(0).to(gpu)
-        random_ref_dwpose_data = random_ref_dwpose_data.unsqueeze(0).to(gpu)
-
-
-
         ### save for visualization
         misc_backups = copy(misc_data)
         frames_num = misc_data.shape[1]
@@ -398,12 +376,7 @@ def is_libuv_supported():
         dwpose_data = rearrange(dwpose_data, 'b f c h w -> b c f h w', b = bs_vd_local)
 
-        y_visual = []
-        if 'image' in cfg.video_compositions:
-            with torch.no_grad():
-                vit_frame = vit_frame.squeeze(1)
-                y_visual = clip_encoder.encode_image(vit_frame).unsqueeze(1) # [60, 1024]
-                y_visual0 = y_visual.clone()
+        # print(torch.get_default_dtype())
 
 
@@ -488,13 +461,42 @@ def is_libuv_supported():
         noise_one = noise
 
         if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
-            clip_encoder.cpu() # add this line
-            del clip_encoder # Delete this object to free memory
             autoencoder.cpu() # add this line
             torch.cuda.empty_cache() # add this line
-            import gc
             gc.collect()
+
+        # [Model] UNet
+        if "config" in cfg.UNet:
+            cfg.UNet["config"] = cfg
+        cfg.UNet["zero_y"] = zero_y
+        model = MODEL.build(cfg.UNet)
+        # Here comes the UniAnimate model
+        # inferences folder
+        current_directory = os.path.dirname(os.path.abspath(__file__))
+        # tools folder
+        parent_directory = os.path.dirname(current_directory)
+        # uniAnimate folder
+        root_directory = os.path.dirname(parent_directory)
+        unifiedModel = os.path.join(root_directory, 'checkpoints/unianimate_16f_32f_non_ema_223000.pth')
+        state_dict = torch.load(unifiedModel, map_location='cpu', mmap = True)
+        if 'state_dict' in state_dict:
+            state_dict = state_dict['state_dict']
+        if 'step' in state_dict:
+            resume_step = state_dict['step']
+        else:
+            resume_step = 0
+        status = model.load_state_dict(state_dict, strict=True)
+        # logging.info('Load model from {} with status {}'.format(unifiedModel, status))
+        print(f'Load model from ({unifiedModel}) with status ({status})')
+        model = model.to(torch.float16).to(gpu)
+        model.eval()
+        if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
+            print("Avoiding DistributedDataParallel to reduce memory usage")
+            #model.to(torch.float16)
+        else:
+            model = DistributedDataParallel(model, device_ids=[gpu]) if not cfg.debug else model
+        torch.cuda.empty_cache()
+
         # print(f' noise_one is ({noise_one})')
         print(f"noise: {noise.shape}")