Improve Visualize Loss, back up the loss list, and add a new node without validation #113

Open · wants to merge 3 commits into base: main
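This PR makes three related changes: (1) LossRecorder in library/train_util.py gains a savepoints list, and the per-step loss history (losslist.json) and savepoint steps (savepoints.json) are written into the saved state directory and reloaded on --resume; (2) nodes.py adds a FluxTrainAndSaveLoop node that trains and saves a checkpoint (optionally with full state) every save_at_steps steps without running validation sampling, and sample_prompts becomes an optional input; (3) the Visualize Loss plot shifts the smoothed curve to account for the moving-average window and draws a dotted vertical line at each savepoint.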
22 changes: 20 additions & 2 deletions library/train_util.py
@@ -5410,7 +5410,7 @@ def save_and_remove_state_on_epoch_end(args: argparse.Namespace, accelerator, epoch_no):
shutil.rmtree(state_dir_old)


def save_and_remove_state_stepwise(args: argparse.Namespace, accelerator, step_no):
def save_and_remove_state_stepwise(args: argparse.Namespace, accelerator, step_no, loss_recorder):
model_name = default_if_none(args.output_name, DEFAULT_STEP_NAME)

logger.info("")
@@ -5419,6 +5419,13 @@ def save_and_remove_state_stepwise(args: argparse.Namespace, accelerator, step_no, loss_recorder):

state_dir = os.path.join(args.output_dir, STEP_STATE_NAME.format(model_name, step_no))
accelerator.save_state(state_dir)
# save loss list
lossfile = os.path.join(state_dir, "losslist.json")
savepointfile = os.path.join(state_dir, "savepoints.json")
with open(lossfile, 'w') as f:
json.dump(loss_recorder.global_loss_list, f, indent=2)
with open(savepointfile, 'w') as f:
json.dump(loss_recorder.savepoints, f, indent=2)
if args.save_state_to_huggingface:
logger.info("uploading state to huggingface.")
huggingface_util.upload(args, state_dir, "/" + STEP_STATE_NAME.format(model_name, step_no))
@@ -5445,7 +5452,13 @@
def save_state_on_train_end(args: argparse.Namespace, accelerator):
def save_state_on_train_end(args: argparse.Namespace, accelerator, loss_recorder):

state_dir = os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name))
accelerator.save_state(state_dir)

# save loss list
lossfile = os.path.join(state_dir, "losslist.json")
savepointfile = os.path.join(state_dir, "savepoints.json")
with open(lossfile, 'w') as f:
json.dump(loss_recorder.global_loss_list, f, indent=2)
with open(savepointfile, 'w') as f:
json.dump(loss_recorder.savepoints, f, indent=2)
if args.save_state_to_huggingface:
logger.info("uploading last state to huggingface.")
huggingface_util.upload(args, state_dir, "/" + LAST_STATE_NAME.format(model_name))
@@ -6030,6 +6043,7 @@ def __init__(self):
self.loss_list: List[float] = []
self.global_loss_list: List[float] = []
self.loss_total: float = 0.0
self.savepoints: List[int] = []

def add(self, *, epoch: int, step: int, global_step: int, loss: float) -> None:
if epoch == 0:
@@ -6043,6 +6057,10 @@ def add(self, *, epoch: int, step: int, global_step: int, loss: float) -> None:
self.global_loss_list.append(loss)
self.loss_total += loss

def addsavepoint(self, step: int) -> None:
self.savepoints.append(step)


@property
def moving_average(self) -> float:
return self.loss_total / len(self.loss_list)
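For context, a minimal sketch of the persistence round-trip these hunks implement; the state_dir path and loss values are illustrative, and LossRecorder is reduced to the two fields that get serialized:

import json
import os

global_loss_list = [0.31, 0.29, 0.27]  # one entry per optimizer step
savepoints = [250, 500]                # global steps at which a checkpoint was saved

state_dir = "output/my_lora-step00000500-state"  # hypothetical state directory
os.makedirs(state_dir, exist_ok=True)

# what save_and_remove_state_stepwise() now writes next to accelerator.save_state()
with open(os.path.join(state_dir, "losslist.json"), "w") as f:
    json.dump(global_loss_list, f, indent=2)
with open(os.path.join(state_dir, "savepoints.json"), "w") as f:
    json.dump(savepoints, f, indent=2)

# what train_network.py reads back when --resume points at this directory
with open(os.path.join(state_dir, "losslist.json")) as f:
    assert json.load(f) == global_loss_list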
80 changes: 76 additions & 4 deletions nodes.py
@@ -392,7 +392,6 @@ def INPUT_TYPES(s):
"gradient_dtype": (["fp32", "fp16", "bf16"], {"default": "fp32", "tooltip": "the actual dtype training uses"}),
"save_dtype": (["fp32", "fp16", "bf16", "fp8_e4m3fn", "fp8_e5m2"], {"default": "bf16", "tooltip": "the dtype to save checkpoints as"}),
"attention_mode": (["sdpa", "xformers", "disabled"], {"default": "sdpa", "tooltip": "memory efficient attention mode"}),
"sample_prompts": ("STRING", {"multiline": True, "default": "illustration of a kitten | photograph of a turtle", "tooltip": "validation sample prompts, for multiple prompts, separate by `|`"}),
},
"optional": {
"additional_args": ("STRING", {"multiline": True, "default": "", "tooltip": "additional args to pass to the training command"}),
@@ -401,6 +400,7 @@ def INPUT_TYPES(s):
"clip_l_lr": ("FLOAT", {"default": 0, "min": 0.0, "max": 10.0, "step": 0.000001, "tooltip": "text encoder learning rate"}),
"T5_lr": ("FLOAT", {"default": 0, "min": 0.0, "max": 10.0, "step": 0.000001, "tooltip": "text encoder learning rate"}),
"block_args": ("ARGS", {"default": "", "tooltip": "limit the blocks used in the LoRA"}),
"sample_prompts": ("STRING", {"multiline": True, "default": "illustration of a kitten | photograph of a turtle", "tooltip": "validation sample prompts, for multiple prompts, separate by `|`"}),
"gradient_checkpointing": (["enabled", "enabled_with_cpu_offloading", "disabled"], {"default": "enabled", "tooltip": "use gradient checkpointing"}),
},
"hidden": {
@@ -413,9 +413,9 @@
FUNCTION = "init_training"
CATEGORY = "FluxTrainer"

def init_training(self, flux_models, dataset, optimizer_settings, sample_prompts, output_name, attention_mode,
def init_training(self, flux_models, dataset, optimizer_settings, output_name, attention_mode,
gradient_dtype, save_dtype, split_mode, additional_args=None, resume_args=None, train_text_encoder='disabled',
block_args=None, gradient_checkpointing="enabled", prompt=None, extra_pnginfo=None, clip_l_lr=0, T5_lr=0, **kwargs,):
block_args=None, sample_prompts="", gradient_checkpointing="enabled", prompt=None, extra_pnginfo=None, clip_l_lr=0, T5_lr=0, **kwargs,):
mm.soft_empty_cache()

output_dir = os.path.abspath(kwargs.get("output_dir"))
@@ -930,6 +930,67 @@ def save(self, network_trainer):
network_trainer.optimizer_train_fn()
print("Saving at step:", network_trainer.global_step)

class FluxTrainAndSaveLoop:
@classmethod
def INPUT_TYPES(cls):
return {"required": {
"network_trainer": ("NETWORKTRAINER",),
"save_at_steps": ("INT", {"default": 250, "min": 1, "max": 10000, "step": 1, "tooltip": "the step point in training to save"}),
"save_state": ("BOOLEAN", {"default": True, "tooltip": "backup the training state of the model"}),
},
}

RETURN_TYPES = ("NETWORKTRAINER", "INT",)
RETURN_NAMES = ("network_trainer", "steps",)
FUNCTION = "train"
CATEGORY = "FluxTrainer"

def train(self, network_trainer, save_at_steps, save_state):
with torch.inference_mode(False):
training_loop = network_trainer["training_loop"]
network_trainer = network_trainer["network_trainer"]

target_global_step = network_trainer.args.max_train_steps
comfy_pbar = comfy.utils.ProgressBar(target_global_step)
network_trainer.comfy_pbar = comfy_pbar

network_trainer.optimizer_train_fn()

while network_trainer.global_step < target_global_step:
next_save_step = ((network_trainer.global_step // save_at_steps) + 1) * save_at_steps

# set current epoch to start epoch on resume
if network_trainer.current_epoch.value < network_trainer.epoch_to_start:
network_trainer.current_epoch.value = network_trainer.epoch_to_start
steps_done = training_loop(
break_at_steps=next_save_step,
epoch=network_trainer.current_epoch.value,
)

# Check if we need to save
if network_trainer.global_step % save_at_steps == 0:
self.save(network_trainer, save_state)

# Also break if the global steps have reached the max train steps
if network_trainer.global_step >= network_trainer.args.max_train_steps:
break

trainer = {
"network_trainer": network_trainer,
"training_loop": training_loop,
}
return (trainer, network_trainer.global_step)

def save(self, network_trainer, save_state):
ckpt_name = train_util.get_step_ckpt_name(network_trainer.args, "." + network_trainer.args.save_model_as, network_trainer.global_step)
network_trainer.optimizer_eval_fn()
network_trainer.loss_recorder.addsavepoint(network_trainer.global_step)
if save_state:
train_util.save_and_remove_state_stepwise(network_trainer.args, network_trainer.accelerator, network_trainer.global_step, network_trainer.loss_recorder)
network_trainer.save_model(ckpt_name, network_trainer.accelerator.unwrap_model(network_trainer.network), network_trainer.global_step, network_trainer.current_epoch.value + 1)
network_trainer.optimizer_train_fn()
print("Saving at step:", network_trainer.global_step)

class FluxTrainSave:
@classmethod
def INPUT_TYPES(s):
@@ -1043,6 +1104,8 @@ def endtrain(self, network_trainer, save_state):
network_trainer.accelerator.end_training()
network_trainer.optimizer_eval_fn()

network_trainer.loss_recorder.addsavepoint(network_trainer.global_step)

if save_state:
train_util.save_state_on_train_end(network_trainer.args, network_trainer.accelerator, network_trainer.loss_recorder)

@@ -1227,6 +1290,8 @@ def INPUT_TYPES(s):
def draw(self, network_trainer, window_size, plot_style, normalize_y, width, height, log_scale):
import numpy as np
loss_values = network_trainer["network_trainer"].loss_recorder.global_loss_list
savepoints = network_trainer["network_trainer"].loss_recorder.savepoints
# drop the final savepoint (added at the end of training) from a copy; del would mutate the recorder's list and raise on an empty list
savepoints = savepoints[:-1]

# Apply moving average
def moving_average(values, window_size):
@@ -1242,9 +1307,14 @@ def moving_average(values, window_size):

# Create a plot
fig, ax = plt.subplots(figsize=(width_inches, height_inches))
ax.plot(loss_values, label='Training Loss')
ax.plot(range(window_size, len(loss_values) + window_size), loss_values, label='Training Loss')
plt.xlim(0, len(loss_values) + window_size - 1)
ax.set_xlabel('Step')
ax.set_ylabel('Loss')
# mark each savepoint with a dotted vertical line
for xpoint in savepoints:
ax.axvline(xpoint, linestyle=':', color='r')
if normalize_y:
plt.ylim(bottom=0)
if log_scale:
@@ -1703,6 +1773,7 @@ def extract(self, original_model, finetuned_model, output_path, dim, save_dtype,
"FluxTrainBlockSelect": FluxTrainBlockSelect,
"TrainDatasetRegularization": TrainDatasetRegularization,
"FluxTrainAndValidateLoop": FluxTrainAndValidateLoop,
"FluxTrainAndSaveLoop": FluxTrainAndSaveLoop,
"OptimizerConfigProdigyPlusScheduleFree": OptimizerConfigProdigyPlusScheduleFree,
}
NODE_DISPLAY_NAME_MAPPINGS = {
@@ -1728,5 +1799,6 @@ def extract(self, original_model, finetuned_model, output_path, dim, save_dtype,
"FluxTrainBlockSelect": "Flux Train Block Select",
"TrainDatasetRegularization": "Train Dataset Regularization",
"FluxTrainAndValidateLoop": "Flux Train And Validate Loop",
"FluxTrainAndSaveLoop": "Flux Train And Save Loop",
"OptimizerConfigProdigyPlusScheduleFree": "Optimizer Config ProdigyPlusScheduleFree",
}
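The save cadence in FluxTrainAndSaveLoop.train rounds the current step up to the next multiple of save_at_steps, so a resumed run realigns to the same save grid. A small sketch of that arithmetic with illustrative values:

save_at_steps = 250

def next_save_step(global_step: int) -> int:
    # round up to the next multiple of save_at_steps, as in FluxTrainAndSaveLoop.train
    return ((global_step // save_at_steps) + 1) * save_at_steps

assert next_save_step(0) == 250    # fresh run: first save after 250 steps
assert next_save_step(250) == 500  # exactly on a boundary: aim for the following one
assert next_save_step(617) == 750  # resumed mid-interval: realign to the grid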
13 changes: 12 additions & 1 deletion train_network.py
@@ -1006,6 +1006,17 @@ def load_model_hook(models, input_dir):
)

self.loss_recorder = train_util.LossRecorder()
if args.resume:
lossfile = os.path.join(args.resume, "losslist.json")
if os.path.isfile(lossfile):
with open(lossfile, 'r') as f:
self.loss_recorder.global_loss_list = json.load(f)
accelerator.print("losslist loaded")
savepointfile = os.path.join(args.resume, "savepoints.json")
if os.path.isfile(savepointfile):
with open(savepointfile, 'r') as f:
self.loss_recorder.savepoints = json.load(f)
accelerator.print("savepointlist loaded")
del train_dataset_group

pbar.update(1)
@@ -1087,7 +1098,7 @@ def remove_model(old_ckpt_name):
self.remove_model = remove_model
self.comfy_pbar = None

progress_bar = tqdm(range(args.max_train_steps - initial_step), smoothing=0, disable=False, desc="steps")
progress_bar = tqdm(range(args.max_train_steps - self.global_step), smoothing=0, disable=False, desc="steps")

def training_loop(break_at_steps, epoch):
steps_done = 0
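The final hunk sizes the progress bar from the restored global_step instead of initial_step, so a resumed run shows only the steps that remain; a tiny check with illustrative numbers:

max_train_steps = 1000
global_step = 750  # restored from the resumed state

remaining = max_train_steps - global_step
assert remaining == 250  # tqdm now counts down exactly the steps left to train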