diff --git a/.gitignore b/.gitignore index 0904a2a41..e492b1add 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,5 @@ wd14_tagger_model venv *.egg-info build -.vscode \ No newline at end of file +.vscode +wandb diff --git a/fine_tune.py b/fine_tune.py index 61f6c1919..4de57b452 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -260,7 +260,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): ) if accelerator.is_main_process: - accelerator.init_trackers("finetuning") + accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name) for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") diff --git a/library/train_util.py b/library/train_util.py index 01d9343e2..8e91de05e 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2067,7 +2067,15 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: default=None, help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する", ) + parser.add_argument( + "--log_with", + type=str, + default=None, + choices=["tensorboard", "wandb", "all"], + help="what logging tool(s) to use (if 'all', TensorBoard and WandB are both used) / ログ出力に使用するツール (allを指定するとTensorBoardとWandBの両方が使用される)", + ) parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列") + parser.add_argument("--log_tracker_name", type=str, default=None, help="name of tracker to use for logging / ログ出力に使用するtrackerの名前") parser.add_argument( "--noise_offset", type=float, @@ -2732,13 +2740,25 @@ def load_tokenizer(args: argparse.Namespace): def prepare_accelerator(args: argparse.Namespace): if args.logging_dir is None: - log_with = None logging_dir = None else: - log_with = "tensorboard" log_prefix = "" if args.log_prefix is None else args.log_prefix logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime()) + if args.log_with is not None: + log_with = "tensorboard" if args.log_with is None else args.log_with + if log_with in ["tensorboard", "all"]: + if logging_dir is None: + raise ValueError("logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください") + if log_with in ["wandb", "all"]: + try: + import wandb + except ImportError: + raise ImportError("No wandb / wandb がインストールされていないようです") + if logging_dir is not None: + os.makedirs(logging_dir) + os.environ["WANDB_DIR"] = logging_dir + accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, @@ -3197,6 +3217,20 @@ def sample_images( image.save(os.path.join(save_dir, img_filename)) + # wandb有効時のみログを送信 + try: + wandb_tracker = accelerator.get_tracker("wandb") + try: + import wandb + except ImportError: # 事前に一度確認するのでここはエラー出ないはず + raise ImportError("No wandb / wandb がインストールされていないようです") + + wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) + except: # wandb 無効時 + pass + + + # clear pipeline and cache to reduce vram usage del pipeline torch.cuda.empty_cache() @@ -3205,7 +3239,6 @@ def sample_images( torch.cuda.set_rng_state(cuda_rng_state) vae.to(org_vae_device) - # endregion # region 前処理用 diff --git a/train_db.py b/train_db.py index eddf8f686..5c4202a6b 100644 --- a/train_db.py +++ b/train_db.py @@ -231,7 +231,7 @@ def train(args): ) if accelerator.is_main_process: - accelerator.init_trackers("dreambooth") + accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name) loss_list = [] loss_total = 0.0 diff --git a/train_network.py b/train_network.py index 658138b70..8b6f2c83c 100644 --- a/train_network.py +++ b/train_network.py @@ -538,7 +538,7 @@ def train(args): beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False ) if accelerator.is_main_process: - accelerator.init_trackers("network_train") + accelerator.init_trackers("network_train" if args.log_tracker_name is None else args.log_tracker_name) loss_list = [] loss_total = 0.0 diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 611adff71..2042a6188 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -337,7 +337,7 @@ def train(args): ) if accelerator.is_main_process: - accelerator.init_trackers("textual_inversion") + accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name) for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 54c4b4e56..c2ebf7cbc 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -371,7 +371,7 @@ def train(args): ) if accelerator.is_main_process: - accelerator.init_trackers("textual_inversion") + accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name) for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}")