
Commit

Change model loading sequence to reduce VRAM cost.
Use mmap to reduce RAM usage during main model loading.
Cast the model to fp16 before moving it to the GPU to save VRAM.
Changes to be committed:
	modified:   configs/UniAnimate_infer.yaml
	modified:   tools/inferences/inference_unianimate_entrance.py
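
A minimal sketch of the loading pattern this commit describes, using a small stand-in module and a throwaway checkpoint path rather than the repository's real UNet:

    import torch
    import torch.nn as nn

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Stand-in for the real UNet; the module and checkpoint path are illustrative only.
    model = nn.Sequential(nn.Conv2d(4, 64, 3, padding=1), nn.SiLU(), nn.Conv2d(64, 4, 3, padding=1))
    torch.save(model.state_dict(), "demo_ckpt.pth")  # in the repo this is the pretrained checkpoint

    # mmap=True memory-maps the checkpoint so tensors are paged in from disk on demand,
    # keeping peak host RAM low while the state dict is read (recent PyTorch releases).
    state_dict = torch.load("demo_ckpt.pth", map_location="cpu", mmap=True)
    model.load_state_dict(state_dict, strict=True)

    # Cast to half precision on the CPU first, then move to the device, so the
    # full-precision weights never have to fit in VRAM.
    model = model.to(torch.float16).to(device)
    model.eval()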
able2608 committed Dec 4, 2024
1 parent ebb9bdf commit cebff67
Showing 2 changed files with 56 additions and 54 deletions.
2 changes: 1 addition & 1 deletion configs/UniAnimate_infer.yaml
@@ -31,7 +31,7 @@ use_fp16: True
batch_size: 1
latent_random_ref: True
chunk_size: 2
- decoder_bs: 2
+ decoder_bs: 1
scale: 8
use_fps_condition: False
test_model: checkpoints/unianimate_16f_32f_non_ema_223000.pth
108 changes: 55 additions & 53 deletions tools/inferences/inference_unianimate_entrance.py
@@ -272,6 +272,28 @@ def is_libuv_supported():
clip_encoder.model.to(gpu)
with torch.no_grad():
_, _, zero_y = clip_encoder(text="")

# initialize reference_image, pose_sequence, frame_interval, max_frames, resolution_x,
vit_frame, video_data, misc_data, dwpose_data, random_ref_frame_data, random_ref_dwpose_data = load_video_frames(reference_image, ref_pose, pose_sequence, train_trans, vit_transforms, train_trans_pose, max_frames, frame_interval, resolution)
misc_data = misc_data.unsqueeze(0).to(gpu)
vit_frame = vit_frame.unsqueeze(0).to(gpu)
dwpose_data = dwpose_data.unsqueeze(0).to(gpu)
random_ref_frame_data = random_ref_frame_data.unsqueeze(0).to(gpu)
random_ref_dwpose_data = random_ref_dwpose_data.unsqueeze(0).to(gpu)

with torch.no_grad():
y_visual = []
if 'image' in cfg.video_compositions:
with torch.no_grad():
vit_frame = vit_frame.squeeze(1)
y_visual = clip_encoder.encode_image(vit_frame).unsqueeze(1) # [60, 1024]
y_visual0 = y_visual.clone()

if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
clip_encoder.cpu() # add this line
del clip_encoder # Delete this object to free memory
import gc
gc.collect()


# [Model] autoencoder
@@ -281,39 +303,6 @@ def is_libuv_supported():
param.requires_grad = False
autoencoder.cuda()

# [Model] UNet
if "config" in cfg.UNet:
cfg.UNet["config"] = cfg
cfg.UNet["zero_y"] = zero_y
model = MODEL.build(cfg.UNet)
# Here comes the UniAnimate model
# inferences folder
current_directory = os.path.dirname(os.path.abspath(__file__))
# tools folder
parent_directory = os.path.dirname(current_directory)
# uniAnimate folder
root_directory = os.path.dirname(parent_directory)
unifiedModel = os.path.join(root_directory, 'checkpoints/unianimate_16f_32f_non_ema_223000.pth')
state_dict = torch.load(unifiedModel, map_location='cpu')
if 'state_dict' in state_dict:
state_dict = state_dict['state_dict']
if 'step' in state_dict:
resume_step = state_dict['step']
else:
resume_step = 0
status = model.load_state_dict(state_dict, strict=True)
# logging.info('Load model from {} with status {}'.format(unifiedModel, status))
print(f'Load model from ({unifiedModel}) with status ({status})')
model = model.to(gpu)
model.eval()
if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
print("Avoiding DistributedDataParallel to reduce memory usage")
model.to(torch.float16)
else:
model = DistributedDataParallel(model, device_ids=[gpu]) if not cfg.debug else model
torch.cuda.empty_cache()


# Where the input image and pose images come in
test_list = cfg.test_list_path
# num_videos = len(test_list)
@@ -332,17 +321,6 @@ def is_libuv_supported():
# logging.info(f"[{idx}]/[{len(test_list)}] Begin to sample {ref_image_key}, pose sequence from {pose_seq_key} init seed {manual_seed} ...")
print(f"Seed: {manual_seed}")


# initialize reference_image, pose_sequence, frame_interval, max_frames, resolution_x,
vit_frame, video_data, misc_data, dwpose_data, random_ref_frame_data, random_ref_dwpose_data = load_video_frames(reference_image, ref_pose, pose_sequence, train_trans, vit_transforms, train_trans_pose, max_frames, frame_interval, resolution)
misc_data = misc_data.unsqueeze(0).to(gpu)
vit_frame = vit_frame.unsqueeze(0).to(gpu)
dwpose_data = dwpose_data.unsqueeze(0).to(gpu)
random_ref_frame_data = random_ref_frame_data.unsqueeze(0).to(gpu)
random_ref_dwpose_data = random_ref_dwpose_data.unsqueeze(0).to(gpu)



### save for visualization
misc_backups = copy(misc_data)
frames_num = misc_data.shape[1]
@@ -398,12 +376,7 @@ def is_libuv_supported():
dwpose_data = rearrange(dwpose_data, 'b f c h w -> b c f h w', b = bs_vd_local)


y_visual = []
if 'image' in cfg.video_compositions:
with torch.no_grad():
vit_frame = vit_frame.squeeze(1)
y_visual = clip_encoder.encode_image(vit_frame).unsqueeze(1) # [60, 1024]
y_visual0 = y_visual.clone()


# print(torch.get_default_dtype())

@@ -488,13 +461,42 @@ def is_libuv_supported():
noise_one = noise

if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
clip_encoder.cpu() # add this line
del clip_encoder # Delete this object to free memory
autoencoder.cpu() # add this line
torch.cuda.empty_cache() # add this line
import gc
gc.collect()

# [Model] UNet
if "config" in cfg.UNet:
cfg.UNet["config"] = cfg
cfg.UNet["zero_y"] = zero_y
model = MODEL.build(cfg.UNet)
# Here comes the UniAnimate model
# inferences folder
current_directory = os.path.dirname(os.path.abspath(__file__))
# tools folder
parent_directory = os.path.dirname(current_directory)
# uniAnimate folder
root_directory = os.path.dirname(parent_directory)
unifiedModel = os.path.join(root_directory, 'checkpoints/unianimate_16f_32f_non_ema_223000.pth')
state_dict = torch.load(unifiedModel, map_location='cpu', mmap = True)
if 'state_dict' in state_dict:
state_dict = state_dict['state_dict']
if 'step' in state_dict:
resume_step = state_dict['step']
else:
resume_step = 0
status = model.load_state_dict(state_dict, strict=True)
# logging.info('Load model from {} with status {}'.format(unifiedModel, status))
print(f'Load model from ({unifiedModel}) with status ({status})')
model = model.to(torch.float16).to(gpu)
model.eval()
if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
print("Avoiding DistributedDataParallel to reduce memory usage")
#model.to(torch.float16)
else:
model = DistributedDataParallel(model, device_ids=[gpu]) if not cfg.debug else model
torch.cuda.empty_cache()

# print(f' noise_one is ({noise_one})')
print(f"noise: {noise.shape}")

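
Taken together, the diff above moves the CLIP conditioning pass ahead of UNet construction and frees the encoder before the UNet weights arrive on the GPU. A condensed, self-contained sketch of that ordering, with small stand-in modules in place of the real CLIP encoder and UNet:

    import gc
    import torch
    import torch.nn as nn

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Stand-ins for the real CLIP encoder and UNet; classes and shapes are illustrative only.
    clip_encoder = nn.Linear(1024, 1024).to(device)
    unet = nn.Conv2d(4, 4, 3, padding=1)

    # 1. Run the conditioning encoder first, while the UNet is not yet on the GPU.
    with torch.no_grad():
        y_visual = clip_encoder(torch.randn(1, 1024, device=device))

    # 2. Free the encoder before the UNet is loaded, so the two never share VRAM.
    clip_encoder.cpu()
    del clip_encoder
    gc.collect()
    torch.cuda.empty_cache()

    # 3. Only now cast the UNet to fp16 and move it to the GPU for sampling.
    unet = unet.to(torch.float16).to(device)
    unet.eval()

The same idea applies to the VAE, which the diff sends to the CPU under the CPU_CLIP_VAE flag before the UNet checkpoint is loaded.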
