Skip to content

Commit

Permalink
use controlnet_model_name_or_path
Browse files Browse the repository at this point in the history
  • Loading branch information
sdbds committed Dec 4, 2024
1 parent b19210a commit 9146bdd
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion flux_train_control_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def train(args):
# load controlnet
controlnet_dtype = torch.float32 if args.deepspeed else weight_dtype
controlnet = flux_utils.load_controlnet(
args.controlnet_model_path, is_schnell, controlnet_dtype, accelerator.device, args.disable_mmap_load_safetensors
args.controlnet_model_name_or_path, is_schnell, controlnet_dtype, accelerator.device, args.disable_mmap_load_safetensors
)
controlnet.train()

Expand Down
2 changes: 1 addition & 1 deletion library/flux_train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
)
parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)")
parser.add_argument(
"--controlnet_model_path",
"--controlnet_model_name_or_path",
type=str,
default=None,
help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)"
Expand Down
8 changes: 4 additions & 4 deletions sdxl_train_control_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,12 +184,12 @@ def unwrap_model(model):

# make control net
logger.info("make ControlNet")
if args.controlnet_model_path:
if args.controlnet_model_name_or_path:
with init_empty_weights():
control_net = SdxlControlNet()

logger.info(f"load ControlNet from {args.controlnet_model_path}")
filename = args.controlnet_model_path
logger.info(f"load ControlNet from {args.controlnet_model_name_or_path}")
filename = args.controlnet_model_name_or_path
if os.path.splitext(filename)[1] == ".safetensors":
state_dict = load_file(filename)
else:
Expand Down Expand Up @@ -675,7 +675,7 @@ def setup_parser() -> argparse.ArgumentParser:
sdxl_train_util.add_sdxl_training_arguments(parser)

parser.add_argument(
"--controlnet_model_path",
"--controlnet_model_name_or_path",
type=str,
default=None,
help="controlnet model name or path / controlnetのモデル名またはパス",
Expand Down
6 changes: 3 additions & 3 deletions train_control_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,8 @@ def __contains__(self, name):

controlnet = ControlNetModel.from_unet(unet)

if args.controlnet_model_path:
filename = args.controlnet_model_path
if args.controlnet_model_name_or_path:
filename = args.controlnet_model_name_or_path
if os.path.isfile(filename):
if os.path.splitext(filename)[1] == ".safetensors":
state_dict = load_file(filename)
Expand Down Expand Up @@ -644,7 +644,7 @@ def setup_parser() -> argparse.ArgumentParser:
help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)",
)
parser.add_argument(
"--controlnet_model_path",
"--controlnet_model_name_or_path",
type=str,
default=None,
help="controlnet model name or path / controlnetのモデル名またはパス",
Expand Down

0 comments on commit 9146bdd

Please sign in to comment.