Skip to content

Commit

Permalink
Merge pull request kohya-ss#428 from p1atdev/dev
Browse files Browse the repository at this point in the history
Add WandB logging support
  • Loading branch information
kohya-ss authored Apr 22, 2023
2 parents 9f8f27f + a69b24a commit c430cf4
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 9 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ wd14_tagger_model
venv
*.egg-info
build
.vscode
.vscode
wandb
2 changes: 1 addition & 1 deletion fine_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
)

if accelerator.is_main_process:
accelerator.init_trackers("finetuning")
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name)

for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
Expand Down
39 changes: 36 additions & 3 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2067,7 +2067,15 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=None,
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する",
)
parser.add_argument(
"--log_with",
type=str,
default=None,
choices=["tensorboard", "wandb", "all"],
help="what logging tool(s) to use (if 'all', TensorBoard and WandB are both used) / ログ出力に使用するツール (allを指定するとTensorBoardとWandBの両方が使用される)",
)
parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
parser.add_argument("--log_tracker_name", type=str, default=None, help="name of tracker to use for logging / ログ出力に使用するtrackerの名前")
parser.add_argument(
"--noise_offset",
type=float,
Expand Down Expand Up @@ -2732,13 +2740,25 @@ def load_tokenizer(args: argparse.Namespace):

def prepare_accelerator(args: argparse.Namespace):
if args.logging_dir is None:
log_with = None
logging_dir = None
else:
log_with = "tensorboard"
log_prefix = "" if args.log_prefix is None else args.log_prefix
logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime())

if args.log_with is not None:
log_with = "tensorboard" if args.log_with is None else args.log_with
if log_with in ["tensorboard", "all"]:
if logging_dir is None:
raise ValueError("logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください")
if log_with in ["wandb", "all"]:
try:
import wandb
except ImportError:
raise ImportError("No wandb / wandb がインストールされていないようです")
if logging_dir is not None:
os.makedirs(logging_dir)
os.environ["WANDB_DIR"] = logging_dir

accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
Expand Down Expand Up @@ -3197,6 +3217,20 @@ def sample_images(

image.save(os.path.join(save_dir, img_filename))

# wandb有効時のみログを送信
try:
wandb_tracker = accelerator.get_tracker("wandb")
try:
import wandb
except ImportError: # 事前に一度確認するのでここはエラー出ないはず
raise ImportError("No wandb / wandb がインストールされていないようです")

wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
except: # wandb 無効時
pass



# clear pipeline and cache to reduce vram usage
del pipeline
torch.cuda.empty_cache()
Expand All @@ -3205,7 +3239,6 @@ def sample_images(
torch.cuda.set_rng_state(cuda_rng_state)
vae.to(org_vae_device)


# endregion

# region 前処理用
Expand Down
2 changes: 1 addition & 1 deletion train_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def train(args):
)

if accelerator.is_main_process:
accelerator.init_trackers("dreambooth")
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name)

loss_list = []
loss_total = 0.0
Expand Down
2 changes: 1 addition & 1 deletion train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ def train(args):
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)
if accelerator.is_main_process:
accelerator.init_trackers("network_train")
accelerator.init_trackers("network_train" if args.log_tracker_name is None else args.log_tracker_name)

loss_list = []
loss_total = 0.0
Expand Down
2 changes: 1 addition & 1 deletion train_textual_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def train(args):
)

if accelerator.is_main_process:
accelerator.init_trackers("textual_inversion")
accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name)

for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
Expand Down
2 changes: 1 addition & 1 deletion train_textual_inversion_XTI.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def train(args):
)

if accelerator.is_main_process:
accelerator.init_trackers("textual_inversion")
accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name)

for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
Expand Down

0 comments on commit c430cf4

Please sign in to comment.