Skip to content

Commit

Permalink
Update inference_unianimate_entrance.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Isi-dev authored Sep 7, 2024
1 parent 50daa52 commit 3c62a5c
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions tools/inferences/inference_unianimate_entrance.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,9 @@ def worker(gpu, seed, steps, useFirstFrame, reference_image, ref_pose, pose_sequ
noise = torch.randn([1, 4, frames_num, int(cfg.resolution[1]/cfg.scale), int(cfg.resolution[0]/cfg.scale)])
noise = noise.to(gpu)
# print(f"noise: {noise.shape}")

# add a noise prior
noise = diffusion.q_sample(random_ref_frame.clone(), getattr(cfg, "noise_prior_value", 949), noise=noise)


if hasattr(cfg.Diffusion, "noise_strength"):
Expand All @@ -417,8 +420,7 @@ def worker(gpu, seed, steps, useFirstFrame, reference_image, ref_pose, pose_sequ



# add a noise prior
noise = diffusion.q_sample(random_ref_frame.clone(), getattr(cfg, "noise_prior_value", 949), noise=noise)


# construct model inputs (CFG)
full_model_kwargs=[{
Expand Down Expand Up @@ -489,6 +491,8 @@ def worker(gpu, seed, steps, useFirstFrame, reference_image, ref_pose, pose_sequ
ddim_timesteps=steps,
eta=0.0)

# print(f"video_data dtype: {video_data.dtype}")

if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
# if run forward of autoencoder or clip_encoder second times, load them again
clip_encoder.cuda()
Expand All @@ -504,15 +508,12 @@ def worker(gpu, seed, steps, useFirstFrame, reference_image, ref_pose, pose_sequ
video_data = torch.cat(decode_data, dim=0)
video_data = rearrange(video_data, '(b f) c h w -> b c f h w', b = cfg.batch_size).float()

# Check sth

# print(f' video_data is of shape ({video_data.shape})')
# print(f' video_data is ({video_data})')

del model_kwargs_one_vis[0][list(model_kwargs_one_vis[0].keys())[0]]
del model_kwargs_one_vis[1][list(model_kwargs_one_vis[1].keys())[0]]

video_data = extract_image_tensors(video_data.cpu(), cfg.mean, cfg.std)


# synchronize to finish some processes
if not cfg.debug:
Expand Down

0 comments on commit 3c62a5c

Please sign in to comment.