From ad29991e40feb9bb980373d905efc578a1c633e8 Mon Sep 17 00:00:00 2001
From: Xin Ma <38418898+maxin-cn@users.noreply.github.com>
Date: Tue, 21 Nov 2023 14:13:00 +1100
Subject: [PATCH] Create sample.py

---
 base/sample.py | 89 ++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 89 insertions(+)
 create mode 100644 base/sample.py

diff --git a/base/sample.py b/base/sample.py
new file mode 100644
index 00000000..422a73ac
--- /dev/null
+++ b/base/sample.py
@@ -0,0 +1,89 @@
+import os
+import torch
+import argparse
+import sys
+
+sys.path.append(os.path.split(sys.path[0])[0])
+from pipelines.pipeline_videogen import VideoGenPipeline
+
+from download import find_model
+from diffusers.schedulers import DDIMScheduler, DDPMScheduler, EulerDiscreteScheduler
+from diffusers.models import AutoencoderKL
+from transformers import CLIPTokenizer, CLIPTextModel
+from omegaconf import OmegaConf
+from models import get_models
+import imageio
+
+def main(args):
+    torch.manual_seed(args.seed)
+    torch.set_grad_enabled(False)
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    # load the UNet and restore the LaVie base checkpoint
+    sd_path = args.pretrained_path + "/stable-diffusion-v1-4"
+    unet = get_models(args, sd_path).to(device, dtype=torch.float16)
+    state_dict = find_model(args.pretrained_path + "/lavie_base.pt")
+    unet.load_state_dict(state_dict)
+
+    # load the VAE, tokenizer, and CLIP text encoder from the Stable Diffusion weights
+    vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae", torch_dtype=torch.float16).to(device)
+    tokenizer_one = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
+    text_encoder_one = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder", torch_dtype=torch.float16).to(device)
+
+    # set eval mode
+    unet.eval()
+    vae.eval()
+    text_encoder_one.eval()
+
+    # select the noise scheduler requested in the config
+    if args.sample_method == 'ddim':
+        scheduler = DDIMScheduler.from_pretrained(sd_path,
+                                                  subfolder="scheduler",
+                                                  beta_start=args.beta_start,
+                                                  beta_end=args.beta_end,
+                                                  beta_schedule=args.beta_schedule)
+    elif args.sample_method == 'eulerdiscrete':
+        scheduler = EulerDiscreteScheduler.from_pretrained(sd_path,
+                                                           subfolder="scheduler",
+                                                           beta_start=args.beta_start,
+                                                           beta_end=args.beta_end,
+                                                           beta_schedule=args.beta_schedule)
+    elif args.sample_method == 'ddpm':
+        scheduler = DDPMScheduler.from_pretrained(sd_path,
+                                                  subfolder="scheduler",
+                                                  beta_start=args.beta_start,
+                                                  beta_end=args.beta_end,
+                                                  beta_schedule=args.beta_schedule)
+    else:
+        raise NotImplementedError("unsupported sample_method: {}".format(args.sample_method))
+
+    # assemble the text-to-video pipeline
+    videogen_pipeline = VideoGenPipeline(vae=vae,
+                                         text_encoder=text_encoder_one,
+                                         tokenizer=tokenizer_one,
+                                         scheduler=scheduler,
+                                         unet=unet).to(device)
+    videogen_pipeline.enable_xformers_memory_efficient_attention()
+
+    if not os.path.exists(args.output_folder):
+        os.makedirs(args.output_folder)
+
+    # sample one video per prompt and write it out as an mp4
+    for prompt in args.text_prompt:
+        print('Processing prompt: {}'.format(prompt))
+        videos = videogen_pipeline(prompt,
+                                   video_length=args.video_length,
+                                   height=args.image_size[0],
+                                   width=args.image_size[1],
+                                   num_inference_steps=args.num_sampling_steps,
+                                   guidance_scale=args.guidance_scale).video
+        imageio.mimwrite(args.output_folder + prompt.replace(' ', '_') + '.mp4', videos[0], fps=8, quality=9)  # quality ranges from 0 (lowest) to 10 (highest)
+
+    print('Saved videos to {}'.format(args.output_folder))
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--config", type=str, default="")
+    args = parser.parse_args()
+
+    main(OmegaConf.load(args.config))
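
Note: the script above is driven entirely by the OmegaConf file passed via --config; main() reads seed, pretrained_path, sample_method, the beta_* scheduler settings, video_length, image_size, num_sampling_steps, guidance_scale, text_prompt, and output_folder from it. Below is a minimal sketch of a helper that writes such a config. The field names mirror what main() reads, but every value (paths, resolution, prompts) is an illustrative assumption rather than a shipped default, and get_models() may consume additional model fields not shown here.

# make_config.py -- hypothetical helper, not part of this patch.
# Field names follow what main() reads; all values are illustrative assumptions.
from omegaconf import OmegaConf

config = OmegaConf.create({
    "seed": 0,
    "pretrained_path": "./pretrained_models",  # assumed to contain stable-diffusion-v1-4/ and lavie_base.pt
    "sample_method": "ddim",                   # ddim, eulerdiscrete, or ddpm
    "beta_start": 0.0001,
    "beta_end": 0.02,
    "beta_schedule": "linear",
    "video_length": 16,                        # frames per generated video
    "image_size": [320, 512],                  # [height, width]
    "num_sampling_steps": 50,
    "guidance_scale": 7.5,
    "text_prompt": ["a teddy bear walking on the beach"],
    "output_folder": "./output/",              # keep the trailing slash: the script concatenates it with the filename
})

# serialize to YAML so the sampling script can load it
OmegaConf.save(config, "sample.yaml")

The resulting file would then be used as: python base/sample.py --config sample.yaml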