diff --git a/tools/inferences/inference_unianimate_long_entrance.py b/tools/inferences/inference_unianimate_long_entrance.py
index 5e841f5..7da0013 100644
--- a/tools/inferences/inference_unianimate_long_entrance.py
+++ b/tools/inferences/inference_unianimate_long_entrance.py
@@ -49,11 +49,10 @@ from ...tools.modules.autoencoder import get_first_stage_encoding
 from ...utils.registry_class import INFER_ENGINE, MODEL, EMBEDDER, AUTO_ENCODER, DIFFUSION
 from copy import copy
-import cv2
 
 
 
-@INFER_ENGINE.register_function()
-def inference_unianimate_long_entrance(cfg_update, **kwargs):
+# @INFER_ENGINE.register_function()
+def inference_unianimate_long_entrance(seed, steps, useFirstFrame, reference_image, refPose, pose_sequence, frame_interval, context_size, context_stride, context_overlap, max_frames, resolution, cfg_update, **kwargs):
     for k, v in cfg_update.items():
         if isinstance(v, dict) and k in cfg:
             cfg[k].update(v)
@@ -74,9 +73,9 @@ def inference_unianimate_long_entrance(cfg_update, **kwargs):
     cfg.world_size = cfg.pmi_world_size * cfg.gpus_per_machine
 
     if cfg.world_size == 1:
-        worker(0, cfg, cfg_update)
+        return worker(0, seed, steps, useFirstFrame, reference_image, refPose, pose_sequence, frame_interval, context_size, context_stride, context_overlap, max_frames, resolution, cfg, cfg_update)
     else:
-        mp.spawn(worker, nprocs=cfg.gpus_per_machine, args=(cfg, cfg_update))
+        return mp.spawn(worker, nprocs=cfg.gpus_per_machine, args=(cfg, cfg_update))
 
     return cfg
@@ -87,105 +86,116 @@ def make_masked_images(imgs, masks):
         masked_imgs.append(torch.cat([imgs[i] * (1 - mask), (1 - mask)], dim=1))
     return torch.stack(masked_imgs, dim=0)
 
-def load_video_frames(ref_image_path, pose_file_path, train_trans, vit_transforms, train_trans_pose, max_frames=32, frame_interval = 1, resolution=[512, 768], get_first_frame=True, vit_resolution=[224, 224]):
-    
+def load_video_frames(ref_image_tensor, ref_pose_tensor, pose_tensors, train_trans, vit_transforms, train_trans_pose, max_frames=32, frame_interval=1, resolution=[512, 768], get_first_frame=True, vit_resolution=[224, 224]):
     for _ in range(5):
         try:
-            dwpose_all = {}
-            frames_all = {}
-            for ii_index in sorted(os.listdir(pose_file_path)):
-                if ii_index != "ref_pose.jpg":
-                    dwpose_all[ii_index] = Image.open(pose_file_path+"/"+ii_index)
-                    frames_all[ii_index] = Image.fromarray(cv2.cvtColor(cv2.imread(ref_image_path),cv2.COLOR_BGR2RGB))
-                    # frames_all[ii_index] = Image.open(ref_image_path)
+            num_poses = len(pose_tensors)
+            numpyFrames = []
+            numpyPoses = []
 
-            pose_ref = Image.open(os.path.join(pose_file_path, "ref_pose.jpg"))
-            first_eq_ref = False
+            # Convert tensors to numpy arrays and prepare lists
+            for i in range(num_poses):
+                frame = ref_image_tensor.squeeze(0).cpu().numpy() # Convert to numpy array
+                # if i == 0:
+                #     print(f'ref image is ({frame})')
+                numpyFrames.append(frame)
+
+                pose = pose_tensors[i].squeeze(0).cpu().numpy() # Convert to numpy array
+                numpyPoses.append(pose)
+
+            # Convert reference pose tensor to numpy array
+            pose_ref = ref_pose_tensor.squeeze(0).cpu().numpy() # Convert to numpy array
 
-            # sample max_frames poses for video generation
+            # Sample max_frames poses for video generation
             stride = frame_interval
-            _total_frame_num = len(frames_all)
-            if max_frames == "None":
-                max_frames = (_total_frame_num-1)//frame_interval + 1
-            cover_frame_num = (stride * (max_frames-1)+1)
-            if _total_frame_num < cover_frame_num:
-                print('_total_frame_num is smaller than cover_frame_num, the sampled frame interval is changed')
-                start_frame = 0 # we set start_frame = 0 because the pose alignment is performed on the first frame
-                end_frame = _total_frame_num
-                stride = max((_total_frame_num-1//(max_frames-1)),1)
-                end_frame = stride*max_frames
+            total_frame_num = len(numpyFrames)
+            if max_frames == 1024000:
+                max_frames = (total_frame_num-1)//frame_interval + 1
+            cover_frame_num = (stride * (max_frames - 1) + 1)
+
+            if total_frame_num < cover_frame_num:
+                print(f'_total_frame_num ({total_frame_num}) is smaller than cover_frame_num ({cover_frame_num}), the sampled frame interval is changed')
+                start_frame = 0
+                end_frame = total_frame_num
+                stride = max((total_frame_num - 1) // (max_frames - 1), 1)
+                end_frame = stride * max_frames
             else:
-                start_frame = 0 # we set start_frame = 0 because the pose alignment is performed on the first frame
+                start_frame = 0
                 end_frame = start_frame + cover_frame_num
-            
+
             frame_list = []
             dwpose_list = []
-            random_ref_frame = frames_all[list(frames_all.keys())[0]]
-            if random_ref_frame.mode != 'RGB':
-                random_ref_frame = random_ref_frame.convert('RGB')
-            random_ref_dwpose = pose_ref
-            if random_ref_dwpose.mode != 'RGB':
-                random_ref_dwpose = random_ref_dwpose.convert('RGB')
+
+            print(f'end_frame is ({end_frame})')
+
             for i_index in range(start_frame, end_frame, stride):
-                if i_index == start_frame and first_eq_ref:
-                    i_key = list(frames_all.keys())[i_index]
-                    i_frame = frames_all[i_key]
-
-                    if i_frame.mode != 'RGB':
-                        i_frame = i_frame.convert('RGB')
-                    i_dwpose = frames_pose_ref
-                    if i_dwpose.mode != 'RGB':
-                        i_dwpose = i_dwpose.convert('RGB')
-                    frame_list.append(i_frame)
-                    dwpose_list.append(i_dwpose)
-                else:
-                    # added
-                    if first_eq_ref:
-                        i_index = i_index - stride
-
-                    i_key = list(frames_all.keys())[i_index]
-                    i_frame = frames_all[i_key]
-                    if i_frame.mode != 'RGB':
-                        i_frame = i_frame.convert('RGB')
-                    i_dwpose = dwpose_all[i_key]
-                    if i_dwpose.mode != 'RGB':
-                        i_dwpose = i_dwpose.convert('RGB')
+                if i_index < len(numpyFrames): # Check index within bounds
+                    i_frame = numpyFrames[i_index]
+                    i_dwpose = numpyPoses[i_index]
+
+                    # Convert numpy arrays to PIL images
+                    # i_frame = np.clip(i_frame, 0, 1)
+                    i_frame = (i_frame - i_frame.min()) / (i_frame.max() - i_frame.min()) #Trying this in place of clip
+                    i_frame = Image.fromarray((i_frame * 255).astype(np.uint8))
+                    i_frame = i_frame.convert('RGB')
+                    # i_dwpose = np.clip(i_dwpose, 0, 1)
+                    i_dwpose = (i_dwpose - i_dwpose.min()) / (i_dwpose.max() - i_dwpose.min()) #Trying this in place of clip
+                    i_dwpose = Image.fromarray((i_dwpose * 255).astype(np.uint8))
+                    i_dwpose = i_dwpose.convert('RGB')
+
+                    # if i_index == 0:
+                    #     print(f'i_frame is ({np.array(i_frame)})')
+
                     frame_list.append(i_frame)
                     dwpose_list.append(i_dwpose)
-            have_frames = len(frame_list)>0
-            middle_indix = 0
-            if have_frames:
-                ref_frame = frame_list[middle_indix]
+
+            if frame_list:
+                # random_ref_frame = np.clip(numpyFrames[0], 0, 1)
+                random_ref_frame = (numpyFrames[0] - numpyFrames[0].min()) / (numpyFrames[0].max() - numpyFrames[0].min()) #Trying this in place of clip
+                random_ref_frame = Image.fromarray((random_ref_frame * 255).astype(np.uint8))
+                if random_ref_frame.mode != 'RGB':
+                    random_ref_frame = random_ref_frame.convert('RGB')
+                # random_ref_dwpose = np.clip(pose_ref, 0, 1)
+                random_ref_dwpose = (pose_ref - pose_ref.min()) / (pose_ref.max() - pose_ref.min()) #Trying this in place of clip
+                random_ref_dwpose = Image.fromarray((random_ref_dwpose * 255).astype(np.uint8))
+                if random_ref_dwpose.mode != 'RGB':
+                    random_ref_dwpose = random_ref_dwpose.convert('RGB')
+
+                # Apply transforms
+                ref_frame = frame_list[0]
                 vit_frame = vit_transforms(ref_frame)
                 random_ref_frame_tmp = train_trans_pose(random_ref_frame)
-                random_ref_dwpose_tmp = train_trans_pose(random_ref_dwpose)
+                random_ref_dwpose_tmp = train_trans_pose(random_ref_dwpose)
                 misc_data_tmp = torch.stack([train_trans_pose(ss) for ss in frame_list], dim=0)
-                video_data_tmp = torch.stack([train_trans(ss) for ss in frame_list], dim=0)
+                video_data_tmp = torch.stack([train_trans(ss) for ss in frame_list], dim=0)
                 dwpose_data_tmp = torch.stack([train_trans_pose(ss) for ss in dwpose_list], dim=0)
-
-            video_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
-            dwpose_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
-            misc_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
-            random_ref_frame_data = torch.zeros(max_frames, 3, resolution[1], resolution[0]) # [32, 3, 512, 768]
-            random_ref_dwpose_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
-            if have_frames:
-                video_data[:len(frame_list), ...] = video_data_tmp
+
+                # Initialize tensors
+                video_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
+                dwpose_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
+                misc_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
+                random_ref_frame_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
+                random_ref_dwpose_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
+
+                # Copy data to tensors
+                video_data[:len(frame_list), ...] = video_data_tmp
                 misc_data[:len(frame_list), ...] = misc_data_tmp
                 dwpose_data[:len(frame_list), ...] = dwpose_data_tmp
-            random_ref_frame_data[:,...] = random_ref_frame_tmp
-            random_ref_dwpose_data[:,...] = random_ref_dwpose_tmp
-            
-            break
-            
+                random_ref_frame_data[:, ...] = random_ref_frame_tmp
+                random_ref_dwpose_data[:, ...] = random_ref_dwpose_tmp
+
+                return vit_frame, video_data, misc_data, dwpose_data, random_ref_frame_data, random_ref_dwpose_data, max_frames
+
         except Exception as e:
-            logging.info('{} read video frame failed with error: {}'.format(pose_file_path, e))
+            # logging.info(f'Error reading video frame: {e}')
+            print(f'Error reading video frame: ({e})')
             continue
-    
-    return vit_frame, video_data, misc_data, dwpose_data, random_ref_frame_data, random_ref_dwpose_data, max_frames
+
+    return None, None, None, None, None, None, None # Return default values if all attempts fail
 
-def worker(gpu, cfg, cfg_update):
+def worker(gpu, seed, steps, useFirstFrame, reference_image, ref_pose, pose_sequence, frame_interval, context_size, context_stride, context_overlap, max_frames, resolution, cfg, cfg_update):
     '''
     Inference worker for each gpu
     '''
@@ -196,7 +206,7 @@ def worker(gpu, cfg, cfg_update):
             cfg[k] = v
 
     cfg.gpu = gpu
-    cfg.seed = int(cfg.seed)
+    cfg.seed = int(seed)
     cfg.rank = cfg.pmi_rank * cfg.gpus_per_machine + gpu
     setup_seed(cfg.seed + cfg.rank)
 
@@ -205,39 +215,41 @@ def worker(gpu, cfg, cfg_update):
     torch.backends.cudnn.benchmark = True
     if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
         torch.backends.cudnn.benchmark = False
-    dist.init_process_group(backend='nccl', world_size=cfg.world_size, rank=cfg.rank)
+    if not dist.is_initialized():
+        dist.init_process_group(backend='gloo', world_size=cfg.world_size, rank=cfg.rank)
 
     # [Log] Save logging and make log dir
-    log_dir = generalized_all_gather(cfg.log_dir)[0]
-    inf_name = osp.basename(cfg.cfg_file).split('.')[0]
-    test_model = osp.basename(cfg.test_model).split('.')[0].split('_')[-1]
+    # log_dir = generalized_all_gather(cfg.log_dir)[0]
+    # inf_name = osp.basename(cfg.cfg_file).split('.')[0]
+    # test_model = osp.basename(cfg.test_model).split('.')[0].split('_')[-1]
 
-    cfg.log_dir = osp.join(cfg.log_dir, '%s' % (inf_name))
-    os.makedirs(cfg.log_dir, exist_ok=True)
-    log_file = osp.join(cfg.log_dir, 'log_%02d.txt' % (cfg.rank))
-    cfg.log_file = log_file
-    reload(logging)
-    logging.basicConfig(
-        level=logging.INFO,
-        format='[%(asctime)s] %(levelname)s: %(message)s',
-        handlers=[
-            logging.FileHandler(filename=log_file),
-            logging.StreamHandler(stream=sys.stdout)])
-    logging.info(cfg)
-    logging.info(f"Running UniAnimate inference on gpu {gpu}")
+    # cfg.log_dir = osp.join(cfg.log_dir, '%s' % (inf_name))
+    # os.makedirs(cfg.log_dir, exist_ok=True)
+    # log_file = osp.join(cfg.log_dir, 'log_%02d.txt' % (cfg.rank))
+    # cfg.log_file = log_file
+    # reload(logging)
+    # logging.basicConfig(
+    #     level=logging.INFO,
+    #     format='[%(asctime)s] %(levelname)s: %(message)s',
+    #     handlers=[
+    #         logging.FileHandler(filename=log_file),
+    #         logging.StreamHandler(stream=sys.stdout)])
+    # logging.info(cfg)
+    # logging.info(f"Running UniAnimate inference on gpu {gpu}")
+    print(f'Running UniAnimate inference on gpu ({gpu})')
 
     # [Diffusion]
     diffusion = DIFFUSION.build(cfg.Diffusion)
 
     # [Data] Data Transform
     train_trans = data.Compose([
-        data.Resize(cfg.resolution),
+        data.Resize(resolution),
         data.ToTensor(),
         data.Normalize(mean=cfg.mean, std=cfg.std)
         ])
 
     train_trans_pose = data.Compose([
-        data.Resize(cfg.resolution),
+        data.Resize(resolution),
         data.ToTensor(),
         ]
         )
@@ -266,7 +278,12 @@ def worker(gpu, cfg, cfg_update):
     cfg.UNet["config"] = cfg
     cfg.UNet["zero_y"] = zero_y
     model = MODEL.build(cfg.UNet)
-    state_dict = torch.load(cfg.test_model, map_location='cpu')
+
+    current_directory = os.path.dirname(os.path.abspath(__file__))
+    parent_directory = os.path.dirname(current_directory)
+    root_directory = os.path.dirname(parent_directory)
+    unifiedModel = os.path.join(root_directory, 'checkpoints/unianimate_16f_32f_non_ema_223000.pth')
+    state_dict = torch.load(unifiedModel, map_location='cpu')
     if 'state_dict' in state_dict:
         state_dict = state_dict['state_dict']
     if 'step' in state_dict:
@@ -274,10 +291,11 @@ def worker(gpu, cfg, cfg_update):
     else:
         resume_step = 0
     status = model.load_state_dict(state_dict, strict=True)
-    logging.info('Load model from {} with status {}'.format(cfg.test_model, status))
+    print(f'Load model from ({unifiedModel}) with status ({status})')
     model = model.to(gpu)
     model.eval()
     if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
+        print("Avoiding DistributedDataParallel to reduce memory usage")
         model.to(torch.float16)
     else:
         model = DistributedDataParallel(model, device_ids=[gpu]) if not cfg.debug else model
@@ -286,223 +304,238 @@ def worker(gpu, cfg, cfg_update):
 
     test_list = cfg.test_list_path
-    num_videos = len(test_list)
-    logging.info(f'There are {num_videos} videos. with {cfg.round} times')
+    # num_videos = len(test_list)
+    # logging.info(f'There are {num_videos} videos. with {cfg.round} times')
     test_list = [item for _ in range(cfg.round) for item in test_list]
-    for idx, file_path in enumerate(test_list):
-        cfg.frame_interval, ref_image_key, pose_seq_key = file_path[0], file_path[1], file_path[2]
-        
-        manual_seed = int(cfg.seed + cfg.rank + idx//num_videos)
-        setup_seed(manual_seed)
-        logging.info(f"[{idx}]/[{len(test_list)}] Begin to sample {ref_image_key}, pose sequence from {pose_seq_key} init seed {manual_seed} ...")
+    # for idx, file_path in enumerate(test_list):
+    #     cfg.frame_interval, ref_image_key, pose_seq_key = file_path[0], file_path[1], file_path[2]
+
+    manual_seed = int(cfg.seed + cfg.rank)
+    setup_seed(manual_seed)
+    # logging.info(f"[{idx}]/[{len(test_list)}] Begin to sample {ref_image_key}, pose sequence from {pose_seq_key} init seed {manual_seed} ...")
+    print(f"Seed: {manual_seed}")
+
+
+    vit_frame, video_data, misc_data, dwpose_data, random_ref_frame_data, random_ref_dwpose_data, max_frames = load_video_frames(reference_image, ref_pose, pose_sequence, train_trans, vit_transforms, train_trans_pose, max_frames, frame_interval, resolution)
+    cfg.max_frames_new = max_frames
+    misc_data = misc_data.unsqueeze(0).to(gpu)
+    vit_frame = vit_frame.unsqueeze(0).to(gpu)
+    dwpose_data = dwpose_data.unsqueeze(0).to(gpu)
+    random_ref_frame_data = random_ref_frame_data.unsqueeze(0).to(gpu)
+    random_ref_dwpose_data = random_ref_dwpose_data.unsqueeze(0).to(gpu)
+
+    ### save for visualization
+    misc_backups = copy(misc_data)
+    frames_num = misc_data.shape[1]
+    misc_backups = rearrange(misc_backups, 'b f c h w -> b c f h w')
+    mv_data_video = []
+
+
+    ### local image (first frame)
+    image_local = []
+    if 'local_image' in cfg.video_compositions:
+        frames_num = misc_data.shape[1]
+        bs_vd_local = misc_data.shape[0]
+        image_local = misc_data[:,:1].clone().repeat(1,frames_num,1,1,1)
+        image_local_clone = rearrange(image_local, 'b f c h w -> b c f h w', b = bs_vd_local)
+        image_local = rearrange(image_local, 'b f c h w -> b c f h w', b = bs_vd_local)
+        if hasattr(cfg, "latent_local_image") and cfg.latent_local_image:
+            with torch.no_grad():
+                temporal_length = frames_num
+                encoder_posterior = autoencoder.encode(video_data[:,0])
+                local_image_data = get_first_stage_encoding(encoder_posterior).detach()
+                image_local = local_image_data.unsqueeze(1).repeat(1,temporal_length,1,1,1) # [10, 16, 4, 64, 40]
+
+
+    ### encode the video_data
+    # bs_vd = misc_data.shape[0]
+    misc_data = rearrange(misc_data, 'b f c h w -> (b f) c h w')
+    # misc_data_list = torch.chunk(misc_data, misc_data.shape[0]//cfg.chunk_size,dim=0)
+
+
+    with torch.no_grad():
-        vit_frame, video_data, misc_data, dwpose_data, random_ref_frame_data, random_ref_dwpose_data, max_frames = load_video_frames(ref_image_key, pose_seq_key, train_trans, vit_transforms, train_trans_pose, max_frames=cfg.max_frames, frame_interval =cfg.frame_interval, resolution=cfg.resolution)
-        cfg.max_frames_new = max_frames
-        misc_data = misc_data.unsqueeze(0).to(gpu)
-        vit_frame = vit_frame.unsqueeze(0).to(gpu)
-        dwpose_data = dwpose_data.unsqueeze(0).to(gpu)
-        random_ref_frame_data = random_ref_frame_data.unsqueeze(0).to(gpu)
-        random_ref_dwpose_data = random_ref_dwpose_data.unsqueeze(0).to(gpu)
-        
-        ### save for visualization
-        misc_backups = copy(misc_data)
-        frames_num = misc_data.shape[1]
-        misc_backups = rearrange(misc_backups, 'b f c h w -> b c f h w')
-        mv_data_video = []
+        random_ref_frame = []
+        if 'randomref' in cfg.video_compositions:
+            random_ref_frame_clone = rearrange(random_ref_frame_data, 'b f c h w -> b c f h w')
+            if hasattr(cfg, "latent_random_ref") and cfg.latent_random_ref:
+
+                temporal_length = random_ref_frame_data.shape[1]
+                encoder_posterior = autoencoder.encode(random_ref_frame_data[:,0].sub(0.5).div_(0.5))
+                random_ref_frame_data = get_first_stage_encoding(encoder_posterior).detach()
+                random_ref_frame_data = random_ref_frame_data.unsqueeze(1).repeat(1,temporal_length,1,1,1) # [10, 16, 4, 64, 40]
+
+            random_ref_frame = rearrange(random_ref_frame_data, 'b f c h w -> b c f h w')
+
+
+        if 'dwpose' in cfg.video_compositions:
+            bs_vd_local = dwpose_data.shape[0]
+            dwpose_data_clone = rearrange(dwpose_data.clone(), 'b f c h w -> b c f h w', b = bs_vd_local)
+            if 'randomref_pose' in cfg.video_compositions:
+                dwpose_data = torch.cat([random_ref_dwpose_data[:,:1], dwpose_data], dim=1)
+            dwpose_data = rearrange(dwpose_data, 'b f c h w -> b c f h w', b = bs_vd_local)
+
+        y_visual = []
+        if 'image' in cfg.video_compositions:
+            with torch.no_grad():
+                vit_frame = vit_frame.squeeze(1)
+                y_visual = clip_encoder.encode_image(vit_frame).unsqueeze(1) # [60, 1024]
+                y_visual0 = y_visual.clone()
+
-        ### local image (first frame)
-        image_local = []
-        if 'local_image' in cfg.video_compositions:
-            frames_num = misc_data.shape[1]
-            bs_vd_local = misc_data.shape[0]
-            image_local = misc_data[:,:1].clone().repeat(1,frames_num,1,1,1)
-            image_local_clone = rearrange(image_local, 'b f c h w -> b c f h w', b = bs_vd_local)
-            image_local = rearrange(image_local, 'b f c h w -> b c f h w', b = bs_vd_local)
-            if hasattr(cfg, "latent_local_image") and cfg.latent_local_image:
-                with torch.no_grad():
-                    temporal_length = frames_num
-                    encoder_posterior = autoencoder.encode(video_data[:,0])
-                    local_image_data = get_first_stage_encoding(encoder_posterior).detach()
-                    image_local = local_image_data.unsqueeze(1).repeat(1,temporal_length,1,1,1) # [10, 16, 4, 64, 40]
-
+        with amp.autocast(enabled=True):
+            # pynvml.nvmlInit()
+            # handle=pynvml.nvmlDeviceGetHandleByIndex(0)
+            # meminfo=pynvml.nvmlDeviceGetMemoryInfo(handle)
+            cur_seed = torch.initial_seed()
+            # logging.info(f"Current seed {cur_seed} ..., cfg.max_frames_new: {cfg.max_frames_new} ....")
+            print(f"Number of frames to denoise: {frames_num}")
+            noise = torch.randn([1, 4, cfg.max_frames_new, int(resolution[1]/cfg.scale), int(resolution[0]/cfg.scale)])
+            noise = noise.to(gpu)
+
+            # add a noise prior
+            noise = diffusion.q_sample(random_ref_frame.clone(), getattr(cfg, "noise_prior_value", 939), noise=noise)
-        ### encode the video_data
-        bs_vd = misc_data.shape[0]
-        misc_data = rearrange(misc_data, 'b f c h w -> (b f) c h w')
-        misc_data_list = torch.chunk(misc_data, misc_data.shape[0]//cfg.chunk_size,dim=0)
+            if hasattr(cfg.Diffusion, "noise_strength"):
+                b, c, f, _, _= noise.shape
+                offset_noise = torch.randn(b, c, f, 1, 1, device=noise.device)
+                noise = noise + cfg.Diffusion.noise_strength * offset_noise
+            # construct model inputs (CFG)
+            full_model_kwargs=[{
+                'y': None,
+                "local_image": None if len(image_local) == 0 else image_local[:],
+                'image': None if len(y_visual) == 0 else y_visual0[:],
+                'dwpose': None if len(dwpose_data) == 0 else dwpose_data[:],
+                'randomref': None if len(random_ref_frame) == 0 else random_ref_frame[:],
+            },
+            {
+                'y': None,
+                "local_image": None,
+                'image': None,
+                'randomref': None,
+                'dwpose': None,
+            }]
+
+            # for visualization
+            full_model_kwargs_vis =[{
+                'y': None,
+                "local_image": None if len(image_local) == 0 else image_local_clone[:],
+                'image': None,
+                'dwpose': None if len(dwpose_data_clone) == 0 else dwpose_data_clone[:],
+                'randomref': None if len(random_ref_frame) == 0 else random_ref_frame_clone[:, :3],
+            },
+            {
+                'y': None,
+                "local_image": None,
+                'image': None,
+                'randomref': None,
+                'dwpose': None,
+            }]
-        with torch.no_grad():
-            
-            random_ref_frame = []
-            if 'randomref' in cfg.video_compositions:
-                random_ref_frame_clone = rearrange(random_ref_frame_data, 'b f c h w -> b c f h w')
-                if hasattr(cfg, "latent_random_ref") and cfg.latent_random_ref:
-                    
-                    temporal_length = random_ref_frame_data.shape[1]
-                    encoder_posterior = autoencoder.encode(random_ref_frame_data[:,0].sub(0.5).div_(0.5))
-                    random_ref_frame_data = get_first_stage_encoding(encoder_posterior).detach()
-                    random_ref_frame_data = random_ref_frame_data.unsqueeze(1).repeat(1,temporal_length,1,1,1) # [10, 16, 4, 64, 40]
-                
-                random_ref_frame = rearrange(random_ref_frame_data, 'b f c h w -> b c f h w')
-            
-            
-            if 'dwpose' in cfg.video_compositions:
-                bs_vd_local = dwpose_data.shape[0]
-                dwpose_data_clone = rearrange(dwpose_data.clone(), 'b f c h w -> b c f h w', b = bs_vd_local)
-                if 'randomref_pose' in cfg.video_compositions:
-                    dwpose_data = torch.cat([random_ref_dwpose_data[:,:1], dwpose_data], dim=1)
-                dwpose_data = rearrange(dwpose_data, 'b f c h w -> b c f h w', b = bs_vd_local)
+
+            partial_keys = [
+                ['image', 'randomref', "dwpose"],
+            ]
+            # if hasattr(cfg, "partial_keys") and cfg.partial_keys:
+            #     partial_keys = cfg.partial_keys
+            if useFirstFrame:
+                partial_keys = [
+                    ['image', 'local_image', "dwpose"],
+                ]
+                print('Using First Frame Conditioning!')
+
+            for partial_keys_one in partial_keys:
+                model_kwargs_one = prepare_model_kwargs(partial_keys = partial_keys_one,
+                                full_model_kwargs = full_model_kwargs,
+                                use_fps_condition = cfg.use_fps_condition)
+                model_kwargs_one_vis = prepare_model_kwargs(partial_keys = partial_keys_one,
+                                full_model_kwargs = full_model_kwargs_vis,
+                                use_fps_condition = cfg.use_fps_condition)
+                noise_one = noise
-            y_visual = []
-            if 'image' in cfg.video_compositions:
-                with torch.no_grad():
-                    vit_frame = vit_frame.squeeze(1)
-                    y_visual = clip_encoder.encode_image(vit_frame).unsqueeze(1) # [60, 1024]
-                    y_visual0 = y_visual.clone()
-            
-            
-            with amp.autocast(enabled=True):
-                # pynvml.nvmlInit()
-                # handle=pynvml.nvmlDeviceGetHandleByIndex(0)
-                # meminfo=pynvml.nvmlDeviceGetMemoryInfo(handle)
-                cur_seed = torch.initial_seed()
-                logging.info(f"Current seed {cur_seed} ..., cfg.max_frames_new: {cfg.max_frames_new} ....")
-                
-                noise = torch.randn([1, 4, cfg.max_frames_new, int(cfg.resolution[1]/cfg.scale), int(cfg.resolution[0]/cfg.scale)])
-                noise = noise.to(gpu)
-                
-                # add a noise prior
-                noise = diffusion.q_sample(random_ref_frame.clone(), getattr(cfg, "noise_prior_value", 939), noise=noise)
-                
-                if hasattr(cfg.Diffusion, "noise_strength"):
-                    b, c, f, _, _= noise.shape
-                    offset_noise = torch.randn(b, c, f, 1, 1, device=noise.device)
-                    noise = noise + cfg.Diffusion.noise_strength * offset_noise
+                if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
+                    clip_encoder.cpu() # add this line
+                    autoencoder.cpu() # add this line
+                    torch.cuda.empty_cache() # add this line
+
+                video_data = diffusion.ddim_sample_loop(
+                    noise=noise_one,
+                    context_size=context_size,
+                    context_stride=context_stride,
+                    context_overlap=context_overlap,
+                    model=model.eval(),
+                    model_kwargs=model_kwargs_one,
+                    guide_scale=cfg.guide_scale,
+                    ddim_timesteps=steps,
+                    eta=0.0,
+                    context_batch_size=getattr(cfg, "context_batch_size", 1)
+                )
-                # construct model inputs (CFG)
-                full_model_kwargs=[{
-                    'y': None,
-                    "local_image": None if len(image_local) == 0 else image_local[:],
-                    'image': None if len(y_visual) == 0 else y_visual0[:],
-                    'dwpose': None if len(dwpose_data) == 0 else dwpose_data[:],
-                    'randomref': None if len(random_ref_frame) == 0 else random_ref_frame[:],
-                },
-                {
-                    'y': None,
-                    "local_image": None,
-                    'image': None,
-                    'randomref': None,
-                    'dwpose': None,
-                }]
-                
-                # for visualization
-                full_model_kwargs_vis =[{
-                    'y': None,
-                    "local_image": None if len(image_local) == 0 else image_local_clone[:],
-                    'image': None,
-                    'dwpose': None if len(dwpose_data_clone) == 0 else dwpose_data_clone[:],
-                    'randomref': None if len(random_ref_frame) == 0 else random_ref_frame_clone[:, :3],
-                },
-                {
-                    'y': None,
-                    "local_image": None,
-                    'image': None,
-                    'randomref': None,
-                    'dwpose': None,
-                }]
+                # print(f"video_data dtype: {video_data.dtype}")
-                
-                partial_keys = [
-                    ['image', 'randomref', "dwpose"],
-                ]
-                if hasattr(cfg, "partial_keys") and cfg.partial_keys:
-                    partial_keys = cfg.partial_keys
-                
-                for partial_keys_one in partial_keys:
-                    model_kwargs_one = prepare_model_kwargs(partial_keys = partial_keys_one,
-                                    full_model_kwargs = full_model_kwargs,
-                                    use_fps_condition = cfg.use_fps_condition)
-                    model_kwargs_one_vis = prepare_model_kwargs(partial_keys = partial_keys_one,
-                                    full_model_kwargs = full_model_kwargs_vis,
-                                    use_fps_condition = cfg.use_fps_condition)
-                    noise_one = noise
-                    
-                    if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
-                        clip_encoder.cpu() # add this line
-                        autoencoder.cpu() # add this line
-                        torch.cuda.empty_cache() # add this line
-                    
-                    video_data = diffusion.ddim_sample_loop(
-                        noise=noise_one,
-                        context_size=cfg.context_size,
-                        context_stride=cfg.context_stride,
-                        context_overlap=cfg.context_overlap,
-                        model=model.eval(),
-                        model_kwargs=model_kwargs_one,
-                        guide_scale=cfg.guide_scale,
-                        ddim_timesteps=cfg.ddim_timesteps,
-                        eta=0.0,
-                        context_batch_size=getattr(cfg, "context_batch_size", 1)
-                    )
-                    
-                    if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
-                        # if run forward of autoencoder or clip_encoder second times, load them again
-                        clip_encoder.cuda()
-                        autoencoder.cuda()
-                    
-                    
-                    video_data = 1. / cfg.scale_factor * video_data # [1, 4, h, w]
-                    video_data = rearrange(video_data, 'b c f h w -> (b f) c h w')
-                    chunk_size = min(cfg.decoder_bs, video_data.shape[0])
-                    video_data_list = torch.chunk(video_data, video_data.shape[0]//chunk_size, dim=0)
-                    decode_data = []
-                    for vd_data in video_data_list:
-                        gen_frames = autoencoder.decode(vd_data)
-                        decode_data.append(gen_frames)
-                    video_data = torch.cat(decode_data, dim=0)
-                    video_data = rearrange(video_data, '(b f) c h w -> b c f h w', b = cfg.batch_size).float()
-                    
-                    text_size = cfg.resolution[-1]
-                    cap_name = re.sub(r'[^\w\s]', '', ref_image_key.split("/")[-1].split('.')[0]) # .replace(' ', '_')
-                    name = f'seed_{cur_seed}'
-                    for ii in partial_keys_one:
-                        name = name + "_" + ii
-                    file_name = f'rank_{cfg.world_size:02d}_{cfg.rank:02d}_{idx:02d}_{name}_{cap_name}_{cfg.resolution[1]}x{cfg.resolution[0]}.mp4'
-                    local_path = os.path.join(cfg.log_dir, f'{file_name}')
-                    os.makedirs(os.path.dirname(local_path), exist_ok=True)
-                    captions = "human"
-                    del model_kwargs_one_vis[0][list(model_kwargs_one_vis[0].keys())[0]]
-                    del model_kwargs_one_vis[1][list(model_kwargs_one_vis[1].keys())[0]]
-                    
-                    save_video_multiple_conditions_not_gif_horizontal_3col(local_path, video_data.cpu(), model_kwargs_one_vis, misc_backups, 
-                                                cfg.mean, cfg.std, nrow=1, save_fps=cfg.save_fps)
+                if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
+                    # if run forward of autoencoder or clip_encoder second times, load them again
+                    clip_encoder.cuda()
+                    autoencoder.cuda()
-                    # try:
-                    #     save_t2vhigen_video_safe(local_path, video_data.cpu(), captions, cfg.mean, cfg.std, text_size)
-                    #     logging.info('Save video to dir %s:' % (local_path))
-                    # except Exception as e:
-                    #     logging.info(f'Step: save text or video error with {e}')
-    
-    logging.info('Congratulations! The inference is completed!')
-    # synchronize to finish some processes
-    if not cfg.debug:
-        torch.cuda.synchronize()
-        dist.barrier()
+                video_data = 1. / cfg.scale_factor * video_data # [1, 4, h, w]
+                video_data = rearrange(video_data, 'b c f h w -> (b f) c h w')
+                chunk_size = min(cfg.decoder_bs, video_data.shape[0])
+                video_data_list = torch.chunk(video_data, video_data.shape[0]//chunk_size, dim=0)
+                decode_data = []
+                for vd_data in video_data_list:
+                    gen_frames = autoencoder.decode(vd_data)
+                    decode_data.append(gen_frames)
+                video_data = torch.cat(decode_data, dim=0)
+                video_data = rearrange(video_data, '(b f) c h w -> b c f h w', b = cfg.batch_size).float()
+
+                # print(f' video_data is of shape ({video_data.shape})')
+
+                del model_kwargs_one_vis[0][list(model_kwargs_one_vis[0].keys())[0]]
+                del model_kwargs_one_vis[1][list(model_kwargs_one_vis[1].keys())[0]]
+
+                video_data = extract_image_tensors(video_data.cpu(), cfg.mean, cfg.std)
+
+    # synchronize to finish some processes
+    if not cfg.debug:
+        torch.cuda.synchronize()
+        dist.barrier()
+
+    return video_data
+
+@torch.no_grad()
+def extract_image_tensors(video_tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
+    # Unnormalize the video tensor
+    mean = torch.tensor(mean, device=video_tensor.device).view(1, -1, 1, 1, 1)  # ncfhw
+    std = torch.tensor(std, device=video_tensor.device).view(1, -1, 1, 1, 1)  # ncfhw
+    video_tensor = video_tensor.mul_(std).add_(mean)  # unnormalize back to [0,1]
+    video_tensor.clamp_(0, 1)
+
+    images = rearrange(video_tensor, 'b c f h w -> b f h w c')
+    images = images.squeeze(0)
+    images_t = []
+    for img in images:
+        img_array = np.array(img)  # Convert PIL Image to numpy array
+        img_tensor = torch.from_numpy(img_array).permute(2, 0, 1).unsqueeze(0).float()  # Convert to tensor and CHW format
+        img_tensor = img_tensor.permute(0, 2, 3, 1)
+        images_t.append(img_tensor)
+
+    print('Inference completed!')
+    images_t = torch.cat(images_t, dim=0)
+    return images_t
 
 def prepare_model_kwargs(partial_keys, full_model_kwargs, use_fps_condition=False):
-    
     if use_fps_condition is True:
         partial_keys.append('fps')
-    
     partial_model_kwargs = [{}, {}]
     for partial_key in partial_keys:
         partial_model_kwargs[0][partial_key] = full_model_kwargs[0][partial_key]
         partial_model_kwargs[1][partial_key] = full_model_kwargs[1][partial_key]
-    
     return partial_model_kwargs
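
Reviewer note (not part of the patch): the refactor turns the entrance function into a direct call that takes image and pose tensors instead of file paths and returns the decoded frames. Below is a minimal sketch of how it might be driven after this change, assuming ComfyUI-style image tensors of shape (1, H, W, 3) in [0, 1]; the import path, tensor shapes, and parameter values are illustrative assumptions, not taken from the repository:

    import torch
    # Hypothetical import path; adjust to how the package is actually installed.
    from tools.inferences.inference_unianimate_long_entrance import inference_unianimate_long_entrance

    # Hypothetical inputs: one reference image, one reference pose, and one pose image
    # per output frame, all as (1, H, W, 3) float tensors in [0, 1].
    reference_image = torch.rand(1, 768, 512, 3)
    ref_pose = torch.rand(1, 768, 512, 3)
    pose_sequence = [torch.rand(1, 768, 512, 3) for _ in range(64)]

    frames = inference_unianimate_long_entrance(
        seed=42,
        steps=30,                 # forwarded to ddim_sample_loop as ddim_timesteps
        useFirstFrame=False,      # True switches to ['image', 'local_image', 'dwpose'] conditioning
        reference_image=reference_image,
        refPose=ref_pose,
        pose_sequence=pose_sequence,
        frame_interval=1,
        context_size=16,          # sliding-window settings for long-video sampling
        context_stride=1,
        context_overlap=4,
        max_frames=1024000,       # sentinel meaning "use every pose frame" in load_video_frames
        resolution=[512, 768],    # [width, height]
        cfg_update={},            # extra config overrides, if any
    )
    # On the single-GPU path the call returns worker()'s output: the decoded frames from
    # extract_image_tensors() as a (num_frames, H, W, C) float tensor in [0, 1].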