From 49af7cb3d1de51f6c19b7ff8836d07c91d79b2cf Mon Sep 17 00:00:00 2001 From: xumingw Date: Thu, 4 Jul 2024 10:35:40 +0800 Subject: [PATCH] fix: bug #124 add temp unet in log validation add face emb validation in extract meta info function --- scripts/extract_meta_info_stage1.py | 6 ++++++ scripts/train_stage1.py | 3 +++ 2 files changed, 9 insertions(+) diff --git a/scripts/extract_meta_info_stage1.py b/scripts/extract_meta_info_stage1.py index d25123e1..936cb06c 100644 --- a/scripts/extract_meta_info_stage1.py +++ b/scripts/extract_meta_info_stage1.py @@ -21,6 +21,8 @@ import os from pathlib import Path +import torch + def collect_video_folder_paths(root_path: Path) -> list: """ @@ -52,6 +54,10 @@ def construct_meta_info(frames_dir_path: Path) -> dict: print(f"Mask path not found: {mask_path}") return None + if torch.load(face_emb_path) is None: + print(f"Face emb is None: {face_emb_path}") + return None + return { "image_path": str(frames_dir_path), "mask_path": mask_path, diff --git a/scripts/train_stage1.py b/scripts/train_stage1.py index 9c6265fa..e9e7e847 100644 --- a/scripts/train_stage1.py +++ b/scripts/train_stage1.py @@ -16,6 +16,7 @@ """ import argparse +import copy import logging import math import os @@ -211,6 +212,7 @@ def log_validation( logger.info("Running validation... ") ori_net = accelerator.unwrap_model(net) + ori_net = copy.deepcopy(ori_net) reference_unet = ori_net.reference_unet denoising_unet = ori_net.denoising_unet face_locator = ori_net.face_locator @@ -278,6 +280,7 @@ def log_validation( canvas.save(out_file) del pipe + del ori_net torch.cuda.empty_cache() return pil_images