From 9146bdd551d4905f43e464e663ef572efba3adb8 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Wed, 4 Dec 2024 22:30:50 +0800 Subject: [PATCH] use controlnet_model_name_or_path --- flux_train_control_net.py | 2 +- library/flux_train_utils.py | 2 +- sdxl_train_control_net.py | 8 ++++---- train_control_net.py | 6 +++--- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 744c265b5..9d36a41d3 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -265,7 +265,7 @@ def train(args): # load controlnet controlnet_dtype = torch.float32 if args.deepspeed else weight_dtype controlnet = flux_utils.load_controlnet( - args.controlnet_model_path, is_schnell, controlnet_dtype, accelerator.device, args.disable_mmap_load_safetensors + args.controlnet_model_name_or_path, is_schnell, controlnet_dtype, accelerator.device, args.disable_mmap_load_safetensors ) controlnet.train() diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 8e2da5c65..f7f06c5cf 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -564,7 +564,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): ) parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)") parser.add_argument( - "--controlnet_model_path", + "--controlnet_model_name_or_path", type=str, default=None, help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)" diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index 01387409a..ffbf03cab 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -184,12 +184,12 @@ def unwrap_model(model): # make control net logger.info("make ControlNet") - if args.controlnet_model_path: + if args.controlnet_model_name_or_path: with init_empty_weights(): control_net = SdxlControlNet() - logger.info(f"load ControlNet from {args.controlnet_model_path}") - filename = args.controlnet_model_path + logger.info(f"load ControlNet from {args.controlnet_model_name_or_path}") + filename = args.controlnet_model_name_or_path if os.path.splitext(filename)[1] == ".safetensors": state_dict = load_file(filename) else: @@ -675,7 +675,7 @@ def setup_parser() -> argparse.ArgumentParser: sdxl_train_util.add_sdxl_training_arguments(parser) parser.add_argument( - "--controlnet_model_path", + "--controlnet_model_name_or_path", type=str, default=None, help="controlnet model name or path / controlnetのモデル名またはパス", diff --git a/train_control_net.py b/train_control_net.py index bd2d6c47e..177d2b11f 100644 --- a/train_control_net.py +++ b/train_control_net.py @@ -221,8 +221,8 @@ def __contains__(self, name): controlnet = ControlNetModel.from_unet(unet) - if args.controlnet_model_path: - filename = args.controlnet_model_path + if args.controlnet_model_name_or_path: + filename = args.controlnet_model_name_or_path if os.path.isfile(filename): if os.path.splitext(filename)[1] == ".safetensors": state_dict = load_file(filename) @@ -644,7 +644,7 @@ def setup_parser() -> argparse.ArgumentParser: help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", ) parser.add_argument( - "--controlnet_model_path", + "--controlnet_model_name_or_path", type=str, default=None, help="controlnet model name or path / controlnetのモデル名またはパス",