diff --git a/library/train_util.py b/library/train_util.py
index 31b3149..aac198a 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -5410,7 +5410,7 @@ def save_and_remove_state_on_epoch_end(args: argparse.Namespace, accelerator, ep
             shutil.rmtree(state_dir_old)
 
 
-def save_and_remove_state_stepwise(args: argparse.Namespace, accelerator, step_no):
+def save_and_remove_state_stepwise(args: argparse.Namespace, accelerator, step_no, loss_recorder):
     model_name = default_if_none(args.output_name, DEFAULT_STEP_NAME)
 
     logger.info("")
@@ -5419,6 +5419,13 @@ def save_and_remove_state_stepwise(args: argparse.Namespace, accelerator, step_n
 
     state_dir = os.path.join(args.output_dir, STEP_STATE_NAME.format(model_name, step_no))
     accelerator.save_state(state_dir)
+    # save loss list
+    lossfile = os.path.join(state_dir, "losslist.json")
+    savepointfile = os.path.join(state_dir, "savepoints.json")
+    with open(lossfile, 'w') as f:
+        json.dump(loss_recorder.global_loss_list, f, indent=2)
+    with open(savepointfile, 'w') as f:
+        json.dump(loss_recorder.savepoints, f, indent=2)
     if args.save_state_to_huggingface:
         logger.info("uploading state to huggingface.")
         huggingface_util.upload(args, state_dir, "/" + STEP_STATE_NAME.format(model_name, step_no))
@@ -5439,13 +5446,20 @@ def save_and_remove_state_stepwise(args: argparse.Namespace, accelerator, step_n
-def save_state_on_train_end(args: argparse.Namespace, accelerator):
+def save_state_on_train_end(args: argparse.Namespace, accelerator, loss_recorder=None):
     model_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME)
 
     logger.info("")
     logger.info("saving last state.")
     os.makedirs(args.output_dir, exist_ok=True)
 
     state_dir = os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name))
     accelerator.save_state(state_dir)
-
+    # save loss list (guarded so callers that do not pass a loss_recorder keep working)
+    if loss_recorder is not None:
+        lossfile = os.path.join(state_dir, "losslist.json")
+        savepointfile = os.path.join(state_dir, "savepoints.json")
+        with open(lossfile, 'w') as f:
+            json.dump(loss_recorder.global_loss_list, f, indent=2)
+        with open(savepointfile, 'w') as f:
+            json.dump(loss_recorder.savepoints, f, indent=2)
     if args.save_state_to_huggingface:
         logger.info("uploading last state to huggingface.")
         huggingface_util.upload(args, state_dir, "/" + LAST_STATE_NAME.format(model_name))
@@ -6030,6 +6043,7 @@ def __init__(self):
         self.loss_list: List[float] = []
         self.global_loss_list: List[float] = []
         self.loss_total: float = 0.0
+        self.savepoints: List[int] = []
 
     def add(self, *, epoch: int, step: int, global_step: int, loss: float) -> None:
         if epoch == 0:
@@ -6043,6 +6057,10 @@ def add(self, *, epoch: int, step: int, global_step: int, loss: float) -> None:
             self.global_loss_list.append(loss)
         self.loss_total += loss
 
+    def addsavepoint(self, step: int) -> None:
+        self.savepoints.append(step)
+
+
     @property
     def moving_average(self) -> float:
         return self.loss_total / len(self.loss_list)
diff --git a/nodes.py b/nodes.py
index 79d9c93..69a4a7a 100644
--- a/nodes.py
+++ b/nodes.py
@@ -392,7 +392,6 @@ def INPUT_TYPES(s):
             "gradient_dtype": (["fp32", "fp16", "bf16"], {"default": "fp32", "tooltip": "the actual dtype training uses"}),
             "save_dtype": (["fp32", "fp16", "bf16", "fp8_e4m3fn", "fp8_e5m2"], {"default": "bf16", "tooltip": "the dtype to save checkpoints as"}),
             "attention_mode": (["sdpa", "xformers", "disabled"], {"default": "sdpa", "tooltip": "memory efficient attention mode"}),
-            "sample_prompts": ("STRING", {"multiline": True, "default": "illustration of a kitten | photograph of a turtle", "tooltip": "validation sample prompts, for multiple prompts, separate by `|`"}),
             },
             "optional": {
                 "additional_args": ("STRING", {"multiline": True, "default": "", "tooltip": "additional args to pass to the training command"}),
@@ -401,6 +400,7 @@ def INPUT_TYPES(s):
                 "clip_l_lr": ("FLOAT", {"default": 0, "min": 0.0, "max": 10.0, "step": 0.000001, "tooltip": "text encoder learning rate"}),
rate"}), "T5_lr": ("FLOAT", {"default": 0, "min": 0.0, "max": 10.0, "step": 0.000001, "tooltip": "text encoder learning rate"}), "block_args": ("ARGS", {"default": "", "tooltip": "limit the blocks used in the LoRA"}), + "sample_prompts": ("STRING", {"multiline": True, "default": "illustration of a kitten | photograph of a turtle", "tooltip": "validation sample prompts, for multiple prompts, separate by `|`"}), "gradient_checkpointing": (["enabled", "enabled_with_cpu_offloading", "disabled"], {"default": "enabled", "tooltip": "use gradient checkpointing"}), }, "hidden": { @@ -413,9 +413,9 @@ def INPUT_TYPES(s): FUNCTION = "init_training" CATEGORY = "FluxTrainer" - def init_training(self, flux_models, dataset, optimizer_settings, sample_prompts, output_name, attention_mode, + def init_training(self, flux_models, dataset, optimizer_settings, output_name, attention_mode, gradient_dtype, save_dtype, split_mode, additional_args=None, resume_args=None, train_text_encoder='disabled', - block_args=None, gradient_checkpointing="enabled", prompt=None, extra_pnginfo=None, clip_l_lr=0, T5_lr=0, **kwargs,): + block_args=None, sample_prompts="", gradient_checkpointing="enabled", prompt=None, extra_pnginfo=None, clip_l_lr=0, T5_lr=0, **kwargs,): mm.soft_empty_cache() output_dir = os.path.abspath(kwargs.get("output_dir")) @@ -930,6 +930,67 @@ def save(self, network_trainer): network_trainer.optimizer_train_fn() print("Saving at step:", network_trainer.global_step) +class FluxTrainAndSaveLoop: + @classmethod + def INPUT_TYPES(cls): + return {"required": { + "network_trainer": ("NETWORKTRAINER",), + "save_at_steps": ("INT", {"default": 250, "min": 1, "max": 10000, "step": 1, "tooltip": "the step point in training to save"}), + "save_state": ("BOOLEAN", {"default": True, "tooltip": "backup the training state of the model"}), + }, + } + + RETURN_TYPES = ("NETWORKTRAINER", "INT",) + RETURN_NAMES = ("network_trainer", "steps",) + FUNCTION = "train" + CATEGORY = "FluxTrainer" + + def train(self, network_trainer, save_at_steps, save_state): + with torch.inference_mode(False): + training_loop = network_trainer["training_loop"] + network_trainer = network_trainer["network_trainer"] + + target_global_step = network_trainer.args.max_train_steps + comfy_pbar = comfy.utils.ProgressBar(target_global_step) + network_trainer.comfy_pbar = comfy_pbar + + network_trainer.optimizer_train_fn() + + while network_trainer.global_step < target_global_step: + next_save_step = ((network_trainer.global_step // save_at_steps) + 1) * save_at_steps + + # set current epoch to start epoch on resume + if network_trainer.current_epoch.value < network_trainer.epoch_to_start: + network_trainer.current_epoch.value = network_trainer.epoch_to_start + steps_done = training_loop( + break_at_steps=next_save_step, + epoch=network_trainer.current_epoch.value, + ) + + # Check if we need to save + if network_trainer.global_step % save_at_steps == 0: + self.save(network_trainer, save_state) + + # Also break if the global steps have reached the max train steps + if network_trainer.global_step >= network_trainer.args.max_train_steps: + break + + trainer = { + "network_trainer": network_trainer, + "training_loop": training_loop, + } + return (trainer, network_trainer.global_step) + + def save(self, network_trainer, save_state): + ckpt_name = train_util.get_step_ckpt_name(network_trainer.args, "." 
+        network_trainer.optimizer_eval_fn()
+        network_trainer.loss_recorder.addsavepoint(network_trainer.global_step)
+        if save_state:
+            train_util.save_and_remove_state_stepwise(network_trainer.args, network_trainer.accelerator, network_trainer.global_step, network_trainer.loss_recorder)
+        network_trainer.save_model(ckpt_name, network_trainer.accelerator.unwrap_model(network_trainer.network), network_trainer.global_step, network_trainer.current_epoch.value + 1)
+        network_trainer.optimizer_train_fn()
+        print("Saving at step:", network_trainer.global_step)
+
 class FluxTrainSave:
     @classmethod
     def INPUT_TYPES(s):
@@ -1043,6 +1104,8 @@ def endtrain(self, network_trainer, save_state):
             network_trainer.accelerator.end_training()
             network_trainer.optimizer_eval_fn()
 
+            network_trainer.loss_recorder.addsavepoint(network_trainer.global_step)
+
             if save_state:
-                train_util.save_state_on_train_end(network_trainer.args, network_trainer.accelerator)
+                train_util.save_state_on_train_end(network_trainer.args, network_trainer.accelerator, network_trainer.loss_recorder)
 
@@ -1227,6 +1290,8 @@ def INPUT_TYPES(s):
     def draw(self, network_trainer, window_size, plot_style, normalize_y, width, height, log_scale):
         import numpy as np
         loss_values = network_trainer["network_trainer"].loss_recorder.global_loss_list
+        # copy without the final savepoint, which coincides with the end of the curve
+        savepoints = network_trainer["network_trainer"].loss_recorder.savepoints[:-1]
 
         # Apply moving average
         def moving_average(values, window_size):
@@ -1242,9 +1307,14 @@ def moving_average(values, window_size):
 
         # Create a plot
         fig, ax = plt.subplots(figsize=(width_inches, height_inches))
-        ax.plot(loss_values, label='Training Loss')
+        ax.plot(range(window_size, len(loss_values) + window_size), loss_values, label='Training Loss')
+        plt.xlim(0, len(loss_values) + window_size - 1)
         ax.set_xlabel('Step')
         ax.set_ylabel('Loss')
+        #ax.set_xticks(savepoints, minor=False)
+        #ax.xaxis.grid(True, which='major')
+        for xpoint in savepoints:
+            ax.axvline(xpoint, linestyle=':', color='r')
         if normalize_y:
             plt.ylim(bottom=0)
         if log_scale:
@@ -1703,6 +1773,7 @@ def extract(self, original_model, finetuned_model, output_path, dim, save_dtype,
     "FluxTrainBlockSelect": FluxTrainBlockSelect,
     "TrainDatasetRegularization": TrainDatasetRegularization,
     "FluxTrainAndValidateLoop": FluxTrainAndValidateLoop,
+    "FluxTrainAndSaveLoop": FluxTrainAndSaveLoop,
     "OptimizerConfigProdigyPlusScheduleFree": OptimizerConfigProdigyPlusScheduleFree,
 }
 NODE_DISPLAY_NAME_MAPPINGS = {
@@ -1728,5 +1799,6 @@ def extract(self, original_model, finetuned_model, output_path, dim, save_dtype,
     "FluxTrainBlockSelect": "Flux Train Block Select",
     "TrainDatasetRegularization": "Train Dataset Regularization",
     "FluxTrainAndValidateLoop": "Flux Train And Validate Loop",
+    "FluxTrainAndSaveLoop": "Flux Train And Save Loop",
     "OptimizerConfigProdigyPlusScheduleFree": "Optimizer Config ProdigyPlusScheduleFree",
 }
diff --git a/train_network.py b/train_network.py
index 0094fd7..4c89066 100644
--- a/train_network.py
+++ b/train_network.py
@@ -1006,6 +1006,17 @@ def load_model_hook(models, input_dir):
         )
 
         self.loss_recorder = train_util.LossRecorder()
+        if args.resume:
+            lossfile = os.path.join(args.resume, "losslist.json")
+            if os.path.isfile(lossfile):
+                with open(lossfile, 'r') as f:
+                    self.loss_recorder.global_loss_list = json.load(f)
+                accelerator.print("losslist loaded")
+            savepointfile = os.path.join(args.resume, "savepoints.json")
+            if os.path.isfile(savepointfile):
+                with open(savepointfile, 'r') as f:
+                    self.loss_recorder.savepoints = json.load(f)
+                accelerator.print("savepointlist loaded")
 
         del train_dataset_group
         pbar.update(1)
@@ -1087,7 +1098,7 @@ def remove_model(old_ckpt_name):
         self.remove_model = remove_model
         self.comfy_pbar = None
 
-        progress_bar = tqdm(range(args.max_train_steps - initial_step), smoothing=0, disable=False, desc="steps")
+        progress_bar = tqdm(range(args.max_train_steps - self.global_step), smoothing=0, disable=False, desc="steps")
 
         def training_loop(break_at_steps, epoch):
            steps_done = 0
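
Note on the two sidecar files this patch writes next to the accelerator state: losslist.json is a flat JSON array of per-step losses (loss_recorder.global_loss_list) and savepoints.json is a JSON array of the global step numbers at which checkpoints were written (loss_recorder.savepoints); on --resume, train_network.py reads the same files back so the plotted history spans the whole run. Below is a minimal standalone sketch (not part of the patch) of how these files could be read back and re-plotted outside ComfyUI. The function name plot_saved_losses and the example state-directory path are hypothetical, and the trailing moving average only approximates the smoothing done by the plotting node above.

# plot_saved_losses.py -- standalone sketch, not part of the patch
import json
import os

import matplotlib.pyplot as plt
import numpy as np


def plot_saved_losses(state_dir: str, window_size: int = 100) -> None:
    # losslist.json / savepoints.json are written by save_and_remove_state_stepwise above
    with open(os.path.join(state_dir, "losslist.json")) as f:
        losses = json.load(f)

    savepoints = []
    savepoint_path = os.path.join(state_dir, "savepoints.json")
    if os.path.isfile(savepoint_path):
        with open(savepoint_path) as f:
            savepoints = json.load(f)

    # trailing moving average; yields len(losses) - window_size + 1 points
    smoothed = np.convolve(losses, np.ones(window_size) / window_size, mode="valid")

    fig, ax = plt.subplots(figsize=(10, 5))
    # shift x so each smoothed point sits at the end of its averaging window,
    # mirroring the offset the patched draw() applies
    ax.plot(range(window_size, window_size + len(smoothed)), smoothed, label="Training Loss")
    for step in savepoints:
        ax.axvline(step, linestyle=":", color="r")
    ax.set_xlabel("Step")
    ax.set_ylabel("Loss")
    ax.legend()
    plt.show()


# usage (hypothetical path):
# plot_saved_losses("output/flux_lora-step00000750-state", window_size=50)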