From 684954695d7b4566a29951c11a8d8304be75af22 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 23 Nov 2023 22:17:49 +0900 Subject: [PATCH 001/103] add gradual latent --- README.md | 32 +++++ gen_img_diffusers.py | 247 ++++++++++++++++++++++++++++++++- sdxl_gen_img.py | 318 ++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 595 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 0edaca25f..3cfbb84b5 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,35 @@ +# Gradual Latent について + +latentのサイズを徐々に大きくしていくHires fixです。`sdxl_gen_img.py` に以下のオプションが追加されています。 + +- `--gradual_latent_timesteps` : latentのサイズを大きくし始めるタイムステップを指定します。デフォルトは None で、Gradual Latentを使用しません。 +- `--gradual_latent_ratio` : latentの初期サイズを指定します。デフォルトは 0.5 で、デフォルトの latent サイズの半分のサイズから始めます。 +- `--gradual_latent_ratio_step`: latentのサイズを大きくする割合を指定します。デフォルトは 0.125 で、latentのサイズを 0.625, 0.75, 0.875, 1.0 と徐々に大きくします。 +- `--gradual_latent_ratio_every_n_steps`: latentのサイズを大きくする間隔を指定します。デフォルトは 3 で、3ステップごとに latent のサイズを大きくします。 + +それぞれのオプションは、プロンプトオプション、`--glt`、`--glr`、`--gls`、`--gle` でも指定できます。 + +__サンプラーに `euler_a` を指定してください。__ 他のサンプラーでは動作しません。 + +`gen_img_diffusers.py` にも同様のオプションが追加されていますが、試した範囲ではどうやっても乱れた画像しか生成できませんでした。 + +# About Gradual Latent + +Gradual Latent is a Hires fix that gradually increases the size of the latent. `sdxl_gen_img.py` has the following options added. + +- `--gradual_latent_timesteps`: Specifies the timestep to start increasing the size of the latent. The default is None, which means Gradual Latent is not used. +- `--gradual_latent_ratio`: Specifies the initial size of the latent. The default is 0.5, which means it starts with half the default latent size. +- `--gradual_latent_ratio_step`: Specifies the ratio to increase the size of the latent. The default is 0.125, which means the latent size is gradually increased to 0.625, 0.75, 0.875, 1.0. +- `--gradual_latent_ratio_every_n_steps`: Specifies the interval to increase the size of the latent. The default is 3, which means the latent size is increased every 3 steps. + +Each option can also be specified with prompt options, `--glt`, `--glr`, `--gls`, `--gle`. + +__Please specify `euler_a` for the sampler.__ It will not work with other samplers. + +`gen_img_diffusers.py` also has the same options, but in the range I tried, it only generated distorted images no matter what I did. + +--- + __SDXL is now supported. The sdxl branch has been merged into the main branch. If you update the repository, please follow the upgrade instructions. Also, the version of accelerate has been updated, so please run accelerate config again.__ The documentation for SDXL training is [here](./README.md#sdxl-training). This repository contains training, generation and utility scripts for Stable Diffusion. diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 7661538c6..c656e6c67 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -484,6 +484,14 @@ def add_token_replacement_XTI(self, target_token_id, rep_token_ids): def set_control_nets(self, ctrl_nets): self.control_nets = ctrl_nets + def set_gradual_latent(self, gradual_latent): + if gradual_latent is None: + print("gradual_latent is disabled") + self.gradual_latent = None + else: + print(f"gradual_latent is enabled: {gradual_latent}") + self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step) + # region xformersとか使う部分:独自に書き換えるので関係なし def enable_xformers_memory_efficient_attention(self): @@ -958,7 +966,41 @@ def __call__( else: text_emb_last = text_embeddings + enable_gradual_latent = False + if self.gradual_latent: + if not hasattr(self.scheduler, "set_resized_size"): + print("gradual_latent is not supported for this scheduler. Ignoring.") + print(self.scheduler.__class__.__name__) + else: + enable_gradual_latent = True + current_ratio, start_timesteps, every_n_steps, ratio_step = self.gradual_latent + step_elapsed = 1000 + + # first, we downscale the latents to the specified ratio / 最初に指定された比率にlatentsをダウンスケールする + height, width = latents.shape[-2:] + org_dtype = latents.dtype + if org_dtype == torch.bfloat16: + latents = latents.float() + latents = torch.nn.functional.interpolate( + latents, scale_factor=current_ratio, mode="bicubic", align_corners=False + ).to(org_dtype) + for i, t in enumerate(tqdm(timesteps)): + resized_size = None + if enable_gradual_latent: + # gradually upscale the latents / latentsを徐々にアップスケールする + if t < start_timesteps and current_ratio < 1.0 and step_elapsed >= every_n_steps: + print("upscale") + current_ratio = min(current_ratio + ratio_step, 1.0) + h = int(height * current_ratio) # // 8 * 8 + w = int(width * current_ratio) # // 8 * 8 + resized_size = (h, w) + self.scheduler.set_resized_size(resized_size) + step_elapsed = 0 + else: + self.scheduler.set_resized_size(None) + step_elapsed += 1 + # expand the latents if we are doing classifier free guidance latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -2112,6 +2154,133 @@ def replacer(): return prompts +# endregion + +# region Gradual Latent hires fix + +import diffusers.schedulers.scheduling_euler_ancestral_discrete +from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput + + +class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.resized_size = None + + def set_resized_size(self, size): + self.resized_size = size + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a + [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple. + + Returns: + [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, + [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned, + otherwise a tuple is returned where the first element is the sample tensor. + + """ + + if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if not self.is_scale_input_called: + logger.warning( + "The `scale_model_input` function should be called before `step` to ensure correct denoising. " + "See `StableDiffusionPipeline` for a usage example." + ) + + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + if self.config.prediction_type == "epsilon": + pred_original_sample = sample - sigma * model_output + elif self.config.prediction_type == "v_prediction": + # * c_out + input * c_skip + pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) + elif self.config.prediction_type == "sample": + raise NotImplementedError("prediction_type not implemented yet: sample") + else: + raise ValueError(f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`") + + sigma_from = self.sigmas[self.step_index] + sigma_to = self.sigmas[self.step_index + 1] + sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5 + sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 + + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma + + dt = sigma_down - sigma + + prev_sample = sample + derivative * dt + + device = model_output.device + if self.resized_size is None: + noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor( + model_output.shape, dtype=model_output.dtype, device=device, generator=generator + ) + else: + print( + "resized_size", self.resized_size, "model_output.shape", model_output.shape, "prev_sample.shape", prev_sample.shape + ) + org_dtype = prev_sample.dtype + if org_dtype == torch.bfloat16: + prev_sample = prev_sample.float() + + prev_sample = torch.nn.functional.interpolate( + prev_sample.float(), size=self.resized_size, mode="bicubic", align_corners=False + ).to(dtype=org_dtype) + + noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor( + (model_output.shape[0], model_output.shape[1], self.resized_size[0], self.resized_size[1]), + dtype=model_output.dtype, + device=device, + generator=generator, + ) + + prev_sample = prev_sample + noise * sigma_up + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return EulerAncestralDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + + # endregion @@ -2249,7 +2418,7 @@ def main(args): scheduler_cls = EulerDiscreteScheduler scheduler_module = diffusers.schedulers.scheduling_euler_discrete elif args.sampler == "euler_a" or args.sampler == "k_euler_a": - scheduler_cls = EulerAncestralDiscreteScheduler + scheduler_cls = EulerAncestralDiscreteSchedulerGL scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++": scheduler_cls = DPMSolverMultistepScheduler @@ -2505,6 +2674,16 @@ def __getattr__(self, item): if args.ds_depth_1 is not None: unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio) + # Gradual Latent + if args.gradual_latent_ratio is not None: + gradual_latent = ( + args.gradual_latent_ratio, + args.gradual_latent_timesteps, + args.gradual_latent_every_n_steps, + args.gradual_latent_ratio_step, + ) + pipe.set_gradual_latent(gradual_latent) + # Extended Textual Inversion および Textual Inversionを処理する if args.XTI_embeddings: diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI @@ -3096,6 +3275,12 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): ds_timesteps_2 = args.ds_timesteps_2 ds_ratio = args.ds_ratio + # Gradual Latent + gl_timesteps = None # means no override + gl_ratio = args.gradual_latent_ratio + gl_every_n_steps = args.gradual_latent_every_n_steps + gl_ratio_step = args.gradual_latent_ratio_step + prompt_args = raw_prompt.strip().split(" --") prompt = prompt_args[0] print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") @@ -3202,6 +3387,34 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): print(f"deep shrink ratio: {ds_ratio}") continue + # Gradual Latent + m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent timesteps + gl_timesteps = int(m.group(1)) + print(f"gradual latent timesteps: {gl_timesteps}") + continue + + m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio + gl_ratio = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + print(f"gradual latent ratio: {ds_ratio}") + continue + + m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent every n steps + gl_every_n_steps = int(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + print(f"gradual latent every n steps: {gl_every_n_steps}") + continue + + m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio step + gl_ratio_step = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + print(f"gradual latent ratio step: {gl_ratio_step}") + continue + except ValueError as ex: print(f"Exception in parsing / 解析エラー: {parg}") print(ex) @@ -3212,6 +3425,12 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): ds_depth_1 = args.ds_depth_1 or 3 unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio) + # override Gradual Latent + if gl_ratio is not None: + if gl_timesteps is None: + gl_timesteps = args.gradual_latent_timesteps or 650 + pipe.set_gradual_latent((gl_ratio, gl_timesteps, gl_every_n_steps, gl_ratio_step)) + # prepare seed if seeds is not None: # given in prompt # 数が足りないなら前のをそのまま使う @@ -3585,6 +3804,32 @@ def setup_parser() -> argparse.ArgumentParser: "--ds_ratio", type=float, default=0.5, help="Deep Shrink ratio for downsampling / Deep Shrinkのdownsampling比率" ) + # gradual latent + parser.add_argument( + "--gradual_latent_timesteps", + type=int, + default=None, + help="enable Gradual Latent hires fix and apply upscaling from this timesteps / Gradual Latent hires fixをこのtimestepsで有効にし、このtimestepsからアップスケーリングを適用する", + ) + parser.add_argument( + "--gradual_latent_ratio", + type=float, + default=0.5, + help=" this size ratio, 0.5 means 1/2 / Gradual Latent hires fixをこのサイズ比率で有効にする、0.5は1/2を意味する", + ) + parser.add_argument( + "--gradual_latent_ratio_step", + type=float, + default=0.125, + help="step to increase ratio for Gradual Latent / Gradual Latentのratioをどのくらいずつ上げるか", + ) + parser.add_argument( + "--gradual_latent_every_n_steps", + type=int, + default=3, + help="steps to increase size of latents every this steps for Gradual Latent / Gradual Latentでlatentsのサイズをこのステップごとに上げる", + ) + return parser diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index a61fb7a89..8be046433 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -345,6 +345,8 @@ def __init__( self.control_nets: List[ControlNetLLLite] = [] self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない + self.gradual_latent = None + # Textual Inversion def add_token_replacement(self, text_encoder_index, target_token_id, rep_token_ids): self.token_replacements_list[text_encoder_index][target_token_id] = rep_token_ids @@ -375,6 +377,14 @@ def replace_tokens(tokens): def set_control_nets(self, ctrl_nets): self.control_nets = ctrl_nets + def set_gradual_latent(self, gradual_latent): + if gradual_latent is None: + print("gradual_latent is disabled") + self.gradual_latent = None + else: + print(f"gradual_latent is enabled: {gradual_latent}") + self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step) + @torch.no_grad() def __call__( self, @@ -706,7 +716,108 @@ def __call__( control_net.set_cond_image(None) each_control_net_enabled = [self.control_net_enabled] * len(self.control_nets) + + # # first, we downscale the latents to the half of the size + # # 最初に1/2に縮小する + # height, width = latents.shape[-2:] + # # latents = torch.nn.functional.interpolate(latents.float(), scale_factor=0.5, mode="bicubic", align_corners=False).to( + # # latents.dtype + # # ) + # latents = latents[:, :, ::2, ::2] + # current_scale = 0.5 + + # # how much to increase the scale at each step: .125 seems to work well (because it's 1/8?) + # # 各ステップに拡大率をどのくらい増やすか:.125がよさそう(たぶん1/8なので) + # scale_step = 0.125 + + # # timesteps at which to start increasing the scale: 1000 seems to be enough + # # 拡大を開始するtimesteps: 1000で十分そうである + # start_timesteps = 1000 + + # # how many steps to wait before increasing the scale again + # # small values leads to blurry images (because the latents are blurry after the upscale, so some denoising might be needed) + # # large values leads to flat images + + # # 何ステップごとに拡大するか + # # 小さいとボケる(拡大後のlatentsはボケた感じになるので、そこから数stepのdenoiseが必要と思われる) + # # 大きすぎると細部が書き込まれずのっぺりした感じになる + # every_n_steps = 5 + + # scale_step = input("scale step:") + # scale_step = float(scale_step) + # start_timesteps = input("start timesteps:") + # start_timesteps = int(start_timesteps) + # every_n_steps = input("every n steps:") + # every_n_steps = int(every_n_steps) + + # # for i, t in enumerate(tqdm(timesteps)): + # i = 0 + # last_step = 0 + # while i < len(timesteps): + # t = timesteps[i] + # print(f"[{i}] t={t}") + + # print(i, t, current_scale, latents.shape) + # if t < start_timesteps and current_scale < 1.0 and i % every_n_steps == 0: + # if i == last_step: + # pass + # else: + # print("upscale") + # current_scale = min(current_scale + scale_step, 1.0) + + # h = int(height * current_scale) // 8 * 8 + # w = int(width * current_scale) // 8 * 8 + + # latents = torch.nn.functional.interpolate(latents.float(), size=(h, w), mode="bicubic", align_corners=False).to( + # latents.dtype + # ) + # last_step = i + # i = max(0, i - every_n_steps + 1) + + # diff = timesteps[i] - timesteps[last_step] + # # resized_init_noise = torch.nn.functional.interpolate( + # # init_noise.float(), size=(h, w), mode="bicubic", align_corners=False + # # ).to(latents.dtype) + # # latents = self.scheduler.add_noise(latents, resized_init_noise, diff) + # latents = self.scheduler.add_noise(latents, torch.randn_like(latents), diff * 4) + # # latents += torch.randn_like(latents) / 100 * diff + # continue + + enable_gradual_latent = False + if self.gradual_latent: + if not hasattr(self.scheduler, "set_resized_size"): + print("gradual_latent is not supported for this scheduler. Ignoring.") + print(self.scheduler.__class__.__name__) + else: + enable_gradual_latent = True + current_ratio, start_timesteps, every_n_steps, ratio_step = self.gradual_latent + step_elapsed = 1000 + + # first, we downscale the latents to the specified ratio / 最初に指定された比率にlatentsをダウンスケールする + height, width = latents.shape[-2:] + org_dtype = latents.dtype + if org_dtype == torch.bfloat16: + latents = latents.float() + latents = torch.nn.functional.interpolate( + latents, scale_factor=current_ratio, mode="bicubic", align_corners=False + ).to(org_dtype) + for i, t in enumerate(tqdm(timesteps)): + resized_size = None + if enable_gradual_latent: + # gradually upscale the latents / latentsを徐々にアップスケールする + if t < start_timesteps and current_ratio < 1.0 and step_elapsed >= every_n_steps: + print("upscale") + current_ratio = min(current_ratio + ratio_step, 1.0) + h = int(height * current_ratio) // 8 * 8 # make divisible by 8 because size of latents must be divisible at bottom of UNet + w = int(width * current_ratio) // 8 * 8 + resized_size = (h, w) + self.scheduler.set_resized_size(resized_size) + step_elapsed = 0 + else: + self.scheduler.set_resized_size(None) + step_elapsed += 1 + # expand the latents if we are doing classifier free guidance latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -775,6 +886,8 @@ def __call__( if is_cancelled_callback is not None and is_cancelled_callback(): return None + i += 1 + if return_latents: return latents @@ -1306,6 +1419,133 @@ def replacer(): return prompts +# endregion + +# region Gradual Latent hires fix + +import diffusers.schedulers.scheduling_euler_ancestral_discrete +from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput + + +class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.resized_size = None + + def set_resized_size(self, size): + self.resized_size = size + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a + [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple. + + Returns: + [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, + [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned, + otherwise a tuple is returned where the first element is the sample tensor. + + """ + + if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if not self.is_scale_input_called: + logger.warning( + "The `scale_model_input` function should be called before `step` to ensure correct denoising. " + "See `StableDiffusionPipeline` for a usage example." + ) + + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + if self.config.prediction_type == "epsilon": + pred_original_sample = sample - sigma * model_output + elif self.config.prediction_type == "v_prediction": + # * c_out + input * c_skip + pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) + elif self.config.prediction_type == "sample": + raise NotImplementedError("prediction_type not implemented yet: sample") + else: + raise ValueError(f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`") + + sigma_from = self.sigmas[self.step_index] + sigma_to = self.sigmas[self.step_index + 1] + sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5 + sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 + + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma + + dt = sigma_down - sigma + + prev_sample = sample + derivative * dt + + device = model_output.device + if self.resized_size is None: + noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor( + model_output.shape, dtype=model_output.dtype, device=device, generator=generator + ) + else: + print( + "resized_size", self.resized_size, "model_output.shape", model_output.shape, "prev_sample.shape", prev_sample.shape + ) + org_dtype = prev_sample.dtype + if org_dtype == torch.bfloat16: + prev_sample = prev_sample.float() + + prev_sample = torch.nn.functional.interpolate( + prev_sample.float(), size=self.resized_size, mode="bicubic", align_corners=False + ).to(dtype=org_dtype) + + noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor( + (model_output.shape[0], model_output.shape[1], self.resized_size[0], self.resized_size[1]), + dtype=model_output.dtype, + device=device, + generator=generator, + ) + + prev_sample = prev_sample + noise * sigma_up + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return EulerAncestralDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + + # endregion @@ -1407,7 +1647,7 @@ def main(args): scheduler_module = diffusers.schedulers.scheduling_euler_discrete has_clip_sample = False elif args.sampler == "euler_a" or args.sampler == "k_euler_a": - scheduler_cls = EulerAncestralDiscreteScheduler + scheduler_cls = EulerAncestralDiscreteSchedulerGL scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete has_clip_sample = False elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++": @@ -1700,6 +1940,16 @@ def __getattr__(self, item): if args.ds_depth_1 is not None: unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio) + # Gradual Latent + if args.gradual_latent_ratio is not None: + gradual_latent = ( + args.gradual_latent_ratio, + args.gradual_latent_timesteps, + args.gradual_latent_every_n_steps, + args.gradual_latent_ratio_step, + ) + pipe.set_gradual_latent(gradual_latent) + # Textual Inversionを処理する if args.textual_inversion_embeddings: token_ids_embeds1 = [] @@ -2297,6 +2547,12 @@ def scale_and_round(x): ds_timesteps_2 = args.ds_timesteps_2 ds_ratio = args.ds_ratio + # Gradual Latent + gl_timesteps = None # means no override + gl_ratio = args.gradual_latent_ratio + gl_every_n_steps = args.gradual_latent_every_n_steps + gl_ratio_step = args.gradual_latent_ratio_step + prompt_args = raw_prompt.strip().split(" --") prompt = prompt_args[0] print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") @@ -2439,6 +2695,34 @@ def scale_and_round(x): print(f"deep shrink ratio: {ds_ratio}") continue + # Gradual Latent + m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent timesteps + gl_timesteps = int(m.group(1)) + print(f"gradual latent timesteps: {gl_timesteps}") + continue + + m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio + gl_ratio = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + print(f"gradual latent ratio: {ds_ratio}") + continue + + m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent every n steps + gl_every_n_steps = int(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + print(f"gradual latent every n steps: {gl_every_n_steps}") + continue + + m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio step + gl_ratio_step = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + print(f"gradual latent ratio step: {gl_ratio_step}") + continue + except ValueError as ex: print(f"Exception in parsing / 解析エラー: {parg}") print(ex) @@ -2449,6 +2733,12 @@ def scale_and_round(x): ds_depth_1 = args.ds_depth_1 or 3 unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio) + # override Gradual Latent + if gl_ratio is not None: + if gl_timesteps is None: + gl_timesteps = args.gradual_latent_timesteps or 650 + pipe.set_gradual_latent((gl_ratio, gl_timesteps, gl_every_n_steps, gl_ratio_step)) + # prepare seed if seeds is not None: # given in prompt # 数が足りないなら前のをそのまま使う @@ -2811,6 +3101,32 @@ def setup_parser() -> argparse.ArgumentParser: "--ds_ratio", type=float, default=0.5, help="Deep Shrink ratio for downsampling / Deep Shrinkのdownsampling比率" ) + # gradual latent + parser.add_argument( + "--gradual_latent_timesteps", + type=int, + default=None, + help="enable Gradual Latent hires fix and apply upscaling from this timesteps / Gradual Latent hires fixをこのtimestepsで有効にし、このtimestepsからアップスケーリングを適用する", + ) + parser.add_argument( + "--gradual_latent_ratio", + type=float, + default=0.5, + help=" this size ratio, 0.5 means 1/2 / Gradual Latent hires fixをこのサイズ比率で有効にする、0.5は1/2を意味する", + ) + parser.add_argument( + "--gradual_latent_ratio_step", + type=float, + default=0.125, + help="step to increase ratio for Gradual Latent / Gradual Latentのratioをどのくらいずつ上げるか", + ) + parser.add_argument( + "--gradual_latent_every_n_steps", + type=int, + default=3, + help="steps to increase size of latents every this steps for Gradual Latent / Gradual Latentでlatentsのサイズをこのステップごとに上げる", + ) + # # parser.add_argument( # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像" # ) From 610566fbb922390c3d33898219eb85b630477dd2 Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Thu, 23 Nov 2023 22:22:36 +0900 Subject: [PATCH 002/103] Update README.md --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 3cfbb84b5..f8a269c73 100644 --- a/README.md +++ b/README.md @@ -2,14 +2,14 @@ latentのサイズを徐々に大きくしていくHires fixです。`sdxl_gen_img.py` に以下のオプションが追加されています。 -- `--gradual_latent_timesteps` : latentのサイズを大きくし始めるタイムステップを指定します。デフォルトは None で、Gradual Latentを使用しません。 +- `--gradual_latent_timesteps` : latentのサイズを大きくし始めるタイムステップを指定します。デフォルトは None で、Gradual Latentを使用しません。750 くらいから始めてみてください。 - `--gradual_latent_ratio` : latentの初期サイズを指定します。デフォルトは 0.5 で、デフォルトの latent サイズの半分のサイズから始めます。 - `--gradual_latent_ratio_step`: latentのサイズを大きくする割合を指定します。デフォルトは 0.125 で、latentのサイズを 0.625, 0.75, 0.875, 1.0 と徐々に大きくします。 - `--gradual_latent_ratio_every_n_steps`: latentのサイズを大きくする間隔を指定します。デフォルトは 3 で、3ステップごとに latent のサイズを大きくします。 それぞれのオプションは、プロンプトオプション、`--glt`、`--glr`、`--gls`、`--gle` でも指定できます。 -__サンプラーに `euler_a` を指定してください。__ 他のサンプラーでは動作しません。 +サンプラーに手を加えているため、__サンプラーに `euler_a` を指定してください。__ 他のサンプラーでは動作しません。 `gen_img_diffusers.py` にも同様のオプションが追加されていますが、試した範囲ではどうやっても乱れた画像しか生成できませんでした。 @@ -17,14 +17,14 @@ __サンプラーに `euler_a` を指定してください。__ 他のサンプ Gradual Latent is a Hires fix that gradually increases the size of the latent. `sdxl_gen_img.py` has the following options added. -- `--gradual_latent_timesteps`: Specifies the timestep to start increasing the size of the latent. The default is None, which means Gradual Latent is not used. +- `--gradual_latent_timesteps`: Specifies the timestep to start increasing the size of the latent. The default is None, which means Gradual Latent is not used. Please try around 750 at first. - `--gradual_latent_ratio`: Specifies the initial size of the latent. The default is 0.5, which means it starts with half the default latent size. - `--gradual_latent_ratio_step`: Specifies the ratio to increase the size of the latent. The default is 0.125, which means the latent size is gradually increased to 0.625, 0.75, 0.875, 1.0. - `--gradual_latent_ratio_every_n_steps`: Specifies the interval to increase the size of the latent. The default is 3, which means the latent size is increased every 3 steps. Each option can also be specified with prompt options, `--glt`, `--glr`, `--gls`, `--gle`. -__Please specify `euler_a` for the sampler.__ It will not work with other samplers. +__Please specify `euler_a` for the sampler.__ Because the source code of the sampler is modified. It will not work with other samplers. `gen_img_diffusers.py` also has the same options, but in the range I tried, it only generated distorted images no matter what I did. From 298c6c2343f7a2d1eff7bc495f2a732b82344ab2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 26 Nov 2023 21:48:36 +0900 Subject: [PATCH 003/103] fix gradual latent cannot be disabled --- sdxl_gen_img.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index c8bd38dd3..bab164f53 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -809,7 +809,9 @@ def __call__( if t < start_timesteps and current_ratio < 1.0 and step_elapsed >= every_n_steps: print("upscale") current_ratio = min(current_ratio + ratio_step, 1.0) - h = int(height * current_ratio) // 8 * 8 # make divisible by 8 because size of latents must be divisible at bottom of UNet + h = ( + int(height * current_ratio) // 8 * 8 + ) # make divisible by 8 because size of latents must be divisible at bottom of UNet w = int(width * current_ratio) // 8 * 8 resized_size = (h, w) self.scheduler.set_resized_size(resized_size) @@ -1946,7 +1948,7 @@ def __getattr__(self, item): unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio) # Gradual Latent - if args.gradual_latent_ratio is not None: + if args.gradual_latent_timesteps is not None: gradual_latent = ( args.gradual_latent_ratio, args.gradual_latent_timesteps, @@ -2739,8 +2741,8 @@ def scale_and_round(x): unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio) # override Gradual Latent - if gl_ratio is not None: - if gl_timesteps is None: + if gl_timesteps is not None: + if gl_timesteps < 0: gl_timesteps = args.gradual_latent_timesteps or 650 pipe.set_gradual_latent((gl_ratio, gl_timesteps, gl_every_n_steps, gl_ratio_step)) From 2c50ea0403c02fc54d45a8636828bbe30363d8e7 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 27 Nov 2023 23:50:21 +0900 Subject: [PATCH 004/103] apply unsharp mask --- sdxl_gen_img.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index bab164f53..9e94d4f6c 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -802,6 +802,10 @@ def __call__( latents, scale_factor=current_ratio, mode="bicubic", align_corners=False ).to(org_dtype) + # apply unsharp mask / アンシャープマスクを適用する + blurred = torchvision.transforms.transforms.GaussianBlur(3, sigma=(0.5, 0.5))(latents) + latents = latents + (latents - blurred) * 0.5 + for i, t in enumerate(tqdm(timesteps)): resized_size = None if enable_gradual_latent: @@ -1434,8 +1438,9 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.resized_size = None - def set_resized_size(self, size): + def set_resized_size(self, size, s_noise=0.5): self.resized_size = size + self.s_noise = s_noise def step( self, @@ -1518,6 +1523,7 @@ def step( noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor( model_output.shape, dtype=model_output.dtype, device=device, generator=generator ) + s_noise = 1.0 else: print( "resized_size", self.resized_size, "model_output.shape", model_output.shape, "prev_sample.shape", prev_sample.shape @@ -1530,14 +1536,19 @@ def step( prev_sample.float(), size=self.resized_size, mode="bicubic", align_corners=False ).to(dtype=org_dtype) + # apply unsharp mask / アンシャープマスクを適用する + blurred = torchvision.transforms.transforms.GaussianBlur(3, sigma=(0.5, 0.5))(prev_sample) + prev_sample = prev_sample + (prev_sample - blurred) * 0.5 + noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor( (model_output.shape[0], model_output.shape[1], self.resized_size[0], self.resized_size[1]), dtype=model_output.dtype, device=device, generator=generator, ) + s_noise = self.s_noise - prev_sample = prev_sample + noise * sigma_up + prev_sample = prev_sample + noise * sigma_up * s_noise # upon completion increase step index by one self._step_index += 1 From 29b6fa621280d4cc16d30200e147c35dd4bbbc45 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 28 Nov 2023 22:33:22 +0900 Subject: [PATCH 005/103] add unsharp mask --- gen_img_diffusers.py | 213 +++++++++++++++--------------------------- library/utils.py | 182 +++++++++++++++++++++++++++++++++++- sdxl_gen_img.py | 215 +++++++++++++------------------------------ 3 files changed, 318 insertions(+), 292 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index fb7866fc4..6428ef8af 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -107,6 +107,7 @@ from tools.original_control_net import ControlNetInfo from library.original_unet import UNet2DConditionModel, InferUNet2DConditionModel from library.original_unet import FlashAttentionFunction +from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI @@ -454,6 +455,8 @@ def __init__( self.control_nets: List[ControlNetInfo] = [] self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない + self.gradual_latent: GradualLatent = None + # Textual Inversion def add_token_replacement(self, target_token_id, rep_token_ids): self.token_replacements[target_token_id] = rep_token_ids @@ -968,13 +971,13 @@ def __call__( enable_gradual_latent = False if self.gradual_latent: - if not hasattr(self.scheduler, "set_resized_size"): + if not hasattr(self.scheduler, "set_gradual_latent_params"): print("gradual_latent is not supported for this scheduler. Ignoring.") print(self.scheduler.__class__.__name__) else: enable_gradual_latent = True - current_ratio, start_timesteps, every_n_steps, ratio_step = self.gradual_latent step_elapsed = 1000 + current_ratio = self.gradual_latent.ratio # first, we downscale the latents to the specified ratio / 最初に指定された比率にlatentsをダウンスケールする height, width = latents.shape[-2:] @@ -985,20 +988,28 @@ def __call__( latents, scale_factor=current_ratio, mode="bicubic", align_corners=False ).to(org_dtype) + # apply unsharp mask / アンシャープマスクを適用する + if self.gradual_latent.gaussian_blur_ksize: + latents = self.gradual_latent.apply_unshark_mask(latents) + for i, t in enumerate(tqdm(timesteps)): resized_size = None if enable_gradual_latent: # gradually upscale the latents / latentsを徐々にアップスケールする - if t < start_timesteps and current_ratio < 1.0 and step_elapsed >= every_n_steps: - print("upscale") - current_ratio = min(current_ratio + ratio_step, 1.0) - h = int(height * current_ratio) # // 8 * 8 - w = int(width * current_ratio) # // 8 * 8 + if ( + t < self.gradual_latent.start_timesteps + and current_ratio < 1.0 + and step_elapsed >= self.gradual_latent.every_n_steps + ): + current_ratio = min(current_ratio + self.gradual_latent.ratio_step, 1.0) + # make divisible by 8 because size of latents must be divisible at bottom of UNet + h = int(height * current_ratio) // 8 * 8 + w = int(width * current_ratio) // 8 * 8 resized_size = (h, w) - self.scheduler.set_resized_size(resized_size) + self.scheduler.set_gradual_latent_params(resized_size, self.gradual_latent) step_elapsed = 0 else: - self.scheduler.set_resized_size(None) + self.scheduler.set_gradual_latent_params(None, None) step_elapsed += 1 # expand the latents if we are doing classifier free guidance @@ -2154,133 +2165,6 @@ def replacer(): return prompts -# endregion - -# region Gradual Latent hires fix - -import diffusers.schedulers.scheduling_euler_ancestral_discrete -from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput - - -class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.resized_size = None - - def set_resized_size(self, size): - self.resized_size = size - - def step( - self, - model_output: torch.FloatTensor, - timestep: Union[float, torch.FloatTensor], - sample: torch.FloatTensor, - generator: Optional[torch.Generator] = None, - return_dict: bool = True, - ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]: - """ - Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion - process from the learned model outputs (most often the predicted noise). - - Args: - model_output (`torch.FloatTensor`): - The direct output from learned diffusion model. - timestep (`float`): - The current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - A current instance of a sample created by the diffusion process. - generator (`torch.Generator`, *optional*): - A random number generator. - return_dict (`bool`): - Whether or not to return a - [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple. - - Returns: - [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`: - If return_dict is `True`, - [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned, - otherwise a tuple is returned where the first element is the sample tensor. - - """ - - if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor): - raise ValueError( - ( - "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" - " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" - " one of the `scheduler.timesteps` as a timestep." - ), - ) - - if not self.is_scale_input_called: - logger.warning( - "The `scale_model_input` function should be called before `step` to ensure correct denoising. " - "See `StableDiffusionPipeline` for a usage example." - ) - - if self.step_index is None: - self._init_step_index(timestep) - - sigma = self.sigmas[self.step_index] - - # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise - if self.config.prediction_type == "epsilon": - pred_original_sample = sample - sigma * model_output - elif self.config.prediction_type == "v_prediction": - # * c_out + input * c_skip - pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) - elif self.config.prediction_type == "sample": - raise NotImplementedError("prediction_type not implemented yet: sample") - else: - raise ValueError(f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`") - - sigma_from = self.sigmas[self.step_index] - sigma_to = self.sigmas[self.step_index + 1] - sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5 - sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 - - # 2. Convert to an ODE derivative - derivative = (sample - pred_original_sample) / sigma - - dt = sigma_down - sigma - - prev_sample = sample + derivative * dt - - device = model_output.device - if self.resized_size is None: - noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor( - model_output.shape, dtype=model_output.dtype, device=device, generator=generator - ) - else: - print( - "resized_size", self.resized_size, "model_output.shape", model_output.shape, "prev_sample.shape", prev_sample.shape - ) - org_dtype = prev_sample.dtype - if org_dtype == torch.bfloat16: - prev_sample = prev_sample.float() - - prev_sample = torch.nn.functional.interpolate( - prev_sample.float(), size=self.resized_size, mode="bicubic", align_corners=False - ).to(dtype=org_dtype) - - noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor( - (model_output.shape[0], model_output.shape[1], self.resized_size[0], self.resized_size[1]), - dtype=model_output.dtype, - device=device, - generator=generator, - ) - - prev_sample = prev_sample + noise * sigma_up - - # upon completion increase step index by one - self._step_index += 1 - - if not return_dict: - return (prev_sample,) - - return EulerAncestralDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) - - # endregion @@ -2683,12 +2567,22 @@ def __getattr__(self, item): unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio) # Gradual Latent - if args.gradual_latent_ratio is not None: - gradual_latent = ( + if args.gradual_latent_timesteps is not None: + if args.gradual_latent_unsharp_params: + ksize, sigma, strength = [float(v) for v in args.gradual_latent_unsharp_params.split(",")] + ksize = int(ksize) + else: + ksize, sigma, strength = None, None, None + + gradual_latent = GradualLatent( args.gradual_latent_ratio, args.gradual_latent_timesteps, args.gradual_latent_every_n_steps, args.gradual_latent_ratio_step, + args.gradual_latent_s_noise, + ksize, + sigma, + strength, ) pipe.set_gradual_latent(gradual_latent) @@ -3288,6 +3182,8 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): gl_ratio = args.gradual_latent_ratio gl_every_n_steps = args.gradual_latent_every_n_steps gl_ratio_step = args.gradual_latent_ratio_step + gl_s_noise = args.gradual_latent_s_noise + gl_unsharp_params = args.gradual_latent_unsharp_params prompt_args = raw_prompt.strip().split(" --") prompt = prompt_args[0] @@ -3423,6 +3319,20 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): print(f"gradual latent ratio step: {gl_ratio_step}") continue + m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent s noise + gl_s_noise = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + print(f"gradual latent s noise: {gl_s_noise}") + continue + + m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # gradual latent unsharp params + gl_unsharp_params = m.group(1) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + print(f"gradual latent unsharp params: {gl_unsharp_params}") + continue + except ValueError as ex: print(f"Exception in parsing / 解析エラー: {parg}") print(ex) @@ -3434,10 +3344,18 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio) # override Gradual Latent - if gl_ratio is not None: - if gl_timesteps is None: + if gl_timesteps is not None: + if gl_timesteps < 0: gl_timesteps = args.gradual_latent_timesteps or 650 - pipe.set_gradual_latent((gl_ratio, gl_timesteps, gl_every_n_steps, gl_ratio_step)) + if gl_unsharp_params is not None: + ksize, sigma, strength = [float(v) for v in gl_unsharp_params.split(",")] + ksize = int(ksize) + else: + ksize, sigma, strength = None, None, None + gradual_latent = GradualLatent( + gl_ratio, gl_timesteps, gl_every_n_steps, gl_ratio_step, gl_s_noise, ksize, sigma, strength + ) + pipe.set_gradual_latent(gradual_latent) # prepare seed if seeds is not None: # given in prompt @@ -3837,6 +3755,19 @@ def setup_parser() -> argparse.ArgumentParser: default=3, help="steps to increase size of latents every this steps for Gradual Latent / Gradual Latentでlatentsのサイズをこのステップごとに上げる", ) + parser.add_argument( + "--gradual_latent_s_noise", + type=float, + default=1.0, + help="s_noise for Gradual Latent / Gradual Latentのs_noise", + ) + parser.add_argument( + "--gradual_latent_unsharp_params", + type=str, + default=None, + help="unsharp mask parameters for Gradual Latent: ksize, sigma, strength. `3,0.5,0.5` is recommended /" + + " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength. `3,0.5,0.5` が推奨", + ) return parser diff --git a/library/utils.py b/library/utils.py index 7d801a676..1ccb141c8 100644 --- a/library/utils.py +++ b/library/utils.py @@ -1,6 +1,186 @@ import threading +import torch +from torchvision import transforms from typing import * +from diffusers import EulerAncestralDiscreteScheduler +import diffusers.schedulers.scheduling_euler_ancestral_discrete +from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput def fire_in_thread(f, *args, **kwargs): - threading.Thread(target=f, args=args, kwargs=kwargs).start() \ No newline at end of file + threading.Thread(target=f, args=args, kwargs=kwargs).start() + + +# TODO make inf_utils.py + + +# region Gradual Latent hires fix + + +class GradualLatent: + def __init__( + self, + ratio, + start_timesteps, + every_n_steps, + ratio_step, + s_noise=1.0, + gaussian_blur_ksize=None, + gaussian_blur_sigma=0.5, + gaussian_blur_strength=0.5, + ): + self.ratio = ratio + self.start_timesteps = start_timesteps + self.every_n_steps = every_n_steps + self.ratio_step = ratio_step + self.s_noise = s_noise + self.gaussian_blur_ksize = gaussian_blur_ksize + self.gaussian_blur_sigma = gaussian_blur_sigma + self.gaussian_blur_strength = gaussian_blur_strength + + def __str__(self) -> str: + return ( + f"GradualLatent(ratio={self.ratio}, start_timesteps={self.start_timesteps}, " + + f"every_n_steps={self.every_n_steps}, ratio_step={self.ratio_step}, s_noise={self.s_noise}, " + + f"gaussian_blur_ksize={self.gaussian_blur_ksize}, gaussian_blur_sigma={self.gaussian_blur_sigma}, gaussian_blur_strength={self.gaussian_blur_strength})" + ) + + def apply_unshark_mask(self, x: torch.Tensor): + if self.gaussian_blur_ksize is None: + return x + blurred = transforms.functional.gaussian_blur(x, self.gaussian_blur_ksize, self.gaussian_blur_sigma) + # mask = torch.sigmoid((x - blurred) * self.gaussian_blur_strength) + mask = (x - blurred) * self.gaussian_blur_strength + sharpened = x + mask + return sharpened + + +class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.resized_size = None + self.gradual_latent = None + + def set_gradual_latent_params(self, size, gradual_latent: GradualLatent): + self.resized_size = size + self.gradual_latent = gradual_latent + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a + [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple. + + Returns: + [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, + [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned, + otherwise a tuple is returned where the first element is the sample tensor. + + """ + + if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if not self.is_scale_input_called: + # logger.warning( + print( + "The `scale_model_input` function should be called before `step` to ensure correct denoising. " + "See `StableDiffusionPipeline` for a usage example." + ) + + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + if self.config.prediction_type == "epsilon": + pred_original_sample = sample - sigma * model_output + elif self.config.prediction_type == "v_prediction": + # * c_out + input * c_skip + pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) + elif self.config.prediction_type == "sample": + raise NotImplementedError("prediction_type not implemented yet: sample") + else: + raise ValueError(f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`") + + sigma_from = self.sigmas[self.step_index] + sigma_to = self.sigmas[self.step_index + 1] + sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5 + sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 + + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma + + dt = sigma_down - sigma + + prev_sample = sample + derivative * dt + + device = model_output.device + if self.resized_size is None: + noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor( + model_output.shape, dtype=model_output.dtype, device=device, generator=generator + ) + s_noise = 1.0 + else: + print( + "resized_size", self.resized_size, "model_output.shape", model_output.shape, "prev_sample.shape", prev_sample.shape + ) + org_dtype = prev_sample.dtype + if org_dtype == torch.bfloat16: + prev_sample = prev_sample.float() + + prev_sample = torch.nn.functional.interpolate( + prev_sample.float(), size=self.resized_size, mode="bicubic", align_corners=False + ).to(dtype=org_dtype) + + # apply unsharp mask / アンシャープマスクを適用する + if self.gradual_latent.gaussian_blur_ksize: + prev_sample = self.gradual_latent.apply_unshark_mask(prev_sample) + + noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor( + (model_output.shape[0], model_output.shape[1], self.resized_size[0], self.resized_size[1]), + dtype=model_output.dtype, + device=device, + generator=generator, + ) + s_noise = self.gradual_latent.s_noise + + prev_sample = prev_sample + noise * sigma_up * s_noise + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return EulerAncestralDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + + +# endregion diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index 9e94d4f6c..880318038 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -60,6 +60,7 @@ from library.sdxl_original_unet import InferSdxlUNet2DConditionModel from library.original_unet import FlashAttentionFunction from networks.control_net_lllite import ControlNetLLLite +from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL # scheduler: SCHEDULER_LINEAR_START = 0.00085 @@ -345,7 +346,7 @@ def __init__( self.control_nets: List[ControlNetLLLite] = [] self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない - self.gradual_latent = None + self.gradual_latent: GradualLatent = None # Textual Inversion def add_token_replacement(self, text_encoder_index, target_token_id, rep_token_ids): @@ -785,13 +786,13 @@ def __call__( enable_gradual_latent = False if self.gradual_latent: - if not hasattr(self.scheduler, "set_resized_size"): + if not hasattr(self.scheduler, "set_gradual_latent_params"): print("gradual_latent is not supported for this scheduler. Ignoring.") print(self.scheduler.__class__.__name__) else: enable_gradual_latent = True - current_ratio, start_timesteps, every_n_steps, ratio_step = self.gradual_latent step_elapsed = 1000 + current_ratio = self.gradual_latent.ratio # first, we downscale the latents to the specified ratio / 最初に指定された比率にlatentsをダウンスケールする height, width = latents.shape[-2:] @@ -803,25 +804,27 @@ def __call__( ).to(org_dtype) # apply unsharp mask / アンシャープマスクを適用する - blurred = torchvision.transforms.transforms.GaussianBlur(3, sigma=(0.5, 0.5))(latents) - latents = latents + (latents - blurred) * 0.5 + if self.gradual_latent.gaussian_blur_ksize: + latents = self.gradual_latent.apply_unshark_mask(latents) for i, t in enumerate(tqdm(timesteps)): resized_size = None if enable_gradual_latent: # gradually upscale the latents / latentsを徐々にアップスケールする - if t < start_timesteps and current_ratio < 1.0 and step_elapsed >= every_n_steps: - print("upscale") - current_ratio = min(current_ratio + ratio_step, 1.0) - h = ( - int(height * current_ratio) // 8 * 8 - ) # make divisible by 8 because size of latents must be divisible at bottom of UNet + if ( + t < self.gradual_latent.start_timesteps + and current_ratio < 1.0 + and step_elapsed >= self.gradual_latent.every_n_steps + ): + current_ratio = min(current_ratio + self.gradual_latent.ratio_step, 1.0) + # make divisible by 8 because size of latents must be divisible at bottom of UNet + h = int(height * current_ratio) // 8 * 8 w = int(width * current_ratio) // 8 * 8 resized_size = (h, w) - self.scheduler.set_resized_size(resized_size) + self.scheduler.set_gradual_latent_params(resized_size, self.gradual_latent) step_elapsed = 0 else: - self.scheduler.set_resized_size(None) + self.scheduler.set_gradual_latent_params(None, None) step_elapsed += 1 # expand the latents if we are doing classifier free guidance @@ -1427,141 +1430,6 @@ def replacer(): # endregion -# region Gradual Latent hires fix - -import diffusers.schedulers.scheduling_euler_ancestral_discrete -from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput - - -class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.resized_size = None - - def set_resized_size(self, size, s_noise=0.5): - self.resized_size = size - self.s_noise = s_noise - - def step( - self, - model_output: torch.FloatTensor, - timestep: Union[float, torch.FloatTensor], - sample: torch.FloatTensor, - generator: Optional[torch.Generator] = None, - return_dict: bool = True, - ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]: - """ - Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion - process from the learned model outputs (most often the predicted noise). - - Args: - model_output (`torch.FloatTensor`): - The direct output from learned diffusion model. - timestep (`float`): - The current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - A current instance of a sample created by the diffusion process. - generator (`torch.Generator`, *optional*): - A random number generator. - return_dict (`bool`): - Whether or not to return a - [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple. - - Returns: - [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`: - If return_dict is `True`, - [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned, - otherwise a tuple is returned where the first element is the sample tensor. - - """ - - if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor): - raise ValueError( - ( - "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" - " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" - " one of the `scheduler.timesteps` as a timestep." - ), - ) - - if not self.is_scale_input_called: - logger.warning( - "The `scale_model_input` function should be called before `step` to ensure correct denoising. " - "See `StableDiffusionPipeline` for a usage example." - ) - - if self.step_index is None: - self._init_step_index(timestep) - - sigma = self.sigmas[self.step_index] - - # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise - if self.config.prediction_type == "epsilon": - pred_original_sample = sample - sigma * model_output - elif self.config.prediction_type == "v_prediction": - # * c_out + input * c_skip - pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) - elif self.config.prediction_type == "sample": - raise NotImplementedError("prediction_type not implemented yet: sample") - else: - raise ValueError(f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`") - - sigma_from = self.sigmas[self.step_index] - sigma_to = self.sigmas[self.step_index + 1] - sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5 - sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 - - # 2. Convert to an ODE derivative - derivative = (sample - pred_original_sample) / sigma - - dt = sigma_down - sigma - - prev_sample = sample + derivative * dt - - device = model_output.device - if self.resized_size is None: - noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor( - model_output.shape, dtype=model_output.dtype, device=device, generator=generator - ) - s_noise = 1.0 - else: - print( - "resized_size", self.resized_size, "model_output.shape", model_output.shape, "prev_sample.shape", prev_sample.shape - ) - org_dtype = prev_sample.dtype - if org_dtype == torch.bfloat16: - prev_sample = prev_sample.float() - - prev_sample = torch.nn.functional.interpolate( - prev_sample.float(), size=self.resized_size, mode="bicubic", align_corners=False - ).to(dtype=org_dtype) - - # apply unsharp mask / アンシャープマスクを適用する - blurred = torchvision.transforms.transforms.GaussianBlur(3, sigma=(0.5, 0.5))(prev_sample) - prev_sample = prev_sample + (prev_sample - blurred) * 0.5 - - noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor( - (model_output.shape[0], model_output.shape[1], self.resized_size[0], self.resized_size[1]), - dtype=model_output.dtype, - device=device, - generator=generator, - ) - s_noise = self.s_noise - - prev_sample = prev_sample + noise * sigma_up * s_noise - - # upon completion increase step index by one - self._step_index += 1 - - if not return_dict: - return (prev_sample,) - - return EulerAncestralDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) - - -# endregion - - # def load_clip_l14_336(dtype): # print(f"loading CLIP: {CLIP_ID_L14_336}") # text_encoder = CLIPTextModel.from_pretrained(CLIP_ID_L14_336, torch_dtype=dtype) @@ -1960,11 +1828,21 @@ def __getattr__(self, item): # Gradual Latent if args.gradual_latent_timesteps is not None: - gradual_latent = ( + if args.gradual_latent_unsharp_params: + ksize, sigma, strength = [float(v) for v in args.gradual_latent_unsharp_params.split(",")] + ksize = int(ksize) + else: + ksize, sigma, strength = None, None, None + + gradual_latent = GradualLatent( args.gradual_latent_ratio, args.gradual_latent_timesteps, args.gradual_latent_every_n_steps, args.gradual_latent_ratio_step, + args.gradual_latent_s_noise, + ksize, + sigma, + strength, ) pipe.set_gradual_latent(gradual_latent) @@ -2570,6 +2448,8 @@ def scale_and_round(x): gl_ratio = args.gradual_latent_ratio gl_every_n_steps = args.gradual_latent_every_n_steps gl_ratio_step = args.gradual_latent_ratio_step + gl_s_noise = args.gradual_latent_s_noise + gl_unsharp_params = args.gradual_latent_unsharp_params prompt_args = raw_prompt.strip().split(" --") prompt = prompt_args[0] @@ -2741,6 +2621,20 @@ def scale_and_round(x): print(f"gradual latent ratio step: {gl_ratio_step}") continue + m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent s noise + gl_s_noise = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + print(f"gradual latent s noise: {gl_s_noise}") + continue + + m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # gradual latent unsharp params + gl_unsharp_params = m.group(1) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + print(f"gradual latent unsharp params: {gl_unsharp_params}") + continue + except ValueError as ex: print(f"Exception in parsing / 解析エラー: {parg}") print(ex) @@ -2755,7 +2649,15 @@ def scale_and_round(x): if gl_timesteps is not None: if gl_timesteps < 0: gl_timesteps = args.gradual_latent_timesteps or 650 - pipe.set_gradual_latent((gl_ratio, gl_timesteps, gl_every_n_steps, gl_ratio_step)) + if gl_unsharp_params is not None: + ksize, sigma, strength = [float(v) for v in gl_unsharp_params.split(",")] + ksize = int(ksize) + else: + ksize, sigma, strength = None, None, None + gradual_latent = GradualLatent( + gl_ratio, gl_timesteps, gl_every_n_steps, gl_ratio_step, gl_s_noise, ksize, sigma, strength + ) + pipe.set_gradual_latent(gradual_latent) # prepare seed if seeds is not None: # given in prompt @@ -3144,6 +3046,19 @@ def setup_parser() -> argparse.ArgumentParser: default=3, help="steps to increase size of latents every this steps for Gradual Latent / Gradual Latentでlatentsのサイズをこのステップごとに上げる", ) + parser.add_argument( + "--gradual_latent_s_noise", + type=float, + default=1.0, + help="s_noise for Gradual Latent / Gradual Latentのs_noise", + ) + parser.add_argument( + "--gradual_latent_unsharp_params", + type=str, + default=None, + help="unsharp mask parameters for Gradual Latent: ksize, sigma, strength. `3,0.5,0.5` is recommended /" + + " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength. `3,0.5,0.5` が推奨", + ) # # parser.add_argument( # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像" From 2952bca5205fb020ab415c27fc7556b2de5a8042 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 1 Dec 2023 21:56:08 +0900 Subject: [PATCH 006/103] fix strength error --- sdxl_gen_img.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index 880318038..35d6575e9 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -1829,10 +1829,10 @@ def __getattr__(self, item): # Gradual Latent if args.gradual_latent_timesteps is not None: if args.gradual_latent_unsharp_params: - ksize, sigma, strength = [float(v) for v in args.gradual_latent_unsharp_params.split(",")] - ksize = int(ksize) + us_ksize, us_sigma, us_strength = [float(v) for v in args.gradual_latent_unsharp_params.split(",")] + us_ksize = int(us_ksize) else: - ksize, sigma, strength = None, None, None + us_ksize, us_sigma, us_strength = None, None, None gradual_latent = GradualLatent( args.gradual_latent_ratio, @@ -1840,9 +1840,9 @@ def __getattr__(self, item): args.gradual_latent_every_n_steps, args.gradual_latent_ratio_step, args.gradual_latent_s_noise, - ksize, - sigma, - strength, + us_ksize, + us_sigma, + us_strength, ) pipe.set_gradual_latent(gradual_latent) @@ -2650,12 +2650,12 @@ def scale_and_round(x): if gl_timesteps < 0: gl_timesteps = args.gradual_latent_timesteps or 650 if gl_unsharp_params is not None: - ksize, sigma, strength = [float(v) for v in gl_unsharp_params.split(",")] - ksize = int(ksize) + us_ksize, us_sigma, us_strength = [float(v) for v in gl_unsharp_params.split(",")] + us_ksize = int(us_ksize) else: - ksize, sigma, strength = None, None, None + us_ksize, us_sigma, us_strength = None, None, None gradual_latent = GradualLatent( - gl_ratio, gl_timesteps, gl_every_n_steps, gl_ratio_step, gl_s_noise, ksize, sigma, strength + gl_ratio, gl_timesteps, gl_every_n_steps, gl_ratio_step, gl_s_noise, us_ksize, us_sigma, us_strength ) pipe.set_gradual_latent(gradual_latent) From 7a4e50705cc59411139c18ebd62176d6b120de16 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 3 Dec 2023 17:59:41 +0900 Subject: [PATCH 007/103] add target_x flag (not sure this impl is correct) --- gen_img_diffusers.py | 38 +++++++++++++++++++++++++------------ library/utils.py | 45 +++++++++++++++++++++++++++----------------- sdxl_gen_img.py | 27 +++++++++++++++++++------- 3 files changed, 74 insertions(+), 36 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 6428ef8af..74f3852de 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -2569,10 +2569,12 @@ def __getattr__(self, item): # Gradual Latent if args.gradual_latent_timesteps is not None: if args.gradual_latent_unsharp_params: - ksize, sigma, strength = [float(v) for v in args.gradual_latent_unsharp_params.split(",")] - ksize = int(ksize) + us_params = args.gradual_latent_unsharp_params.split(",") + us_ksize, us_sigma, us_strength = [float(v) for v in us_params[:3]] + us_target_x = True if len(us_params) <= 3 else bool(int(us_params[3])) + us_ksize = int(us_ksize) else: - ksize, sigma, strength = None, None, None + us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None gradual_latent = GradualLatent( args.gradual_latent_ratio, @@ -2580,9 +2582,10 @@ def __getattr__(self, item): args.gradual_latent_every_n_steps, args.gradual_latent_ratio_step, args.gradual_latent_s_noise, - ksize, - sigma, - strength, + us_ksize, + us_sigma, + us_strength, + us_target_x, ) pipe.set_gradual_latent(gradual_latent) @@ -3348,12 +3351,23 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): if gl_timesteps < 0: gl_timesteps = args.gradual_latent_timesteps or 650 if gl_unsharp_params is not None: - ksize, sigma, strength = [float(v) for v in gl_unsharp_params.split(",")] - ksize = int(ksize) + unsharp_params = gl_unsharp_params.split(",") + us_ksize, us_sigma, us_strength = [float(v) for v in unsharp_params[:3]] + print(unsharp_params) + us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3])) + us_ksize = int(us_ksize) else: - ksize, sigma, strength = None, None, None + us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None gradual_latent = GradualLatent( - gl_ratio, gl_timesteps, gl_every_n_steps, gl_ratio_step, gl_s_noise, ksize, sigma, strength + gl_ratio, + gl_timesteps, + gl_every_n_steps, + gl_ratio_step, + gl_s_noise, + us_ksize, + us_sigma, + us_strength, + us_target_x, ) pipe.set_gradual_latent(gradual_latent) @@ -3765,8 +3779,8 @@ def setup_parser() -> argparse.ArgumentParser: "--gradual_latent_unsharp_params", type=str, default=None, - help="unsharp mask parameters for Gradual Latent: ksize, sigma, strength. `3,0.5,0.5` is recommended /" - + " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength. `3,0.5,0.5` が推奨", + help="unsharp mask parameters for Gradual Latent: ksize, sigma, strength, target-x (1 means True). `3,0.5,0.5,1` or `3,1.0,1.0,0` is recommended /" + + " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. `3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨", ) return parser diff --git a/library/utils.py b/library/utils.py index 1ccb141c8..48fb40e2d 100644 --- a/library/utils.py +++ b/library/utils.py @@ -28,6 +28,7 @@ def __init__( gaussian_blur_ksize=None, gaussian_blur_sigma=0.5, gaussian_blur_strength=0.5, + unsharp_target_x=True, ): self.ratio = ratio self.start_timesteps = start_timesteps @@ -37,12 +38,14 @@ def __init__( self.gaussian_blur_ksize = gaussian_blur_ksize self.gaussian_blur_sigma = gaussian_blur_sigma self.gaussian_blur_strength = gaussian_blur_strength + self.unsharp_target_x = unsharp_target_x def __str__(self) -> str: return ( f"GradualLatent(ratio={self.ratio}, start_timesteps={self.start_timesteps}, " + f"every_n_steps={self.every_n_steps}, ratio_step={self.ratio_step}, s_noise={self.s_noise}, " - + f"gaussian_blur_ksize={self.gaussian_blur_ksize}, gaussian_blur_sigma={self.gaussian_blur_sigma}, gaussian_blur_strength={self.gaussian_blur_strength})" + + f"gaussian_blur_ksize={self.gaussian_blur_ksize}, gaussian_blur_sigma={self.gaussian_blur_sigma}, gaussian_blur_strength={self.gaussian_blur_strength}, " + + f"unsharp_target_x={self.unsharp_target_x})" ) def apply_unshark_mask(self, x: torch.Tensor): @@ -54,6 +57,19 @@ def apply_unshark_mask(self, x: torch.Tensor): sharpened = x + mask return sharpened + def interpolate(self, x: torch.Tensor, resized_size, unsharp=True): + org_dtype = x.dtype + if org_dtype == torch.bfloat16: + x = x.float() + + x = torch.nn.functional.interpolate(x, size=resized_size, mode="bicubic", align_corners=False).to(dtype=org_dtype) + + # apply unsharp mask / アンシャープマスクを適用する + if unsharp and self.gaussian_blur_ksize: + x = self.apply_unshark_mask(x) + + return x + class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler): def __init__(self, *args, **kwargs): @@ -140,29 +156,25 @@ def step( dt = sigma_down - sigma - prev_sample = sample + derivative * dt - device = model_output.device if self.resized_size is None: + prev_sample = sample + derivative * dt + noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor( model_output.shape, dtype=model_output.dtype, device=device, generator=generator ) s_noise = 1.0 else: - print( - "resized_size", self.resized_size, "model_output.shape", model_output.shape, "prev_sample.shape", prev_sample.shape - ) - org_dtype = prev_sample.dtype - if org_dtype == torch.bfloat16: - prev_sample = prev_sample.float() - - prev_sample = torch.nn.functional.interpolate( - prev_sample.float(), size=self.resized_size, mode="bicubic", align_corners=False - ).to(dtype=org_dtype) + print("resized_size", self.resized_size, "model_output.shape", model_output.shape, "sample.shape", sample.shape) + s_noise = self.gradual_latent.s_noise - # apply unsharp mask / アンシャープマスクを適用する - if self.gradual_latent.gaussian_blur_ksize: - prev_sample = self.gradual_latent.apply_unshark_mask(prev_sample) + if self.gradual_latent.unsharp_target_x: + prev_sample = sample + derivative * dt + prev_sample = self.gradual_latent.interpolate(prev_sample, self.resized_size) + else: + sample = self.gradual_latent.interpolate(sample, self.resized_size) + derivative = self.gradual_latent.interpolate(derivative, self.resized_size, unsharp=False) + prev_sample = sample + derivative * dt noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor( (model_output.shape[0], model_output.shape[1], self.resized_size[0], self.resized_size[1]), @@ -170,7 +182,6 @@ def step( device=device, generator=generator, ) - s_noise = self.gradual_latent.s_noise prev_sample = prev_sample + noise * sigma_up * s_noise diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index 35d6575e9..bfe2e5124 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -1829,10 +1829,12 @@ def __getattr__(self, item): # Gradual Latent if args.gradual_latent_timesteps is not None: if args.gradual_latent_unsharp_params: - us_ksize, us_sigma, us_strength = [float(v) for v in args.gradual_latent_unsharp_params.split(",")] + us_params = args.gradual_latent_unsharp_params.split(",") + us_ksize, us_sigma, us_strength = [float(v) for v in us_params[:3]] + us_target_x = True if len(us_params) <= 3 else bool(int(us_params[3])) us_ksize = int(us_ksize) else: - us_ksize, us_sigma, us_strength = None, None, None + us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None gradual_latent = GradualLatent( args.gradual_latent_ratio, @@ -1843,6 +1845,7 @@ def __getattr__(self, item): us_ksize, us_sigma, us_strength, + us_target_x, ) pipe.set_gradual_latent(gradual_latent) @@ -2650,12 +2653,22 @@ def scale_and_round(x): if gl_timesteps < 0: gl_timesteps = args.gradual_latent_timesteps or 650 if gl_unsharp_params is not None: - us_ksize, us_sigma, us_strength = [float(v) for v in gl_unsharp_params.split(",")] + unsharp_params = gl_unsharp_params.split(",") + us_ksize, us_sigma, us_strength = [float(v) for v in unsharp_params[:3]] + us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3])) us_ksize = int(us_ksize) else: - us_ksize, us_sigma, us_strength = None, None, None + us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None gradual_latent = GradualLatent( - gl_ratio, gl_timesteps, gl_every_n_steps, gl_ratio_step, gl_s_noise, us_ksize, us_sigma, us_strength + gl_ratio, + gl_timesteps, + gl_every_n_steps, + gl_ratio_step, + gl_s_noise, + us_ksize, + us_sigma, + us_strength, + us_target_x, ) pipe.set_gradual_latent(gradual_latent) @@ -3056,8 +3069,8 @@ def setup_parser() -> argparse.ArgumentParser: "--gradual_latent_unsharp_params", type=str, default=None, - help="unsharp mask parameters for Gradual Latent: ksize, sigma, strength. `3,0.5,0.5` is recommended /" - + " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength. `3,0.5,0.5` が推奨", + help="unsharp mask parameters for Gradual Latent: ksize, sigma, strength, target-x (1 means True). `3,0.5,0.5,1` or `3,1.0,1.0,0` is recommended /" + + " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. `3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨", ) # # parser.add_argument( From 07ef03d3403d56a05a87afdabbe6e8964cc05a9d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 12 Dec 2023 08:03:27 +0900 Subject: [PATCH 008/103] fix controlnet to work with gradual latent --- tools/original_control_net.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/tools/original_control_net.py b/tools/original_control_net.py index cd47bd76a..5e1c074eb 100644 --- a/tools/original_control_net.py +++ b/tools/original_control_net.py @@ -179,8 +179,21 @@ def call_unet_and_control_net( return original_unet(sample, timestep, encoder_hidden_states) guided_hint = guided_hints[cnet_idx] + + # gradual latent support: match the size of guided_hint to the size of sample + if guided_hint.shape[-2:] != sample.shape[-2:]: + # print(f"guided_hint.shape={guided_hint.shape}, sample.shape={sample.shape}") + org_dtype = guided_hint.dtype + if org_dtype == torch.bfloat16: + guided_hint = guided_hint.to(torch.float32) + guided_hint = torch.nn.functional.interpolate(guided_hint, size=sample.shape[-2:], mode="bicubic") + if org_dtype == torch.bfloat16: + guided_hint = guided_hint.to(org_dtype) + guided_hint = guided_hint.repeat((num_latent_input, 1, 1, 1)) - outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states_for_control_net) + outs = unet_forward( + True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states_for_control_net + ) outs = [o * cnet_info.weight for o in outs] # U-Net From d61ecb26fd0251a6cba313cb05771e4e2cee1bfc Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 12 Dec 2023 08:20:36 +0900 Subject: [PATCH 009/103] enable comment in prompt file, record raw prompt to metadata --- gen_img_diffusers.py | 23 +++++++++++++++++------ sdxl_gen_img.py | 21 +++++++++++++++------ 2 files changed, 32 insertions(+), 12 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 74f3852de..b20e71794 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -2184,6 +2184,7 @@ class BatchDataBase(NamedTuple): mask_image: Any clip_prompt: str guide_image: Any + raw_prompt: str class BatchDataExt(NamedTuple): @@ -2710,7 +2711,7 @@ def __getattr__(self, item): print(f"reading prompts from {args.from_file}") with open(args.from_file, "r", encoding="utf-8") as f: prompt_list = f.read().splitlines() - prompt_list = [d for d in prompt_list if len(d.strip()) > 0] + prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"] elif args.prompt is not None: prompt_list = [args.prompt] else: @@ -2954,13 +2955,14 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): # このバッチの情報を取り出す ( return_latents, - (step_first, _, _, _, init_image, mask_image, _, guide_image), + (step_first, _, _, _, init_image, mask_image, _, guide_image, _), (width, height, steps, scale, negative_scale, strength, network_muls, num_sub_prompts), ) = batch[0] noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR) prompts = [] negative_prompts = [] + raw_prompts = [] start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) noises = [ torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) @@ -2991,11 +2993,16 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): all_images_are_same = True all_masks_are_same = True all_guide_images_are_same = True - for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch): + for i, ( + _, + (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt), + _, + ) in enumerate(batch): prompts.append(prompt) negative_prompts.append(negative_prompt) seeds.append(seed) clip_prompts.append(clip_prompt) + raw_prompts.append(raw_prompt) if init_image is not None: init_images.append(init_image) @@ -3087,8 +3094,8 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): # save image highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) - for i, (image, prompt, negative_prompts, seed, clip_prompt) in enumerate( - zip(images, prompts, negative_prompts, seeds, clip_prompts) + for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate( + zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts) ): if highres_fix: seed -= 1 # record original seed @@ -3104,6 +3111,8 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): metadata.add_text("negative-scale", str(negative_scale)) if clip_prompt is not None: metadata.add_text("clip-prompt", clip_prompt) + if raw_prompt is not None: + metadata.add_text("raw-prompt", raw_prompt) if args.use_original_file_name and init_images is not None: if type(init_images) is list: @@ -3438,7 +3447,9 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): b1 = BatchData( False, - BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), + BatchDataBase( + global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt + ), BatchDataExt( width, height, diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index 297268219..37109e50d 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -1449,6 +1449,7 @@ class BatchDataBase(NamedTuple): mask_image: Any clip_prompt: str guide_image: Any + raw_prompt: str class BatchDataExt(NamedTuple): @@ -1918,7 +1919,7 @@ def __getattr__(self, item): print(f"reading prompts from {args.from_file}") with open(args.from_file, "r", encoding="utf-8") as f: prompt_list = f.read().splitlines() - prompt_list = [d for d in prompt_list if len(d.strip()) > 0] + prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"] elif args.prompt is not None: prompt_list = [args.prompt] else: @@ -2190,7 +2191,7 @@ def scale_and_round(x): # このバッチの情報を取り出す ( return_latents, - (step_first, _, _, _, init_image, mask_image, _, guide_image), + (step_first, _, _, _, init_image, mask_image, _, guide_image, _), ( width, height, @@ -2212,6 +2213,7 @@ def scale_and_round(x): prompts = [] negative_prompts = [] + raw_prompts = [] start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) noises = [ torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) @@ -2242,11 +2244,16 @@ def scale_and_round(x): all_images_are_same = True all_masks_are_same = True all_guide_images_are_same = True - for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch): + for i, ( + _, + (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt), + _, + ) in enumerate(batch): prompts.append(prompt) negative_prompts.append(negative_prompt) seeds.append(seed) clip_prompts.append(clip_prompt) + raw_prompts.append(raw_prompt) if init_image is not None: init_images.append(init_image) @@ -2344,8 +2351,8 @@ def scale_and_round(x): # save image highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) - for i, (image, prompt, negative_prompts, seed, clip_prompt) in enumerate( - zip(images, prompts, negative_prompts, seeds, clip_prompts) + for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate( + zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts ) ): if highres_fix: seed -= 1 # record original seed @@ -2361,6 +2368,8 @@ def scale_and_round(x): metadata.add_text("negative-scale", str(negative_scale)) if clip_prompt is not None: metadata.add_text("clip-prompt", clip_prompt) + if raw_prompt is not None: + metadata.add_text("raw-prompt", raw_prompt) metadata.add_text("original-height", str(original_height)) metadata.add_text("original-width", str(original_width)) metadata.add_text("original-height-negative", str(original_height_negative)) @@ -2736,7 +2745,7 @@ def scale_and_round(x): b1 = BatchData( False, - BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), + BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt), BatchDataExt( width, height, From 976d092c689b86e2e9f0ff76f8db24ce869ec33f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 17 Jan 2024 21:31:50 +0900 Subject: [PATCH 010/103] fix text encodes are on gpu even when not trained --- sdxl_train_network.py | 4 ++-- train_network.py | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/sdxl_train_network.py b/sdxl_train_network.py index a35779d00..d810ce7d4 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -95,8 +95,8 @@ def cache_text_encoder_outputs_if_needed( unet.to(org_unet_device) else: # Text Encoderから毎回出力を取得するので、GPUに乗せておく - text_encoders[0].to(accelerator.device) - text_encoders[1].to(accelerator.device) + text_encoders[0].to(accelerator.device, dtype=weight_dtype) + text_encoders[1].to(accelerator.device, dtype=weight_dtype) def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: diff --git a/train_network.py b/train_network.py index a75299cda..c2b7fbdef 100644 --- a/train_network.py +++ b/train_network.py @@ -117,7 +117,7 @@ def cache_text_encoder_outputs_if_needed( self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype ): for t_enc in text_encoders: - t_enc.to(accelerator.device) + t_enc.to(accelerator.device, dtype=weight_dtype) def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): input_ids = batch["input_ids"].to(accelerator.device) @@ -278,6 +278,7 @@ def train(self, args): accelerator.wait_for_everyone() # 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される + # cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu self.cache_text_encoder_outputs_if_needed( args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype ) @@ -394,8 +395,7 @@ def train(self, args): for t_enc in text_encoders: t_enc.requires_grad_(False) - # acceleratorがなんかよろしくやってくれるらしい - # TODO めちゃくちゃ冗長なのでコードを整理する + # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good if train_unet: unet = accelerator.prepare(unet) else: @@ -407,8 +407,8 @@ def train(self, args): text_encoder = accelerator.prepare(text_encoder) text_encoders = [text_encoder] else: - for t_enc in text_encoders: - t_enc.to(accelerator.device, dtype=weight_dtype) + pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set + network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler) if args.gradient_checkpointing: @@ -685,7 +685,7 @@ def train(self, args): if accelerator.is_main_process: init_kwargs = {} if args.wandb_run_name: - init_kwargs['wandb'] = {'name': args.wandb_run_name} + init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( From 987d4a969dad5807b9a9bea445698fd1718414be Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 17 Jan 2024 21:38:49 +0900 Subject: [PATCH 011/103] update readme --- README.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/README.md b/README.md index 03ff5ab11..c62f9ec59 100644 --- a/README.md +++ b/README.md @@ -249,6 +249,16 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum ## Change History +### Jan 17, 2024 / 2024/1/17: v0.8.1 + +- Fixed a bug that the VRAM usage without Text Encoder training is larger than before in training scripts for LoRA etc (`train_network.py`, `sdxl_train_network.py`). + - Text Encoders were not moved to CPU. +- Fixed typos. Thanks to akx! [PR #1053](https://github.com/kohya-ss/sd-scripts/pull/1053) + +- LoRA 等の学習スクリプト(`train_network.py`、`sdxl_train_network.py`)で、Text Encoder を学習しない場合の VRAM 使用量が以前に比べて大きくなっていた不具合を修正しました。 + - Text Encoder が GPU に保持されたままになっていました。 +- 誤字が修正されました。 [PR #1053](https://github.com/kohya-ss/sd-scripts/pull/1053) akx 氏に感謝します。 + ### Jan 15, 2024 / 2024/1/15: v0.8.0 - Diffusers, Accelerate, Transformers and other related libraries have been updated. Please update the libraries with [Upgrade](#upgrade). From 6f3f701d3d6a4c1386ae770a3bbe997a14bbdef6 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 18 Jan 2024 18:23:21 +0200 Subject: [PATCH 012/103] Deduplicate ipex initialization code --- XTI_hijack.py | 10 +++------- fine_tune.py | 9 ++------- gen_img_diffusers.py | 9 ++------- library/ipex_interop.py | 24 ++++++++++++++++++++++++ library/model_util.py | 10 ++-------- sdxl_gen_img.py | 9 ++------- sdxl_minimal_inference.py | 12 +++++------- sdxl_train.py | 9 ++------- sdxl_train_control_net_lllite.py | 12 +++++------- sdxl_train_control_net_lllite_old.py | 12 +++++------- sdxl_train_network.py | 9 ++------- sdxl_train_textual_inversion.py | 10 +++------- train_controlnet.py | 9 ++------- train_db.py | 9 ++------- train_network.py | 9 ++------- train_textual_inversion.py | 9 ++------- train_textual_inversion_XTI.py | 12 +++++------- 17 files changed, 70 insertions(+), 113 deletions(-) create mode 100644 library/ipex_interop.py diff --git a/XTI_hijack.py b/XTI_hijack.py index ec0849455..1dbc263ac 100644 --- a/XTI_hijack.py +++ b/XTI_hijack.py @@ -1,11 +1,7 @@ import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass +from library.ipex_interop import init_ipex + +init_ipex() from typing import Union, List, Optional, Dict, Any, Tuple from diffusers.models.unet_2d_condition import UNet2DConditionOutput diff --git a/fine_tune.py b/fine_tune.py index be61b3d16..982dc8aec 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -11,15 +11,10 @@ from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass from accelerate.utils import set_seed from diffusers import DDPMScheduler diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index be43847a6..a207ad5a1 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -66,15 +66,10 @@ import numpy as np import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass import torchvision from diffusers import ( AutoencoderKL, diff --git a/library/ipex_interop.py b/library/ipex_interop.py new file mode 100644 index 000000000..6fe320c57 --- /dev/null +++ b/library/ipex_interop.py @@ -0,0 +1,24 @@ +import torch + + +def init_ipex(): + """ + Try to import `intel_extension_for_pytorch`, and apply + the hijacks using `library.ipex.ipex_init`. + + If IPEX is not installed, this function does nothing. + """ + try: + import intel_extension_for_pytorch as ipex # noqa + except ImportError: + return + + try: + from library.ipex import ipex_init + + if torch.xpu.is_available(): + is_initialized, error_message = ipex_init() + if not is_initialized: + print("failed to initialize ipex:", error_message) + except Exception as e: + print("failed to initialize ipex:", e) diff --git a/library/model_util.py b/library/model_util.py index 1f40ce324..4361b4994 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -5,15 +5,9 @@ import os import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - - ipex_init() -except Exception: - pass +init_ipex() import diffusers from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index ab5399842..0db9e340e 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -18,15 +18,10 @@ import numpy as np import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass import torchvision from diffusers import ( AutoencoderKL, diff --git a/sdxl_minimal_inference.py b/sdxl_minimal_inference.py index 45b9edd65..15a70678f 100644 --- a/sdxl_minimal_inference.py +++ b/sdxl_minimal_inference.py @@ -9,13 +9,11 @@ from einops import repeat import numpy as np import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass + +from library.ipex_interop import init_ipex + +init_ipex() + from tqdm import tqdm from transformers import CLIPTokenizer from diffusers import EulerDiscreteScheduler diff --git a/sdxl_train.py b/sdxl_train.py index b4ce2770e..a3f6f3a17 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -11,15 +11,10 @@ from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass from accelerate.utils import set_seed from diffusers import DDPMScheduler from library import sdxl_model_util diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 4436dd3cd..7a88feb84 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -14,13 +14,11 @@ from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass + +from library.ipex_interop import init_ipex + +init_ipex() + from torch.nn.parallel import DistributedDataParallel as DDP from accelerate.utils import set_seed import accelerate diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 6ae5377ba..b94bf5c1b 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -11,13 +11,11 @@ from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass + +from library.ipex_interop import init_ipex + +init_ipex() + from torch.nn.parallel import DistributedDataParallel as DDP from accelerate.utils import set_seed from diffusers import DDPMScheduler, ControlNetModel diff --git a/sdxl_train_network.py b/sdxl_train_network.py index d810ce7d4..5d363280d 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -1,15 +1,10 @@ import argparse import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass from library import sdxl_model_util, sdxl_train_util, train_util import train_network diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index f8a1d7bce..df3937135 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -3,13 +3,9 @@ import regex import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass +from library.ipex_interop import init_ipex + +init_ipex() import open_clip from library import sdxl_model_util, sdxl_train_util, train_util diff --git a/train_controlnet.py b/train_controlnet.py index cc0eaab7a..7b0b2bbfe 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -12,15 +12,10 @@ from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass from torch.nn.parallel import DistributedDataParallel as DDP from accelerate.utils import set_seed from diffusers import DDPMScheduler, ControlNetModel diff --git a/train_db.py b/train_db.py index 14d9dff13..888cad25e 100644 --- a/train_db.py +++ b/train_db.py @@ -12,15 +12,10 @@ from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass from accelerate.utils import set_seed from diffusers import DDPMScheduler diff --git a/train_network.py b/train_network.py index c2b7fbdef..e1ff20c33 100644 --- a/train_network.py +++ b/train_network.py @@ -14,15 +14,10 @@ import torch from torch.nn.parallel import DistributedDataParallel as DDP -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass from accelerate.utils import set_seed from diffusers import DDPMScheduler from library import model_util diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 0e3912b1d..ccf7596d7 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -8,15 +8,10 @@ from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass from accelerate.utils import set_seed from diffusers import DDPMScheduler from transformers import CLIPTokenizer diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 71b43549d..7046a4808 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -8,13 +8,11 @@ from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass + +from library.ipex_interop import init_ipex + +init_ipex() + from accelerate.utils import set_seed import diffusers from diffusers import DDPMScheduler From 9cfa68c92fc45fe6e78745d1ec5544b3d7b5cf5f Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 20 Jan 2024 08:46:53 +0800 Subject: [PATCH 013/103] [Experimental Feature] FP8 weight dtype for base model when running train_network (or sdxl_train_network) (#1057) * Add fp8 support * remove some debug prints * Better implementation for te * Fix some misunderstanding * as same as unet, add explicit convert * better impl for convert TE to fp8 * fp8 for not only unet * Better cache TE and TE lr * match arg name * Fix with list * Add timeout settings * Fix arg style * Add custom seperator * Fix typo * Fix typo again * Fix dtype error * Fix gradient problem * Fix req grad * fix merge * Fix merge * Resolve merge * arrangement and document * Resolve merge error * Add assert for mixed precision --- library/train_util.py | 3 +++ train_network.py | 36 ++++++++++++++++++++++++++++++------ 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index ff161feab..21e7638da 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2904,6 +2904,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument( "--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する" ) # TODO move to SDXL training, because it is not supported by SD1/2 + parser.add_argument( + "--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う" + ) parser.add_argument( "--ddp_timeout", type=int, diff --git a/train_network.py b/train_network.py index c2b7fbdef..5f28a5e0d 100644 --- a/train_network.py +++ b/train_network.py @@ -390,16 +390,36 @@ def train(self, args): accelerator.print("enable full bf16 training.") network.to(weight_dtype) + unet_weight_dtype = te_weight_dtype = weight_dtype + # Experimental Feature: Put base model into fp8 to save vram + if args.fp8_base: + assert ( + torch.__version__ >= '2.1.0' + ), "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。" + assert ( + args.mixed_precision != 'no' + ), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。" + accelerator.print("enable fp8 training.") + unet_weight_dtype = torch.float8_e4m3fn + te_weight_dtype = torch.float8_e4m3fn + unet.requires_grad_(False) - unet.to(dtype=weight_dtype) + unet.to(dtype=unet_weight_dtype) for t_enc in text_encoders: t_enc.requires_grad_(False) + t_enc.to(dtype=te_weight_dtype) + # nn.Embedding not support FP8 + t_enc.text_model.embeddings.to(dtype=( + weight_dtype + if te_weight_dtype == torch.float8_e4m3fn + else te_weight_dtype + )) # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good if train_unet: unet = accelerator.prepare(unet) else: - unet.to(accelerator.device, dtype=weight_dtype) # move to device because unet is not prepared by accelerator + unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator if train_text_encoder: if len(text_encoders) > 1: text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders] @@ -421,9 +441,6 @@ def train(self, args): if train_text_encoder: t_enc.text_model.embeddings.requires_grad_(True) - # set top parameter requires_grad = True for gradient checkpointing works - if not train_text_encoder: # train U-Net only - unet.parameters().__next__().requires_grad_(True) else: unet.eval() for t_enc in text_encoders: @@ -778,10 +795,17 @@ def remove_model(old_ckpt_name): args, noise_scheduler, latents ) + # ensure the hidden state will require grad + if args.gradient_checkpointing: + for x in noisy_latents: + x.requires_grad_(True) + for t in text_encoder_conds: + t.requires_grad_(True) + # Predict the noise residual with accelerator.autocast(): noise_pred = self.call_unet( - args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype + args, accelerator, unet, noisy_latents.requires_grad_(train_unet), timesteps, text_encoder_conds, batch, weight_dtype ) if args.v_parameterization: From a7ef6422b658660dd4c4685397f9b56e24996eb2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 20 Jan 2024 10:00:30 +0900 Subject: [PATCH 014/103] fix to work with torch 2.0 --- train_network.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/train_network.py b/train_network.py index 5f28a5e0d..b1291ed10 100644 --- a/train_network.py +++ b/train_network.py @@ -393,11 +393,9 @@ def train(self, args): unet_weight_dtype = te_weight_dtype = weight_dtype # Experimental Feature: Put base model into fp8 to save vram if args.fp8_base: + assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。" assert ( - torch.__version__ >= '2.1.0' - ), "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。" - assert ( - args.mixed_precision != 'no' + args.mixed_precision != "no" ), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。" accelerator.print("enable fp8 training.") unet_weight_dtype = torch.float8_e4m3fn @@ -407,13 +405,12 @@ def train(self, args): unet.to(dtype=unet_weight_dtype) for t_enc in text_encoders: t_enc.requires_grad_(False) - t_enc.to(dtype=te_weight_dtype) - # nn.Embedding not support FP8 - t_enc.text_model.embeddings.to(dtype=( - weight_dtype - if te_weight_dtype == torch.float8_e4m3fn - else te_weight_dtype - )) + + # in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16 + if t_enc.device.type != "cpu": + t_enc.to(dtype=te_weight_dtype) + # nn.Embedding not support FP8 + t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good if train_unet: @@ -805,7 +802,14 @@ def remove_model(old_ckpt_name): # Predict the noise residual with accelerator.autocast(): noise_pred = self.call_unet( - args, accelerator, unet, noisy_latents.requires_grad_(train_unet), timesteps, text_encoder_conds, batch, weight_dtype + args, + accelerator, + unet, + noisy_latents.requires_grad_(train_unet), + timesteps, + text_encoder_conds, + batch, + weight_dtype, ) if args.v_parameterization: From 1f77bb6e73573d0bde67b94c76b549590833ece9 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 20 Jan 2024 10:57:42 +0900 Subject: [PATCH 015/103] fix to work sample generation in fp8 ref #1057 --- library/sdxl_lpw_stable_diffusion.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/library/sdxl_lpw_stable_diffusion.py b/library/sdxl_lpw_stable_diffusion.py index e03ee4056..0562e88ad 100644 --- a/library/sdxl_lpw_stable_diffusion.py +++ b/library/sdxl_lpw_stable_diffusion.py @@ -923,7 +923,11 @@ def __call__( if up1 is not None: uncond_pool = up1 - dtype = self.unet.dtype + unet_dtype = self.unet.dtype + dtype = unet_dtype + if dtype.itemsize == 1: # fp8 + dtype = torch.float16 + self.unet.to(dtype) # 4. Preprocess image and mask if isinstance(image, PIL.Image.Image): @@ -1028,6 +1032,7 @@ def __call__( if is_cancelled_callback is not None and is_cancelled_callback(): return None + self.unet.to(unet_dtype) return latents def latents_to_image(self, latents): From 2a0f45aea951d8d2649917938f216705e4fc632a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 20 Jan 2024 11:08:20 +0900 Subject: [PATCH 016/103] update readme --- README.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/README.md b/README.md index a505c0b34..663f52e86 100644 --- a/README.md +++ b/README.md @@ -249,6 +249,20 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum ## Change History +### 作業中の内容 / Work in progress + +- [Experimental] The `--fp8_base` option is added to the training scripts for LoRA etc. The base model (U-Net, and Text Encoder when training modules for Text Encoder) can be trained with fp8. PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) Thanks to KohakuBlueleaf! + - Please specify `--fp8_base` in `train_network.py` or `sdxl_train_network.py`. + - PyTorch 2.1 or later is required. + - If you use xformers with PyTorch 2.1, please see [xformers repository](https://github.com/facebookresearch/xformers) and install the appropriate version according to your CUDA version. + - The sample image generation during training consumes a lot of memory. It is recommended to turn it off. + +- (実験的) LoRA等の学習スクリプトで、ベースモデル(U-Net、および Text Encoder のモジュール学習時は Text Encoder も)の重みを fp8 にして学習するオプションが追加されました。 PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) KohakuBlueleaf 氏に感謝します。 + - `train_network.py` または `sdxl_train_network.py` で `--fp8_base` を指定してください。 + - PyTorch 2.1 以降が必要です。 + - PyTorch 2.1 で xformers を使用する場合は、[xformers のリポジトリ](https://github.com/facebookresearch/xformers) を参照し、CUDA バージョンに応じて適切なバージョンをインストールしてください。 + - 学習中のサンプル画像生成はメモリを大量に消費するため、オフにすることをお勧めします。 + ### Jan 17, 2024 / 2024/1/17: v0.8.1 - Fixed a bug that the VRAM usage without Text Encoder training is larger than before in training scripts for LoRA etc (`train_network.py`, `sdxl_train_network.py`). From 5a1ebc4c7c21600ac773c22b7f1b7d20d383e318 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 20 Jan 2024 13:10:45 +0900 Subject: [PATCH 017/103] format by black --- library/config_util.py | 1035 +++++++++++++++++++++------------------- 1 file changed, 549 insertions(+), 486 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 47868f3ba..716cecff8 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -1,462 +1,496 @@ import argparse from dataclasses import ( - asdict, - dataclass, + asdict, + dataclass, ) import functools import random from textwrap import dedent, indent import json from pathlib import Path + # from toolz import curry from typing import ( - List, - Optional, - Sequence, - Tuple, - Union, + List, + Optional, + Sequence, + Tuple, + Union, ) import toml import voluptuous from voluptuous import ( - Any, - ExactSequence, - MultipleInvalid, - Object, - Required, - Schema, + Any, + ExactSequence, + MultipleInvalid, + Object, + Required, + Schema, ) from transformers import CLIPTokenizer from . import train_util from .train_util import ( - DreamBoothSubset, - FineTuningSubset, - ControlNetSubset, - DreamBoothDataset, - FineTuningDataset, - ControlNetDataset, - DatasetGroup, + DreamBoothSubset, + FineTuningSubset, + ControlNetSubset, + DreamBoothDataset, + FineTuningDataset, + ControlNetDataset, + DatasetGroup, ) def add_config_arguments(parser: argparse.ArgumentParser): - parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル") + parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル") + # TODO: inherit Params class in Subset, Dataset + @dataclass class BaseSubsetParams: - image_dir: Optional[str] = None - num_repeats: int = 1 - shuffle_caption: bool = False - caption_separator: str = ',', - keep_tokens: int = 0 - keep_tokens_separator: str = None, - color_aug: bool = False - flip_aug: bool = False - face_crop_aug_range: Optional[Tuple[float, float]] = None - random_crop: bool = False - caption_prefix: Optional[str] = None - caption_suffix: Optional[str] = None - caption_dropout_rate: float = 0.0 - caption_dropout_every_n_epochs: int = 0 - caption_tag_dropout_rate: float = 0.0 - token_warmup_min: int = 1 - token_warmup_step: float = 0 + image_dir: Optional[str] = None + num_repeats: int = 1 + shuffle_caption: bool = False + caption_separator: str = (",",) + keep_tokens: int = 0 + keep_tokens_separator: str = (None,) + color_aug: bool = False + flip_aug: bool = False + face_crop_aug_range: Optional[Tuple[float, float]] = None + random_crop: bool = False + caption_prefix: Optional[str] = None + caption_suffix: Optional[str] = None + caption_dropout_rate: float = 0.0 + caption_dropout_every_n_epochs: int = 0 + caption_tag_dropout_rate: float = 0.0 + token_warmup_min: int = 1 + token_warmup_step: float = 0 + @dataclass class DreamBoothSubsetParams(BaseSubsetParams): - is_reg: bool = False - class_tokens: Optional[str] = None - caption_extension: str = ".caption" + is_reg: bool = False + class_tokens: Optional[str] = None + caption_extension: str = ".caption" + @dataclass class FineTuningSubsetParams(BaseSubsetParams): - metadata_file: Optional[str] = None + metadata_file: Optional[str] = None + @dataclass class ControlNetSubsetParams(BaseSubsetParams): - conditioning_data_dir: str = None - caption_extension: str = ".caption" + conditioning_data_dir: str = None + caption_extension: str = ".caption" + @dataclass class BaseDatasetParams: - tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None - max_token_length: int = None - resolution: Optional[Tuple[int, int]] = None - debug_dataset: bool = False + tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None + max_token_length: int = None + resolution: Optional[Tuple[int, int]] = None + debug_dataset: bool = False + @dataclass class DreamBoothDatasetParams(BaseDatasetParams): - batch_size: int = 1 - enable_bucket: bool = False - min_bucket_reso: int = 256 - max_bucket_reso: int = 1024 - bucket_reso_steps: int = 64 - bucket_no_upscale: bool = False - prior_loss_weight: float = 1.0 + batch_size: int = 1 + enable_bucket: bool = False + min_bucket_reso: int = 256 + max_bucket_reso: int = 1024 + bucket_reso_steps: int = 64 + bucket_no_upscale: bool = False + prior_loss_weight: float = 1.0 + @dataclass class FineTuningDatasetParams(BaseDatasetParams): - batch_size: int = 1 - enable_bucket: bool = False - min_bucket_reso: int = 256 - max_bucket_reso: int = 1024 - bucket_reso_steps: int = 64 - bucket_no_upscale: bool = False + batch_size: int = 1 + enable_bucket: bool = False + min_bucket_reso: int = 256 + max_bucket_reso: int = 1024 + bucket_reso_steps: int = 64 + bucket_no_upscale: bool = False + @dataclass class ControlNetDatasetParams(BaseDatasetParams): - batch_size: int = 1 - enable_bucket: bool = False - min_bucket_reso: int = 256 - max_bucket_reso: int = 1024 - bucket_reso_steps: int = 64 - bucket_no_upscale: bool = False + batch_size: int = 1 + enable_bucket: bool = False + min_bucket_reso: int = 256 + max_bucket_reso: int = 1024 + bucket_reso_steps: int = 64 + bucket_no_upscale: bool = False + @dataclass class SubsetBlueprint: - params: Union[DreamBoothSubsetParams, FineTuningSubsetParams] + params: Union[DreamBoothSubsetParams, FineTuningSubsetParams] + @dataclass class DatasetBlueprint: - is_dreambooth: bool - is_controlnet: bool - params: Union[DreamBoothDatasetParams, FineTuningDatasetParams] - subsets: Sequence[SubsetBlueprint] + is_dreambooth: bool + is_controlnet: bool + params: Union[DreamBoothDatasetParams, FineTuningDatasetParams] + subsets: Sequence[SubsetBlueprint] + @dataclass class DatasetGroupBlueprint: - datasets: Sequence[DatasetBlueprint] + datasets: Sequence[DatasetBlueprint] + + @dataclass class Blueprint: - dataset_group: DatasetGroupBlueprint + dataset_group: DatasetGroupBlueprint class ConfigSanitizer: - # @curry - @staticmethod - def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple: - Schema(ExactSequence([klass, klass]))(value) - return tuple(value) - - # @curry - @staticmethod - def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple: - Schema(Any(klass, ExactSequence([klass, klass])))(value) - try: - Schema(klass)(value) - return (value, value) - except: - return ConfigSanitizer.__validate_and_convert_twodim(klass, value) - - # subset schema - SUBSET_ASCENDABLE_SCHEMA = { - "color_aug": bool, - "face_crop_aug_range": functools.partial(__validate_and_convert_twodim.__func__, float), - "flip_aug": bool, - "num_repeats": int, - "random_crop": bool, - "shuffle_caption": bool, - "keep_tokens": int, - "keep_tokens_separator": str, - "token_warmup_min": int, - "token_warmup_step": Any(float,int), - "caption_prefix": str, - "caption_suffix": str, - } - # DO means DropOut - DO_SUBSET_ASCENDABLE_SCHEMA = { - "caption_dropout_every_n_epochs": int, - "caption_dropout_rate": Any(float, int), - "caption_tag_dropout_rate": Any(float, int), - } - # DB means DreamBooth - DB_SUBSET_ASCENDABLE_SCHEMA = { - "caption_extension": str, - "class_tokens": str, - } - DB_SUBSET_DISTINCT_SCHEMA = { - Required("image_dir"): str, - "is_reg": bool, - } - # FT means FineTuning - FT_SUBSET_DISTINCT_SCHEMA = { - Required("metadata_file"): str, - "image_dir": str, - } - CN_SUBSET_ASCENDABLE_SCHEMA = { - "caption_extension": str, - } - CN_SUBSET_DISTINCT_SCHEMA = { - Required("image_dir"): str, - Required("conditioning_data_dir"): str, - } - - # datasets schema - DATASET_ASCENDABLE_SCHEMA = { - "batch_size": int, - "bucket_no_upscale": bool, - "bucket_reso_steps": int, - "enable_bucket": bool, - "max_bucket_reso": int, - "min_bucket_reso": int, - "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), - } - - # options handled by argparse but not handled by user config - ARGPARSE_SPECIFIC_SCHEMA = { - "debug_dataset": bool, - "max_token_length": Any(None, int), - "prior_loss_weight": Any(float, int), - } - # for handling default None value of argparse - ARGPARSE_NULLABLE_OPTNAMES = [ - "face_crop_aug_range", - "resolution", - ] - # prepare map because option name may differ among argparse and user config - ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = { - "train_batch_size": "batch_size", - "dataset_repeats": "num_repeats", - } - - def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None: - assert support_dreambooth or support_finetuning or support_controlnet, "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。" - - self.db_subset_schema = self.__merge_dict( - self.SUBSET_ASCENDABLE_SCHEMA, - self.DB_SUBSET_DISTINCT_SCHEMA, - self.DB_SUBSET_ASCENDABLE_SCHEMA, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - ) - - self.ft_subset_schema = self.__merge_dict( - self.SUBSET_ASCENDABLE_SCHEMA, - self.FT_SUBSET_DISTINCT_SCHEMA, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - ) - - self.cn_subset_schema = self.__merge_dict( - self.SUBSET_ASCENDABLE_SCHEMA, - self.CN_SUBSET_DISTINCT_SCHEMA, - self.CN_SUBSET_ASCENDABLE_SCHEMA, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - ) - - self.db_dataset_schema = self.__merge_dict( - self.DATASET_ASCENDABLE_SCHEMA, - self.SUBSET_ASCENDABLE_SCHEMA, - self.DB_SUBSET_ASCENDABLE_SCHEMA, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - {"subsets": [self.db_subset_schema]}, - ) - - self.ft_dataset_schema = self.__merge_dict( - self.DATASET_ASCENDABLE_SCHEMA, - self.SUBSET_ASCENDABLE_SCHEMA, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - {"subsets": [self.ft_subset_schema]}, - ) - - self.cn_dataset_schema = self.__merge_dict( - self.DATASET_ASCENDABLE_SCHEMA, - self.SUBSET_ASCENDABLE_SCHEMA, - self.CN_SUBSET_ASCENDABLE_SCHEMA, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - {"subsets": [self.cn_subset_schema]}, - ) - - if support_dreambooth and support_finetuning: - def validate_flex_dataset(dataset_config: dict): - subsets_config = dataset_config.get("subsets", []) - - if support_controlnet and all(["conditioning_data_dir" in subset for subset in subsets_config]): - return Schema(self.cn_dataset_schema)(dataset_config) - # check dataset meets FT style - # NOTE: all FT subsets should have "metadata_file" - elif all(["metadata_file" in subset for subset in subsets_config]): - return Schema(self.ft_dataset_schema)(dataset_config) - # check dataset meets DB style - # NOTE: all DB subsets should have no "metadata_file" - elif all(["metadata_file" not in subset for subset in subsets_config]): - return Schema(self.db_dataset_schema)(dataset_config) - else: - raise voluptuous.Invalid("DreamBooth subset and fine tuning subset cannot be mixed in the same dataset. Please split them into separate datasets. / DreamBoothのサブセットとfine tuninのサブセットを同一のデータセットに混在させることはできません。別々のデータセットに分割してください。") - - self.dataset_schema = validate_flex_dataset - elif support_dreambooth: - self.dataset_schema = self.db_dataset_schema - elif support_finetuning: - self.dataset_schema = self.ft_dataset_schema - elif support_controlnet: - self.dataset_schema = self.cn_dataset_schema - - self.general_schema = self.__merge_dict( - self.DATASET_ASCENDABLE_SCHEMA, - self.SUBSET_ASCENDABLE_SCHEMA, - self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {}, - self.CN_SUBSET_ASCENDABLE_SCHEMA if support_controlnet else {}, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - ) - - self.user_config_validator = Schema({ - "general": self.general_schema, - "datasets": [self.dataset_schema], - }) - - self.argparse_schema = self.__merge_dict( - self.general_schema, - self.ARGPARSE_SPECIFIC_SCHEMA, - {optname: Any(None, self.general_schema[optname]) for optname in self.ARGPARSE_NULLABLE_OPTNAMES}, - {a_name: self.general_schema[c_name] for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()}, - ) - - self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA) - - def sanitize_user_config(self, user_config: dict) -> dict: - try: - return self.user_config_validator(user_config) - except MultipleInvalid: - # TODO: エラー発生時のメッセージをわかりやすくする - print("Invalid user config / ユーザ設定の形式が正しくないようです") - raise - - # NOTE: In nature, argument parser result is not needed to be sanitize - # However this will help us to detect program bug - def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace: - try: - return self.argparse_config_validator(argparse_namespace) - except MultipleInvalid: - # XXX: this should be a bug - print("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。") - raise - - # NOTE: value would be overwritten by latter dict if there is already the same key - @staticmethod - def __merge_dict(*dict_list: dict) -> dict: - merged = {} - for schema in dict_list: - # merged |= schema - for k, v in schema.items(): - merged[k] = v - return merged + # @curry + @staticmethod + def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple: + Schema(ExactSequence([klass, klass]))(value) + return tuple(value) + + # @curry + @staticmethod + def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple: + Schema(Any(klass, ExactSequence([klass, klass])))(value) + try: + Schema(klass)(value) + return (value, value) + except: + return ConfigSanitizer.__validate_and_convert_twodim(klass, value) + + # subset schema + SUBSET_ASCENDABLE_SCHEMA = { + "color_aug": bool, + "face_crop_aug_range": functools.partial(__validate_and_convert_twodim.__func__, float), + "flip_aug": bool, + "num_repeats": int, + "random_crop": bool, + "shuffle_caption": bool, + "keep_tokens": int, + "keep_tokens_separator": str, + "token_warmup_min": int, + "token_warmup_step": Any(float, int), + "caption_prefix": str, + "caption_suffix": str, + } + # DO means DropOut + DO_SUBSET_ASCENDABLE_SCHEMA = { + "caption_dropout_every_n_epochs": int, + "caption_dropout_rate": Any(float, int), + "caption_tag_dropout_rate": Any(float, int), + } + # DB means DreamBooth + DB_SUBSET_ASCENDABLE_SCHEMA = { + "caption_extension": str, + "class_tokens": str, + } + DB_SUBSET_DISTINCT_SCHEMA = { + Required("image_dir"): str, + "is_reg": bool, + } + # FT means FineTuning + FT_SUBSET_DISTINCT_SCHEMA = { + Required("metadata_file"): str, + "image_dir": str, + } + CN_SUBSET_ASCENDABLE_SCHEMA = { + "caption_extension": str, + } + CN_SUBSET_DISTINCT_SCHEMA = { + Required("image_dir"): str, + Required("conditioning_data_dir"): str, + } + + # datasets schema + DATASET_ASCENDABLE_SCHEMA = { + "batch_size": int, + "bucket_no_upscale": bool, + "bucket_reso_steps": int, + "enable_bucket": bool, + "max_bucket_reso": int, + "min_bucket_reso": int, + "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), + } + + # options handled by argparse but not handled by user config + ARGPARSE_SPECIFIC_SCHEMA = { + "debug_dataset": bool, + "max_token_length": Any(None, int), + "prior_loss_weight": Any(float, int), + } + # for handling default None value of argparse + ARGPARSE_NULLABLE_OPTNAMES = [ + "face_crop_aug_range", + "resolution", + ] + # prepare map because option name may differ among argparse and user config + ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = { + "train_batch_size": "batch_size", + "dataset_repeats": "num_repeats", + } + + def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None: + assert ( + support_dreambooth or support_finetuning or support_controlnet + ), "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。" + + self.db_subset_schema = self.__merge_dict( + self.SUBSET_ASCENDABLE_SCHEMA, + self.DB_SUBSET_DISTINCT_SCHEMA, + self.DB_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) + + self.ft_subset_schema = self.__merge_dict( + self.SUBSET_ASCENDABLE_SCHEMA, + self.FT_SUBSET_DISTINCT_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) + + self.cn_subset_schema = self.__merge_dict( + self.SUBSET_ASCENDABLE_SCHEMA, + self.CN_SUBSET_DISTINCT_SCHEMA, + self.CN_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) + + self.db_dataset_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.DB_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + {"subsets": [self.db_subset_schema]}, + ) + + self.ft_dataset_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + {"subsets": [self.ft_subset_schema]}, + ) + + self.cn_dataset_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.CN_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + {"subsets": [self.cn_subset_schema]}, + ) + + if support_dreambooth and support_finetuning: + + def validate_flex_dataset(dataset_config: dict): + subsets_config = dataset_config.get("subsets", []) + + if support_controlnet and all(["conditioning_data_dir" in subset for subset in subsets_config]): + return Schema(self.cn_dataset_schema)(dataset_config) + # check dataset meets FT style + # NOTE: all FT subsets should have "metadata_file" + elif all(["metadata_file" in subset for subset in subsets_config]): + return Schema(self.ft_dataset_schema)(dataset_config) + # check dataset meets DB style + # NOTE: all DB subsets should have no "metadata_file" + elif all(["metadata_file" not in subset for subset in subsets_config]): + return Schema(self.db_dataset_schema)(dataset_config) + else: + raise voluptuous.Invalid( + "DreamBooth subset and fine tuning subset cannot be mixed in the same dataset. Please split them into separate datasets. / DreamBoothのサブセットとfine tuninのサブセットを同一のデータセットに混在させることはできません。別々のデータセットに分割してください。" + ) + + self.dataset_schema = validate_flex_dataset + elif support_dreambooth: + self.dataset_schema = self.db_dataset_schema + elif support_finetuning: + self.dataset_schema = self.ft_dataset_schema + elif support_controlnet: + self.dataset_schema = self.cn_dataset_schema + + self.general_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {}, + self.CN_SUBSET_ASCENDABLE_SCHEMA if support_controlnet else {}, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) + + self.user_config_validator = Schema( + { + "general": self.general_schema, + "datasets": [self.dataset_schema], + } + ) + + self.argparse_schema = self.__merge_dict( + self.general_schema, + self.ARGPARSE_SPECIFIC_SCHEMA, + {optname: Any(None, self.general_schema[optname]) for optname in self.ARGPARSE_NULLABLE_OPTNAMES}, + {a_name: self.general_schema[c_name] for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()}, + ) + + self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA) + + def sanitize_user_config(self, user_config: dict) -> dict: + try: + return self.user_config_validator(user_config) + except MultipleInvalid: + # TODO: エラー発生時のメッセージをわかりやすくする + print("Invalid user config / ユーザ設定の形式が正しくないようです") + raise + + # NOTE: In nature, argument parser result is not needed to be sanitize + # However this will help us to detect program bug + def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace: + try: + return self.argparse_config_validator(argparse_namespace) + except MultipleInvalid: + # XXX: this should be a bug + print("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。") + raise + + # NOTE: value would be overwritten by latter dict if there is already the same key + @staticmethod + def __merge_dict(*dict_list: dict) -> dict: + merged = {} + for schema in dict_list: + # merged |= schema + for k, v in schema.items(): + merged[k] = v + return merged class BlueprintGenerator: - BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = { - } - - def __init__(self, sanitizer: ConfigSanitizer): - self.sanitizer = sanitizer - - # runtime_params is for parameters which is only configurable on runtime, such as tokenizer - def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint: - sanitized_user_config = self.sanitizer.sanitize_user_config(user_config) - sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace) - - # convert argparse namespace to dict like config - # NOTE: it is ok to have extra entries in dict - optname_map = self.sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME - argparse_config = {optname_map.get(optname, optname): value for optname, value in vars(sanitized_argparse_namespace).items()} - - general_config = sanitized_user_config.get("general", {}) - - dataset_blueprints = [] - for dataset_config in sanitized_user_config.get("datasets", []): - # NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets - subsets = dataset_config.get("subsets", []) - is_dreambooth = all(["metadata_file" not in subset for subset in subsets]) - is_controlnet = all(["conditioning_data_dir" in subset for subset in subsets]) - if is_controlnet: - subset_params_klass = ControlNetSubsetParams - dataset_params_klass = ControlNetDatasetParams - elif is_dreambooth: - subset_params_klass = DreamBoothSubsetParams - dataset_params_klass = DreamBoothDatasetParams - else: - subset_params_klass = FineTuningSubsetParams - dataset_params_klass = FineTuningDatasetParams - - subset_blueprints = [] - for subset_config in subsets: - params = self.generate_params_by_fallbacks(subset_params_klass, - [subset_config, dataset_config, general_config, argparse_config, runtime_params]) - subset_blueprints.append(SubsetBlueprint(params)) - - params = self.generate_params_by_fallbacks(dataset_params_klass, - [dataset_config, general_config, argparse_config, runtime_params]) - dataset_blueprints.append(DatasetBlueprint(is_dreambooth, is_controlnet, params, subset_blueprints)) - - dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints) - - return Blueprint(dataset_group_blueprint) - - @staticmethod - def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]): - name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME - search_value = BlueprintGenerator.search_value - default_params = asdict(param_klass()) - param_names = default_params.keys() - - params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names} - - return param_klass(**params) - - @staticmethod - def search_value(key: str, fallbacks: Sequence[dict], default_value = None): - for cand in fallbacks: - value = cand.get(key) - if value is not None: - return value - - return default_value + BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {} + + def __init__(self, sanitizer: ConfigSanitizer): + self.sanitizer = sanitizer + + # runtime_params is for parameters which is only configurable on runtime, such as tokenizer + def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint: + sanitized_user_config = self.sanitizer.sanitize_user_config(user_config) + sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace) + + # convert argparse namespace to dict like config + # NOTE: it is ok to have extra entries in dict + optname_map = self.sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME + argparse_config = { + optname_map.get(optname, optname): value for optname, value in vars(sanitized_argparse_namespace).items() + } + + general_config = sanitized_user_config.get("general", {}) + + dataset_blueprints = [] + for dataset_config in sanitized_user_config.get("datasets", []): + # NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets + subsets = dataset_config.get("subsets", []) + is_dreambooth = all(["metadata_file" not in subset for subset in subsets]) + is_controlnet = all(["conditioning_data_dir" in subset for subset in subsets]) + if is_controlnet: + subset_params_klass = ControlNetSubsetParams + dataset_params_klass = ControlNetDatasetParams + elif is_dreambooth: + subset_params_klass = DreamBoothSubsetParams + dataset_params_klass = DreamBoothDatasetParams + else: + subset_params_klass = FineTuningSubsetParams + dataset_params_klass = FineTuningDatasetParams + + subset_blueprints = [] + for subset_config in subsets: + params = self.generate_params_by_fallbacks( + subset_params_klass, [subset_config, dataset_config, general_config, argparse_config, runtime_params] + ) + subset_blueprints.append(SubsetBlueprint(params)) + + params = self.generate_params_by_fallbacks( + dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params] + ) + dataset_blueprints.append(DatasetBlueprint(is_dreambooth, is_controlnet, params, subset_blueprints)) + + dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints) + + return Blueprint(dataset_group_blueprint) + + @staticmethod + def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]): + name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME + search_value = BlueprintGenerator.search_value + default_params = asdict(param_klass()) + param_names = default_params.keys() + + params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names} + + return param_klass(**params) + + @staticmethod + def search_value(key: str, fallbacks: Sequence[dict], default_value=None): + for cand in fallbacks: + value = cand.get(key) + if value is not None: + return value + + return default_value def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint): - datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] - - for dataset_blueprint in dataset_group_blueprint.datasets: - if dataset_blueprint.is_controlnet: - subset_klass = ControlNetSubset - dataset_klass = ControlNetDataset - elif dataset_blueprint.is_dreambooth: - subset_klass = DreamBoothSubset - dataset_klass = DreamBoothDataset - else: - subset_klass = FineTuningSubset - dataset_klass = FineTuningDataset - - subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] - dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) - datasets.append(dataset) - - # print info - info = "" - for i, dataset in enumerate(datasets): - is_dreambooth = isinstance(dataset, DreamBoothDataset) - is_controlnet = isinstance(dataset, ControlNetDataset) - info += dedent(f"""\ + datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] + + for dataset_blueprint in dataset_group_blueprint.datasets: + if dataset_blueprint.is_controlnet: + subset_klass = ControlNetSubset + dataset_klass = ControlNetDataset + elif dataset_blueprint.is_dreambooth: + subset_klass = DreamBoothSubset + dataset_klass = DreamBoothDataset + else: + subset_klass = FineTuningSubset + dataset_klass = FineTuningDataset + + subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] + dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) + datasets.append(dataset) + + # print info + info = "" + for i, dataset in enumerate(datasets): + is_dreambooth = isinstance(dataset, DreamBoothDataset) + is_controlnet = isinstance(dataset, ControlNetDataset) + info += dedent( + f"""\ [Dataset {i}] batch_size: {dataset.batch_size} resolution: {(dataset.width, dataset.height)} enable_bucket: {dataset.enable_bucket} - """) + """ + ) - if dataset.enable_bucket: - info += indent(dedent(f"""\ + if dataset.enable_bucket: + info += indent( + dedent( + f"""\ min_bucket_reso: {dataset.min_bucket_reso} max_bucket_reso: {dataset.max_bucket_reso} bucket_reso_steps: {dataset.bucket_reso_steps} bucket_no_upscale: {dataset.bucket_no_upscale} - \n"""), " ") - else: - info += "\n" + \n""" + ), + " ", + ) + else: + info += "\n" - for j, subset in enumerate(dataset.subsets): - info += indent(dedent(f"""\ + for j, subset in enumerate(dataset.subsets): + info += indent( + dedent( + f"""\ [Subset {j} of Dataset {i}] image_dir: "{subset.image_dir}" image_count: {subset.img_count} @@ -475,147 +509,176 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu random_crop: {subset.random_crop} token_warmup_min: {subset.token_warmup_min}, token_warmup_step: {subset.token_warmup_step}, - """), " ") - - if is_dreambooth: - info += indent(dedent(f"""\ + """ + ), + " ", + ) + + if is_dreambooth: + info += indent( + dedent( + f"""\ is_reg: {subset.is_reg} class_tokens: {subset.class_tokens} caption_extension: {subset.caption_extension} - \n"""), " ") - elif not is_controlnet: - info += indent(dedent(f"""\ + \n""" + ), + " ", + ) + elif not is_controlnet: + info += indent( + dedent( + f"""\ metadata_file: {subset.metadata_file} - \n"""), " ") + \n""" + ), + " ", + ) - print(info) + print(info) - # make buckets first because it determines the length of dataset - # and set the same seed for all datasets - seed = random.randint(0, 2**31) # actual seed is seed + epoch_no - for i, dataset in enumerate(datasets): - print(f"[Dataset {i}]") - dataset.make_buckets() - dataset.set_seed(seed) + # make buckets first because it determines the length of dataset + # and set the same seed for all datasets + seed = random.randint(0, 2**31) # actual seed is seed + epoch_no + for i, dataset in enumerate(datasets): + print(f"[Dataset {i}]") + dataset.make_buckets() + dataset.set_seed(seed) - return DatasetGroup(datasets) + return DatasetGroup(datasets) def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None): - def extract_dreambooth_params(name: str) -> Tuple[int, str]: - tokens = name.split('_') - try: - n_repeats = int(tokens[0]) - except ValueError as e: - print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}") - return 0, "" - caption_by_folder = '_'.join(tokens[1:]) - return n_repeats, caption_by_folder - - def generate(base_dir: Optional[str], is_reg: bool): - if base_dir is None: - return [] - - base_dir: Path = Path(base_dir) - if not base_dir.is_dir(): - return [] + def extract_dreambooth_params(name: str) -> Tuple[int, str]: + tokens = name.split("_") + try: + n_repeats = int(tokens[0]) + except ValueError as e: + print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}") + return 0, "" + caption_by_folder = "_".join(tokens[1:]) + return n_repeats, caption_by_folder + + def generate(base_dir: Optional[str], is_reg: bool): + if base_dir is None: + return [] + + base_dir: Path = Path(base_dir) + if not base_dir.is_dir(): + return [] + + subsets_config = [] + for subdir in base_dir.iterdir(): + if not subdir.is_dir(): + continue + + num_repeats, class_tokens = extract_dreambooth_params(subdir.name) + if num_repeats < 1: + continue + + subset_config = {"image_dir": str(subdir), "num_repeats": num_repeats, "is_reg": is_reg, "class_tokens": class_tokens} + subsets_config.append(subset_config) + + return subsets_config subsets_config = [] - for subdir in base_dir.iterdir(): - if not subdir.is_dir(): - continue - - num_repeats, class_tokens = extract_dreambooth_params(subdir.name) - if num_repeats < 1: - continue - - subset_config = {"image_dir": str(subdir), "num_repeats": num_repeats, "is_reg": is_reg, "class_tokens": class_tokens} - subsets_config.append(subset_config) + subsets_config += generate(train_data_dir, False) + subsets_config += generate(reg_data_dir, True) return subsets_config - subsets_config = [] - subsets_config += generate(train_data_dir, False) - subsets_config += generate(reg_data_dir, True) - return subsets_config +def generate_controlnet_subsets_config_by_subdirs( + train_data_dir: Optional[str] = None, conditioning_data_dir: Optional[str] = None, caption_extension: str = ".txt" +): + def generate(base_dir: Optional[str]): + if base_dir is None: + return [] + base_dir: Path = Path(base_dir) + if not base_dir.is_dir(): + return [] -def generate_controlnet_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, conditioning_data_dir: Optional[str] = None, caption_extension: str = ".txt"): - def generate(base_dir: Optional[str]): - if base_dir is None: - return [] + subsets_config = [] + subset_config = { + "image_dir": train_data_dir, + "conditioning_data_dir": conditioning_data_dir, + "caption_extension": caption_extension, + "num_repeats": 1, + } + subsets_config.append(subset_config) - base_dir: Path = Path(base_dir) - if not base_dir.is_dir(): - return [] + return subsets_config subsets_config = [] - subset_config = {"image_dir": train_data_dir, "conditioning_data_dir": conditioning_data_dir, "caption_extension": caption_extension, "num_repeats": 1} - subsets_config.append(subset_config) + subsets_config += generate(train_data_dir) return subsets_config - subsets_config = [] - subsets_config += generate(train_data_dir) - return subsets_config +def load_user_config(file: str) -> dict: + file: Path = Path(file) + if not file.is_file(): + raise ValueError(f"file not found / ファイルが見つかりません: {file}") + + if file.name.lower().endswith(".json"): + try: + with open(file, "r") as f: + config = json.load(f) + except Exception: + print( + f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}" + ) + raise + elif file.name.lower().endswith(".toml"): + try: + config = toml.load(file) + except Exception: + print( + f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}" + ) + raise + else: + raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}") + return config -def load_user_config(file: str) -> dict: - file: Path = Path(file) - if not file.is_file(): - raise ValueError(f"file not found / ファイルが見つかりません: {file}") - - if file.name.lower().endswith('.json'): - try: - with open(file, 'r') as f: - config = json.load(f) - except Exception: - print(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") - raise - elif file.name.lower().endswith('.toml'): - try: - config = toml.load(file) - except Exception: - print(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") - raise - else: - raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}") - - return config # for config test if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--support_dreambooth", action="store_true") - parser.add_argument("--support_finetuning", action="store_true") - parser.add_argument("--support_controlnet", action="store_true") - parser.add_argument("--support_dropout", action="store_true") - parser.add_argument("dataset_config") - config_args, remain = parser.parse_known_args() - - parser = argparse.ArgumentParser() - train_util.add_dataset_arguments(parser, config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout) - train_util.add_training_arguments(parser, config_args.support_dreambooth) - argparse_namespace = parser.parse_args(remain) - train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning) + parser = argparse.ArgumentParser() + parser.add_argument("--support_dreambooth", action="store_true") + parser.add_argument("--support_finetuning", action="store_true") + parser.add_argument("--support_controlnet", action="store_true") + parser.add_argument("--support_dropout", action="store_true") + parser.add_argument("dataset_config") + config_args, remain = parser.parse_known_args() + + parser = argparse.ArgumentParser() + train_util.add_dataset_arguments( + parser, config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout + ) + train_util.add_training_arguments(parser, config_args.support_dreambooth) + argparse_namespace = parser.parse_args(remain) + train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning) - print("[argparse_namespace]") - print(vars(argparse_namespace)) + print("[argparse_namespace]") + print(vars(argparse_namespace)) - user_config = load_user_config(config_args.dataset_config) + user_config = load_user_config(config_args.dataset_config) - print("\n[user_config]") - print(user_config) + print("\n[user_config]") + print(user_config) - sanitizer = ConfigSanitizer(config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout) - sanitized_user_config = sanitizer.sanitize_user_config(user_config) + sanitizer = ConfigSanitizer( + config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout + ) + sanitized_user_config = sanitizer.sanitize_user_config(user_config) - print("\n[sanitized_user_config]") - print(sanitized_user_config) + print("\n[sanitized_user_config]") + print(sanitized_user_config) - blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace) + blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace) - print("\n[blueprint]") - print(blueprint) + print("\n[blueprint]") + print(blueprint) From fef172966fff02a5f918840843eb613ee1a6d50e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 20 Jan 2024 16:24:43 +0900 Subject: [PATCH 018/103] Add network_multiplier for dataset and train LoRA --- README.md | 31 +++++++++++++++++++++++++++++ library/config_util.py | 3 +++ library/train_util.py | 45 +++++++++++++++++++++++++----------------- train_network.py | 13 +++++++++++- 4 files changed, 73 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index 663f52e86..c1043783c 100644 --- a/README.md +++ b/README.md @@ -257,11 +257,42 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - If you use xformers with PyTorch 2.1, please see [xformers repository](https://github.com/facebookresearch/xformers) and install the appropriate version according to your CUDA version. - The sample image generation during training consumes a lot of memory. It is recommended to turn it off. +- [Experimental] The network multiplier can be specified for each dataset in the training scripts for LoRA etc. + - This is an experimental option and may be removed or changed in the future. + - For example, if you train with state A as `1.0` and state B as `-1.0`, you may be able to generate by switching between state A and B depending on the LoRA application rate. + - Also, if you prepare five states and train them as `0.2`, `0.4`, `0.6`, `0.8`, and `1.0`, you may be able to generate by switching the states smoothly depending on the application rate. + - Please specify `network_multiplier` in `[[datasets]]` in `.toml` file. + - (実験的) LoRA等の学習スクリプトで、ベースモデル(U-Net、および Text Encoder のモジュール学習時は Text Encoder も)の重みを fp8 にして学習するオプションが追加されました。 PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) KohakuBlueleaf 氏に感謝します。 - `train_network.py` または `sdxl_train_network.py` で `--fp8_base` を指定してください。 - PyTorch 2.1 以降が必要です。 - PyTorch 2.1 で xformers を使用する場合は、[xformers のリポジトリ](https://github.com/facebookresearch/xformers) を参照し、CUDA バージョンに応じて適切なバージョンをインストールしてください。 - 学習中のサンプル画像生成はメモリを大量に消費するため、オフにすることをお勧めします。 +- (実験的) LoRA 等の学習で、データセットごとに異なるネットワーク適用率を指定できるようになりました。 + - 実験的オプションのため、将来的に削除または仕様変更される可能性があります。 + - たとえば状態 A を `1.0`、状態 B を `-1.0` として学習すると、LoRA の適用率に応じて状態 A と B を切り替えつつ生成できるかもしれません。 + - また、五段階の状態を用意し、それぞれ `0.2`、`0.4`、`0.6`、`0.8`、`1.0` として学習すると、適用率でなめらかに状態を切り替えて生成できるかもしれません。 + - `.toml` ファイルで `[[datasets]]` に `network_multiplier` を指定してください。 + +- `.toml` example for network multiplier / ネットワーク適用率の `.toml` の記述例 + +```toml +[general] +[[datasets]] +resolution = 512 +batch_size = 8 +network_multiplier = 1.0 + +... subset settings ... + +[[datasets]] +resolution = 512 +batch_size = 8 +network_multiplier = -1.0 + +... subset settings ... +``` + ### Jan 17, 2024 / 2024/1/17: v0.8.1 diff --git a/library/config_util.py b/library/config_util.py index 716cecff8..a98c2b90d 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -93,6 +93,7 @@ class BaseDatasetParams: tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None max_token_length: int = None resolution: Optional[Tuple[int, int]] = None + network_multiplier: float = 1.0 debug_dataset: bool = False @@ -219,6 +220,7 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence] "max_bucket_reso": int, "min_bucket_reso": int, "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), + "network_multiplier": float, } # options handled by argparse but not handled by user config @@ -469,6 +471,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu batch_size: {dataset.batch_size} resolution: {(dataset.width, dataset.height)} enable_bucket: {dataset.enable_bucket} + network_multiplier: {dataset.network_multiplier} """ ) diff --git a/library/train_util.py b/library/train_util.py index 21e7638da..4ac6728b0 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -558,6 +558,7 @@ def __init__( tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]], max_token_length: int, resolution: Optional[Tuple[int, int]], + network_multiplier: float, debug_dataset: bool, ) -> None: super().__init__() @@ -567,6 +568,7 @@ def __init__( self.max_token_length = max_token_length # width/height is used when enable_bucket==False self.width, self.height = (None, None) if resolution is None else resolution + self.network_multiplier = network_multiplier self.debug_dataset = debug_dataset self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = [] @@ -1106,7 +1108,9 @@ def __getitem__(self, index): for image_key in bucket[image_index : image_index + bucket_batch_size]: image_info = self.image_data[image_key] subset = self.image_to_subset[image_key] - loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0) + loss_weights.append( + self.prior_loss_weight if image_info.is_reg else 1.0 + ) # in case of fine tuning, is_reg is always False flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance @@ -1272,6 +1276,8 @@ def __getitem__(self, index): example["target_sizes_hw"] = torch.stack([torch.LongTensor(x) for x in target_sizes_hw]) example["flippeds"] = flippeds + example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions)) + if self.debug_dataset: example["image_keys"] = bucket[image_index : image_index + self.batch_size] return example @@ -1346,15 +1352,16 @@ def __init__( tokenizer, max_token_length, resolution, + network_multiplier: float, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, prior_loss_weight: float, - debug_dataset, + debug_dataset: bool, ) -> None: - super().__init__(tokenizer, max_token_length, resolution, debug_dataset) + super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" @@ -1520,14 +1527,15 @@ def __init__( tokenizer, max_token_length, resolution, + network_multiplier: float, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, - debug_dataset, + debug_dataset: bool, ) -> None: - super().__init__(tokenizer, max_token_length, resolution, debug_dataset) + super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) self.batch_size = batch_size @@ -1724,14 +1732,15 @@ def __init__( tokenizer, max_token_length, resolution, + network_multiplier: float, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, - debug_dataset, + debug_dataset: float, ) -> None: - super().__init__(tokenizer, max_token_length, resolution, debug_dataset) + super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) db_subsets = [] for subset in subsets: @@ -2039,6 +2048,8 @@ def debug_dataset(train_dataset, show_input_ids=False): print( f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop top left: {crptl}, target size: {trgsz}, flipped: {flpdz}' ) + if "network_multipliers" in example: + print(f"network multiplier: {example['network_multipliers'][j]}") if show_input_ids: print(f"input ids: {iid}") @@ -2105,8 +2116,8 @@ def glob_images_pathlib(dir_path, recursive): class MinimalDataset(BaseDataset): - def __init__(self, tokenizer, max_token_length, resolution, debug_dataset=False): - super().__init__(tokenizer, max_token_length, resolution, debug_dataset) + def __init__(self, tokenizer, max_token_length, resolution, network_multiplier, debug_dataset=False): + super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) self.num_train_images = 0 # update in subclass self.num_reg_images = 0 # update in subclass @@ -2850,14 +2861,14 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: ) parser.add_argument("--torch_compile", action="store_true", help="use torch.compile (requires PyTorch 2.0) / torch.compile を使う") parser.add_argument( - "--dynamo_backend", - type=str, - default="inductor", + "--dynamo_backend", + type=str, + default="inductor", # available backends: # https://github.com/huggingface/accelerate/blob/d1abd59114ada8ba673e1214218cb2878c13b82d/src/accelerate/utils/dataclasses.py#L376-L388C5 # https://pytorch.org/docs/stable/torch.compiler.html - choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"], - help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)" + choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"], + help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)", ) parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う") parser.add_argument( @@ -2904,9 +2915,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument( "--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する" ) # TODO move to SDXL training, because it is not supported by SD1/2 - parser.add_argument( - "--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う" - ) + parser.add_argument("--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う") parser.add_argument( "--ddp_timeout", type=int, @@ -3889,7 +3898,7 @@ def prepare_accelerator(args: argparse.Namespace): os.environ["WANDB_DIR"] = logging_dir if args.wandb_api_key is not None: wandb.login(key=args.wandb_api_key) - + # torch.compile のオプション。 NO の場合は torch.compile は使わない dynamo_backend = "NO" if args.torch_compile: diff --git a/train_network.py b/train_network.py index b1291ed10..ef7d41973 100644 --- a/train_network.py +++ b/train_network.py @@ -310,6 +310,7 @@ def train(self, args): ) if network is None: return + network_has_multiplier = hasattr(network, "set_multiplier") if hasattr(network, "prepare_network"): network.prepare_network(args) @@ -768,7 +769,17 @@ def remove_model(old_ckpt_name): accelerator.print("NaN found in latents, replacing with zeros") latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * self.vae_scale_factor - b_size = latents.shape[0] + + # get multiplier for each sample + if network_has_multiplier: + multipliers = batch["network_multipliers"] + # if all multipliers are same, use single multiplier + if torch.all(multipliers == multipliers[0]): + multipliers = multipliers[0].item() + else: + raise NotImplementedError("multipliers for each sample is not supported yet") + # print(f"set multiplier: {multipliers}") + network.set_multiplier(multipliers) with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning From c59249a664a9202d8c7e2300587480d0a59a4c9b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 20 Jan 2024 18:45:54 +0900 Subject: [PATCH 019/103] Add options to reduce memory usage in extract_lora_from_models.py closes #1059 --- README.md | 7 +++ networks/extract_lora_from_models.py | 83 ++++++++++++++++++++++++---- 2 files changed, 79 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index c1043783c..be6df5883 100644 --- a/README.md +++ b/README.md @@ -262,6 +262,9 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - For example, if you train with state A as `1.0` and state B as `-1.0`, you may be able to generate by switching between state A and B depending on the LoRA application rate. - Also, if you prepare five states and train them as `0.2`, `0.4`, `0.6`, `0.8`, and `1.0`, you may be able to generate by switching the states smoothly depending on the application rate. - Please specify `network_multiplier` in `[[datasets]]` in `.toml` file. +- Some options are added to `networks/extract_lora_from_models.py` to reduce the memory usage. + - `--load_precision` option can be used to specify the precision when loading the model. If the model is saved in fp16, you can reduce the memory usage by specifying `--load_precision fp16` without losing precision. + - `--load_original_model_to` option can be used to specify the device to load the original model. `--load_tuned_model_to` option can be used to specify the device to load the derived model. The default is `cpu` for both options, but you can specify `cuda` etc. You can reduce the memory usage by loading one of them to GPU. - (実験的) LoRA等の学習スクリプトで、ベースモデル(U-Net、および Text Encoder のモジュール学習時は Text Encoder も)の重みを fp8 にして学習するオプションが追加されました。 PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) KohakuBlueleaf 氏に感謝します。 - `train_network.py` または `sdxl_train_network.py` で `--fp8_base` を指定してください。 @@ -273,6 +276,10 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - たとえば状態 A を `1.0`、状態 B を `-1.0` として学習すると、LoRA の適用率に応じて状態 A と B を切り替えつつ生成できるかもしれません。 - また、五段階の状態を用意し、それぞれ `0.2`、`0.4`、`0.6`、`0.8`、`1.0` として学習すると、適用率でなめらかに状態を切り替えて生成できるかもしれません。 - `.toml` ファイルで `[[datasets]]` に `network_multiplier` を指定してください。 +- `networks/extract_lora_from_models.py` に使用メモリ量を削減するいくつかのオプションを追加しました。 + - `--load_precision` で読み込み時の精度を指定できます。モデルが fp16 で保存されている場合は `--load_precision fp16` を指定して精度を変えずにメモリ量を削減できます。 + - `--load_original_model_to` で元モデルを読み込むデバイスを、`--load_tuned_model_to` で派生モデルを読み込むデバイスを指定できます。デフォルトは両方とも `cpu` ですがそれぞれ `cuda` 等を指定できます。片方を GPU に読み込むことでメモリ量を削減できます。 + - `.toml` example for network multiplier / ネットワーク適用率の `.toml` の記述例 diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py index 6357df55d..b9027adba 100644 --- a/networks/extract_lora_from_models.py +++ b/networks/extract_lora_from_models.py @@ -43,6 +43,9 @@ def svd( clamp_quantile=0.99, min_diff=0.01, no_metadata=False, + load_precision=None, + load_original_model_to=None, + load_tuned_model_to=None, ): def str_to_dtype(p): if p == "float": @@ -57,28 +60,51 @@ def str_to_dtype(p): if v_parameterization is None: v_parameterization = v2 + load_dtype = str_to_dtype(load_precision) if load_precision else None save_dtype = str_to_dtype(save_precision) + work_device = "cpu" # load models if not sdxl: print(f"loading original SD model : {model_org}") text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org) text_encoders_o = [text_encoder_o] + if load_dtype is not None: + text_encoder_o = text_encoder_o.to(load_dtype) + unet_o = unet_o.to(load_dtype) + print(f"loading tuned SD model : {model_tuned}") text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned) text_encoders_t = [text_encoder_t] + if load_dtype is not None: + text_encoder_t = text_encoder_t.to(load_dtype) + unet_t = unet_t.to(load_dtype) + model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization) else: + device_org = load_original_model_to if load_original_model_to else "cpu" + device_tuned = load_tuned_model_to if load_tuned_model_to else "cpu" + print(f"loading original SDXL model : {model_org}") text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( - sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, "cpu" + sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org ) text_encoders_o = [text_encoder_o1, text_encoder_o2] + if load_dtype is not None: + text_encoder_o1 = text_encoder_o1.to(load_dtype) + text_encoder_o2 = text_encoder_o2.to(load_dtype) + unet_o = unet_o.to(load_dtype) + print(f"loading original SDXL model : {model_tuned}") text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( - sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, "cpu" + sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned ) text_encoders_t = [text_encoder_t1, text_encoder_t2] + if load_dtype is not None: + text_encoder_t1 = text_encoder_t1.to(load_dtype) + text_encoder_t2 = text_encoder_t2.to(load_dtype) + unet_t = unet_t.to(load_dtype) + model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0 # create LoRA network to extract weights: Use dim (rank) as alpha @@ -100,38 +126,54 @@ def str_to_dtype(p): lora_name = lora_o.lora_name module_o = lora_o.org_module module_t = lora_t.org_module - diff = module_t.weight - module_o.weight + diff = module_t.weight.to(work_device) - module_o.weight.to(work_device) + + # clear weight to save memory + module_o.weight = None + module_t.weight = None # Text Encoder might be same if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff: text_encoder_different = True print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}") - diff = diff.float() diffs[lora_name] = diff + # clear target Text Encoder to save memory + for text_encoder in text_encoders_t: + del text_encoder + if not text_encoder_different: print("Text encoder is same. Extract U-Net only.") lora_network_o.text_encoder_loras = [] - diffs = {} + diffs = {} # clear diffs for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)): lora_name = lora_o.lora_name module_o = lora_o.org_module module_t = lora_t.org_module - diff = module_t.weight - module_o.weight - diff = diff.float() + diff = module_t.weight.to(work_device) - module_o.weight.to(work_device) - if args.device: - diff = diff.to(args.device) + # clear weight to save memory + module_o.weight = None + module_t.weight = None diffs[lora_name] = diff + # clear LoRA network, target U-Net to save memory + del lora_network_o + del lora_network_t + del unet_t + # make LoRA with svd print("calculating by svd") lora_weights = {} with torch.no_grad(): for lora_name, mat in tqdm(list(diffs.items())): + if args.device: + mat = mat.to(args.device) + mat = mat.to(torch.float) # calc by float + # if conv_dim is None, diffs do not include LoRAs for conv2d-3x3 conv2d = len(mat.size()) == 4 kernel_size = None if not conv2d else mat.size()[2:4] @@ -171,8 +213,8 @@ def str_to_dtype(p): U = U.reshape(out_dim, rank, 1, 1) Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1]) - U = U.to("cpu").contiguous() - Vh = Vh.to("cpu").contiguous() + U = U.to(work_device, dtype=save_dtype).contiguous() + Vh = Vh.to(work_device, dtype=save_dtype).contiguous() lora_weights[lora_name] = (U, Vh) @@ -230,6 +272,13 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--sdxl", action="store_true", help="load Stable Diffusion SDXL base model / Stable Diffusion SDXL baseのモデルを読み込む" ) + parser.add_argument( + "--load_precision", + type=str, + default=None, + choices=[None, "float", "fp16", "bf16"], + help="precision in loading, model default if omitted / 読み込み時に精度を変更して読み込む、省略時はモデルファイルによる" + ) parser.add_argument( "--save_precision", type=str, @@ -285,6 +334,18 @@ def setup_parser() -> argparse.ArgumentParser: help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / " + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)", ) + parser.add_argument( + "--load_original_model_to", + type=str, + default=None, + help="location to load original model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 元モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効", + ) + parser.add_argument( + "--load_tuned_model_to", + type=str, + default=None, + help="location to load tuned model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 派生モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効", + ) return parser From e0a3c69223d09f38a2fc0d86ed289378cf0f7958 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 20 Jan 2024 18:47:10 +0900 Subject: [PATCH 020/103] update readme --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index be6df5883..23108b78e 100644 --- a/README.md +++ b/README.md @@ -264,7 +264,7 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - Please specify `network_multiplier` in `[[datasets]]` in `.toml` file. - Some options are added to `networks/extract_lora_from_models.py` to reduce the memory usage. - `--load_precision` option can be used to specify the precision when loading the model. If the model is saved in fp16, you can reduce the memory usage by specifying `--load_precision fp16` without losing precision. - - `--load_original_model_to` option can be used to specify the device to load the original model. `--load_tuned_model_to` option can be used to specify the device to load the derived model. The default is `cpu` for both options, but you can specify `cuda` etc. You can reduce the memory usage by loading one of them to GPU. + - `--load_original_model_to` option can be used to specify the device to load the original model. `--load_tuned_model_to` option can be used to specify the device to load the derived model. The default is `cpu` for both options, but you can specify `cuda` etc. You can reduce the memory usage by loading one of them to GPU. This option is available only for SDXL. - (実験的) LoRA等の学習スクリプトで、ベースモデル(U-Net、および Text Encoder のモジュール学習時は Text Encoder も)の重みを fp8 にして学習するオプションが追加されました。 PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) KohakuBlueleaf 氏に感謝します。 - `train_network.py` または `sdxl_train_network.py` で `--fp8_base` を指定してください。 @@ -278,7 +278,7 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - `.toml` ファイルで `[[datasets]]` に `network_multiplier` を指定してください。 - `networks/extract_lora_from_models.py` に使用メモリ量を削減するいくつかのオプションを追加しました。 - `--load_precision` で読み込み時の精度を指定できます。モデルが fp16 で保存されている場合は `--load_precision fp16` を指定して精度を変えずにメモリ量を削減できます。 - - `--load_original_model_to` で元モデルを読み込むデバイスを、`--load_tuned_model_to` で派生モデルを読み込むデバイスを指定できます。デフォルトは両方とも `cpu` ですがそれぞれ `cuda` 等を指定できます。片方を GPU に読み込むことでメモリ量を削減できます。 + - `--load_original_model_to` で元モデルを読み込むデバイスを、`--load_tuned_model_to` で派生モデルを読み込むデバイスを指定できます。デフォルトは両方とも `cpu` ですがそれぞれ `cuda` 等を指定できます。片方を GPU に読み込むことでメモリ量を削減できます。SDXL の場合のみ有効です。 - `.toml` example for network multiplier / ネットワーク適用率の `.toml` の記述例 From 696dd7f66802091f973ef0ce5ffa1e8002e90789 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 22 Jan 2024 12:43:37 +0900 Subject: [PATCH 021/103] Fix dtype issue in PyTorch 2.0 for generating samples in training sdxl network --- library/sdxl_lpw_stable_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/sdxl_lpw_stable_diffusion.py b/library/sdxl_lpw_stable_diffusion.py index 0562e88ad..03b182566 100644 --- a/library/sdxl_lpw_stable_diffusion.py +++ b/library/sdxl_lpw_stable_diffusion.py @@ -925,7 +925,7 @@ def __call__( unet_dtype = self.unet.dtype dtype = unet_dtype - if dtype.itemsize == 1: # fp8 + if hasattr(dtype, "itemsize") and dtype.itemsize == 1: # fp8 dtype = torch.float16 self.unet.to(dtype) From 711b40ccda0143cbf0014249ecdb9353231cbb79 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 23 Jan 2024 11:49:03 +0800 Subject: [PATCH 022/103] Avoid always sync --- train_network.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/train_network.py b/train_network.py index ef7d41973..9036f4862 100644 --- a/train_network.py +++ b/train_network.py @@ -847,10 +847,11 @@ def remove_model(old_ckpt_name): loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし accelerator.backward(loss) - self.all_reduce_network(accelerator, network) # sync DDP grad manually - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = accelerator.unwrap_model(network).get_trainable_params() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + if accelerator.sync_gradients: + self.all_reduce_network(accelerator, network) # sync DDP grad manually + if args.max_grad_norm != 0.0: + params_to_clip = accelerator.unwrap_model(network).get_trainable_params() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() lr_scheduler.step() From 6805cafa9be066e549b3aa2b9f53ec03a91ffda6 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 23 Jan 2024 20:17:19 +0900 Subject: [PATCH 023/103] fix TI training crashes in multigpu #1019 --- train_textual_inversion.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 0e3912b1d..f1cf6fbda 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -505,7 +505,7 @@ def train(self, args): if accelerator.is_main_process: init_kwargs = {} if args.wandb_run_name: - init_kwargs['wandb'] = {'name': args.wandb_run_name} + init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( @@ -730,14 +730,13 @@ def remove_model(old_ckpt_name): is_main_process = accelerator.is_main_process if is_main_process: text_encoder = accelerator.unwrap_model(text_encoder) + updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone() accelerator.end_training() if args.save_state and is_main_process: train_util.save_state_on_train_end(args, accelerator) - updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone() - if is_main_process: ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) save_model(ckpt_name, updated_embs_list, global_step, num_train_epochs, force_sync_upload=True) From 7cb44e45023295548788449167bcc8cd6afc5417 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 23 Jan 2024 21:02:40 +0900 Subject: [PATCH 024/103] update readme --- README.md | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 23108b78e..0de787b8f 100644 --- a/README.md +++ b/README.md @@ -249,7 +249,7 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum ## Change History -### 作業中の内容 / Work in progress +### Jan 23, 2024 / 2024/1/23: v0.8.2 - [Experimental] The `--fp8_base` option is added to the training scripts for LoRA etc. The base model (U-Net, and Text Encoder when training modules for Text Encoder) can be trained with fp8. PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) Thanks to KohakuBlueleaf! - Please specify `--fp8_base` in `train_network.py` or `sdxl_train_network.py`. @@ -266,6 +266,10 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - `--load_precision` option can be used to specify the precision when loading the model. If the model is saved in fp16, you can reduce the memory usage by specifying `--load_precision fp16` without losing precision. - `--load_original_model_to` option can be used to specify the device to load the original model. `--load_tuned_model_to` option can be used to specify the device to load the derived model. The default is `cpu` for both options, but you can specify `cuda` etc. You can reduce the memory usage by loading one of them to GPU. This option is available only for SDXL. +- The gradient synchronization in LoRA training with multi-GPU is improved. PR [#1064](https://github.com/kohya-ss/sd-scripts/pull/1064) Thanks to KohakuBlueleaf! +- The code for Intel IPEX support is improved. PR [#1060](https://github.com/kohya-ss/sd-scripts/pull/1060) Thanks to akx! +- Fixed a bug in multi-GPU Textual Inversion training. + - (実験的) LoRA等の学習スクリプトで、ベースモデル(U-Net、および Text Encoder のモジュール学習時は Text Encoder も)の重みを fp8 にして学習するオプションが追加されました。 PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) KohakuBlueleaf 氏に感謝します。 - `train_network.py` または `sdxl_train_network.py` で `--fp8_base` を指定してください。 - PyTorch 2.1 以降が必要です。 @@ -279,7 +283,9 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - `networks/extract_lora_from_models.py` に使用メモリ量を削減するいくつかのオプションを追加しました。 - `--load_precision` で読み込み時の精度を指定できます。モデルが fp16 で保存されている場合は `--load_precision fp16` を指定して精度を変えずにメモリ量を削減できます。 - `--load_original_model_to` で元モデルを読み込むデバイスを、`--load_tuned_model_to` で派生モデルを読み込むデバイスを指定できます。デフォルトは両方とも `cpu` ですがそれぞれ `cuda` 等を指定できます。片方を GPU に読み込むことでメモリ量を削減できます。SDXL の場合のみ有効です。 - +- マルチ GPU での LoRA 等の学習時に勾配の同期が改善されました。 PR [#1064](https://github.com/kohya-ss/sd-scripts/pull/1064) KohakuBlueleaf 氏に感謝します。 +- Intel IPEX サポートのコードが改善されました。PR [#1060](https://github.com/kohya-ss/sd-scripts/pull/1060) akx 氏に感謝します。 +- マルチ GPU での Textual Inversion 学習の不具合を修正しました。 - `.toml` example for network multiplier / ネットワーク適用率の `.toml` の記述例 From 2e4bee6f24c81836f917fe6057fdb1e5899108a3 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Tue, 16 Jan 2024 15:10:24 +0200 Subject: [PATCH 025/103] Log accelerator device --- library/train_util.py | 1 + 1 file changed, 1 insertion(+) diff --git a/library/train_util.py b/library/train_util.py index 4ac6728b0..320755d77 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3919,6 +3919,7 @@ def prepare_accelerator(args: argparse.Namespace): kwargs_handlers=kwargs_handlers, dynamo_backend=dynamo_backend, ) + print("accelerator device:", accelerator.device) return accelerator From afc38707d57a055577627ee4e17ade4581ed0140 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Tue, 16 Jan 2024 14:47:44 +0200 Subject: [PATCH 026/103] Refactor memory cleaning into a single function --- fine_tune.py | 6 ++---- gen_img_diffusers.py | 7 +++---- library/device_utils.py | 9 +++++++++ library/sdxl_train_util.py | 5 ++--- library/train_util.py | 10 ++++------ sdxl_gen_img.py | 10 ++++------ sdxl_train.py | 9 +++------ sdxl_train_control_net_lllite.py | 9 +++------ sdxl_train_control_net_lllite_old.py | 9 +++------ sdxl_train_network.py | 7 +++---- train_controlnet.py | 6 ++---- train_db.py | 6 ++---- train_network.py | 6 ++---- train_textual_inversion.py | 6 ++---- train_textual_inversion_XTI.py | 6 ++---- 15 files changed, 46 insertions(+), 65 deletions(-) create mode 100644 library/device_utils.py diff --git a/fine_tune.py b/fine_tune.py index 982dc8aec..11e94e560 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -2,7 +2,6 @@ # XXX dropped option: hypernetwork training import argparse -import gc import math import os from multiprocessing import Value @@ -11,6 +10,7 @@ from tqdm import tqdm import torch +from library.device_utils import clean_memory from library.ipex_interop import init_ipex init_ipex() @@ -158,9 +158,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + clean_memory() accelerator.wait_for_everyone() diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index a207ad5a1..4d2a73082 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -66,6 +66,7 @@ import numpy as np import torch +from library.device_utils import clean_memory from library.ipex_interop import init_ipex init_ipex() @@ -888,8 +889,7 @@ def __call__( init_latent_dist = self.vae.encode(init_image).latent_dist init_latents = init_latent_dist.sample(generator=generator) else: - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory() init_latents = [] for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)): init_latent_dist = self.vae.encode( @@ -1047,8 +1047,7 @@ def __call__( if vae_batch_size >= batch_size: image = self.vae.decode(latents).sample else: - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory() images = [] for i in tqdm(range(0, batch_size, vae_batch_size)): images.append( diff --git a/library/device_utils.py b/library/device_utils.py new file mode 100644 index 000000000..49af622bb --- /dev/null +++ b/library/device_utils.py @@ -0,0 +1,9 @@ +import gc + +import torch + + +def clean_memory(): + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index 5ad748d15..d2becad6f 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -1,5 +1,4 @@ import argparse -import gc import math import os from typing import Optional @@ -8,6 +7,7 @@ from tqdm import tqdm from transformers import CLIPTokenizer from library import model_util, sdxl_model_util, train_util, sdxl_original_unet +from library.device_utils import clean_memory from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline TOKENIZER1_PATH = "openai/clip-vit-large-patch14" @@ -47,8 +47,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype): unet.to(accelerator.device) vae.to(accelerator.device) - gc.collect() - torch.cuda.empty_cache() + clean_memory() accelerator.wait_for_everyone() return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info diff --git a/library/train_util.py b/library/train_util.py index 320755d77..8b5606df3 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -20,7 +20,6 @@ Union, ) from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs -import gc import glob import math import os @@ -67,6 +66,7 @@ # from library.attention_processors import FlashAttnProcessor # from library.hypernetwork import replace_attentions_for_hypernetwork +from library.device_utils import clean_memory from library.original_unet import UNet2DConditionModel # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う @@ -2278,8 +2278,7 @@ def cache_batch_latents( info.latents_flipped = flipped_latent # FIXME this slows down caching a lot, specify this as an option - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory() def cache_batch_text_encoder_outputs( @@ -4006,8 +4005,7 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio unet.to(accelerator.device) vae.to(accelerator.device) - gc.collect() - torch.cuda.empty_cache() + clean_memory() accelerator.wait_for_everyone() return text_encoder, vae, unet, load_stable_diffusion_format @@ -4816,7 +4814,7 @@ def sample_images_common( # clear pipeline and cache to reduce vram usage del pipeline - torch.cuda.empty_cache() + clean_memory() torch.set_rng_state(rng_state) if cuda_rng_state is not None: diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index 0db9e340e..14b05bfb7 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -18,6 +18,7 @@ import numpy as np import torch +from library.device_utils import clean_memory from library.ipex_interop import init_ipex init_ipex() @@ -640,8 +641,7 @@ def __call__( init_latent_dist = self.vae.encode(init_image.to(self.vae.dtype)).latent_dist init_latents = init_latent_dist.sample(generator=generator) else: - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory() init_latents = [] for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)): init_latent_dist = self.vae.encode( @@ -780,8 +780,7 @@ def __call__( if vae_batch_size >= batch_size: image = self.vae.decode(latents.to(self.vae.dtype)).sample else: - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory() images = [] for i in tqdm(range(0, batch_size, vae_batch_size)): images.append( @@ -796,8 +795,7 @@ def __call__( # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory() if output_type == "pil": # image = self.numpy_to_pil(image) diff --git a/sdxl_train.py b/sdxl_train.py index a3f6f3a17..78cfaf495 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -1,7 +1,6 @@ # training with captions import argparse -import gc import math import os from multiprocessing import Value @@ -11,6 +10,7 @@ from tqdm import tqdm import torch +from library.device_utils import clean_memory from library.ipex_interop import init_ipex init_ipex() @@ -252,9 +252,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + clean_memory() accelerator.wait_for_everyone() @@ -407,8 +405,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 text_encoder1.to("cpu", dtype=torch.float32) text_encoder2.to("cpu", dtype=torch.float32) - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory() else: # make sure Text Encoders are on GPU text_encoder1.to(accelerator.device) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 7a88feb84..95b755f18 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -2,7 +2,6 @@ # training code for ControlNet-LLLite with passing cond_image to U-Net's forward import argparse -import gc import json import math import os @@ -15,6 +14,7 @@ from tqdm import tqdm import torch +from library.device_utils import clean_memory from library.ipex_interop import init_ipex init_ipex() @@ -164,9 +164,7 @@ def train(args): accelerator.is_main_process, ) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + clean_memory() accelerator.wait_for_everyone() @@ -291,8 +289,7 @@ def train(args): # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 text_encoder1.to("cpu", dtype=torch.float32) text_encoder2.to("cpu", dtype=torch.float32) - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory() else: # make sure Text Encoders are on GPU text_encoder1.to(accelerator.device) diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index b94bf5c1b..fd24898c4 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -1,5 +1,4 @@ import argparse -import gc import json import math import os @@ -12,6 +11,7 @@ from tqdm import tqdm import torch +from library.device_utils import clean_memory from library.ipex_interop import init_ipex init_ipex() @@ -163,9 +163,7 @@ def train(args): accelerator.is_main_process, ) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + clean_memory() accelerator.wait_for_everyone() @@ -264,8 +262,7 @@ def train(args): # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 text_encoder1.to("cpu", dtype=torch.float32) text_encoder2.to("cpu", dtype=torch.float32) - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory() else: # make sure Text Encoders are on GPU text_encoder1.to(accelerator.device) diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 5d363280d..af0c8d1d7 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -1,6 +1,7 @@ import argparse import torch +from library.device_utils import clean_memory from library.ipex_interop import init_ipex init_ipex() @@ -65,8 +66,7 @@ def cache_text_encoder_outputs_if_needed( org_unet_device = unet.device vae.to("cpu") unet.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory() # When TE is not be trained, it will not be prepared so we need to use explicit autocast with accelerator.autocast(): @@ -81,8 +81,7 @@ def cache_text_encoder_outputs_if_needed( text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU text_encoders[1].to("cpu", dtype=torch.float32) - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory() if not args.lowram: print("move vae and unet back to original device") diff --git a/train_controlnet.py b/train_controlnet.py index 7b0b2bbfe..e6bea2c9f 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -1,5 +1,4 @@ import argparse -import gc import json import math import os @@ -12,6 +11,7 @@ from tqdm import tqdm import torch +from library.device_utils import clean_memory from library.ipex_interop import init_ipex init_ipex() @@ -219,9 +219,7 @@ def train(args): accelerator.is_main_process, ) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + clean_memory() accelerator.wait_for_everyone() diff --git a/train_db.py b/train_db.py index 888cad25e..daeb6d668 100644 --- a/train_db.py +++ b/train_db.py @@ -1,7 +1,6 @@ # DreamBooth training # XXX dropped option: fine_tune -import gc import argparse import itertools import math @@ -12,6 +11,7 @@ from tqdm import tqdm import torch +from library.device_utils import clean_memory from library.ipex_interop import init_ipex init_ipex() @@ -138,9 +138,7 @@ def train(args): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + clean_memory() accelerator.wait_for_everyone() diff --git a/train_network.py b/train_network.py index 8b6c395c5..f593e405d 100644 --- a/train_network.py +++ b/train_network.py @@ -1,6 +1,5 @@ import importlib import argparse -import gc import math import os import sys @@ -14,6 +13,7 @@ import torch from torch.nn.parallel import DistributedDataParallel as DDP +from library.device_utils import clean_memory from library.ipex_interop import init_ipex init_ipex() @@ -266,9 +266,7 @@ def train(self, args): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + clean_memory() accelerator.wait_for_everyone() diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 441c1e00b..821cfe786 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -1,5 +1,4 @@ import argparse -import gc import math import os from multiprocessing import Value @@ -8,6 +7,7 @@ from tqdm import tqdm import torch +from library.device_utils import clean_memory from library.ipex_interop import init_ipex init_ipex() @@ -363,9 +363,7 @@ def train(self, args): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + clean_memory() accelerator.wait_for_everyone() diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 7046a4808..ecd6d087f 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -1,6 +1,5 @@ import importlib import argparse -import gc import math import os import toml @@ -9,6 +8,7 @@ from tqdm import tqdm import torch +from library.device_utils import clean_memory from library.ipex_interop import init_ipex init_ipex() @@ -286,9 +286,7 @@ def train(args): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + clean_memory() accelerator.wait_for_everyone() From 478156b4f7fb681549410d33e89d439ba4bd12f9 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Tue, 16 Jan 2024 15:01:59 +0200 Subject: [PATCH 027/103] Refactor device determination to function; add MPS fallback --- finetune/make_captions.py | 3 ++- finetune/make_captions_by_git.py | 4 ++-- finetune/prepare_buckets_latents.py | 4 +++- gen_img_diffusers.py | 4 ++-- library/device_utils.py | 29 +++++++++++++++++++++++++++-- networks/lora_diffusers.py | 4 +++- networks/lora_interrogator.py | 3 ++- sdxl_gen_img.py | 4 ++-- sdxl_minimal_inference.py | 3 ++- tools/latent_upscaler.py | 4 +++- 10 files changed, 48 insertions(+), 14 deletions(-) diff --git a/finetune/make_captions.py b/finetune/make_captions.py index 074576bc2..524f80b9d 100644 --- a/finetune/make_captions.py +++ b/finetune/make_captions.py @@ -15,8 +15,9 @@ sys.path.append(os.path.dirname(__file__)) from blip.blip import blip_decoder, is_url import library.train_util as train_util +from library.device_utils import get_preferred_device -DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +DEVICE = get_preferred_device() IMAGE_SIZE = 384 diff --git a/finetune/make_captions_by_git.py b/finetune/make_captions_by_git.py index b3c5cc423..2b650eb01 100644 --- a/finetune/make_captions_by_git.py +++ b/finetune/make_captions_by_git.py @@ -10,9 +10,9 @@ from transformers.generation.utils import GenerationMixin import library.train_util as train_util +from library.device_utils import get_preferred_device - -DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +DEVICE = get_preferred_device() PATTERN_REPLACE = [ re.compile(r'(has|with|and) the (words?|letters?|name) (" ?[^"]*"|\w+)( ?(is )?(on|in) (the |her |their |him )?\w+)?'), diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py index 1bccb1d3b..9d352dd6e 100644 --- a/finetune/prepare_buckets_latents.py +++ b/finetune/prepare_buckets_latents.py @@ -14,7 +14,9 @@ import library.model_util as model_util import library.train_util as train_util -DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +from library.device_utils import get_preferred_device + +DEVICE = get_preferred_device() IMAGE_TRANSFORMS = transforms.Compose( [ diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 4d2a73082..38b1ceabf 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -66,7 +66,7 @@ import numpy as np import torch -from library.device_utils import clean_memory +from library.device_utils import clean_memory, get_preferred_device from library.ipex_interop import init_ipex init_ipex() @@ -2324,7 +2324,7 @@ def __getattr__(self, item): scheduler.config.clip_sample = True # deviceを決定する - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない + device = get_preferred_device() # custom pipelineをコピったやつを生成する if args.vae_slices: diff --git a/library/device_utils.py b/library/device_utils.py index 49af622bb..353bfa9f3 100644 --- a/library/device_utils.py +++ b/library/device_utils.py @@ -1,9 +1,34 @@ +import functools import gc import torch +try: + HAS_CUDA = torch.cuda.is_available() +except Exception: + HAS_CUDA = False + +try: + HAS_MPS = torch.backends.mps.is_available() +except Exception: + HAS_MPS = False + def clean_memory(): - if torch.cuda.is_available(): - torch.cuda.empty_cache() gc.collect() + if HAS_CUDA: + torch.cuda.empty_cache() + if HAS_MPS: + torch.mps.empty_cache() + + +@functools.lru_cache(maxsize=None) +def get_preferred_device() -> torch.device: + if HAS_CUDA: + device = torch.device("cuda") + elif HAS_MPS: + device = torch.device("mps") + else: + device = torch.device("cpu") + print(f"get_preferred_device() -> {device}") + return device diff --git a/networks/lora_diffusers.py b/networks/lora_diffusers.py index 47d75ac4d..0056ac78e 100644 --- a/networks/lora_diffusers.py +++ b/networks/lora_diffusers.py @@ -11,6 +11,8 @@ from transformers import CLIPTextModel import torch +from library.device_utils import get_preferred_device + def make_unet_conversion_map() -> Dict[str, str]: unet_conversion_map_layer = [] @@ -476,7 +478,7 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline import torch - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = get_preferred_device() parser = argparse.ArgumentParser() parser.add_argument("--model_id", type=str, default=None, help="model id for huggingface") diff --git a/networks/lora_interrogator.py b/networks/lora_interrogator.py index 0dc066fd1..83942d7ce 100644 --- a/networks/lora_interrogator.py +++ b/networks/lora_interrogator.py @@ -9,11 +9,12 @@ import library.model_util as model_util import lora +from library.device_utils import get_preferred_device TOKENIZER_PATH = "openai/clip-vit-large-patch14" V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う -DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +DEVICE = get_preferred_device() def interrogate(args): diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index 14b05bfb7..0722b93f4 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -18,7 +18,7 @@ import numpy as np import torch -from library.device_utils import clean_memory +from library.device_utils import clean_memory, get_preferred_device from library.ipex_interop import init_ipex init_ipex() @@ -1495,7 +1495,7 @@ def __getattr__(self, item): # scheduler.config.clip_sample = True # deviceを決定する - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない + device = get_preferred_device() # custom pipelineをコピったやつを生成する if args.vae_slices: diff --git a/sdxl_minimal_inference.py b/sdxl_minimal_inference.py index 15a70678f..3eae20441 100644 --- a/sdxl_minimal_inference.py +++ b/sdxl_minimal_inference.py @@ -10,6 +10,7 @@ import numpy as np import torch +from library.device_utils import get_preferred_device from library.ipex_interop import init_ipex init_ipex() @@ -85,7 +86,7 @@ def get_timestep_embedding(x, outdim): guidance_scale = 7 seed = None # 1 - DEVICE = "cuda" + DEVICE = get_preferred_device() DTYPE = torch.float16 # bfloat16 may work parser = argparse.ArgumentParser() diff --git a/tools/latent_upscaler.py b/tools/latent_upscaler.py index ab1fa3390..27d13ef61 100644 --- a/tools/latent_upscaler.py +++ b/tools/latent_upscaler.py @@ -15,6 +15,8 @@ from tqdm import tqdm from PIL import Image +from library.device_utils import get_preferred_device + class ResidualBlock(nn.Module): def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1): @@ -255,7 +257,7 @@ def create_upscaler(**kwargs): # another interface: upscale images with a model for given images from command line def upscale_images(args: argparse.Namespace): - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + DEVICE = get_preferred_device() us_dtype = torch.float16 # TODO: support fp32/bf16 os.makedirs(args.output_dir, exist_ok=True) From c576f806397f7bca1085518616f8aaebfa2e98b4 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 25 Jan 2024 18:43:07 +0900 Subject: [PATCH 028/103] Fix ControlNetLLLite training issue #1069 --- library/train_util.py | 1 + 1 file changed, 1 insertion(+) diff --git a/library/train_util.py b/library/train_util.py index 4ac6728b0..ba428e508 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1774,6 +1774,7 @@ def __init__( tokenizer, max_token_length, resolution, + network_multiplier, enable_bucket, min_bucket_reso, max_bucket_reso, From 322ee52c7775bf09ff32913f2e0d14d85605680a Mon Sep 17 00:00:00 2001 From: feffy380 <114889020+feffy380@users.noreply.github.com> Date: Thu, 25 Jan 2024 19:15:53 +0100 Subject: [PATCH 029/103] Update requirements.txt Update safetensors to fix a crash when using `--fp8_base --save_state` --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 8517d95ac..c09e62b7f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,7 +8,7 @@ einops==0.6.1 pytorch-lightning==1.9.0 # bitsandbytes==0.39.1 tensorboard==2.10.1 -safetensors==0.3.1 +safetensors==0.4.2 # gradio==3.16.2 altair==4.2.2 easygui==0.98.3 From 85bc371ebc3ab311b2e6b75a98bdf9fb310516f7 Mon Sep 17 00:00:00 2001 From: DukeG Date: Fri, 26 Jan 2024 18:58:47 +0800 Subject: [PATCH 030/103] test --- train_network.py | 1 + 1 file changed, 1 insertion(+) diff --git a/train_network.py b/train_network.py index 8b6c395c5..2e06262f7 100644 --- a/train_network.py +++ b/train_network.py @@ -774,6 +774,7 @@ def remove_model(old_ckpt_name): else: raise NotImplementedError("multipliers for each sample is not supported yet") # print(f"set multiplier: {multipliers}") + print(network) network.set_multiplier(multipliers) with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): From 50f631c768eb3db578c98224106236cf3abc03f9 Mon Sep 17 00:00:00 2001 From: DukeG Date: Fri, 26 Jan 2024 20:02:48 +0800 Subject: [PATCH 031/103] test --- train_network.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 2e06262f7..d64d7a1bc 100644 --- a/train_network.py +++ b/train_network.py @@ -774,7 +774,8 @@ def remove_model(old_ckpt_name): else: raise NotImplementedError("multipliers for each sample is not supported yet") # print(f"set multiplier: {multipliers}") - print(network) + print(type(network)) + network.set_multiplier(multipliers) with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): From 4e67fb8444b54deb0d867cc3d622f694a62b5c2b Mon Sep 17 00:00:00 2001 From: DukeG Date: Fri, 26 Jan 2024 20:22:49 +0800 Subject: [PATCH 032/103] test --- train_network.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/train_network.py b/train_network.py index d64d7a1bc..8d102ae8f 100644 --- a/train_network.py +++ b/train_network.py @@ -774,9 +774,7 @@ def remove_model(old_ckpt_name): else: raise NotImplementedError("multipliers for each sample is not supported yet") # print(f"set multiplier: {multipliers}") - print(type(network)) - - network.set_multiplier(multipliers) + accelerator.unwrap_model(network).set_multiplier(multipliers) with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning From 736365bdd5033697e925c5ece377ef5fe7f6f902 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 27 Jan 2024 18:31:01 +0900 Subject: [PATCH 033/103] update README.md --- README.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/README.md b/README.md index 0de787b8f..e2c606708 100644 --- a/README.md +++ b/README.md @@ -249,6 +249,18 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum ## Change History +### Jan 27, 2024 / 2024/1/27: v0.8.3 + +- Fixed a bug that the training crashes when `--fp8_base` is specified with `--save_state`. PR [#1079](https://github.com/kohya-ss/sd-scripts/pull/1079) Thanks to feffy380! + - `safetensors` is updated. Please see [Upgrade](#upgrade) and update the library. +- Fixed a bug that the training crashes when `network_multiplier` is specified with multi-GPU training. PR [#1084](https://github.com/kohya-ss/sd-scripts/pull/1084) Thanks to fireicewolf! +- Fixed a bug that the training crashes when training ControlNet-LLLite. + +- `--fp8_base` 指定時に `--save_state` での保存がエラーになる不具合が修正されました。 PR [#1079](https://github.com/kohya-ss/sd-scripts/pull/1079) feffy380 氏に感謝します。 + - `safetensors` がバージョンアップされていますので、[Upgrade](#upgrade) を参照し更新をお願いします。 +- 複数 GPU での学習時に `network_multiplier` を指定するとクラッシュする不具合が修正されました。 PR [#1084](https://github.com/kohya-ss/sd-scripts/pull/1084) fireicewolf 氏に感謝します。 +- ControlNet-LLLite の学習がエラーになる不具合を修正しました。 + ### Jan 23, 2024 / 2024/1/23: v0.8.2 - [Experimental] The `--fp8_base` option is added to the training scripts for LoRA etc. The base model (U-Net, and Text Encoder when training modules for Text Encoder) can be trained with fp8. PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) Thanks to KohakuBlueleaf! From ccc3a481e7cce678a226552e64c906628d0cad32 Mon Sep 17 00:00:00 2001 From: Disty0 Date: Sun, 28 Jan 2024 14:14:31 +0300 Subject: [PATCH 034/103] Update IPEX Libs --- library/ipex/__init__.py | 15 ++++++---- library/ipex/attention.py | 2 ++ library/ipex/diffusers.py | 2 ++ library/ipex/hijacks.py | 58 +++++++++++++++++++++++++++++++++------ 4 files changed, 63 insertions(+), 14 deletions(-) diff --git a/library/ipex/__init__.py b/library/ipex/__init__.py index 333504935..9f2e7c410 100644 --- a/library/ipex/__init__.py +++ b/library/ipex/__init__.py @@ -125,9 +125,13 @@ def ipex_init(): # pylint: disable=too-many-statements # AMP: torch.cuda.amp = torch.xpu.amp + torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled + torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype + if not hasattr(torch.cuda.amp, "common"): torch.cuda.amp.common = contextlib.nullcontext() torch.cuda.amp.common.amp_definitely_not_available = lambda: False + try: torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler except Exception: # pylint: disable=broad-exception-caught @@ -151,15 +155,16 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.has_half = True torch.cuda.is_bf16_supported = lambda *args, **kwargs: True torch.cuda.is_fp16_supported = lambda *args, **kwargs: True - torch.version.cuda = "11.7" - torch.cuda.get_device_capability = lambda *args, **kwargs: [11,7] - torch.cuda.get_device_properties.major = 11 - torch.cuda.get_device_properties.minor = 7 + torch.backends.cuda.is_built = lambda *args, **kwargs: True + torch.version.cuda = "12.1" + torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1] + torch.cuda.get_device_properties.major = 12 + torch.cuda.get_device_properties.minor = 1 torch.cuda.ipc_collect = lambda *args, **kwargs: None torch.cuda.utilization = lambda *args, **kwargs: 0 ipex_hijacks() - if not torch.xpu.has_fp64_dtype(): + if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None: try: from .diffusers import ipex_diffusers ipex_diffusers() diff --git a/library/ipex/attention.py b/library/ipex/attention.py index e98807a84..8253c5b17 100644 --- a/library/ipex/attention.py +++ b/library/ipex/attention.py @@ -124,6 +124,7 @@ def torch_bmm_32_bit(input, mat2, *, out=None): ) else: return original_torch_bmm(input, mat2, out=out) + torch.xpu.synchronize(input.device) return hidden_states original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention @@ -172,4 +173,5 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropo ) else: return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) + torch.xpu.synchronize(query.device) return hidden_states diff --git a/library/ipex/diffusers.py b/library/ipex/diffusers.py index 47b0375ae..732a18568 100644 --- a/library/ipex/diffusers.py +++ b/library/ipex/diffusers.py @@ -149,6 +149,7 @@ def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice del attn_slice + torch.xpu.synchronize(query.device) else: query_slice = query[start_idx:end_idx] key_slice = key[start_idx:end_idx] @@ -283,6 +284,7 @@ def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, hidden_states[start_idx:end_idx] = attn_slice del attn_slice + torch.xpu.synchronize(query.device) else: attention_probs = attn.get_attention_scores(query, key, attention_mask) hidden_states = torch.bmm(attention_probs, value) diff --git a/library/ipex/hijacks.py b/library/ipex/hijacks.py index b6d246dd2..51de31ad0 100644 --- a/library/ipex/hijacks.py +++ b/library/ipex/hijacks.py @@ -1,6 +1,9 @@ -import contextlib +import os +from functools import wraps +from contextlib import nullcontext import torch import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import +import numpy as np # pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return @@ -11,7 +14,7 @@ def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: return module.to("xpu") def return_null_context(*args, **kwargs): # pylint: disable=unused-argument - return contextlib.nullcontext() + return nullcontext() @property def is_cuda(self): @@ -25,15 +28,17 @@ def return_xpu(device): # Autocast -original_autocast = torch.autocast -def ipex_autocast(*args, **kwargs): - if len(args) > 0 and args[0] == "cuda": - return original_autocast("xpu", *args[1:], **kwargs) +original_autocast_init = torch.amp.autocast_mode.autocast.__init__ +@wraps(torch.amp.autocast_mode.autocast.__init__) +def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=None): + if device_type == "cuda": + return original_autocast_init(self, device_type="xpu", dtype=dtype, enabled=enabled, cache_enabled=cache_enabled) else: - return original_autocast(*args, **kwargs) + return original_autocast_init(self, device_type=device_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled) # Latent Antialias CPU Offload: original_interpolate = torch.nn.functional.interpolate +@wraps(torch.nn.functional.interpolate) def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments if antialias or align_corners is not None: return_device = tensor.device @@ -46,13 +51,25 @@ def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corn # Diffusers Float64 (Alchemist GPUs doesn't support 64 bit): original_from_numpy = torch.from_numpy +@wraps(torch.from_numpy) def from_numpy(ndarray): if ndarray.dtype == float: return original_from_numpy(ndarray.astype('float32')) else: return original_from_numpy(ndarray) -if torch.xpu.has_fp64_dtype(): +original_as_tensor = torch.as_tensor +@wraps(torch.as_tensor) +def as_tensor(data, dtype=None, device=None): + if check_device(device): + device = return_xpu(device) + if isinstance(data, np.ndarray) and data.dtype == float and not ( + (isinstance(device, torch.device) and device.type == "cpu") or (isinstance(device, str) and "cpu" in device)): + return original_as_tensor(data, dtype=torch.float32, device=device) + else: + return original_as_tensor(data, dtype=dtype, device=device) + +if torch.xpu.has_fp64_dtype() and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None: original_torch_bmm = torch.bmm original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention else: @@ -66,20 +83,25 @@ def from_numpy(ndarray): # Data Type Errors: +@wraps(torch.bmm) def torch_bmm(input, mat2, *, out=None): if input.dtype != mat2.dtype: mat2 = mat2.to(input.dtype) return original_torch_bmm(input, mat2, out=out) +@wraps(torch.nn.functional.scaled_dot_product_attention) def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): if query.dtype != key.dtype: key = key.to(dtype=query.dtype) if query.dtype != value.dtype: value = value.to(dtype=query.dtype) + if attn_mask is not None and query.dtype != attn_mask.dtype: + attn_mask = attn_mask.to(dtype=query.dtype) return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) # A1111 FP16 original_functional_group_norm = torch.nn.functional.group_norm +@wraps(torch.nn.functional.group_norm) def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05): if weight is not None and input.dtype != weight.data.dtype: input = input.to(dtype=weight.data.dtype) @@ -89,6 +111,7 @@ def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05): # A1111 BF16 original_functional_layer_norm = torch.nn.functional.layer_norm +@wraps(torch.nn.functional.layer_norm) def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05): if weight is not None and input.dtype != weight.data.dtype: input = input.to(dtype=weight.data.dtype) @@ -98,6 +121,7 @@ def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1 # Training original_functional_linear = torch.nn.functional.linear +@wraps(torch.nn.functional.linear) def functional_linear(input, weight, bias=None): if input.dtype != weight.data.dtype: input = input.to(dtype=weight.data.dtype) @@ -106,6 +130,7 @@ def functional_linear(input, weight, bias=None): return original_functional_linear(input, weight, bias=bias) original_functional_conv2d = torch.nn.functional.conv2d +@wraps(torch.nn.functional.conv2d) def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): if input.dtype != weight.data.dtype: input = input.to(dtype=weight.data.dtype) @@ -115,6 +140,7 @@ def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, # A1111 Embedding BF16 original_torch_cat = torch.cat +@wraps(torch.cat) def torch_cat(tensor, *args, **kwargs): if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype): return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs) @@ -123,6 +149,7 @@ def torch_cat(tensor, *args, **kwargs): # SwinIR BF16: original_functional_pad = torch.nn.functional.pad +@wraps(torch.nn.functional.pad) def functional_pad(input, pad, mode='constant', value=None): if mode == 'reflect' and input.dtype == torch.bfloat16: return original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16) @@ -131,6 +158,7 @@ def functional_pad(input, pad, mode='constant', value=None): original_torch_tensor = torch.tensor +@wraps(torch.tensor) def torch_tensor(*args, device=None, **kwargs): if check_device(device): return original_torch_tensor(*args, device=return_xpu(device), **kwargs) @@ -138,6 +166,7 @@ def torch_tensor(*args, device=None, **kwargs): return original_torch_tensor(*args, device=device, **kwargs) original_Tensor_to = torch.Tensor.to +@wraps(torch.Tensor.to) def Tensor_to(self, device=None, *args, **kwargs): if check_device(device): return original_Tensor_to(self, return_xpu(device), *args, **kwargs) @@ -145,6 +174,7 @@ def Tensor_to(self, device=None, *args, **kwargs): return original_Tensor_to(self, device, *args, **kwargs) original_Tensor_cuda = torch.Tensor.cuda +@wraps(torch.Tensor.cuda) def Tensor_cuda(self, device=None, *args, **kwargs): if check_device(device): return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs) @@ -152,6 +182,7 @@ def Tensor_cuda(self, device=None, *args, **kwargs): return original_Tensor_cuda(self, device, *args, **kwargs) original_UntypedStorage_init = torch.UntypedStorage.__init__ +@wraps(torch.UntypedStorage.__init__) def UntypedStorage_init(*args, device=None, **kwargs): if check_device(device): return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs) @@ -159,6 +190,7 @@ def UntypedStorage_init(*args, device=None, **kwargs): return original_UntypedStorage_init(*args, device=device, **kwargs) original_UntypedStorage_cuda = torch.UntypedStorage.cuda +@wraps(torch.UntypedStorage.cuda) def UntypedStorage_cuda(self, device=None, *args, **kwargs): if check_device(device): return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs) @@ -166,6 +198,7 @@ def UntypedStorage_cuda(self, device=None, *args, **kwargs): return original_UntypedStorage_cuda(self, device, *args, **kwargs) original_torch_empty = torch.empty +@wraps(torch.empty) def torch_empty(*args, device=None, **kwargs): if check_device(device): return original_torch_empty(*args, device=return_xpu(device), **kwargs) @@ -173,6 +206,7 @@ def torch_empty(*args, device=None, **kwargs): return original_torch_empty(*args, device=device, **kwargs) original_torch_randn = torch.randn +@wraps(torch.randn) def torch_randn(*args, device=None, **kwargs): if check_device(device): return original_torch_randn(*args, device=return_xpu(device), **kwargs) @@ -180,6 +214,7 @@ def torch_randn(*args, device=None, **kwargs): return original_torch_randn(*args, device=device, **kwargs) original_torch_ones = torch.ones +@wraps(torch.ones) def torch_ones(*args, device=None, **kwargs): if check_device(device): return original_torch_ones(*args, device=return_xpu(device), **kwargs) @@ -187,6 +222,7 @@ def torch_ones(*args, device=None, **kwargs): return original_torch_ones(*args, device=device, **kwargs) original_torch_zeros = torch.zeros +@wraps(torch.zeros) def torch_zeros(*args, device=None, **kwargs): if check_device(device): return original_torch_zeros(*args, device=return_xpu(device), **kwargs) @@ -194,6 +230,7 @@ def torch_zeros(*args, device=None, **kwargs): return original_torch_zeros(*args, device=device, **kwargs) original_torch_linspace = torch.linspace +@wraps(torch.linspace) def torch_linspace(*args, device=None, **kwargs): if check_device(device): return original_torch_linspace(*args, device=return_xpu(device), **kwargs) @@ -201,6 +238,7 @@ def torch_linspace(*args, device=None, **kwargs): return original_torch_linspace(*args, device=device, **kwargs) original_torch_Generator = torch.Generator +@wraps(torch.Generator) def torch_Generator(device=None): if check_device(device): return original_torch_Generator(return_xpu(device)) @@ -208,6 +246,7 @@ def torch_Generator(device=None): return original_torch_Generator(device) original_torch_load = torch.load +@wraps(torch.load) def torch_load(f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs): if check_device(map_location): return original_torch_load(f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs) @@ -232,7 +271,7 @@ def ipex_hijacks(): torch.backends.cuda.sdp_kernel = return_null_context torch.nn.DataParallel = DummyDataParallel torch.UntypedStorage.is_cuda = is_cuda - torch.autocast = ipex_autocast + torch.amp.autocast_mode.autocast.__init__ = autocast_init torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention torch.nn.functional.group_norm = functional_group_norm @@ -246,3 +285,4 @@ def ipex_hijacks(): torch.cat = torch_cat if not torch.xpu.has_fp64_dtype(): torch.from_numpy = from_numpy + torch.as_tensor = as_tensor From d4b9568269d09f5e4f81d5f243ca0a9bd3e6dcd3 Mon Sep 17 00:00:00 2001 From: mgz <49577754+mgz-dev@users.noreply.github.com> Date: Sun, 28 Jan 2024 11:59:07 -0600 Subject: [PATCH 035/103] fix broken import in svd_merge_lora script remove missing import, and remove unused imports --- networks/svd_merge_lora.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py index 16e813b36..ee2dae300 100644 --- a/networks/svd_merge_lora.py +++ b/networks/svd_merge_lora.py @@ -1,4 +1,3 @@ -import math import argparse import os import time @@ -6,8 +5,6 @@ from safetensors.torch import load_file, save_file from tqdm import tqdm from library import sai_model_spec, train_util -import library.model_util as model_util -import lora CLAMP_QUANTILE = 0.99 From 988dee02b9feae8a5fba1dd12cd40aaa42905df8 Mon Sep 17 00:00:00 2001 From: Disty0 Date: Tue, 30 Jan 2024 01:52:32 +0300 Subject: [PATCH 036/103] IPEX torch.tensor FP64 workaround --- library/ipex/hijacks.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/library/ipex/hijacks.py b/library/ipex/hijacks.py index 51de31ad0..7b2d26d4a 100644 --- a/library/ipex/hijacks.py +++ b/library/ipex/hijacks.py @@ -5,6 +5,8 @@ import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import import numpy as np +device_supports_fp64 = torch.xpu.has_fp64_dtype() + # pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods @@ -49,6 +51,7 @@ def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corn return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias) + # Diffusers Float64 (Alchemist GPUs doesn't support 64 bit): original_from_numpy = torch.from_numpy @wraps(torch.from_numpy) @@ -69,7 +72,8 @@ def as_tensor(data, dtype=None, device=None): else: return original_as_tensor(data, dtype=dtype, device=device) -if torch.xpu.has_fp64_dtype() and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None: + +if device_supports_fp64 and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None: original_torch_bmm = torch.bmm original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention else: @@ -159,11 +163,16 @@ def functional_pad(input, pad, mode='constant', value=None): original_torch_tensor = torch.tensor @wraps(torch.tensor) -def torch_tensor(*args, device=None, **kwargs): +def torch_tensor(data, *args, dtype=None, device=None, **kwargs): if check_device(device): - return original_torch_tensor(*args, device=return_xpu(device), **kwargs) - else: - return original_torch_tensor(*args, device=device, **kwargs) + device = return_xpu(device) + if not device_supports_fp64: + if (isinstance(device, torch.device) and device.type == "xpu") or (isinstance(device, str) and "xpu" in device): + if dtype == torch.float64: + dtype = torch.float32 + elif dtype is None and (hasattr(data, "dtype") and (data.dtype == torch.float64 or data.dtype == float)): + dtype = torch.float32 + return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs) original_Tensor_to = torch.Tensor.to @wraps(torch.Tensor.to) @@ -253,6 +262,7 @@ def torch_load(f, map_location=None, pickle_module=None, *, weights_only=False, else: return original_torch_load(f, map_location=map_location, pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs) + # Hijack Functions: def ipex_hijacks(): torch.tensor = torch_tensor @@ -283,6 +293,6 @@ def ipex_hijacks(): torch.bmm = torch_bmm torch.cat = torch_cat - if not torch.xpu.has_fp64_dtype(): + if not device_supports_fp64: torch.from_numpy = from_numpy torch.as_tensor = as_tensor From a6a2b5a867594635d8b1bcdbd30008fdf4d60858 Mon Sep 17 00:00:00 2001 From: Disty0 Date: Wed, 31 Jan 2024 17:32:37 +0300 Subject: [PATCH 037/103] Fix IPEX support and add XPU device to device_utils --- XTI_hijack.py | 4 +- fine_tune.py | 6 +- finetune/make_captions.py | 5 +- finetune/make_captions_by_git.py | 5 +- finetune/prepare_buckets_latents.py | 6 +- gen_img_diffusers.py | 6 +- library/device_utils.py | 29 +++ library/ipex/__init__.py | 304 ++++++++++++++------------- library/ipex_interop.py | 24 --- library/model_util.py | 6 +- library/sdxl_train_util.py | 5 +- library/train_util.py | 5 +- networks/lora_diffusers.py | 5 +- networks/lora_interrogator.py | 4 +- sdxl_gen_img.py | 6 +- sdxl_minimal_inference.py | 6 +- sdxl_train.py | 6 +- sdxl_train_control_net_lllite.py | 6 +- sdxl_train_control_net_lllite_old.py | 6 +- sdxl_train_network.py | 6 +- sdxl_train_textual_inversion.py | 5 +- tools/latent_upscaler.py | 5 +- train_controlnet.py | 6 +- train_db.py | 6 +- train_network.py | 9 +- train_textual_inversion.py | 6 +- train_textual_inversion_XTI.py | 6 +- 27 files changed, 248 insertions(+), 245 deletions(-) delete mode 100644 library/ipex_interop.py diff --git a/XTI_hijack.py b/XTI_hijack.py index 1dbc263ac..93bc1c0b1 100644 --- a/XTI_hijack.py +++ b/XTI_hijack.py @@ -1,7 +1,7 @@ import torch -from library.ipex_interop import init_ipex - +from library.device_utils import init_ipex init_ipex() + from typing import Union, List, Optional, Dict, Any, Tuple from diffusers.models.unet_2d_condition import UNet2DConditionOutput diff --git a/fine_tune.py b/fine_tune.py index 11e94e560..a6a5c1e25 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -8,11 +8,9 @@ import toml from tqdm import tqdm -import torch - -from library.device_utils import clean_memory -from library.ipex_interop import init_ipex +import torch +from library.device_utils import init_ipex, clean_memory init_ipex() from accelerate.utils import set_seed diff --git a/finetune/make_captions.py b/finetune/make_captions.py index 524f80b9d..e7590faa6 100644 --- a/finetune/make_captions.py +++ b/finetune/make_captions.py @@ -9,13 +9,16 @@ from PIL import Image from tqdm import tqdm import numpy as np + import torch +from library.device_utils import init_ipex, get_preferred_device +init_ipex() + from torchvision import transforms from torchvision.transforms.functional import InterpolationMode sys.path.append(os.path.dirname(__file__)) from blip.blip import blip_decoder, is_url import library.train_util as train_util -from library.device_utils import get_preferred_device DEVICE = get_preferred_device() diff --git a/finetune/make_captions_by_git.py b/finetune/make_captions_by_git.py index 2b650eb01..549c477a3 100644 --- a/finetune/make_captions_by_git.py +++ b/finetune/make_captions_by_git.py @@ -5,12 +5,15 @@ from pathlib import Path from PIL import Image from tqdm import tqdm + import torch +from library.device_utils import init_ipex, get_preferred_device +init_ipex() + from transformers import AutoProcessor, AutoModelForCausalLM from transformers.generation.utils import GenerationMixin import library.train_util as train_util -from library.device_utils import get_preferred_device DEVICE = get_preferred_device() diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py index 9d352dd6e..29710ca92 100644 --- a/finetune/prepare_buckets_latents.py +++ b/finetune/prepare_buckets_latents.py @@ -8,14 +8,16 @@ import numpy as np from PIL import Image import cv2 + import torch +from library.device_utils import init_ipex, get_preferred_device +init_ipex() + from torchvision import transforms import library.model_util as model_util import library.train_util as train_util -from library.device_utils import get_preferred_device - DEVICE = get_preferred_device() IMAGE_TRANSFORMS = transforms.Compose( diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 38b1ceabf..ae79e2f9b 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -64,11 +64,9 @@ import diffusers import numpy as np -import torch - -from library.device_utils import clean_memory, get_preferred_device -from library.ipex_interop import init_ipex +import torch +from library.device_utils import init_ipex, clean_memory, get_preferred_device init_ipex() import torchvision diff --git a/library/device_utils.py b/library/device_utils.py index 353bfa9f3..93371ca64 100644 --- a/library/device_utils.py +++ b/library/device_utils.py @@ -13,11 +13,19 @@ except Exception: HAS_MPS = False +try: + import intel_extension_for_pytorch as ipex # noqa + HAS_XPU = torch.xpu.is_available() +except Exception: + HAS_XPU = False + def clean_memory(): gc.collect() if HAS_CUDA: torch.cuda.empty_cache() + if HAS_XPU: + torch.xpu.empty_cache() if HAS_MPS: torch.mps.empty_cache() @@ -26,9 +34,30 @@ def clean_memory(): def get_preferred_device() -> torch.device: if HAS_CUDA: device = torch.device("cuda") + elif HAS_XPU: + device = torch.device("xpu") elif HAS_MPS: device = torch.device("mps") else: device = torch.device("cpu") print(f"get_preferred_device() -> {device}") return device + +def init_ipex(): + """ + Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`. + + This function should run right after importing torch and before doing anything else. + + If IPEX is not available, this function does nothing. + """ + try: + if HAS_XPU: + from library.ipex import ipex_init + is_initialized, error_message = ipex_init() + if not is_initialized: + print("failed to initialize ipex:", error_message) + else: + return + except Exception as e: + print("failed to initialize ipex:", e) \ No newline at end of file diff --git a/library/ipex/__init__.py b/library/ipex/__init__.py index 9f2e7c410..972a3bf63 100644 --- a/library/ipex/__init__.py +++ b/library/ipex/__init__.py @@ -9,167 +9,171 @@ def ipex_init(): # pylint: disable=too-many-statements try: - # Replace cuda with xpu: - torch.cuda.current_device = torch.xpu.current_device - torch.cuda.current_stream = torch.xpu.current_stream - torch.cuda.device = torch.xpu.device - torch.cuda.device_count = torch.xpu.device_count - torch.cuda.device_of = torch.xpu.device_of - torch.cuda.get_device_name = torch.xpu.get_device_name - torch.cuda.get_device_properties = torch.xpu.get_device_properties - torch.cuda.init = torch.xpu.init - torch.cuda.is_available = torch.xpu.is_available - torch.cuda.is_initialized = torch.xpu.is_initialized - torch.cuda.is_current_stream_capturing = lambda: False - torch.cuda.set_device = torch.xpu.set_device - torch.cuda.stream = torch.xpu.stream - torch.cuda.synchronize = torch.xpu.synchronize - torch.cuda.Event = torch.xpu.Event - torch.cuda.Stream = torch.xpu.Stream - torch.cuda.FloatTensor = torch.xpu.FloatTensor - torch.Tensor.cuda = torch.Tensor.xpu - torch.Tensor.is_cuda = torch.Tensor.is_xpu - torch.UntypedStorage.cuda = torch.UntypedStorage.xpu - torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock - torch.cuda._initialized = torch.xpu.lazy_init._initialized - torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker - torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls - torch.cuda._tls = torch.xpu.lazy_init._tls - torch.cuda.threading = torch.xpu.lazy_init.threading - torch.cuda.traceback = torch.xpu.lazy_init.traceback - torch.cuda.Optional = torch.xpu.Optional - torch.cuda.__cached__ = torch.xpu.__cached__ - torch.cuda.__loader__ = torch.xpu.__loader__ - torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage - torch.cuda.Tuple = torch.xpu.Tuple - torch.cuda.streams = torch.xpu.streams - torch.cuda._lazy_new = torch.xpu._lazy_new - torch.cuda.FloatStorage = torch.xpu.FloatStorage - torch.cuda.Any = torch.xpu.Any - torch.cuda.__doc__ = torch.xpu.__doc__ - torch.cuda.default_generators = torch.xpu.default_generators - torch.cuda.HalfTensor = torch.xpu.HalfTensor - torch.cuda._get_device_index = torch.xpu._get_device_index - torch.cuda.__path__ = torch.xpu.__path__ - torch.cuda.Device = torch.xpu.Device - torch.cuda.IntTensor = torch.xpu.IntTensor - torch.cuda.ByteStorage = torch.xpu.ByteStorage - torch.cuda.set_stream = torch.xpu.set_stream - torch.cuda.BoolStorage = torch.xpu.BoolStorage - torch.cuda.os = torch.xpu.os - torch.cuda.torch = torch.xpu.torch - torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage - torch.cuda.Union = torch.xpu.Union - torch.cuda.DoubleTensor = torch.xpu.DoubleTensor - torch.cuda.ShortTensor = torch.xpu.ShortTensor - torch.cuda.LongTensor = torch.xpu.LongTensor - torch.cuda.IntStorage = torch.xpu.IntStorage - torch.cuda.LongStorage = torch.xpu.LongStorage - torch.cuda.__annotations__ = torch.xpu.__annotations__ - torch.cuda.__package__ = torch.xpu.__package__ - torch.cuda.__builtins__ = torch.xpu.__builtins__ - torch.cuda.CharTensor = torch.xpu.CharTensor - torch.cuda.List = torch.xpu.List - torch.cuda._lazy_init = torch.xpu._lazy_init - torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor - torch.cuda.DoubleStorage = torch.xpu.DoubleStorage - torch.cuda.ByteTensor = torch.xpu.ByteTensor - torch.cuda.StreamContext = torch.xpu.StreamContext - torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage - torch.cuda.ShortStorage = torch.xpu.ShortStorage - torch.cuda._lazy_call = torch.xpu._lazy_call - torch.cuda.HalfStorage = torch.xpu.HalfStorage - torch.cuda.random = torch.xpu.random - torch.cuda._device = torch.xpu._device - torch.cuda.classproperty = torch.xpu.classproperty - torch.cuda.__name__ = torch.xpu.__name__ - torch.cuda._device_t = torch.xpu._device_t - torch.cuda.warnings = torch.xpu.warnings - torch.cuda.__spec__ = torch.xpu.__spec__ - torch.cuda.BoolTensor = torch.xpu.BoolTensor - torch.cuda.CharStorage = torch.xpu.CharStorage - torch.cuda.__file__ = torch.xpu.__file__ - torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork - # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing + if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked: + return True, "Skipping IPEX hijack" + else: + # Replace cuda with xpu: + torch.cuda.current_device = torch.xpu.current_device + torch.cuda.current_stream = torch.xpu.current_stream + torch.cuda.device = torch.xpu.device + torch.cuda.device_count = torch.xpu.device_count + torch.cuda.device_of = torch.xpu.device_of + torch.cuda.get_device_name = torch.xpu.get_device_name + torch.cuda.get_device_properties = torch.xpu.get_device_properties + torch.cuda.init = torch.xpu.init + torch.cuda.is_available = torch.xpu.is_available + torch.cuda.is_initialized = torch.xpu.is_initialized + torch.cuda.is_current_stream_capturing = lambda: False + torch.cuda.set_device = torch.xpu.set_device + torch.cuda.stream = torch.xpu.stream + torch.cuda.synchronize = torch.xpu.synchronize + torch.cuda.Event = torch.xpu.Event + torch.cuda.Stream = torch.xpu.Stream + torch.cuda.FloatTensor = torch.xpu.FloatTensor + torch.Tensor.cuda = torch.Tensor.xpu + torch.Tensor.is_cuda = torch.Tensor.is_xpu + torch.UntypedStorage.cuda = torch.UntypedStorage.xpu + torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock + torch.cuda._initialized = torch.xpu.lazy_init._initialized + torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker + torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls + torch.cuda._tls = torch.xpu.lazy_init._tls + torch.cuda.threading = torch.xpu.lazy_init.threading + torch.cuda.traceback = torch.xpu.lazy_init.traceback + torch.cuda.Optional = torch.xpu.Optional + torch.cuda.__cached__ = torch.xpu.__cached__ + torch.cuda.__loader__ = torch.xpu.__loader__ + torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage + torch.cuda.Tuple = torch.xpu.Tuple + torch.cuda.streams = torch.xpu.streams + torch.cuda._lazy_new = torch.xpu._lazy_new + torch.cuda.FloatStorage = torch.xpu.FloatStorage + torch.cuda.Any = torch.xpu.Any + torch.cuda.__doc__ = torch.xpu.__doc__ + torch.cuda.default_generators = torch.xpu.default_generators + torch.cuda.HalfTensor = torch.xpu.HalfTensor + torch.cuda._get_device_index = torch.xpu._get_device_index + torch.cuda.__path__ = torch.xpu.__path__ + torch.cuda.Device = torch.xpu.Device + torch.cuda.IntTensor = torch.xpu.IntTensor + torch.cuda.ByteStorage = torch.xpu.ByteStorage + torch.cuda.set_stream = torch.xpu.set_stream + torch.cuda.BoolStorage = torch.xpu.BoolStorage + torch.cuda.os = torch.xpu.os + torch.cuda.torch = torch.xpu.torch + torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage + torch.cuda.Union = torch.xpu.Union + torch.cuda.DoubleTensor = torch.xpu.DoubleTensor + torch.cuda.ShortTensor = torch.xpu.ShortTensor + torch.cuda.LongTensor = torch.xpu.LongTensor + torch.cuda.IntStorage = torch.xpu.IntStorage + torch.cuda.LongStorage = torch.xpu.LongStorage + torch.cuda.__annotations__ = torch.xpu.__annotations__ + torch.cuda.__package__ = torch.xpu.__package__ + torch.cuda.__builtins__ = torch.xpu.__builtins__ + torch.cuda.CharTensor = torch.xpu.CharTensor + torch.cuda.List = torch.xpu.List + torch.cuda._lazy_init = torch.xpu._lazy_init + torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor + torch.cuda.DoubleStorage = torch.xpu.DoubleStorage + torch.cuda.ByteTensor = torch.xpu.ByteTensor + torch.cuda.StreamContext = torch.xpu.StreamContext + torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage + torch.cuda.ShortStorage = torch.xpu.ShortStorage + torch.cuda._lazy_call = torch.xpu._lazy_call + torch.cuda.HalfStorage = torch.xpu.HalfStorage + torch.cuda.random = torch.xpu.random + torch.cuda._device = torch.xpu._device + torch.cuda.classproperty = torch.xpu.classproperty + torch.cuda.__name__ = torch.xpu.__name__ + torch.cuda._device_t = torch.xpu._device_t + torch.cuda.warnings = torch.xpu.warnings + torch.cuda.__spec__ = torch.xpu.__spec__ + torch.cuda.BoolTensor = torch.xpu.BoolTensor + torch.cuda.CharStorage = torch.xpu.CharStorage + torch.cuda.__file__ = torch.xpu.__file__ + torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork + # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing - # Memory: - torch.cuda.memory = torch.xpu.memory - if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read(): - torch.xpu.empty_cache = lambda: None - torch.cuda.empty_cache = torch.xpu.empty_cache - torch.cuda.memory_stats = torch.xpu.memory_stats - torch.cuda.memory_summary = torch.xpu.memory_summary - torch.cuda.memory_snapshot = torch.xpu.memory_snapshot - torch.cuda.memory_allocated = torch.xpu.memory_allocated - torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated - torch.cuda.memory_reserved = torch.xpu.memory_reserved - torch.cuda.memory_cached = torch.xpu.memory_reserved - torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved - torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved - torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats - torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats - torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats - torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict - torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats + # Memory: + torch.cuda.memory = torch.xpu.memory + if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read(): + torch.xpu.empty_cache = lambda: None + torch.cuda.empty_cache = torch.xpu.empty_cache + torch.cuda.memory_stats = torch.xpu.memory_stats + torch.cuda.memory_summary = torch.xpu.memory_summary + torch.cuda.memory_snapshot = torch.xpu.memory_snapshot + torch.cuda.memory_allocated = torch.xpu.memory_allocated + torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated + torch.cuda.memory_reserved = torch.xpu.memory_reserved + torch.cuda.memory_cached = torch.xpu.memory_reserved + torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved + torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved + torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats + torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats + torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats + torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict + torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats - # RNG: - torch.cuda.get_rng_state = torch.xpu.get_rng_state - torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all - torch.cuda.set_rng_state = torch.xpu.set_rng_state - torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all - torch.cuda.manual_seed = torch.xpu.manual_seed - torch.cuda.manual_seed_all = torch.xpu.manual_seed_all - torch.cuda.seed = torch.xpu.seed - torch.cuda.seed_all = torch.xpu.seed_all - torch.cuda.initial_seed = torch.xpu.initial_seed + # RNG: + torch.cuda.get_rng_state = torch.xpu.get_rng_state + torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all + torch.cuda.set_rng_state = torch.xpu.set_rng_state + torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all + torch.cuda.manual_seed = torch.xpu.manual_seed + torch.cuda.manual_seed_all = torch.xpu.manual_seed_all + torch.cuda.seed = torch.xpu.seed + torch.cuda.seed_all = torch.xpu.seed_all + torch.cuda.initial_seed = torch.xpu.initial_seed - # AMP: - torch.cuda.amp = torch.xpu.amp - torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled - torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype + # AMP: + torch.cuda.amp = torch.xpu.amp + torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled + torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype - if not hasattr(torch.cuda.amp, "common"): - torch.cuda.amp.common = contextlib.nullcontext() - torch.cuda.amp.common.amp_definitely_not_available = lambda: False + if not hasattr(torch.cuda.amp, "common"): + torch.cuda.amp.common = contextlib.nullcontext() + torch.cuda.amp.common.amp_definitely_not_available = lambda: False - try: - torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler - except Exception: # pylint: disable=broad-exception-caught try: - from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error - gradscaler_init() torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler except Exception: # pylint: disable=broad-exception-caught - torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler + try: + from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error + gradscaler_init() + torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler + except Exception: # pylint: disable=broad-exception-caught + torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler - # C - torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream - ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_eu_count - ipex._C._DeviceProperties.major = 2023 - ipex._C._DeviceProperties.minor = 2 + # C + torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream + ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_eu_count + ipex._C._DeviceProperties.major = 2023 + ipex._C._DeviceProperties.minor = 2 - # Fix functions with ipex: - torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory] - torch._utils._get_available_device_type = lambda: "xpu" - torch.has_cuda = True - torch.cuda.has_half = True - torch.cuda.is_bf16_supported = lambda *args, **kwargs: True - torch.cuda.is_fp16_supported = lambda *args, **kwargs: True - torch.backends.cuda.is_built = lambda *args, **kwargs: True - torch.version.cuda = "12.1" - torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1] - torch.cuda.get_device_properties.major = 12 - torch.cuda.get_device_properties.minor = 1 - torch.cuda.ipc_collect = lambda *args, **kwargs: None - torch.cuda.utilization = lambda *args, **kwargs: 0 + # Fix functions with ipex: + torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory] + torch._utils._get_available_device_type = lambda: "xpu" + torch.has_cuda = True + torch.cuda.has_half = True + torch.cuda.is_bf16_supported = lambda *args, **kwargs: True + torch.cuda.is_fp16_supported = lambda *args, **kwargs: True + torch.backends.cuda.is_built = lambda *args, **kwargs: True + torch.version.cuda = "12.1" + torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1] + torch.cuda.get_device_properties.major = 12 + torch.cuda.get_device_properties.minor = 1 + torch.cuda.ipc_collect = lambda *args, **kwargs: None + torch.cuda.utilization = lambda *args, **kwargs: 0 - ipex_hijacks() - if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None: - try: - from .diffusers import ipex_diffusers - ipex_diffusers() - except Exception: # pylint: disable=broad-exception-caught - pass + ipex_hijacks() + if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None: + try: + from .diffusers import ipex_diffusers + ipex_diffusers() + except Exception: # pylint: disable=broad-exception-caught + pass + torch.cuda.is_xpu_hijacked = True except Exception as e: return False, e return True, None diff --git a/library/ipex_interop.py b/library/ipex_interop.py deleted file mode 100644 index 6fe320c57..000000000 --- a/library/ipex_interop.py +++ /dev/null @@ -1,24 +0,0 @@ -import torch - - -def init_ipex(): - """ - Try to import `intel_extension_for_pytorch`, and apply - the hijacks using `library.ipex.ipex_init`. - - If IPEX is not installed, this function does nothing. - """ - try: - import intel_extension_for_pytorch as ipex # noqa - except ImportError: - return - - try: - from library.ipex import ipex_init - - if torch.xpu.is_available(): - is_initialized, error_message = ipex_init() - if not is_initialized: - print("failed to initialize ipex:", error_message) - except Exception as e: - print("failed to initialize ipex:", e) diff --git a/library/model_util.py b/library/model_util.py index 4361b4994..6398e8a06 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -3,11 +3,11 @@ import math import os -import torch - -from library.ipex_interop import init_ipex +import torch +from library.device_utils import init_ipex init_ipex() + import diffusers from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index d2becad6f..0ecf4feb6 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -2,12 +2,15 @@ import math import os from typing import Optional + import torch +from library.device_utils import init_ipex, clean_memory +init_ipex() + from accelerate import init_empty_weights from tqdm import tqdm from transformers import CLIPTokenizer from library import model_util, sdxl_model_util, train_util, sdxl_original_unet -from library.device_utils import clean_memory from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline TOKENIZER1_PATH = "openai/clip-vit-large-patch14" diff --git a/library/train_util.py b/library/train_util.py index d59f42584..3b74ddcef 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -30,7 +30,11 @@ import toml from tqdm import tqdm + import torch +from library.device_utils import init_ipex, clean_memory +init_ipex() + from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer from torchvision import transforms @@ -66,7 +70,6 @@ # from library.attention_processors import FlashAttnProcessor # from library.hypernetwork import replace_attentions_for_hypernetwork -from library.device_utils import clean_memory from library.original_unet import UNet2DConditionModel # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う diff --git a/networks/lora_diffusers.py b/networks/lora_diffusers.py index 0056ac78e..cbba44f73 100644 --- a/networks/lora_diffusers.py +++ b/networks/lora_diffusers.py @@ -9,9 +9,10 @@ import numpy as np from tqdm import tqdm from transformers import CLIPTextModel -import torch -from library.device_utils import get_preferred_device +import torch +from library.device_utils import init_ipex, get_preferred_device +init_ipex() def make_unet_conversion_map() -> Dict[str, str]: diff --git a/networks/lora_interrogator.py b/networks/lora_interrogator.py index 83942d7ce..6cda391bc 100644 --- a/networks/lora_interrogator.py +++ b/networks/lora_interrogator.py @@ -5,11 +5,13 @@ import library.train_util as train_util import argparse from transformers import CLIPTokenizer + import torch +from library.device_utils import init_ipex, get_preferred_device +init_ipex() import library.model_util as model_util import lora -from library.device_utils import get_preferred_device TOKENIZER_PATH = "openai/clip-vit-large-patch14" V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index 0722b93f4..8824a3a12 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -16,11 +16,9 @@ import diffusers import numpy as np -import torch - -from library.device_utils import clean_memory, get_preferred_device -from library.ipex_interop import init_ipex +import torch +from library.device_utils import init_ipex, clean_memory, get_preferred_device init_ipex() import torchvision diff --git a/sdxl_minimal_inference.py b/sdxl_minimal_inference.py index 3eae20441..4f509b6d0 100644 --- a/sdxl_minimal_inference.py +++ b/sdxl_minimal_inference.py @@ -8,11 +8,9 @@ import random from einops import repeat import numpy as np -import torch - -from library.device_utils import get_preferred_device -from library.ipex_interop import init_ipex +import torch +from library.device_utils import init_ipex, get_preferred_device init_ipex() from tqdm import tqdm diff --git a/sdxl_train.py b/sdxl_train.py index 78cfaf495..b0dcdbe9d 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -8,11 +8,9 @@ import toml from tqdm import tqdm -import torch - -from library.device_utils import clean_memory -from library.ipex_interop import init_ipex +import torch +from library.device_utils import init_ipex, clean_memory init_ipex() from accelerate.utils import set_seed diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 95b755f18..1b0696247 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -12,11 +12,9 @@ import toml from tqdm import tqdm -import torch - -from library.device_utils import clean_memory -from library.ipex_interop import init_ipex +import torch +from library.device_utils import init_ipex, clean_memory init_ipex() from torch.nn.parallel import DistributedDataParallel as DDP diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index fd24898c4..b74e3b90b 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -9,11 +9,9 @@ import toml from tqdm import tqdm -import torch - -from library.device_utils import clean_memory -from library.ipex_interop import init_ipex +import torch +from library.device_utils import init_ipex, clean_memory init_ipex() from torch.nn.parallel import DistributedDataParallel as DDP diff --git a/sdxl_train_network.py b/sdxl_train_network.py index af0c8d1d7..205b526e8 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -1,9 +1,7 @@ import argparse -import torch - -from library.device_utils import clean_memory -from library.ipex_interop import init_ipex +import torch +from library.device_utils import init_ipex, clean_memory init_ipex() from library import sdxl_model_util, sdxl_train_util, train_util diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index df3937135..b9a948bb2 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -2,10 +2,11 @@ import os import regex -import torch -from library.ipex_interop import init_ipex +import torch +from library.device_utils import init_ipex init_ipex() + import open_clip from library import sdxl_model_util, sdxl_train_util, train_util diff --git a/tools/latent_upscaler.py b/tools/latent_upscaler.py index 27d13ef61..81f51f023 100644 --- a/tools/latent_upscaler.py +++ b/tools/latent_upscaler.py @@ -11,12 +11,13 @@ import numpy as np import torch +from library.device_utils import init_ipex, get_preferred_device +init_ipex() + from torch import nn from tqdm import tqdm from PIL import Image -from library.device_utils import get_preferred_device - class ResidualBlock(nn.Module): def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1): diff --git a/train_controlnet.py b/train_controlnet.py index e6bea2c9f..e7a06ae14 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -9,11 +9,9 @@ import toml from tqdm import tqdm -import torch - -from library.device_utils import clean_memory -from library.ipex_interop import init_ipex +import torch +from library.device_utils import init_ipex, clean_memory init_ipex() from torch.nn.parallel import DistributedDataParallel as DDP diff --git a/train_db.py b/train_db.py index daeb6d668..f6795dcef 100644 --- a/train_db.py +++ b/train_db.py @@ -9,11 +9,9 @@ import toml from tqdm import tqdm -import torch - -from library.device_utils import clean_memory -from library.ipex_interop import init_ipex +import torch +from library.device_utils import init_ipex, clean_memory init_ipex() from accelerate.utils import set_seed diff --git a/train_network.py b/train_network.py index 9aabd4d7c..f8dd6ab23 100644 --- a/train_network.py +++ b/train_network.py @@ -10,14 +10,13 @@ import toml from tqdm import tqdm -import torch -from torch.nn.parallel import DistributedDataParallel as DDP - -from library.device_utils import clean_memory -from library.ipex_interop import init_ipex +import torch +from library.device_utils import init_ipex, clean_memory init_ipex() +from torch.nn.parallel import DistributedDataParallel as DDP + from accelerate.utils import set_seed from diffusers import DDPMScheduler from library import model_util diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 821cfe786..d18837bd4 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -5,11 +5,9 @@ import toml from tqdm import tqdm -import torch - -from library.device_utils import clean_memory -from library.ipex_interop import init_ipex +import torch +from library.device_utils import init_ipex, clean_memory init_ipex() from accelerate.utils import set_seed diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index ecd6d087f..9d4b0aef1 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -6,11 +6,9 @@ from multiprocessing import Value from tqdm import tqdm -import torch - -from library.device_utils import clean_memory -from library.ipex_interop import init_ipex +import torch +from library.device_utils import init_ipex, clean_memory init_ipex() from accelerate.utils import set_seed From 716a92cbed2623a24cc74026f5d2cbe2526668a8 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 1 Feb 2024 01:57:52 +0000 Subject: [PATCH 038/103] Bump crate-ci/typos from 1.16.26 to 1.17.2 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.16.26 to 1.17.2. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.16.26...v1.17.2) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- .github/workflows/typos.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/typos.yml b/.github/workflows/typos.yml index bd4ef334e..c9edf2650 100644 --- a/.github/workflows/typos.yml +++ b/.github/workflows/typos.yml @@ -18,4 +18,4 @@ jobs: - uses: actions/checkout@v4 - name: typos-action - uses: crate-ci/typos@v1.16.26 + uses: crate-ci/typos@v1.17.2 From 5cca1fdc40a5457c40dd40be4a4812f90032ab2d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 1 Feb 2024 21:55:55 +0900 Subject: [PATCH 039/103] add highvram option and do not clear cache in caching latents --- library/train_util.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 3b74ddcef..21e0d1919 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -33,6 +33,7 @@ import torch from library.device_utils import init_ipex, clean_memory + init_ipex() from torch.nn.parallel import DistributedDataParallel as DDP @@ -76,6 +77,8 @@ TOKENIZER_PATH = "openai/clip-vit-large-patch14" V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ +HIGH_VRAM = False + # checkpointファイル名 EPOCH_STATE_NAME = "{}-{:06d}-state" EPOCH_FILE_NAME = "{}-{:06d}" @@ -2281,8 +2284,8 @@ def cache_batch_latents( if flip_aug: info.latents_flipped = flipped_latent - # FIXME this slows down caching a lot, specify this as an option - clean_memory() + if not HIGH_VRAM: + clean_memory() def cache_batch_text_encoder_outputs( @@ -3037,7 +3040,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument( "--lowram", action="store_true", - help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)", + help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込む等(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)", + ) + parser.add_argument( + "--highvram", + action="store_true", + help="disable low VRAM optimization. e.g. do not clear CUDA cache after each latent caching (for machines which have bigger VRAM) / VRAMが少ない環境向け最適化を無効にする。たとえば各latentのキャッシュ後のCUDAキャッシュを行わない等(VRAMが多い環境向け)", ) parser.add_argument( @@ -3128,6 +3136,15 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: def verify_training_args(args: argparse.Namespace): + r""" + Verify training arguments. Also reflect highvram option to global variable + 学習用引数を検証する。あわせて highvram オプションの指定をグローバル変数に反映する + """ + if args.highvram: + print("highvram is enabled / highvramが有効です") + global HIGH_VRAM + HIGH_VRAM = True + if args.v_parameterization and not args.v2: print("v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません") if args.v2 and args.clip_skip is not None: From 1567ce1e1777c169bd59237f1f722b2e9722bbf3 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 3 Feb 2024 20:46:31 +0800 Subject: [PATCH 040/103] Enable distributed sample image generation on multi-GPU enviroment (#1061) * Update train_util.py Modifying to attempt enable multi GPU inference * Update train_util.py additional VRAM checking, refactor check_vram_usage to return string for use with accelerator.print * Update train_network.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py remove sample image debug outputs * Update train_util.py * Update train_util.py * Update train_network.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_network.py * Update train_util.py * Update train_network.py * Update train_network.py * Update train_network.py * Cleanup of debugging outputs * adopt more elegant coding Co-authored-by: Aarni Koskela * Update train_util.py Fix leftover debugging code attempt to refactor inference into separate function * refactor in function generate_per_device_prompt_list() generation of distributed prompt list * Clean up missing variables * fix syntax error * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * true random sample image generation update code to reinitialize random seed to true random if seed was set * true random sample image generation * simplify per process prompt * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_network.py * Update train_network.py * Update train_network.py --------- Co-authored-by: Aarni Koskela --- library/train_util.py | 208 +++++++++++++++++++++++------------------- 1 file changed, 115 insertions(+), 93 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index ba428e508..3e6125f08 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -19,7 +19,7 @@ Tuple, Union, ) -from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs +from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState import gc import glob import math @@ -4636,7 +4636,6 @@ def line_to_prompt_dict(line: str) -> dict: return prompt_dict - def sample_images_common( pipe_class, accelerator: Accelerator, @@ -4654,6 +4653,7 @@ def sample_images_common( """ StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した """ + if steps == 0: if not args.sample_at_first: return @@ -4668,13 +4668,15 @@ def sample_images_common( if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch return + distributed_state = PartialState() #testing implementation of multi gpu distributed inference + print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}") if not os.path.isfile(args.sample_prompts): print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") return org_vae_device = vae.device # CPUにいるはず - vae.to(device) + vae.to(distributed_state.device) # unwrap unet and text_encoder(s) unet = accelerator.unwrap_model(unet) @@ -4700,12 +4702,11 @@ def sample_images_common( with open(args.sample_prompts, "r", encoding="utf-8") as f: prompts = json.load(f) - schedulers: dict = {} + # schedulers: dict = {} cannot find where this is used default_scheduler = get_my_scheduler( sample_sampler=args.sample_sampler, v_parameterization=args.v_parameterization, ) - schedulers[args.sample_sampler] = default_scheduler pipeline = pipe_class( text_encoder=text_encoder, @@ -4718,114 +4719,135 @@ def sample_images_common( requires_safety_checker=False, clip_skip=args.clip_skip, ) - pipeline.to(device) - + pipeline.to(distributed_state.device) save_dir = args.output_dir + "/sample" os.makedirs(save_dir, exist_ok=True) + + # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processess available (number of devices available) + # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. + per_process_prompts = generate_per_device_prompt_list(prompts, num_of_processes = distributed_state.num_processes, prompt_replacement = prompt_replacement) rng_state = torch.get_rng_state() cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + # True random sample image generation + torch.seed() + torch.cuda.seed() with torch.no_grad(): - # with accelerator.autocast(): - for i, prompt_dict in enumerate(prompts): - if not accelerator.is_main_process: - continue - - if isinstance(prompt_dict, str): - prompt_dict = line_to_prompt_dict(prompt_dict) - - assert isinstance(prompt_dict, dict) - negative_prompt = prompt_dict.get("negative_prompt") - sample_steps = prompt_dict.get("sample_steps", 30) - width = prompt_dict.get("width", 512) - height = prompt_dict.get("height", 512) - scale = prompt_dict.get("scale", 7.5) - seed = prompt_dict.get("seed") - controlnet_image = prompt_dict.get("controlnet_image") - prompt: str = prompt_dict.get("prompt", "") - sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) - - if seed is not None: - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - - scheduler = schedulers.get(sampler_name) - if scheduler is None: - scheduler = get_my_scheduler( - sample_sampler=sampler_name, - v_parameterization=args.v_parameterization, - ) - schedulers[sampler_name] = scheduler - pipeline.scheduler = scheduler - - if prompt_replacement is not None: - prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) - if negative_prompt is not None: - negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) - - if controlnet_image is not None: - controlnet_image = Image.open(controlnet_image).convert("RGB") - controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) - - height = max(64, height - height % 8) # round to divisible by 8 - width = max(64, width - width % 8) # round to divisible by 8 - print(f"prompt: {prompt}") - print(f"negative_prompt: {negative_prompt}") - print(f"height: {height}") - print(f"width: {width}") - print(f"sample_steps: {sample_steps}") - print(f"scale: {scale}") - print(f"sample_sampler: {sampler_name}") - if seed is not None: - print(f"seed: {seed}") - with accelerator.autocast(): - latents = pipeline( - prompt=prompt, - height=height, - width=width, - num_inference_steps=sample_steps, - guidance_scale=scale, - negative_prompt=negative_prompt, - controlnet=controlnet, - controlnet_image=controlnet_image, - ) + with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists: + for prompt_dict in prompt_dict_lists[0]: + sample_image_inference(accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, controlnet=controlnet) - image = pipeline.latents_to_image(latents)[0] - - ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) - num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" - seed_suffix = "" if seed is None else f"_{seed}" - img_filename = ( - f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png" - ) - - image.save(os.path.join(save_dir, img_filename)) - - # wandb有効時のみログを送信 - try: - wandb_tracker = accelerator.get_tracker("wandb") - try: - import wandb - except ImportError: # 事前に一度確認するのでここはエラー出ないはず - raise ImportError("No wandb / wandb がインストールされていないようです") - - wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) - except: # wandb 無効時 - pass # clear pipeline and cache to reduce vram usage del pipeline - torch.cuda.empty_cache() + with torch.cuda.device(torch.cuda.current_device()): + torch.cuda.empty_cache() + torch.set_rng_state(rng_state) if cuda_rng_state is not None: torch.cuda.set_rng_state(cuda_rng_state) vae.to(org_vae_device) +def generate_per_device_prompt_list(prompts, num_of_processes, prompt_replacement=None): + + # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processess available (number of devices available) + # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. + per_process_prompts = [[] for i in range(num_of_processes)] + for i, prompt in enumerate(prompts): + if isinstance(prompt, str): + prompt = line_to_prompt_dict(prompt) + assert isinstance(prompt, dict) + prompt.pop("subset", None) # Clean up subset key + prompt["enum"] = i + # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. + if prompt_replacement is not None: + prompt["prompt"] = prompt["prompt"].replace(prompt_replacement[0], prompt_replacement[1]) + if prompt["negative_prompt"] is not None: + prompt["negative_prompt"] = prompt["negative_prompt"].replace(prompt_replacement[0], prompt_replacement[1]) + # Refactor prompt replacement to here in order to simplify sample_image_inference function. + per_process_prompts[i % num_of_processes].append(prompt) + return per_process_prompts + +def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, pipeline, save_dir, prompt_dict, epoch, steps, controlnet=None): + assert isinstance(prompt_dict, dict) + negative_prompt = prompt_dict.get("negative_prompt") + sample_steps = prompt_dict.get("sample_steps", 30) + width = prompt_dict.get("width", 512) + height = prompt_dict.get("height", 512) + scale = prompt_dict.get("scale", 7.5) + seed = prompt_dict.get("seed") + controlnet_image = prompt_dict.get("controlnet_image") + prompt: str = prompt_dict.get("prompt", "") + sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) + + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + scheduler = get_my_scheduler( + sample_sampler=sampler_name, + v_parameterization=args.v_parameterization, + ) + pipeline.scheduler = scheduler + + if controlnet_image is not None: + controlnet_image = Image.open(controlnet_image).convert("RGB") + controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) + + height = max(64, height - height % 8) # round to divisible by 8 + width = max(64, width - width % 8) # round to divisible by 8 + print(f"\nprompt: {prompt}") + print(f"negative_prompt: {negative_prompt}") + print(f"height: {height}") + print(f"width: {width}") + print(f"sample_steps: {sample_steps}") + print(f"scale: {scale}") + print(f"sample_sampler: {sampler_name}") + if seed is not None: + print(f"seed: {seed}") + with accelerator.autocast(): + latents = pipeline( + prompt=prompt, + height=height, + width=width, + num_inference_steps=sample_steps, + guidance_scale=scale, + negative_prompt=negative_prompt, + controlnet=controlnet, + controlnet_image=controlnet_image, + ) + image = pipeline.latents_to_image(latents)[0] + # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" + seed_suffix = "" if seed is None else f"_{seed}" + i: int = prompt_dict["enum"] + img_filename = ( + f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" + ) + + image.save(os.path.join(save_dir, img_filename)) + if seed is not None: + torch.seed() + torch.cuda.seed() + # wandb有効時のみログを送信 + try: + wandb_tracker = accelerator.get_tracker("wandb") + try: + import wandb + except ImportError: # 事前に一度確認するのでここはエラー出ないはず + raise ImportError("No wandb / wandb がインストールされていないようです") + wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) + except: # wandb 無効時 + pass # endregion + + + # region 前処理用 From 11aced35005221c05920e33658ceba46fc0e4272 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 3 Feb 2024 22:25:29 +0900 Subject: [PATCH 041/103] simplify multi-GPU sample generation --- library/train_util.py | 92 ++++++++++++++++++++++--------------------- 1 file changed, 47 insertions(+), 45 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 3e6125f08..177fae551 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4668,13 +4668,13 @@ def sample_images_common( if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch return - distributed_state = PartialState() #testing implementation of multi gpu distributed inference - print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}") if not os.path.isfile(args.sample_prompts): print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") return + distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here + org_vae_device = vae.device # CPUにいるはず vae.to(distributed_state.device) @@ -4686,10 +4686,6 @@ def sample_images_common( text_encoder = accelerator.unwrap_model(text_encoder) # read prompts - - # with open(args.sample_prompts, "rt", encoding="utf-8") as f: - # prompts = f.readlines() - if args.sample_prompts.endswith(".txt"): with open(args.sample_prompts, "r", encoding="utf-8") as f: lines = f.readlines() @@ -4722,22 +4718,39 @@ def sample_images_common( pipeline.to(distributed_state.device) save_dir = args.output_dir + "/sample" os.makedirs(save_dir, exist_ok=True) - - # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processess available (number of devices available) - # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. - per_process_prompts = generate_per_device_prompt_list(prompts, num_of_processes = distributed_state.num_processes, prompt_replacement = prompt_replacement) - rng_state = torch.get_rng_state() - cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None - # True random sample image generation - torch.seed() - torch.cuda.seed() + # preprocess prompts + for i in range(len(prompts)): + prompt_dict = prompts[i] + if isinstance(prompt_dict, str): + prompt_dict = line_to_prompt_dict(prompt_dict) + prompts[i] = prompt_dict + assert isinstance(prompt_dict, dict) - with torch.no_grad(): - with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists: - for prompt_dict in prompt_dict_lists[0]: - sample_image_inference(accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, controlnet=controlnet) + # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. + prompt_dict["enum"] = i + prompt_dict.pop("subset", None) + # save random state to restore later + rng_state = torch.get_rng_state() + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None # TODO mps etc. support + + if distributed_state.num_processes <= 1: + # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. + with torch.no_grad(): + for prompt_dict in prompts: + sample_image_inference(accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet) + else: + # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processess available (number of devices available) + # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. + per_process_prompts = [] # list of lists + for i in range(distributed_state.num_processes): + per_process_prompts.append(prompts[i::distributed_state.num_processes]) + + with torch.no_grad(): + with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists: + for prompt_dict in prompt_dict_lists[0]: + sample_image_inference(accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet) # clear pipeline and cache to reduce vram usage del pipeline @@ -4750,27 +4763,7 @@ def sample_images_common( torch.cuda.set_rng_state(cuda_rng_state) vae.to(org_vae_device) -def generate_per_device_prompt_list(prompts, num_of_processes, prompt_replacement=None): - - # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processess available (number of devices available) - # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. - per_process_prompts = [[] for i in range(num_of_processes)] - for i, prompt in enumerate(prompts): - if isinstance(prompt, str): - prompt = line_to_prompt_dict(prompt) - assert isinstance(prompt, dict) - prompt.pop("subset", None) # Clean up subset key - prompt["enum"] = i - # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. - if prompt_replacement is not None: - prompt["prompt"] = prompt["prompt"].replace(prompt_replacement[0], prompt_replacement[1]) - if prompt["negative_prompt"] is not None: - prompt["negative_prompt"] = prompt["negative_prompt"].replace(prompt_replacement[0], prompt_replacement[1]) - # Refactor prompt replacement to here in order to simplify sample_image_inference function. - per_process_prompts[i % num_of_processes].append(prompt) - return per_process_prompts - -def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, pipeline, save_dir, prompt_dict, epoch, steps, controlnet=None): +def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=None): assert isinstance(prompt_dict, dict) negative_prompt = prompt_dict.get("negative_prompt") sample_steps = prompt_dict.get("sample_steps", 30) @@ -4781,10 +4774,19 @@ def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, p controlnet_image = prompt_dict.get("controlnet_image") prompt: str = prompt_dict.get("prompt", "") sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) - + + if prompt_replacement is not None: + prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) + if negative_prompt is not None: + negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + if seed is not None: torch.manual_seed(seed) torch.cuda.manual_seed(seed) + else: + # True random sample image generation + torch.seed() + torch.cuda.seed() scheduler = get_my_scheduler( sample_sampler=sampler_name, @@ -4819,7 +4821,10 @@ def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, p controlnet_image=controlnet_image, ) image = pipeline.latents_to_image(latents)[0] + # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list + # but adding 'enum' to the filename should be enough + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" seed_suffix = "" if seed is None else f"_{seed}" @@ -4827,11 +4832,8 @@ def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, p img_filename = ( f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" ) - image.save(os.path.join(save_dir, img_filename)) - if seed is not None: - torch.seed() - torch.cuda.seed() + # wandb有効時のみログを送信 try: wandb_tracker = accelerator.get_tracker("wandb") From 2f9a34429729c8b44c6f04cb27656d578ecb9420 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 3 Feb 2024 23:26:57 +0900 Subject: [PATCH 042/103] fix typo --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 177fae551..1377997c9 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4741,7 +4741,7 @@ def sample_images_common( for prompt_dict in prompts: sample_image_inference(accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet) else: - # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processess available (number of devices available) + # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. per_process_prompts = [] # list of lists for i in range(distributed_state.num_processes): From 6269682c568e7743d0889726e0fdcd6bd8d51bbc Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 3 Feb 2024 23:33:48 +0900 Subject: [PATCH 043/103] unificaition of gen scripts for SD and SDXL, work in progress --- gen_img.py | 3234 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 3234 insertions(+) create mode 100644 gen_img.py diff --git a/gen_img.py b/gen_img.py new file mode 100644 index 000000000..f5a740c3e --- /dev/null +++ b/gen_img.py @@ -0,0 +1,3234 @@ +import itertools +import json +from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable +import glob +import importlib +import inspect +import time +import zipfile +from diffusers.utils import deprecate +from diffusers.configuration_utils import FrozenDict +import argparse +import math +import os +import random +import re + +import diffusers +import numpy as np +import torch + +from library.ipex_interop import init_ipex + +init_ipex() + +import torchvision +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + DPMSolverSinglestepScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + DDIMScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + KDPM2DiscreteScheduler, + KDPM2AncestralDiscreteScheduler, + # UNet2DConditionModel, + StableDiffusionPipeline, +) +from einops import rearrange +from tqdm import tqdm +from torchvision import transforms +from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPImageProcessor +import PIL +from PIL import Image +from PIL.PngImagePlugin import PngInfo + +import library.model_util as model_util +import library.train_util as train_util +import library.sdxl_model_util as sdxl_model_util +import library.sdxl_train_util as sdxl_train_util +from networks.lora import LoRANetwork +import tools.original_control_net as original_control_net +from tools.original_control_net import ControlNetInfo +from library.original_unet import UNet2DConditionModel, InferUNet2DConditionModel +from library.sdxl_original_unet import InferSdxlUNet2DConditionModel +from library.original_unet import FlashAttentionFunction +from networks.control_net_lllite import ControlNetLLLite +from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL + +# scheduler: +SCHEDULER_LINEAR_START = 0.00085 +SCHEDULER_LINEAR_END = 0.0120 +SCHEDULER_TIMESTEPS = 1000 +SCHEDLER_SCHEDULE = "scaled_linear" + +# その他の設定 +LATENT_CHANNELS = 4 +DOWNSAMPLING_FACTOR = 8 + +CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + +# region モジュール入れ替え部 +""" +高速化のためのモジュール入れ替え +""" + + +def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): + if mem_eff_attn: + print("Enable memory efficient attention for U-Net") + + # これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い + unet.set_use_memory_efficient_attention(False, True) + elif xformers: + print("Enable xformers for U-Net") + try: + import xformers.ops + except ImportError: + raise ImportError("No xformers / xformersがインストールされていないようです") + + unet.set_use_memory_efficient_attention(True, False) + elif sdpa: + print("Enable SDPA for U-Net") + unet.set_use_memory_efficient_attention(False, False) + unet.set_use_sdpa(True) + + +# TODO common train_util.py +def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers, sdpa): + if mem_eff_attn: + replace_vae_attn_to_memory_efficient() + elif xformers: + # replace_vae_attn_to_xformers() # 解像度によってxformersがエラーを出す? + vae.set_use_memory_efficient_attention_xformers(True) # とりあえずこっちを使う + elif sdpa: + replace_vae_attn_to_sdpa() + + +def replace_vae_attn_to_memory_efficient(): + print("VAE Attention.forward has been replaced to FlashAttention (not xformers)") + flash_func = FlashAttentionFunction + + def forward_flash_attn(self, hidden_states, **kwargs): + q_bucket_size = 512 + k_bucket_size = 1024 + + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.to_q(hidden_states) + key_proj = self.to_k(hidden_states) + value_proj = self.to_v(hidden_states) + + query_proj, key_proj, value_proj = map( + lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj) + ) + + out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size) + + out = rearrange(out, "b h n d -> b n (h d)") + + # compute next hidden_states + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + def forward_flash_attn_0_14(self, hidden_states, **kwargs): + if not hasattr(self, "to_q"): + self.to_q = self.query + self.to_k = self.key + self.to_v = self.value + self.to_out = [self.proj_attn, torch.nn.Identity()] + self.heads = self.num_heads + return forward_flash_attn(self, hidden_states, **kwargs) + + if diffusers.__version__ < "0.15.0": + diffusers.models.attention.AttentionBlock.forward = forward_flash_attn_0_14 + else: + diffusers.models.attention_processor.Attention.forward = forward_flash_attn + + +def replace_vae_attn_to_xformers(): + print("VAE: Attention.forward has been replaced to xformers") + import xformers.ops + + def forward_xformers(self, hidden_states, **kwargs): + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.to_q(hidden_states) + key_proj = self.to_k(hidden_states) + value_proj = self.to_v(hidden_states) + + query_proj, key_proj, value_proj = map( + lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj) + ) + + query_proj = query_proj.contiguous() + key_proj = key_proj.contiguous() + value_proj = value_proj.contiguous() + out = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None) + + out = rearrange(out, "b h n d -> b n (h d)") + + # compute next hidden_states + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + def forward_xformers_0_14(self, hidden_states, **kwargs): + if not hasattr(self, "to_q"): + self.to_q = self.query + self.to_k = self.key + self.to_v = self.value + self.to_out = [self.proj_attn, torch.nn.Identity()] + self.heads = self.num_heads + return forward_xformers(self, hidden_states, **kwargs) + + if diffusers.__version__ < "0.15.0": + diffusers.models.attention.AttentionBlock.forward = forward_xformers_0_14 + else: + diffusers.models.attention_processor.Attention.forward = forward_xformers + + +def replace_vae_attn_to_sdpa(): + print("VAE: Attention.forward has been replaced to sdpa") + + def forward_sdpa(self, hidden_states, **kwargs): + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.to_q(hidden_states) + key_proj = self.to_k(hidden_states) + value_proj = self.to_v(hidden_states) + + query_proj, key_proj, value_proj = map( + lambda t: rearrange(t, "b n (h d) -> b n h d", h=self.heads), (query_proj, key_proj, value_proj) + ) + + out = torch.nn.functional.scaled_dot_product_attention( + query_proj, key_proj, value_proj, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + out = rearrange(out, "b n h d -> b n (h d)") + + # compute next hidden_states + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + def forward_sdpa_0_14(self, hidden_states, **kwargs): + if not hasattr(self, "to_q"): + self.to_q = self.query + self.to_k = self.key + self.to_v = self.value + self.to_out = [self.proj_attn, torch.nn.Identity()] + self.heads = self.num_heads + return forward_sdpa(self, hidden_states, **kwargs) + + if diffusers.__version__ < "0.15.0": + diffusers.models.attention.AttentionBlock.forward = forward_sdpa_0_14 + else: + diffusers.models.attention_processor.Attention.forward = forward_sdpa + + +# endregion + +# region 画像生成の本体:lpw_stable_diffusion.py (ASL)からコピーして修正 +# https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py +# Pipelineだけ独立して使えないのと機能追加するのとでコピーして修正 + + +class PipelineLike: + def __init__( + self, + is_sdxl, + device, + vae: AutoencoderKL, + text_encoders: List[CLIPTextModel], + tokenizers: List[CLIPTokenizer], + unet: InferSdxlUNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + clip_skip: int, + ): + super().__init__() + self.is_sdxl = is_sdxl + self.device = device + self.clip_skip = clip_skip + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + self.vae = vae + self.text_encoders = text_encoders + self.tokenizers = tokenizers + self.unet: Union[InferUNet2DConditionModel, InferSdxlUNet2DConditionModel] = unet + self.scheduler = scheduler + self.safety_checker = None + + # Textual Inversion + self.token_replacements_list = [] + for _ in range(len(self.text_encoders)): + self.token_replacements_list.append({}) + + # ControlNet + self.control_nets: List[ControlNetInfo] = [] # only for SD 1.5 + self.control_net_lllites: List[ControlNetLLLite] = [] + self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない + + self.gradual_latent: GradualLatent = None + + # Textual Inversion + def add_token_replacement(self, text_encoder_index, target_token_id, rep_token_ids): + self.token_replacements_list[text_encoder_index][target_token_id] = rep_token_ids + + def set_enable_control_net(self, en: bool): + self.control_net_enabled = en + + def get_token_replacer(self, tokenizer): + tokenizer_index = self.tokenizers.index(tokenizer) + token_replacements = self.token_replacements_list[tokenizer_index] + + def replace_tokens(tokens): + # print("replace_tokens", tokens, "=>", token_replacements) + if isinstance(tokens, torch.Tensor): + tokens = tokens.tolist() + + new_tokens = [] + for token in tokens: + if token in token_replacements: + replacement = token_replacements[token] + new_tokens.extend(replacement) + else: + new_tokens.append(token) + return new_tokens + + return replace_tokens + + def set_control_nets(self, ctrl_nets): + self.control_nets = ctrl_nets + + def set_control_net_lllites(self, ctrl_net_lllites): + self.control_net_lllites = ctrl_net_lllites + + def set_gradual_latent(self, gradual_latent): + if gradual_latent is None: + print("gradual_latent is disabled") + self.gradual_latent = None + else: + print(f"gradual_latent is enabled: {gradual_latent}") + self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + init_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, + mask_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, + height: int = 1024, + width: int = 1024, + original_height: int = None, + original_width: int = None, + original_height_negative: int = None, + original_width_negative: int = None, + crop_top: int = 0, + crop_left: int = 0, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_scale: float = None, + strength: float = 0.8, + # num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + vae_batch_size: float = None, + return_latents: bool = False, + # return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, + callback_steps: Optional[int] = 1, + img2img_noise=None, + clip_guide_images=None, + **kwargs, + ): + # TODO support secondary prompt + num_images_per_prompt = 1 # fixed because already prompt is repeated + + if isinstance(prompt, str): + batch_size = 1 + prompt = [prompt] + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + regional_network = " AND " in prompt[0] + + vae_batch_size = ( + batch_size + if vae_batch_size is None + else (int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size))) + ) + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." + ) + + # get prompt text embeddings + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + if not do_classifier_free_guidance and negative_scale is not None: + print(f"negative_scale is ignored if guidance scalle <= 1.0") + negative_scale = None + + # get unconditional embeddings for classifier free guidance + if negative_prompt is None: + negative_prompt = [""] * batch_size + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + if batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + tes_text_embs = [] + tes_uncond_embs = [] + tes_real_uncond_embs = [] + + for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders): + token_replacer = self.get_token_replacer(tokenizer) + + # use last text_pool, because it is from text encoder 2 + text_embeddings, text_pool, uncond_embeddings, uncond_pool, _ = get_weighted_text_embeddings( + self.is_sdxl, + tokenizer, + text_encoder, + prompt=prompt, + uncond_prompt=negative_prompt if do_classifier_free_guidance else None, + max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, + token_replacer=token_replacer, + device=self.device, + **kwargs, + ) + tes_text_embs.append(text_embeddings) + tes_uncond_embs.append(uncond_embeddings) + + if negative_scale is not None: + _, real_uncond_embeddings, _ = get_weighted_text_embeddings( + self.is_sdxl, + token_replacer, + prompt=prompt, # こちらのトークン長に合わせてuncondを作るので75トークン超で必須 + uncond_prompt=[""] * batch_size, + max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, + token_replacer=token_replacer, + device=self.device, + **kwargs, + ) + tes_real_uncond_embs.append(real_uncond_embeddings) + + # concat text encoder outputs + text_embeddings = tes_text_embs[0] + uncond_embeddings = tes_uncond_embs[0] + for i in range(1, len(tes_text_embs)): + text_embeddings = torch.cat([text_embeddings, tes_text_embs[i]], dim=2) # n,77,2048 + if do_classifier_free_guidance: + uncond_embeddings = torch.cat([uncond_embeddings, tes_uncond_embs[i]], dim=2) # n,77,2048 + + if do_classifier_free_guidance: + if negative_scale is None: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + else: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) + + if self.control_net_lllites: + # ControlNetのhintにguide imageを流用する。ControlNetの場合はControlNet側で行う + if isinstance(clip_guide_images, PIL.Image.Image): + clip_guide_images = [clip_guide_images] + if isinstance(clip_guide_images[0], PIL.Image.Image): + clip_guide_images = [preprocess_image(im) for im in clip_guide_images] + clip_guide_images = torch.cat(clip_guide_images) + if isinstance(clip_guide_images, list): + clip_guide_images = torch.stack(clip_guide_images) + + clip_guide_images = clip_guide_images.to(self.device, dtype=text_embeddings.dtype) + + # create size embs + if original_height is None: + original_height = height + if original_width is None: + original_width = width + if original_height_negative is None: + original_height_negative = original_height + if original_width_negative is None: + original_width_negative = original_width + if crop_top is None: + crop_top = 0 + if crop_left is None: + crop_left = 0 + if self.is_sdxl: + emb1 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256) + uc_emb1 = sdxl_train_util.get_timestep_embedding( + torch.FloatTensor([original_height_negative, original_width_negative]).unsqueeze(0), 256 + ) + emb2 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256) + emb3 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([height, width]).unsqueeze(0), 256) + c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1) + uc_vector = torch.cat([uc_emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1) + + if regional_network: + # use last pool for conditioning + num_sub_prompts = len(text_pool) // batch_size + text_pool = text_pool[num_sub_prompts - 1 :: num_sub_prompts] # last subprompt + + if init_image is not None and self.clip_vision_model is not None: + print(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}") + vision_input = self.clip_vision_processor(init_image, return_tensors="pt", device=self.device) + pixel_values = vision_input["pixel_values"].to(self.device, dtype=text_embeddings.dtype) + + clip_vision_embeddings = self.clip_vision_model( + pixel_values=pixel_values, output_hidden_states=True, return_dict=True + ) + clip_vision_embeddings = clip_vision_embeddings.image_embeds + + if len(clip_vision_embeddings) == 1 and batch_size > 1: + clip_vision_embeddings = clip_vision_embeddings.repeat((batch_size, 1)) + + clip_vision_embeddings = clip_vision_embeddings * self.clip_vision_strength + assert clip_vision_embeddings.shape == text_pool.shape, f"{clip_vision_embeddings.shape} != {text_pool.shape}" + text_pool = clip_vision_embeddings # replace: same as ComfyUI (?) + + c_vector = torch.cat([text_pool, c_vector], dim=1) + if do_classifier_free_guidance: + uc_vector = torch.cat([uncond_pool, uc_vector], dim=1) + vector_embeddings = torch.cat([uc_vector, c_vector]) + else: + vector_embeddings = c_vector + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps, self.device) + + latents_dtype = text_embeddings.dtype + init_latents_orig = None + mask = None + + if init_image is None: + # get the initial random noise unless the user supplied it + + # Unlike in other pipelines, latents need to be generated in the target device + # for 1-to-1 results reproducibility with the CompVis implementation. + # However this currently doesn't work in `mps`. + latents_shape = ( + batch_size * num_images_per_prompt, + self.unet.in_channels, + height // 8, + width // 8, + ) + + if latents is None: + if self.device.type == "mps": + # randn does not exist on mps + latents = torch.randn( + latents_shape, + generator=generator, + device="cpu", + dtype=latents_dtype, + ).to(self.device) + else: + latents = torch.randn( + latents_shape, + generator=generator, + device=self.device, + dtype=latents_dtype, + ) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + latents = latents.to(self.device) + + timesteps = self.scheduler.timesteps.to(self.device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + else: + # image to tensor + if isinstance(init_image, PIL.Image.Image): + init_image = [init_image] + if isinstance(init_image[0], PIL.Image.Image): + init_image = [preprocess_image(im) for im in init_image] + init_image = torch.cat(init_image) + if isinstance(init_image, list): + init_image = torch.stack(init_image) + + # mask image to tensor + if mask_image is not None: + if isinstance(mask_image, PIL.Image.Image): + mask_image = [mask_image] + if isinstance(mask_image[0], PIL.Image.Image): + mask_image = torch.cat([preprocess_mask(im) for im in mask_image]) # H*W, 0 for repaint + + # encode the init image into latents and scale the latents + init_image = init_image.to(device=self.device, dtype=latents_dtype) + if init_image.size()[-2:] == (height // 8, width // 8): + init_latents = init_image + else: + if vae_batch_size >= batch_size: + init_latent_dist = self.vae.encode(init_image.to(self.vae.dtype)).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + else: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + init_latents = [] + for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)): + init_latent_dist = self.vae.encode( + (init_image[i : i + vae_batch_size] if vae_batch_size > 1 else init_image[i].unsqueeze(0)).to( + self.vae.dtype + ) + ).latent_dist + init_latents.append(init_latent_dist.sample(generator=generator)) + init_latents = torch.cat(init_latents) + + init_latents = (sdxl_model_util.VAE_SCALE_FACTOR if self.is_sdxl else 0.18215) * init_latents + + if len(init_latents) == 1: + init_latents = init_latents.repeat((batch_size, 1, 1, 1)) + init_latents_orig = init_latents + + # preprocess mask + if mask_image is not None: + mask = mask_image.to(device=self.device, dtype=latents_dtype) + if len(mask) == 1: + mask = mask.repeat((batch_size, 1, 1, 1)) + + # check sizes + if not mask.shape == init_latents.shape: + raise ValueError("The mask and init_image should be the same size!") + + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + timesteps = self.scheduler.timesteps[-init_timestep] + timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device) + + # add noise to latents using the timesteps + latents = self.scheduler.add_noise(init_latents, img2img_noise, timesteps) + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:].to(self.device) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1 + + if self.control_nets: + guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images) + each_control_net_enabled = [self.control_net_enabled] * len(self.control_nets) + + if self.control_net_lllites: + # guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images) + if self.control_net_enabled: + for control_net, _ in self.control_net_lllites: + with torch.no_grad(): + control_net.set_cond_image(clip_guide_images) + else: + for control_net, _ in self.control_net_lllites: + control_net.set_cond_image(None) + + each_control_net_enabled = [self.control_net_enabled] * len(self.control_net_lllites) + + enable_gradual_latent = False + if self.gradual_latent: + if not hasattr(self.scheduler, "set_gradual_latent_params"): + print("gradual_latent is not supported for this scheduler. Ignoring.") + print(self.scheduler.__class__.__name__) + else: + enable_gradual_latent = True + step_elapsed = 1000 + current_ratio = self.gradual_latent.ratio + + # first, we downscale the latents to the specified ratio / 最初に指定された比率にlatentsをダウンスケールする + height, width = latents.shape[-2:] + org_dtype = latents.dtype + if org_dtype == torch.bfloat16: + latents = latents.float() + latents = torch.nn.functional.interpolate( + latents, scale_factor=current_ratio, mode="bicubic", align_corners=False + ).to(org_dtype) + + # apply unsharp mask / アンシャープマスクを適用する + if self.gradual_latent.gaussian_blur_ksize: + latents = self.gradual_latent.apply_unshark_mask(latents) + + for i, t in enumerate(tqdm(timesteps)): + resized_size = None + if enable_gradual_latent: + # gradually upscale the latents / latentsを徐々にアップスケールする + if ( + t < self.gradual_latent.start_timesteps + and current_ratio < 1.0 + and step_elapsed >= self.gradual_latent.every_n_steps + ): + current_ratio = min(current_ratio + self.gradual_latent.ratio_step, 1.0) + # make divisible by 8 because size of latents must be divisible at bottom of UNet + h = int(height * current_ratio) // 8 * 8 + w = int(width * current_ratio) // 8 * 8 + resized_size = (h, w) + self.scheduler.set_gradual_latent_params(resized_size, self.gradual_latent) + step_elapsed = 0 + else: + self.scheduler.set_gradual_latent_params(None, None) + step_elapsed += 1 + + # expand the latents if we are doing classifier free guidance + latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # disable ControlNet-LLLite if ratio is set. ControlNet is disabled in ControlNetInfo + if self.control_net_lllites: + for j, ((control_net, ratio), enabled) in enumerate(zip(self.control_net_lllites, each_control_net_enabled)): + if not enabled or ratio >= 1.0: + continue + if ratio < i / len(timesteps): + print(f"ControlNetLLLite {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})") + control_net.set_cond_image(None) + each_control_net_enabled[j] = False + + # predict the noise residual + if self.control_nets and self.control_net_enabled: + if regional_network: + num_sub_and_neg_prompts = len(text_embeddings) // batch_size + text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt + else: + text_emb_last = text_embeddings + + noise_pred = original_control_net.call_unet_and_control_net( + i, + num_latent_input, + self.unet, + self.control_nets, + guided_hints, + i / len(timesteps), + latent_model_input, + t, + text_embeddings, + text_emb_last, + ).sample + elif self.is_sdxl: + noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings) + else: + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + if negative_scale is None: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input) # uncond by negative prompt + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + else: + noise_pred_negative, noise_pred_text, noise_pred_uncond = noise_pred.chunk( + num_latent_input + ) # uncond is real uncond + noise_pred = ( + noise_pred_uncond + + guidance_scale * (noise_pred_text - noise_pred_uncond) + - negative_scale * (noise_pred_negative - noise_pred_uncond) + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if mask is not None: + # masking + init_latents_proper = self.scheduler.add_noise(init_latents_orig, img2img_noise, torch.tensor([t])) + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + # call the callback, if provided + if i % callback_steps == 0: + if callback is not None: + callback(i, t, latents) + if is_cancelled_callback is not None and is_cancelled_callback(): + return None + + if return_latents: + return latents + + latents = 1 / (sdxl_model_util.VAE_SCALE_FACTOR if self.is_sdxl else 0.18215) * latents + if vae_batch_size >= batch_size: + image = self.vae.decode(latents.to(self.vae.dtype)).sample + else: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + images = [] + for i in tqdm(range(0, batch_size, vae_batch_size)): + images.append( + self.vae.decode( + (latents[i : i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).to(self.vae.dtype) + ).sample + ) + image = torch.cat(images) + + image = (image / 2 + 0.5).clamp(0, 1) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + if output_type == "pil": + # image = self.numpy_to_pil(image) + image = (image * 255).round().astype("uint8") + image = [Image.fromarray(im) for im in image] + + return image + + # return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + +re_attention = re.compile( + r""" +\\\(| +\\\)| +\\\[| +\\]| +\\\\| +\\| +\(| +\[| +:([+-]?[.\d]+)\)| +\)| +]| +[^\\()\[\]:]+| +: +""", + re.X, +) + + +def parse_prompt_attention(text): + """ + Parses a string with attention tokens and returns a list of pairs: text and its associated weight. + Accepted tokens are: + (abc) - increases attention to abc by a multiplier of 1.1 + (abc:3.12) - increases attention to abc by a multiplier of 3.12 + [abc] - decreases attention to abc by a multiplier of 1.1 + \( - literal character '(' + \[ - literal character '[' + \) - literal character ')' + \] - literal character ']' + \\ - literal character '\' + anything else - just text + >>> parse_prompt_attention('normal text') + [['normal text', 1.0]] + >>> parse_prompt_attention('an (important) word') + [['an ', 1.0], ['important', 1.1], [' word', 1.0]] + >>> parse_prompt_attention('(unbalanced') + [['unbalanced', 1.1]] + >>> parse_prompt_attention('\(literal\]') + [['(literal]', 1.0]] + >>> parse_prompt_attention('(unnecessary)(parens)') + [['unnecessaryparens', 1.1]] + >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') + [['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1]] + """ + + res = [] + round_brackets = [] + square_brackets = [] + + round_bracket_multiplier = 1.1 + square_bracket_multiplier = 1 / 1.1 + + def multiply_range(start_position, multiplier): + for p in range(start_position, len(res)): + res[p][1] *= multiplier + + # keep break as separate token + text = text.replace("BREAK", "\\BREAK\\") + + for m in re_attention.finditer(text): + text = m.group(0) + weight = m.group(1) + + if text.startswith("\\"): + res.append([text[1:], 1.0]) + elif text == "(": + round_brackets.append(len(res)) + elif text == "[": + square_brackets.append(len(res)) + elif weight is not None and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), float(weight)) + elif text == ")" and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), round_bracket_multiplier) + elif text == "]" and len(square_brackets) > 0: + multiply_range(square_brackets.pop(), square_bracket_multiplier) + else: + res.append([text, 1.0]) + + for pos in round_brackets: + multiply_range(pos, round_bracket_multiplier) + + for pos in square_brackets: + multiply_range(pos, square_bracket_multiplier) + + if len(res) == 0: + res = [["", 1.0]] + + # merge runs of identical weights + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1] and res[i][0].strip() != "BREAK" and res[i + 1][0].strip() != "BREAK": + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 + + return res + + +def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: List[str], max_length: int): + r""" + Tokenize a list of prompts and return its tokens with weights of each token. + No padding, starting or ending token is included. + """ + tokens = [] + weights = [] + truncated = False + + for text in prompt: + texts_and_weights = parse_prompt_attention(text) + text_token = [] + text_weight = [] + for word, weight in texts_and_weights: + if word.strip() == "BREAK": + # pad until next multiple of tokenizer's max token length + pad_len = tokenizer.model_max_length - (len(text_token) % tokenizer.model_max_length) + print(f"BREAK pad_len: {pad_len}") + for i in range(pad_len): + # v2のときEOSをつけるべきかどうかわからないぜ + # if i == 0: + # text_token.append(tokenizer.eos_token_id) + # else: + text_token.append(tokenizer.pad_token_id) + text_weight.append(1.0) + continue + + # tokenize and discard the starting and the ending token + token = tokenizer(word).input_ids[1:-1] + + token = token_replacer(token) # for Textual Inversion + + text_token += token + # copy the weight by length of token + text_weight += [weight] * len(token) + # stop if the text is too long (longer than truncation limit) + if len(text_token) > max_length: + truncated = True + break + # truncate + if len(text_token) > max_length: + truncated = True + text_token = text_token[:max_length] + text_weight = text_weight[:max_length] + tokens.append(text_token) + weights.append(text_weight) + if truncated: + print("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + return tokens, weights + + +def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77): + r""" + Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. + """ + max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) + weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length + for i in range(len(tokens)): + tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i])) + if no_boseos_middle: + weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) + else: + w = [] + if len(weights[i]) == 0: + w = [1.0] * weights_length + else: + for j in range(max_embeddings_multiples): + w.append(1.0) # weight for starting token in this chunk + w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))] + w.append(1.0) # weight for ending token in this chunk + w += [1.0] * (weights_length - len(w)) + weights[i] = w[:] + + return tokens, weights + + +def get_unweighted_text_embeddings( + is_sdxl: bool, + text_encoder: CLIPTextModel, + text_input: torch.Tensor, + chunk_length: int, + clip_skip: int, + eos: int, + pad: int, + no_boseos_middle: Optional[bool] = True, +): + """ + When the length of tokens is a multiple of the capacity of the text encoder, + it should be split into chunks and sent to the text encoder individually. + """ + max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) + if max_embeddings_multiples > 1: + text_embeddings = [] + pool = None + for i in range(max_embeddings_multiples): + # extract the i-th chunk + text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone() + + # cover the head and the tail by the starting and the ending tokens + text_input_chunk[:, 0] = text_input[0, 0] + if pad == eos: # v1 + text_input_chunk[:, -1] = text_input[0, -1] + else: # v2 + for j in range(len(text_input_chunk)): + if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある + text_input_chunk[j, -1] = eos + if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD + text_input_chunk[j, 1] = eos + + # in sdxl, value of clip_skip is same for Text Encoder 1 and 2 + enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True) + text_embedding = enc_out["hidden_states"][-clip_skip] + if not is_sdxl: # SD 1.5 requires final_layer_norm + text_embedding = text_encoder.text_model.final_layer_norm(text_embedding) + if pool is None: + pool = enc_out.get("text_embeds", None) # use 1st chunk, if provided + if pool is not None: + pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input_chunk, eos) + + if no_boseos_middle: + if i == 0: + # discard the ending token + text_embedding = text_embedding[:, :-1] + elif i == max_embeddings_multiples - 1: + # discard the starting token + text_embedding = text_embedding[:, 1:] + else: + # discard both starting and ending tokens + text_embedding = text_embedding[:, 1:-1] + + text_embeddings.append(text_embedding) + text_embeddings = torch.concat(text_embeddings, axis=1) + else: + enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True) + text_embeddings = enc_out["hidden_states"][-clip_skip] + if not is_sdxl: # SD 1.5 requires final_layer_norm + text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings) + pool = enc_out.get("text_embeds", None) # text encoder 1 doesn't return this + if pool is not None: + pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input, eos) + return text_embeddings, pool + + +def get_weighted_text_embeddings( + is_sdxl: bool, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModel, + prompt: Union[str, List[str]], + uncond_prompt: Optional[Union[str, List[str]]] = None, + max_embeddings_multiples: Optional[int] = 1, + no_boseos_middle: Optional[bool] = False, + skip_parsing: Optional[bool] = False, + skip_weighting: Optional[bool] = False, + clip_skip: int = 1, + token_replacer=None, + device=None, + **kwargs, +): + max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + if isinstance(prompt, str): + prompt = [prompt] + + # split the prompts with "AND". each prompt must have the same number of splits + new_prompts = [] + for p in prompt: + new_prompts.extend(p.split(" AND ")) + prompt = new_prompts + + if not skip_parsing: + prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, token_replacer, prompt, max_length - 2) + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens, uncond_weights = get_prompts_with_weights(tokenizer, token_replacer, uncond_prompt, max_length - 2) + else: + prompt_tokens = [token[1:-1] for token in tokenizer(prompt, max_length=max_length, truncation=True).input_ids] + prompt_weights = [[1.0] * len(token) for token in prompt_tokens] + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens = [token[1:-1] for token in tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids] + uncond_weights = [[1.0] * len(token) for token in uncond_tokens] + + # round up the longest length of tokens to a multiple of (model_max_length - 2) + max_length = max([len(token) for token in prompt_tokens]) + if uncond_prompt is not None: + max_length = max(max_length, max([len(token) for token in uncond_tokens])) + + max_embeddings_multiples = min( + max_embeddings_multiples, + (max_length - 1) // (tokenizer.model_max_length - 2) + 1, + ) + max_embeddings_multiples = max(1, max_embeddings_multiples) + max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + + # pad the length of tokens and weights + bos = tokenizer.bos_token_id + eos = tokenizer.eos_token_id + pad = tokenizer.pad_token_id + prompt_tokens, prompt_weights = pad_tokens_and_weights( + prompt_tokens, + prompt_weights, + max_length, + bos, + eos, + pad, + no_boseos_middle=no_boseos_middle, + chunk_length=tokenizer.model_max_length, + ) + prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device) + if uncond_prompt is not None: + uncond_tokens, uncond_weights = pad_tokens_and_weights( + uncond_tokens, + uncond_weights, + max_length, + bos, + eos, + pad, + no_boseos_middle=no_boseos_middle, + chunk_length=tokenizer.model_max_length, + ) + uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device) + + # get the embeddings + text_embeddings, text_pool = get_unweighted_text_embeddings( + is_sdxl, + text_encoder, + prompt_tokens, + tokenizer.model_max_length, + clip_skip, + eos, + pad, + no_boseos_middle=no_boseos_middle, + ) + + prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device) + if uncond_prompt is not None: + uncond_embeddings, uncond_pool = get_unweighted_text_embeddings( + is_sdxl, + text_encoder, + uncond_tokens, + tokenizer.model_max_length, + clip_skip, + eos, + pad, + no_boseos_middle=no_boseos_middle, + ) + uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=device) + + # assign weights to the prompts and normalize in the sense of mean + # TODO: should we normalize by chunk or in a whole (current implementation)? + # →全体でいいんじゃないかな + if (not skip_parsing) and (not skip_weighting): + previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= prompt_weights.unsqueeze(-1) + current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + if uncond_prompt is not None: + previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= uncond_weights.unsqueeze(-1) + current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + + if uncond_prompt is not None: + return text_embeddings, text_pool, uncond_embeddings, uncond_pool, prompt_tokens + return text_embeddings, text_pool, None, None, prompt_tokens + + +def preprocess_image(image): + w, h = image.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +def preprocess_mask(mask): + mask = mask.convert("L") + w, h = mask.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + mask = mask.resize((w // 8, h // 8), resample=PIL.Image.BILINEAR) # LANCZOS) + mask = np.array(mask).astype(np.float32) / 255.0 + mask = np.tile(mask, (4, 1, 1)) + mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? + mask = 1 - mask # repaint white, keep black + mask = torch.from_numpy(mask) + return mask + + +# regular expression for dynamic prompt: +# starts and ends with "{" and "}" +# contains at least one variant divided by "|" +# optional framgments divided by "$$" at start +# if the first fragment is "E" or "e", enumerate all variants +# if the second fragment is a number or two numbers, repeat the variants in the range +# if the third fragment is a string, use it as a separator + +RE_DYNAMIC_PROMPT = re.compile(r"\{((e|E)\$\$)?(([\d\-]+)\$\$)?(([^\|\}]+?)\$\$)?(.+?((\|).+?)*?)\}") + + +def handle_dynamic_prompt_variants(prompt, repeat_count): + founds = list(RE_DYNAMIC_PROMPT.finditer(prompt)) + if not founds: + return [prompt] + + # make each replacement for each variant + enumerating = False + replacers = [] + for found in founds: + # if "e$$" is found, enumerate all variants + found_enumerating = found.group(2) is not None + enumerating = enumerating or found_enumerating + + separator = ", " if found.group(6) is None else found.group(6) + variants = found.group(7).split("|") + + # parse count range + count_range = found.group(4) + if count_range is None: + count_range = [1, 1] + else: + count_range = count_range.split("-") + if len(count_range) == 1: + count_range = [int(count_range[0]), int(count_range[0])] + elif len(count_range) == 2: + count_range = [int(count_range[0]), int(count_range[1])] + else: + print(f"invalid count range: {count_range}") + count_range = [1, 1] + if count_range[0] > count_range[1]: + count_range = [count_range[1], count_range[0]] + if count_range[0] < 0: + count_range[0] = 0 + if count_range[1] > len(variants): + count_range[1] = len(variants) + + if found_enumerating: + # make function to enumerate all combinations + def make_replacer_enum(vari, cr, sep): + def replacer(): + values = [] + for count in range(cr[0], cr[1] + 1): + for comb in itertools.combinations(vari, count): + values.append(sep.join(comb)) + return values + + return replacer + + replacers.append(make_replacer_enum(variants, count_range, separator)) + else: + # make function to choose random combinations + def make_replacer_single(vari, cr, sep): + def replacer(): + count = random.randint(cr[0], cr[1]) + comb = random.sample(vari, count) + return [sep.join(comb)] + + return replacer + + replacers.append(make_replacer_single(variants, count_range, separator)) + + # make each prompt + if not enumerating: + # if not enumerating, repeat the prompt, replace each variant randomly + prompts = [] + for _ in range(repeat_count): + current = prompt + for found, replacer in zip(founds, replacers): + current = current.replace(found.group(0), replacer()[0], 1) + prompts.append(current) + else: + # if enumerating, iterate all combinations for previous prompts + prompts = [prompt] + + for found, replacer in zip(founds, replacers): + if found.group(2) is not None: + # make all combinations for existing prompts + new_prompts = [] + for current in prompts: + replecements = replacer() + for replecement in replecements: + new_prompts.append(current.replace(found.group(0), replecement, 1)) + prompts = new_prompts + + for found, replacer in zip(founds, replacers): + # make random selection for existing prompts + if found.group(2) is None: + for i in range(len(prompts)): + prompts[i] = prompts[i].replace(found.group(0), replacer()[0], 1) + + return prompts + + +# endregion + +# def load_clip_l14_336(dtype): +# print(f"loading CLIP: {CLIP_ID_L14_336}") +# text_encoder = CLIPTextModel.from_pretrained(CLIP_ID_L14_336, torch_dtype=dtype) +# return text_encoder + + +class BatchDataBase(NamedTuple): + # バッチ分割が必要ないデータ + step: int + prompt: str + negative_prompt: str + seed: int + init_image: Any + mask_image: Any + clip_prompt: str + guide_image: Any + raw_prompt: str + + +class BatchDataExt(NamedTuple): + # バッチ分割が必要なデータ + width: int + height: int + original_width: int + original_height: int + original_width_negative: int + original_height_negative: int + crop_left: int + crop_top: int + steps: int + scale: float + negative_scale: float + strength: float + network_muls: Tuple[float] + num_sub_prompts: int + + +class BatchData(NamedTuple): + return_latents: bool + base: BatchDataBase + ext: BatchDataExt + + +def main(args): + if args.fp16: + dtype = torch.float16 + elif args.bf16: + dtype = torch.bfloat16 + else: + dtype = torch.float32 + + highres_fix = args.highres_fix_scale is not None + # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" + + if args.v_parameterization and not args.v2: + print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") + if args.v2 and args.clip_skip is not None: + print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") + + # モデルを読み込む + if not os.path.exists(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う + files = glob.glob(args.ckpt) + if len(files) == 1: + args.ckpt = files[0] + + name_or_path = os.readlink(args.ckpt) if os.path.islink(args.ckpt) else args.ckpt + use_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers + + # SDXLかどうかを判定する + is_sdxl = args.sdxl + if not is_sdxl and not args.v1 and not args.v2: # どれも指定されていない場合は自動で判定する + if use_stable_diffusion_format: + # if file size > 5.5GB, sdxl + is_sdxl = os.path.getsize(name_or_path) > 5.5 * 1024**3 + else: + # if `text_encoder_2` subdirectory exists, sdxl + is_sdxl = os.path.isdir(os.path.join(name_or_path, "text_encoder_2")) + print(f"SDXL: {is_sdxl}") + + if is_sdxl: + if args.clip_skip is None: + args.clip_skip = 2 + + (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( + args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype + ) + unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet) + text_encoders = [text_encoder1, text_encoder2] + else: + if args.clip_skip is None: + args.clip_skip = 2 if args.v2 else 1 + + if use_stable_diffusion_format: + print("load StableDiffusion checkpoint") + text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt) + else: + print("load Diffusers pretrained models") + loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype) + text_encoder = loading_pipe.text_encoder + vae = loading_pipe.vae + unet = loading_pipe.unet + tokenizer = loading_pipe.tokenizer + del loading_pipe + + # Diffusers U-Net to original U-Net + original_unet = UNet2DConditionModel( + unet.config.sample_size, + unet.config.attention_head_dim, + unet.config.cross_attention_dim, + unet.config.use_linear_projection, + unet.config.upcast_attention, + ) + original_unet.load_state_dict(unet.state_dict()) + unet = original_unet + unet: InferUNet2DConditionModel = InferUNet2DConditionModel(unet) + text_encoders = [text_encoder] + + # VAEを読み込む + if args.vae is not None: + vae = model_util.load_vae(args.vae, dtype) + print("additional VAE loaded") + + # xformers、Hypernetwork対応 + if not args.diffusers_xformers: + mem_eff = not (args.xformers or args.sdpa) + replace_unet_modules(unet, mem_eff, args.xformers, args.sdpa) + replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa) + + # tokenizerを読み込む + print("loading tokenizer") + if is_sdxl: + tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) + tokenizers = [tokenizer1, tokenizer2] + else: + if use_stable_diffusion_format: + tokenizer = train_util.load_tokenizer(args) + tokenizers = [tokenizer] + + # schedulerを用意する + sched_init_args = {} + has_steps_offset = True + has_clip_sample = True + scheduler_num_noises_per_step = 1 + + if args.sampler == "ddim": + scheduler_cls = DDIMScheduler + scheduler_module = diffusers.schedulers.scheduling_ddim + elif args.sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある + scheduler_cls = DDPMScheduler + scheduler_module = diffusers.schedulers.scheduling_ddpm + elif args.sampler == "pndm": + scheduler_cls = PNDMScheduler + scheduler_module = diffusers.schedulers.scheduling_pndm + has_clip_sample = False + elif args.sampler == "lms" or args.sampler == "k_lms": + scheduler_cls = LMSDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_lms_discrete + has_clip_sample = False + elif args.sampler == "euler" or args.sampler == "k_euler": + scheduler_cls = EulerDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_euler_discrete + has_clip_sample = False + elif args.sampler == "euler_a" or args.sampler == "k_euler_a": + scheduler_cls = EulerAncestralDiscreteSchedulerGL + scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete + has_clip_sample = False + elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++": + scheduler_cls = DPMSolverMultistepScheduler + sched_init_args["algorithm_type"] = args.sampler + scheduler_module = diffusers.schedulers.scheduling_dpmsolver_multistep + has_clip_sample = False + elif args.sampler == "dpmsingle": + scheduler_cls = DPMSolverSinglestepScheduler + scheduler_module = diffusers.schedulers.scheduling_dpmsolver_singlestep + has_clip_sample = False + has_steps_offset = False + elif args.sampler == "heun": + scheduler_cls = HeunDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_heun_discrete + has_clip_sample = False + elif args.sampler == "dpm_2" or args.sampler == "k_dpm_2": + scheduler_cls = KDPM2DiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_discrete + has_clip_sample = False + elif args.sampler == "dpm_2_a" or args.sampler == "k_dpm_2_a": + scheduler_cls = KDPM2AncestralDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete + scheduler_num_noises_per_step = 2 + has_clip_sample = False + + if args.v_parameterization: + sched_init_args["prediction_type"] = "v_prediction" + + # 警告を出さないようにする + if has_steps_offset: + sched_init_args["steps_offset"] = 1 + if has_clip_sample: + sched_init_args["clip_sample"] = False + + # samplerの乱数をあらかじめ指定するための処理 + + # replace randn + class NoiseManager: + def __init__(self): + self.sampler_noises = None + self.sampler_noise_index = 0 + + def reset_sampler_noises(self, noises): + self.sampler_noise_index = 0 + self.sampler_noises = noises + + def randn(self, shape, device=None, dtype=None, layout=None, generator=None): + # print("replacing", shape, len(self.sampler_noises), self.sampler_noise_index) + if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises): + noise = self.sampler_noises[self.sampler_noise_index] + if shape != noise.shape: + noise = None + else: + noise = None + + if noise == None: + print(f"unexpected noise request: {self.sampler_noise_index}, {shape}") + noise = torch.randn(shape, dtype=dtype, device=device, generator=generator) + + self.sampler_noise_index += 1 + return noise + + class TorchRandReplacer: + def __init__(self, noise_manager): + self.noise_manager = noise_manager + + def __getattr__(self, item): + if item == "randn": + return self.noise_manager.randn + if hasattr(torch, item): + return getattr(torch, item) + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) + + noise_manager = NoiseManager() + if scheduler_module is not None: + scheduler_module.torch = TorchRandReplacer(noise_manager) + + scheduler = scheduler_cls( + num_train_timesteps=SCHEDULER_TIMESTEPS, + beta_start=SCHEDULER_LINEAR_START, + beta_end=SCHEDULER_LINEAR_END, + beta_schedule=SCHEDLER_SCHEDULE, + **sched_init_args, + ) + + # ↓以下は結局PipeでFalseに設定されるので意味がなかった + # # clip_sample=Trueにする + # if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: + # print("set clip_sample to True") + # scheduler.config.clip_sample = True + + # deviceを決定する + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない + + # custom pipelineをコピったやつを生成する + if args.vae_slices: + from library.slicing_vae import SlicingAutoencoderKL + + sli_vae = SlicingAutoencoderKL( + act_fn="silu", + block_out_channels=(128, 256, 512, 512), + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"], + in_channels=3, + latent_channels=4, + layers_per_block=2, + norm_num_groups=32, + out_channels=3, + sample_size=512, + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"], + num_slices=args.vae_slices, + ) + sli_vae.load_state_dict(vae.state_dict()) # vaeのパラメータをコピーする + vae = sli_vae + del sli_vae + + vae_dtype = dtype + if args.no_half_vae: + print("set vae_dtype to float32") + vae_dtype = torch.float32 + vae.to(vae_dtype).to(device) + vae.eval() + + for text_encoder in text_encoders: + text_encoder.to(dtype).to(device) + text_encoder.eval() + unet.to(dtype).to(device) + unet.eval() + + # networkを組み込む + if args.network_module: + networks = [] + network_default_muls = [] + network_pre_calc = args.network_pre_calc + + # merge関連の引数を統合する + if args.network_merge: + network_merge = len(args.network_module) # all networks are merged + elif args.network_merge_n_models: + network_merge = args.network_merge_n_models + else: + network_merge = 0 + print(f"network_merge: {network_merge}") + + for i, network_module in enumerate(args.network_module): + print("import network module:", network_module) + imported_module = importlib.import_module(network_module) + + network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] + + net_kwargs = {} + if args.network_args and i < len(args.network_args): + network_args = args.network_args[i] + # TODO escape special chars + network_args = network_args.split(";") + for net_arg in network_args: + key, value = net_arg.split("=") + net_kwargs[key] = value + + if args.network_weights is None or len(args.network_weights) <= i: + raise ValueError("No weight. Weight is required.") + + network_weight = args.network_weights[i] + print("load network weights from:", network_weight) + + if model_util.is_safetensors(network_weight) and args.network_show_meta: + from safetensors.torch import safe_open + + with safe_open(network_weight, framework="pt") as f: + metadata = f.metadata() + if metadata is not None: + print(f"metadata for: {network_weight}: {metadata}") + + network, weights_sd = imported_module.create_network_from_weights( + network_mul, network_weight, vae, text_encoders, unet, for_inference=True, **net_kwargs + ) + if network is None: + return + + mergeable = network.is_mergeable() + if network_merge and not mergeable: + print("network is not mergiable. ignore merge option.") + + if not mergeable or i >= network_merge: + # not merging + network.apply_to(text_encoders, unet) + info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい + print(f"weights are loaded: {info}") + + if args.opt_channels_last: + network.to(memory_format=torch.channels_last) + network.to(dtype).to(device) + + if network_pre_calc: + print("backup original weights") + network.backup_weights() + + networks.append(network) + network_default_muls.append(network_mul) + else: + network.merge_to(text_encoders, unet, weights_sd, dtype, device) + + else: + networks = [] + + # upscalerの指定があれば取得する + upscaler = None + if args.highres_fix_upscaler: + print("import upscaler module:", args.highres_fix_upscaler) + imported_module = importlib.import_module(args.highres_fix_upscaler) + + us_kwargs = {} + if args.highres_fix_upscaler_args: + for net_arg in args.highres_fix_upscaler_args.split(";"): + key, value = net_arg.split("=") + us_kwargs[key] = value + + print("create upscaler") + upscaler = imported_module.create_upscaler(**us_kwargs) + upscaler.to(dtype).to(device) + + # ControlNetの処理 + control_nets: List[ControlNetInfo] = [] + if args.control_net_models: + for i, model in enumerate(args.control_net_models): + prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i] + weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i] + ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] + + ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model) + prep = original_control_net.load_preprocess(prep_type) + control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) + + control_net_lllites: List[Tuple[ControlNetLLLite, float]] = [] + if args.control_net_lllite_models: + for i, model_file in enumerate(args.control_net_lllite_models): + print(f"loading ControlNet-LLLite: {model_file}") + + from safetensors.torch import load_file + + state_dict = load_file(model_file) + mlp_dim = None + cond_emb_dim = None + for key, value in state_dict.items(): + if mlp_dim is None and "down.0.weight" in key: + mlp_dim = value.shape[0] + elif cond_emb_dim is None and "conditioning1.0" in key: + cond_emb_dim = value.shape[0] * 2 + if mlp_dim is not None and cond_emb_dim is not None: + break + assert mlp_dim is not None and cond_emb_dim is not None, f"invalid control net: {model_file}" + + multiplier = ( + 1.0 + if not args.control_net_multipliers or len(args.control_net_multipliers) <= i + else args.control_net_multipliers[i] + ) + ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] + + control_net_lllite = ControlNetLLLite(unet, cond_emb_dim, mlp_dim, multiplier=multiplier) + control_net_lllite.apply_to() + control_net_lllite.load_state_dict(state_dict) + control_net_lllite.to(dtype).to(device) + control_net_lllite.set_batch_cond_only(False, False) + control_net_lllites.append((control_net_lllite, ratio)) + assert ( + len(control_nets) == 0 or len(control_net_lllites) == 0 + ), "ControlNet and ControlNet-LLLite cannot be used at the same time" + + if args.opt_channels_last: + print(f"set optimizing: channels last") + for text_encoder in text_encoders: + text_encoder.to(memory_format=torch.channels_last) + vae.to(memory_format=torch.channels_last) + unet.to(memory_format=torch.channels_last) + if networks: + for network in networks: + network.to(memory_format=torch.channels_last) + + for cn in control_nets: + cn.to(memory_format=torch.channels_last) + + for cn in control_net_lllites: + cn.to(memory_format=torch.channels_last) + + pipe = PipelineLike( + is_sdxl, + device, + vae, + text_encoders, + tokenizers, + unet, + scheduler, + args.clip_skip, + ) + pipe.set_control_nets(control_nets) + pipe.set_control_net_lllites(control_net_lllites) + print("pipeline is ready.") + + if args.diffusers_xformers: + pipe.enable_xformers_memory_efficient_attention() + + # Deep Shrink + if args.ds_depth_1 is not None: + unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio) + + # Gradual Latent + if args.gradual_latent_timesteps is not None: + if args.gradual_latent_unsharp_params: + us_params = args.gradual_latent_unsharp_params.split(",") + us_ksize, us_sigma, us_strength = [float(v) for v in us_params[:3]] + us_target_x = True if len(us_params) <= 3 else bool(int(us_params[3])) + us_ksize = int(us_ksize) + else: + us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None + + gradual_latent = GradualLatent( + args.gradual_latent_ratio, + args.gradual_latent_timesteps, + args.gradual_latent_every_n_steps, + args.gradual_latent_ratio_step, + args.gradual_latent_s_noise, + us_ksize, + us_sigma, + us_strength, + us_target_x, + ) + pipe.set_gradual_latent(gradual_latent) + + # Textual Inversionを処理する + if args.textual_inversion_embeddings: + token_ids_embeds1 = [] + token_ids_embeds2 = [] + for embeds_file in args.textual_inversion_embeddings: + if model_util.is_safetensors(embeds_file): + from safetensors.torch import load_file + + data = load_file(embeds_file) + else: + data = torch.load(embeds_file, map_location="cpu") + + if "string_to_param" in data: + data = data["string_to_param"] + if is_sdxl: + + embeds1 = data["clip_l"] # text encoder 1 + embeds2 = data["clip_g"] # text encoder 2 + else: + embeds1 = next(iter(data.values())) + embeds2 = None + + num_vectors_per_token = embeds1.size()[0] + token_string = os.path.splitext(os.path.basename(embeds_file))[0] + + token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)] + + # add new word to tokenizer, count is num_vectors_per_token + num_added_tokens1 = tokenizers[0].add_tokens(token_strings) + num_added_tokens2 = tokenizers[1].add_tokens(token_strings) if is_sdxl else 0 + assert num_added_tokens1 == num_vectors_per_token and ( + num_added_tokens2 == 0 or num_added_tokens2 == num_vectors_per_token + ), ( + f"tokenizer has same word to token string (filename): {embeds_file}" + + f" / 指定した名前(ファイル名)のトークンが既に存在します: {embeds_file}" + ) + + token_ids1 = tokenizers[0].convert_tokens_to_ids(token_strings) + token_ids2 = tokenizers[1].convert_tokens_to_ids(token_strings) if is_sdxl else None + print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}") + assert ( + min(token_ids1) == token_ids1[0] and token_ids1[-1] == token_ids1[0] + len(token_ids1) - 1 + ), f"token ids1 is not ordered" + assert not is_sdxl or ( + min(token_ids2) == token_ids2[0] and token_ids2[-1] == token_ids2[0] + len(token_ids2) - 1 + ), f"token ids2 is not ordered" + assert len(tokenizers[0]) - 1 == token_ids1[-1], f"token ids 1 is not end of tokenize: {len(tokenizers[0])}" + assert ( + not is_sdxl or len(tokenizers[1]) - 1 == token_ids2[-1] + ), f"token ids 2 is not end of tokenize: {len(tokenizers[1])}" + + if num_vectors_per_token > 1: + pipe.add_token_replacement(0, token_ids1[0], token_ids1) # hoge -> hoge, hogea, hogeb, ... + if is_sdxl: + pipe.add_token_replacement(1, token_ids2[0], token_ids2) + + token_ids_embeds1.append((token_ids1, embeds1)) + if is_sdxl: + token_ids_embeds2.append((token_ids2, embeds2)) + + text_encoders[0].resize_token_embeddings(len(tokenizers[0])) + token_embeds1 = text_encoders[0].get_input_embeddings().weight.data + for token_ids, embeds in token_ids_embeds1: + for token_id, embed in zip(token_ids, embeds): + token_embeds1[token_id] = embed + + if is_sdxl: + text_encoders[1].resize_token_embeddings(len(tokenizers[1])) + token_embeds2 = text_encoders[1].get_input_embeddings().weight.data + for token_ids, embeds in token_ids_embeds2: + for token_id, embed in zip(token_ids, embeds): + token_embeds2[token_id] = embed + + # promptを取得する + if args.from_file is not None: + print(f"reading prompts from {args.from_file}") + with open(args.from_file, "r", encoding="utf-8") as f: + prompt_list = f.read().splitlines() + prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"] + elif args.prompt is not None: + prompt_list = [args.prompt] + else: + prompt_list = [] + + if args.interactive: + args.n_iter = 1 + + # img2imgの前処理、画像の読み込みなど + def load_images(path): + if os.path.isfile(path): + paths = [path] + else: + paths = ( + glob.glob(os.path.join(path, "*.png")) + + glob.glob(os.path.join(path, "*.jpg")) + + glob.glob(os.path.join(path, "*.jpeg")) + + glob.glob(os.path.join(path, "*.webp")) + ) + paths.sort() + + images = [] + for p in paths: + image = Image.open(p) + if image.mode != "RGB": + print(f"convert image to RGB from {image.mode}: {p}") + image = image.convert("RGB") + images.append(image) + + return images + + def resize_images(imgs, size): + resized = [] + for img in imgs: + r_img = img.resize(size, Image.Resampling.LANCZOS) + if hasattr(img, "filename"): # filename属性がない場合があるらしい + r_img.filename = img.filename + resized.append(r_img) + return resized + + if args.image_path is not None: + print(f"load image for img2img: {args.image_path}") + init_images = load_images(args.image_path) + assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}" + print(f"loaded {len(init_images)} images for img2img") + + # CLIP Vision + if args.clip_vision_strength is not None: + print(f"load CLIP Vision model: {CLIP_VISION_MODEL}") + vision_model = CLIPVisionModelWithProjection.from_pretrained(CLIP_VISION_MODEL, projection_dim=1280) + vision_model.to(device, dtype) + processor = CLIPImageProcessor.from_pretrained(CLIP_VISION_MODEL) + + pipe.clip_vision_model = vision_model + pipe.clip_vision_processor = processor + pipe.clip_vision_strength = args.clip_vision_strength + print(f"CLIP Vision model loaded.") + + else: + init_images = None + + if args.mask_path is not None: + print(f"load mask for inpainting: {args.mask_path}") + mask_images = load_images(args.mask_path) + assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}" + print(f"loaded {len(mask_images)} mask images for inpainting") + else: + mask_images = None + + # promptがないとき、画像のPngInfoから取得する + if init_images is not None and len(prompt_list) == 0 and not args.interactive: + print("get prompts from images' metadata") + for img in init_images: + if "prompt" in img.text: + prompt = img.text["prompt"] + if "negative-prompt" in img.text: + prompt += " --n " + img.text["negative-prompt"] + prompt_list.append(prompt) + + # プロンプトと画像を一致させるため指定回数だけ繰り返す(画像を増幅する) + l = [] + for im in init_images: + l.extend([im] * args.images_per_prompt) + init_images = l + + if mask_images is not None: + l = [] + for im in mask_images: + l.extend([im] * args.images_per_prompt) + mask_images = l + + # 画像サイズにオプション指定があるときはリサイズする + if args.W is not None and args.H is not None: + # highres fix を考慮に入れる + w, h = args.W, args.H + if highres_fix: + w = int(w * args.highres_fix_scale + 0.5) + h = int(h * args.highres_fix_scale + 0.5) + + if init_images is not None: + print(f"resize img2img source images to {w}*{h}") + init_images = resize_images(init_images, (w, h)) + if mask_images is not None: + print(f"resize img2img mask images to {w}*{h}") + mask_images = resize_images(mask_images, (w, h)) + + regional_network = False + if networks and mask_images: + # mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応 + regional_network = True + print("use mask as region") + + size = None + for i, network in enumerate(networks): + if (i < 3 and args.network_regional_mask_max_color_codes is None) or i < args.network_regional_mask_max_color_codes: + np_mask = np.array(mask_images[0]) + + if args.network_regional_mask_max_color_codes: + # カラーコードでマスクを指定する + ch0 = (i + 1) & 1 + ch1 = ((i + 1) >> 1) & 1 + ch2 = ((i + 1) >> 2) & 1 + np_mask = np.all(np_mask == np.array([ch0, ch1, ch2]) * 255, axis=2) + np_mask = np_mask.astype(np.uint8) * 255 + else: + np_mask = np_mask[:, :, i] + size = np_mask.shape + else: + np_mask = np.full(size, 255, dtype=np.uint8) + mask = torch.from_numpy(np_mask.astype(np.float32) / 255.0) + network.set_region(i, i == len(networks) - 1, mask) + mask_images = None + + prev_image = None # for VGG16 guided + if args.guide_image_path is not None: + print(f"load image for ControlNet guidance: {args.guide_image_path}") + guide_images = [] + for p in args.guide_image_path: + guide_images.extend(load_images(p)) + + print(f"loaded {len(guide_images)} guide images for guidance") + if len(guide_images) == 0: + print( + f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}" + ) + guide_images = None + else: + guide_images = None + + # seed指定時はseedを決めておく + if args.seed is not None: + # dynamic promptを使うと足りなくなる→images_per_promptを適当に大きくしておいてもらう + random.seed(args.seed) + predefined_seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.n_iter * len(prompt_list) * args.images_per_prompt)] + if len(predefined_seeds) == 1: + predefined_seeds[0] = args.seed + else: + predefined_seeds = None + + # デフォルト画像サイズを設定する:img2imgではこれらの値は無視される(またはW*Hにリサイズ済み) + if args.W is None: + args.W = 1024 if is_sdxl else 512 + if args.H is None: + args.H = 1024 if is_sdxl else 512 + + # 画像生成のループ + os.makedirs(args.outdir, exist_ok=True) + max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples + + for gen_iter in range(args.n_iter): + print(f"iteration {gen_iter+1}/{args.n_iter}") + iter_seed = random.randint(0, 0x7FFFFFFF) + + # shuffle prompt list + if args.shuffle_prompts: + random.shuffle(prompt_list) + + # バッチ処理の関数 + def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): + batch_size = len(batch) + + # highres_fixの処理 + if highres_fix and not highres_1st: + # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す + is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling + + print("process 1st stage") + batch_1st = [] + for _, base, ext in batch: + + def scale_and_round(x): + if x is None: + return None + return int(x * args.highres_fix_scale + 0.5) + + width_1st = scale_and_round(ext.width) + height_1st = scale_and_round(ext.height) + width_1st = width_1st - width_1st % 32 + height_1st = height_1st - height_1st % 32 + + original_width_1st = scale_and_round(ext.original_width) + original_height_1st = scale_and_round(ext.original_height) + original_width_negative_1st = scale_and_round(ext.original_width_negative) + original_height_negative_1st = scale_and_round(ext.original_height_negative) + crop_left_1st = scale_and_round(ext.crop_left) + crop_top_1st = scale_and_round(ext.crop_top) + + strength_1st = ext.strength if args.highres_fix_strength is None else args.highres_fix_strength + + ext_1st = BatchDataExt( + width_1st, + height_1st, + original_width_1st, + original_height_1st, + original_width_negative_1st, + original_height_negative_1st, + crop_left_1st, + crop_top_1st, + args.highres_fix_steps, + ext.scale, + ext.negative_scale, + strength_1st, + ext.network_muls, + ext.num_sub_prompts, + ) + batch_1st.append(BatchData(is_1st_latent, base, ext_1st)) + + pipe.set_enable_control_net(True) # 1st stageではControlNetを有効にする + images_1st = process_batch(batch_1st, True, True) + + # 2nd stageのバッチを作成して以下処理する + print("process 2nd stage") + width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height + + if upscaler: + # upscalerを使って画像を拡大する + lowreso_imgs = None if is_1st_latent else images_1st + lowreso_latents = None if not is_1st_latent else images_1st + + # 戻り値はPIL.Image.Imageかtorch.Tensorのlatents + batch_size = len(images_1st) + vae_batch_size = ( + batch_size + if args.vae_batch_size is None + else (max(1, int(batch_size * args.vae_batch_size)) if args.vae_batch_size < 1 else args.vae_batch_size) + ) + vae_batch_size = int(vae_batch_size) + images_1st = upscaler.upscale( + vae, lowreso_imgs, lowreso_latents, dtype, width_2nd, height_2nd, batch_size, vae_batch_size + ) + + elif args.highres_fix_latents_upscaling: + # latentを拡大する + org_dtype = images_1st.dtype + if images_1st.dtype == torch.bfloat16: + images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない + images_1st = torch.nn.functional.interpolate( + images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode="bilinear" + ) # , antialias=True) + images_1st = images_1st.to(org_dtype) + + else: + # 画像をLANCZOSで拡大する + images_1st = [image.resize((width_2nd, height_2nd), resample=PIL.Image.LANCZOS) for image in images_1st] + + batch_2nd = [] + for i, (bd, image) in enumerate(zip(batch, images_1st)): + bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed + 1, image, None, *bd.base[6:]), bd.ext) + batch_2nd.append(bd_2nd) + batch = batch_2nd + + if args.highres_fix_disable_control_net: + pipe.set_enable_control_net(False) # オプション指定時、2nd stageではControlNetを無効にする + + # このバッチの情報を取り出す + ( + return_latents, + (step_first, _, _, _, init_image, mask_image, _, guide_image, _), + ( + width, + height, + original_width, + original_height, + original_width_negative, + original_height_negative, + crop_left, + crop_top, + steps, + scale, + negative_scale, + strength, + network_muls, + num_sub_prompts, + ), + ) = batch[0] + noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR) + + prompts = [] + negative_prompts = [] + raw_prompts = [] + start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) + noises = [ + torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) + for _ in range(steps * scheduler_num_noises_per_step) + ] + seeds = [] + clip_prompts = [] + + if init_image is not None: # img2img? + i2i_noises = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) + init_images = [] + + if mask_image is not None: + mask_images = [] + else: + mask_images = None + else: + i2i_noises = None + init_images = None + mask_images = None + + if guide_image is not None: # CLIP image guided? + guide_images = [] + else: + guide_images = None + + # バッチ内の位置に関わらず同じ乱数を使うためにここで乱数を生成しておく。あわせてimage/maskがbatch内で同一かチェックする + all_images_are_same = True + all_masks_are_same = True + all_guide_images_are_same = True + for i, ( + _, + (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt), + _, + ) in enumerate(batch): + prompts.append(prompt) + negative_prompts.append(negative_prompt) + seeds.append(seed) + clip_prompts.append(clip_prompt) + raw_prompts.append(raw_prompt) + + if init_image is not None: + init_images.append(init_image) + if i > 0 and all_images_are_same: + all_images_are_same = init_images[-2] is init_image + + if mask_image is not None: + mask_images.append(mask_image) + if i > 0 and all_masks_are_same: + all_masks_are_same = mask_images[-2] is mask_image + + if guide_image is not None: + if type(guide_image) is list: + guide_images.extend(guide_image) + all_guide_images_are_same = False + else: + guide_images.append(guide_image) + if i > 0 and all_guide_images_are_same: + all_guide_images_are_same = guide_images[-2] is guide_image + + # make start code + torch.manual_seed(seed) + start_code[i] = torch.randn(noise_shape, device=device, dtype=dtype) + + # make each noises + for j in range(steps * scheduler_num_noises_per_step): + noises[j][i] = torch.randn(noise_shape, device=device, dtype=dtype) + + if i2i_noises is not None: # img2img noise + i2i_noises[i] = torch.randn(noise_shape, device=device, dtype=dtype) + + noise_manager.reset_sampler_noises(noises) + + # すべての画像が同じなら1枚だけpipeに渡すことでpipe側で処理を高速化する + if init_images is not None and all_images_are_same: + init_images = init_images[0] + if mask_images is not None and all_masks_are_same: + mask_images = mask_images[0] + if guide_images is not None and all_guide_images_are_same: + guide_images = guide_images[0] + + # ControlNet使用時はguide imageをリサイズする + if control_nets or control_net_lllites: + # TODO resampleのメソッド + guide_images = guide_images if type(guide_images) == list else [guide_images] + guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images] + if len(guide_images) == 1: + guide_images = guide_images[0] + + # generate + if networks: + # 追加ネットワークの処理 + shared = {} + for n, m in zip(networks, network_muls if network_muls else network_default_muls): + n.set_multiplier(m) + if regional_network: + n.set_current_generation(batch_size, num_sub_prompts, width, height, shared) + + if not regional_network and network_pre_calc: + for n in networks: + n.restore_weights() + for n in networks: + n.pre_calculation() + print("pre-calculation... done") + + images = pipe( + prompts, + negative_prompts, + init_images, + mask_images, + height, + width, + original_height, + original_width, + original_height_negative, + original_width_negative, + crop_top, + crop_left, + steps, + scale, + negative_scale, + strength, + latents=start_code, + output_type="pil", + max_embeddings_multiples=max_embeddings_multiples, + img2img_noise=i2i_noises, + vae_batch_size=args.vae_batch_size, + return_latents=return_latents, + clip_prompts=clip_prompts, + clip_guide_images=guide_images, + ) + if highres_1st and not args.highres_fix_save_1st: # return images or latents + return images + + # save image + highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate( + zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts) + ): + if highres_fix: + seed -= 1 # record original seed + metadata = PngInfo() + metadata.add_text("prompt", prompt) + metadata.add_text("seed", str(seed)) + metadata.add_text("sampler", args.sampler) + metadata.add_text("steps", str(steps)) + metadata.add_text("scale", str(scale)) + if negative_prompt is not None: + metadata.add_text("negative-prompt", negative_prompt) + if negative_scale is not None: + metadata.add_text("negative-scale", str(negative_scale)) + if clip_prompt is not None: + metadata.add_text("clip-prompt", clip_prompt) + if raw_prompt is not None: + metadata.add_text("raw-prompt", raw_prompt) + if is_sdxl: + metadata.add_text("original-height", str(original_height)) + metadata.add_text("original-width", str(original_width)) + metadata.add_text("original-height-negative", str(original_height_negative)) + metadata.add_text("original-width-negative", str(original_width_negative)) + metadata.add_text("crop-top", str(crop_top)) + metadata.add_text("crop-left", str(crop_left)) + + if args.use_original_file_name and init_images is not None: + if type(init_images) is list: + fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png" + else: + fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png" + elif args.sequential_file_name: + fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png" + else: + fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" + + image.save(os.path.join(args.outdir, fln), pnginfo=metadata) + + if not args.no_preview and not highres_1st and args.interactive: + try: + import cv2 + + for prompt, image in zip(prompts, images): + cv2.imshow(prompt[:128], np.array(image)[:, :, ::-1]) # プロンプトが長いと死ぬ + cv2.waitKey() + cv2.destroyAllWindows() + except ImportError: + print( + "opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません" + ) + + return images + + # 画像生成のプロンプトが一周するまでのループ + prompt_index = 0 + global_step = 0 + batch_data = [] + while args.interactive or prompt_index < len(prompt_list): + if len(prompt_list) == 0: + # interactive + valid = False + while not valid: + print("\nType prompt:") + try: + raw_prompt = input() + except EOFError: + break + + valid = len(raw_prompt.strip().split(" --")[0].strip()) > 0 + if not valid: # EOF, end app + break + else: + raw_prompt = prompt_list[prompt_index] + + # sd-dynamic-prompts like variants: + # count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration) + raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt) + + # repeat prompt + for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): + raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0] + + if pi == 0 or len(raw_prompts) > 1: + # parse prompt: if prompt is not changed, skip parsing + width = args.W + height = args.H + original_width = args.original_width + original_height = args.original_height + original_width_negative = args.original_width_negative + original_height_negative = args.original_height_negative + crop_top = args.crop_top + crop_left = args.crop_left + scale = args.scale + negative_scale = args.negative_scale + steps = args.steps + seed = None + seeds = None + strength = 0.8 if args.strength is None else args.strength + negative_prompt = "" + clip_prompt = None + network_muls = None + + # Deep Shrink + ds_depth_1 = None # means no override + ds_timesteps_1 = args.ds_timesteps_1 + ds_depth_2 = args.ds_depth_2 + ds_timesteps_2 = args.ds_timesteps_2 + ds_ratio = args.ds_ratio + + # Gradual Latent + gl_timesteps = None # means no override + gl_ratio = args.gradual_latent_ratio + gl_every_n_steps = args.gradual_latent_every_n_steps + gl_ratio_step = args.gradual_latent_ratio_step + gl_s_noise = args.gradual_latent_s_noise + gl_unsharp_params = args.gradual_latent_unsharp_params + + prompt_args = raw_prompt.strip().split(" --") + prompt = prompt_args[0] + print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") + + for parg in prompt_args[1:]: + try: + m = re.match(r"w (\d+)", parg, re.IGNORECASE) + if m: + width = int(m.group(1)) + print(f"width: {width}") + continue + + m = re.match(r"h (\d+)", parg, re.IGNORECASE) + if m: + height = int(m.group(1)) + print(f"height: {height}") + continue + + m = re.match(r"ow (\d+)", parg, re.IGNORECASE) + if m: + original_width = int(m.group(1)) + print(f"original width: {original_width}") + continue + + m = re.match(r"oh (\d+)", parg, re.IGNORECASE) + if m: + original_height = int(m.group(1)) + print(f"original height: {original_height}") + continue + + m = re.match(r"nw (\d+)", parg, re.IGNORECASE) + if m: + original_width_negative = int(m.group(1)) + print(f"original width negative: {original_width_negative}") + continue + + m = re.match(r"nh (\d+)", parg, re.IGNORECASE) + if m: + original_height_negative = int(m.group(1)) + print(f"original height negative: {original_height_negative}") + continue + + m = re.match(r"ct (\d+)", parg, re.IGNORECASE) + if m: + crop_top = int(m.group(1)) + print(f"crop top: {crop_top}") + continue + + m = re.match(r"cl (\d+)", parg, re.IGNORECASE) + if m: + crop_left = int(m.group(1)) + print(f"crop left: {crop_left}") + continue + + m = re.match(r"s (\d+)", parg, re.IGNORECASE) + if m: # steps + steps = max(1, min(1000, int(m.group(1)))) + print(f"steps: {steps}") + continue + + m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) + if m: # seed + seeds = [int(d) for d in m.group(1).split(",")] + print(f"seeds: {seeds}") + continue + + m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) + if m: # scale + scale = float(m.group(1)) + print(f"scale: {scale}") + continue + + m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) + if m: # negative scale + if m.group(1).lower() == "none": + negative_scale = None + else: + negative_scale = float(m.group(1)) + print(f"negative scale: {negative_scale}") + continue + + m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) + if m: # strength + strength = float(m.group(1)) + print(f"strength: {strength}") + continue + + m = re.match(r"n (.+)", parg, re.IGNORECASE) + if m: # negative prompt + negative_prompt = m.group(1) + print(f"negative prompt: {negative_prompt}") + continue + + m = re.match(r"c (.+)", parg, re.IGNORECASE) + if m: # clip prompt + clip_prompt = m.group(1) + print(f"clip prompt: {clip_prompt}") + continue + + m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # network multiplies + network_muls = [float(v) for v in m.group(1).split(",")] + while len(network_muls) < len(networks): + network_muls.append(network_muls[-1]) + print(f"network mul: {network_muls}") + continue + + # Deep Shrink + m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink depth 1 + ds_depth_1 = int(m.group(1)) + print(f"deep shrink depth 1: {ds_depth_1}") + continue + + m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink timesteps 1 + ds_timesteps_1 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + print(f"deep shrink timesteps 1: {ds_timesteps_1}") + continue + + m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink depth 2 + ds_depth_2 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + print(f"deep shrink depth 2: {ds_depth_2}") + continue + + m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink timesteps 2 + ds_timesteps_2 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + print(f"deep shrink timesteps 2: {ds_timesteps_2}") + continue + + m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink ratio + ds_ratio = float(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + print(f"deep shrink ratio: {ds_ratio}") + continue + + # Gradual Latent + m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent timesteps + gl_timesteps = int(m.group(1)) + print(f"gradual latent timesteps: {gl_timesteps}") + continue + + m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio + gl_ratio = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + print(f"gradual latent ratio: {ds_ratio}") + continue + + m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent every n steps + gl_every_n_steps = int(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + print(f"gradual latent every n steps: {gl_every_n_steps}") + continue + + m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio step + gl_ratio_step = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + print(f"gradual latent ratio step: {gl_ratio_step}") + continue + + m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent s noise + gl_s_noise = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + print(f"gradual latent s noise: {gl_s_noise}") + continue + + m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # gradual latent unsharp params + gl_unsharp_params = m.group(1) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + print(f"gradual latent unsharp params: {gl_unsharp_params}") + continue + + except ValueError as ex: + print(f"Exception in parsing / 解析エラー: {parg}") + print(ex) + + # override Deep Shrink + if ds_depth_1 is not None: + if ds_depth_1 < 0: + ds_depth_1 = args.ds_depth_1 or 3 + unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio) + + # override Gradual Latent + if gl_timesteps is not None: + if gl_timesteps < 0: + gl_timesteps = args.gradual_latent_timesteps or 650 + if gl_unsharp_params is not None: + unsharp_params = gl_unsharp_params.split(",") + us_ksize, us_sigma, us_strength = [float(v) for v in unsharp_params[:3]] + us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3])) + us_ksize = int(us_ksize) + else: + us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None + gradual_latent = GradualLatent( + gl_ratio, + gl_timesteps, + gl_every_n_steps, + gl_ratio_step, + gl_s_noise, + us_ksize, + us_sigma, + us_strength, + us_target_x, + ) + pipe.set_gradual_latent(gradual_latent) + + # prepare seed + if seeds is not None: # given in prompt + # 数が足りないなら前のをそのまま使う + if len(seeds) > 0: + seed = seeds.pop(0) + else: + if predefined_seeds is not None: + if len(predefined_seeds) > 0: + seed = predefined_seeds.pop(0) + else: + print("predefined seeds are exhausted") + seed = None + elif args.iter_same_seed: + seeds = iter_seed + else: + seed = None # 前のを消す + + if seed is None: + seed = random.randint(0, 0x7FFFFFFF) + if args.interactive: + print(f"seed: {seed}") + + # prepare init image, guide image and mask + init_image = mask_image = guide_image = None + + # 同一イメージを使うとき、本当はlatentに変換しておくと無駄がないが面倒なのでとりあえず毎回処理する + if init_images is not None: + init_image = init_images[global_step % len(init_images)] + + # img2imgの場合は、基本的に元画像のサイズで生成する。highres fixの場合はargs.W, args.Hとscaleに従いリサイズ済みなので無視する + # 32単位に丸めたやつにresizeされるので踏襲する + if not highres_fix: + width, height = init_image.size + width = width - width % 32 + height = height - height % 32 + if width != init_image.size[0] or height != init_image.size[1]: + print( + f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます" + ) + + if mask_images is not None: + mask_image = mask_images[global_step % len(mask_images)] + + if guide_images is not None: + if control_nets or control_net_lllites: # 複数件の場合あり + c = max(len(control_nets), len(control_net_lllites)) + p = global_step % (len(guide_images) // c) + guide_image = guide_images[p * c : p * c + c] + else: + guide_image = guide_images[global_step % len(guide_images)] + + if regional_network: + num_sub_prompts = len(prompt.split(" AND ")) + assert ( + len(networks) <= num_sub_prompts + ), "Number of networks must be less than or equal to number of sub prompts." + else: + num_sub_prompts = None + + b1 = BatchData( + False, + BatchDataBase( + global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt + ), + BatchDataExt( + width, + height, + original_width, + original_height, + original_width_negative, + original_height_negative, + crop_left, + crop_top, + steps, + scale, + negative_scale, + strength, + tuple(network_muls) if network_muls else None, + num_sub_prompts, + ), + ) + if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要? + process_batch(batch_data, highres_fix) + batch_data.clear() + + batch_data.append(b1) + if len(batch_data) == args.batch_size: + prev_image = process_batch(batch_data, highres_fix)[0] + batch_data.clear() + + global_step += 1 + + prompt_index += 1 + + if len(batch_data) > 0: + process_batch(batch_data, highres_fix) + batch_data.clear() + + print("done!") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + parser.add_argument( + "--sdxl", action="store_true", help="load Stable Diffusion XL model / Stable Diffusion XLのモデルを読み込む" + ) + parser.add_argument( + "--v1", action="store_true", help="load Stable Diffusion v1.x model / Stable Diffusion 1.xのモデルを読み込む" + ) + parser.add_argument( + "--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む" + ) + parser.add_argument( + "--v_parameterization", action="store_true", help="enable v-parameterization training / v-parameterization学習を有効にする" + ) + + parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト") + parser.add_argument( + "--from_file", + type=str, + default=None, + help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む", + ) + parser.add_argument( + "--interactive", + action="store_true", + help="interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)", + ) + parser.add_argument( + "--no_preview", action="store_true", help="do not show generated image in interactive mode / 対話モードで画像を表示しない" + ) + parser.add_argument( + "--image_path", type=str, default=None, help="image to inpaint or to generate from / img2imgまたはinpaintを行う元画像" + ) + parser.add_argument("--mask_path", type=str, default=None, help="mask in inpainting / inpaint時のマスク") + parser.add_argument("--strength", type=float, default=None, help="img2img strength / img2img時のstrength") + parser.add_argument("--images_per_prompt", type=int, default=1, help="number of images per prompt / プロンプトあたりの出力枚数") + parser.add_argument("--outdir", type=str, default="outputs", help="dir to write results to / 生成画像の出力先") + parser.add_argument( + "--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファイル名を連番にする" + ) + parser.add_argument( + "--use_original_file_name", + action="store_true", + help="prepend original file name in img2img / img2imgで元画像のファイル名を生成画像のファイル名の先頭に付ける", + ) + # parser.add_argument("--ddim_eta", type=float, default=0.0, help="ddim eta (eta=0.0 corresponds to deterministic sampling", ) + parser.add_argument("--n_iter", type=int, default=1, help="sample this often / 繰り返し回数") + parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ") + parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅") + parser.add_argument( + "--original_height", + type=int, + default=None, + help="original height for SDXL conditioning / SDXLの条件付けに用いるoriginal heightの値", + ) + parser.add_argument( + "--original_width", + type=int, + default=None, + help="original width for SDXL conditioning / SDXLの条件付けに用いるoriginal widthの値", + ) + parser.add_argument( + "--original_height_negative", + type=int, + default=None, + help="original height for SDXL unconditioning / SDXLのネガティブ条件付けに用いるoriginal heightの値", + ) + parser.add_argument( + "--original_width_negative", + type=int, + default=None, + help="original width for SDXL unconditioning / SDXLのネガティブ条件付けに用いるoriginal widthの値", + ) + parser.add_argument( + "--crop_top", type=int, default=None, help="crop top for SDXL conditioning / SDXLの条件付けに用いるcrop topの値" + ) + parser.add_argument( + "--crop_left", type=int, default=None, help="crop left for SDXL conditioning / SDXLの条件付けに用いるcrop leftの値" + ) + parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ") + parser.add_argument( + "--vae_batch_size", + type=float, + default=None, + help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率", + ) + parser.add_argument( + "--vae_slices", + type=int, + default=None, + help="number of slices to split image into for VAE to reduce VRAM usage, None for no splitting (default), slower if specified. 16 or 32 recommended / VAE処理時にVRAM使用量削減のため画像を分割するスライス数、Noneの場合は分割しない(デフォルト)、指定すると遅くなる。16か32程度を推奨", + ) + parser.add_argument( + "--no_half_vae", action="store_true", help="do not use fp16/bf16 precision for VAE / VAE処理時にfp16/bf16を使わない" + ) + parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数") + parser.add_argument( + "--sampler", + type=str, + default="ddim", + choices=[ + "ddim", + "pndm", + "lms", + "euler", + "euler_a", + "heun", + "dpm_2", + "dpm_2_a", + "dpmsolver", + "dpmsolver++", + "dpmsingle", + "k_lms", + "k_euler", + "k_euler_a", + "k_dpm_2", + "k_dpm_2_a", + ], + help=f"sampler (scheduler) type / サンプラー(スケジューラ)の種類", + ) + parser.add_argument( + "--scale", + type=float, + default=7.5, + help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty)) / guidance scale", + ) + parser.add_argument( + "--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ" + ) + parser.add_argument( + "--vae", + type=str, + default=None, + help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ", + ) + parser.add_argument( + "--tokenizer_cache_dir", + type=str, + default=None, + help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)", + ) + # parser.add_argument("--replace_clip_l14_336", action='store_true', + # help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える") + parser.add_argument( + "--seed", + type=int, + default=None, + help="seed, or seed of seeds in multiple generation / 1枚生成時のseed、または複数枚生成時の乱数seedを決めるためのseed", + ) + parser.add_argument( + "--iter_same_seed", + action="store_true", + help="use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使う(プロンプト間の差異の比較用)", + ) + parser.add_argument( + "--shuffle_prompts", + action="store_true", + help="shuffle prompts in iteration / 繰り返し内のプロンプトをシャッフルする", + ) + parser.add_argument("--fp16", action="store_true", help="use fp16 / fp16を指定し省メモリ化する") + parser.add_argument("--bf16", action="store_true", help="use bfloat16 / bfloat16を指定し省メモリ化する") + parser.add_argument("--xformers", action="store_true", help="use xformers / xformersを使用し高速化する") + parser.add_argument("--sdpa", action="store_true", help="use sdpa in PyTorch 2 / sdpa") + parser.add_argument( + "--diffusers_xformers", + action="store_true", + help="use xformers by diffusers (Hypernetworks doesn't work) / Diffusersでxformersを使用する(Hypernetwork利用不可)", + ) + parser.add_argument( + "--opt_channels_last", + action="store_true", + help="set channels last option to model / モデルにchannels lastを指定し最適化する", + ) + parser.add_argument( + "--network_module", + type=str, + default=None, + nargs="*", + help="additional network module to use / 追加ネットワークを使う時そのモジュール名", + ) + parser.add_argument( + "--network_weights", type=str, default=None, nargs="*", help="additional network weights to load / 追加ネットワークの重み" + ) + parser.add_argument( + "--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率" + ) + parser.add_argument( + "--network_args", + type=str, + default=None, + nargs="*", + help="additional arguments for network (key=value) / ネットワークへの追加の引数", + ) + parser.add_argument( + "--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する" + ) + parser.add_argument( + "--network_merge_n_models", + type=int, + default=None, + help="merge this number of networks / この数だけネットワークをマージする", + ) + parser.add_argument( + "--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする" + ) + parser.add_argument( + "--network_pre_calc", + action="store_true", + help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する", + ) + parser.add_argument( + "--network_regional_mask_max_color_codes", + type=int, + default=None, + help="max color codes for regional mask (default is None, mask by channel) / regional maskの最大色数(デフォルトはNoneでチャンネルごとのマスク)", + ) + parser.add_argument( + "--textual_inversion_embeddings", + type=str, + default=None, + nargs="*", + help="Embeddings files of Textual Inversion / Textual Inversionのembeddings", + ) + parser.add_argument( + "--clip_skip", + type=int, + default=None, + help="layer number from bottom to use in CLIP, default is 1 for SD1/2, 2 for SDXL " + + "/ CLIPの後ろからn層目の出力を使う(デフォルトはSD1/2の場合1、SDXLの場合2)", + ) + parser.add_argument( + "--max_embeddings_multiples", + type=int, + default=None, + help="max embedding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる", + ) + parser.add_argument( + "--guide_image_path", type=str, default=None, nargs="*", help="image to ControlNet / ControlNetでガイドに使う画像" + ) + parser.add_argument( + "--highres_fix_scale", + type=float, + default=None, + help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする", + ) + parser.add_argument( + "--highres_fix_steps", + type=int, + default=28, + help="1st stage steps for highres fix / highres fixの最初のステージのステップ数", + ) + parser.add_argument( + "--highres_fix_strength", + type=float, + default=None, + help="1st stage img2img strength for highres fix / highres fixの最初のステージのimg2img時のstrength、省略時はstrengthと同じ", + ) + parser.add_argument( + "--highres_fix_save_1st", + action="store_true", + help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する", + ) + parser.add_argument( + "--highres_fix_latents_upscaling", + action="store_true", + help="use latents upscaling for highres fix / highres fixでlatentで拡大する", + ) + parser.add_argument( + "--highres_fix_upscaler", + type=str, + default=None, + help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名", + ) + parser.add_argument( + "--highres_fix_upscaler_args", + type=str, + default=None, + help="additional arguments for upscaler (key=value) / upscalerへの追加の引数", + ) + parser.add_argument( + "--highres_fix_disable_control_net", + action="store_true", + help="disable ControlNet for highres fix / highres fixでControlNetを使わない", + ) + + parser.add_argument( + "--negative_scale", + type=float, + default=None, + help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する", + ) + + parser.add_argument( + "--control_net_lllite_models", + type=str, + default=None, + nargs="*", + help="ControlNet models to use / 使用するControlNetのモデル名", + ) + parser.add_argument( + "--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名" + ) + parser.add_argument( + "--control_net_preps", + type=str, + default=None, + nargs="*", + help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名", + ) + parser.add_argument( + "--control_net_multipliers", type=float, default=None, nargs="*", help="ControlNet multiplier / ControlNetの適用率" + ) + parser.add_argument( + "--control_net_ratios", + type=float, + default=None, + nargs="*", + help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率", + ) + parser.add_argument( + "--clip_vision_strength", + type=float, + default=None, + help="enable CLIP Vision Conditioning for img2img with this strength / img2imgでCLIP Vision Conditioningを有効にしてこのstrengthで処理する", + ) + + # Deep Shrink + parser.add_argument( + "--ds_depth_1", + type=int, + default=None, + help="Enable Deep Shrink with this depth 1, valid values are 0 to 8 / Deep Shrinkをこのdepthで有効にする", + ) + parser.add_argument( + "--ds_timesteps_1", + type=int, + default=650, + help="Apply Deep Shrink depth 1 until this timesteps / Deep Shrink depth 1を適用するtimesteps", + ) + parser.add_argument("--ds_depth_2", type=int, default=None, help="Deep Shrink depth 2 / Deep Shrinkのdepth 2") + parser.add_argument( + "--ds_timesteps_2", + type=int, + default=650, + help="Apply Deep Shrink depth 2 until this timesteps / Deep Shrink depth 2を適用するtimesteps", + ) + parser.add_argument( + "--ds_ratio", type=float, default=0.5, help="Deep Shrink ratio for downsampling / Deep Shrinkのdownsampling比率" + ) + + # gradual latent + parser.add_argument( + "--gradual_latent_timesteps", + type=int, + default=None, + help="enable Gradual Latent hires fix and apply upscaling from this timesteps / Gradual Latent hires fixをこのtimestepsで有効にし、このtimestepsからアップスケーリングを適用する", + ) + parser.add_argument( + "--gradual_latent_ratio", + type=float, + default=0.5, + help=" this size ratio, 0.5 means 1/2 / Gradual Latent hires fixをこのサイズ比率で有効にする、0.5は1/2を意味する", + ) + parser.add_argument( + "--gradual_latent_ratio_step", + type=float, + default=0.125, + help="step to increase ratio for Gradual Latent / Gradual Latentのratioをどのくらいずつ上げるか", + ) + parser.add_argument( + "--gradual_latent_every_n_steps", + type=int, + default=3, + help="steps to increase size of latents every this steps for Gradual Latent / Gradual Latentでlatentsのサイズをこのステップごとに上げる", + ) + parser.add_argument( + "--gradual_latent_s_noise", + type=float, + default=1.0, + help="s_noise for Gradual Latent / Gradual Latentのs_noise", + ) + parser.add_argument( + "--gradual_latent_unsharp_params", + type=str, + default=None, + help="unsharp mask parameters for Gradual Latent: ksize, sigma, strength, target-x (1 means True). `3,0.5,0.5,1` or `3,1.0,1.0,0` is recommended /" + + " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. `3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨", + ) + + # # parser.add_argument( + # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像" + # ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + main(args) From bf2de5620c7a7f283878f6090d8feab96abda44e Mon Sep 17 00:00:00 2001 From: mgz <49577754+mgz-dev@users.noreply.github.com> Date: Sat, 3 Feb 2024 20:09:37 -0600 Subject: [PATCH 044/103] fix formatting in resize_lora.py --- networks/resize_lora.py | 429 ++++++++++++++++++++-------------------- 1 file changed, 216 insertions(+), 213 deletions(-) diff --git a/networks/resize_lora.py b/networks/resize_lora.py index 03fc545e7..3c866f1eb 100644 --- a/networks/resize_lora.py +++ b/networks/resize_lora.py @@ -14,68 +14,68 @@ # Model save and load functions def load_state_dict(file_name, dtype): - if model_util.is_safetensors(file_name): - sd = load_file(file_name) - with safe_open(file_name, framework="pt") as f: - metadata = f.metadata() - else: - sd = torch.load(file_name, map_location='cpu') - metadata = None + if model_util.is_safetensors(file_name): + sd = load_file(file_name) + with safe_open(file_name, framework="pt") as f: + metadata = f.metadata() + else: + sd = torch.load(file_name, map_location='cpu') + metadata = None - for key in list(sd.keys()): - if type(sd[key]) == torch.Tensor: - sd[key] = sd[key].to(dtype) + for key in list(sd.keys()): + if type(sd[key]) == torch.Tensor: + sd[key] = sd[key].to(dtype) - return sd, metadata + return sd, metadata def save_to_file(file_name, model, state_dict, dtype, metadata): - if dtype is not None: - for key in list(state_dict.keys()): - if type(state_dict[key]) == torch.Tensor: - state_dict[key] = state_dict[key].to(dtype) + if dtype is not None: + for key in list(state_dict.keys()): + if type(state_dict[key]) == torch.Tensor: + state_dict[key] = state_dict[key].to(dtype) - if model_util.is_safetensors(file_name): - save_file(model, file_name, metadata) - else: - torch.save(model, file_name) + if model_util.is_safetensors(file_name): + save_file(model, file_name, metadata) + else: + torch.save(model, file_name) # Indexing functions def index_sv_cumulative(S, target): - original_sum = float(torch.sum(S)) - cumulative_sums = torch.cumsum(S, dim=0)/original_sum - index = int(torch.searchsorted(cumulative_sums, target)) + 1 - index = max(1, min(index, len(S)-1)) + original_sum = float(torch.sum(S)) + cumulative_sums = torch.cumsum(S, dim=0)/original_sum + index = int(torch.searchsorted(cumulative_sums, target)) + 1 + index = max(1, min(index, len(S)-1)) - return index + return index def index_sv_fro(S, target): - S_squared = S.pow(2) - s_fro_sq = float(torch.sum(S_squared)) - sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq - index = int(torch.searchsorted(sum_S_squared, target**2)) + 1 - index = max(1, min(index, len(S)-1)) + S_squared = S.pow(2) + s_fro_sq = float(torch.sum(S_squared)) + sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq + index = int(torch.searchsorted(sum_S_squared, target**2)) + 1 + index = max(1, min(index, len(S)-1)) - return index + return index def index_sv_ratio(S, target): - max_sv = S[0] - min_sv = max_sv/target - index = int(torch.sum(S > min_sv).item()) - index = max(1, min(index, len(S)-1)) + max_sv = S[0] + min_sv = max_sv/target + index = int(torch.sum(S > min_sv).item()) + index = max(1, min(index, len(S)-1)) - return index + return index # Modified from Kohaku-blueleaf's extract/merge functions def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1): out_size, in_size, kernel_size, _ = weight.size() U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device)) - + param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale) lora_rank = param_dict["new_rank"] @@ -92,17 +92,17 @@ def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1): out_size, in_size = weight.size() - + U, S, Vh = torch.linalg.svd(weight.to(device)) - + param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale) lora_rank = param_dict["new_rank"] - + U = U[:, :lora_rank] S = S[:lora_rank] U = U @ torch.diag(S) Vh = Vh[:lora_rank, :] - + param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu() param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu() del U, S, Vh, weight @@ -113,7 +113,7 @@ def merge_conv(lora_down, lora_up, device): in_rank, in_size, kernel_size, k_ = lora_down.shape out_size, out_rank, _, _ = lora_up.shape assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch" - + lora_down = lora_down.to(device) lora_up = lora_up.to(device) @@ -127,31 +127,31 @@ def merge_linear(lora_down, lora_up, device): in_rank, in_size = lora_down.shape out_size, out_rank = lora_up.shape assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch" - + lora_down = lora_down.to(device) lora_up = lora_up.to(device) - + weight = lora_up @ lora_down del lora_up, lora_down return weight - + # Calculate new rank def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1): param_dict = {} - if dynamic_method=="sv_ratio": + if dynamic_method == "sv_ratio": # Calculate new dim and alpha based off ratio new_rank = index_sv_ratio(S, dynamic_param) + 1 new_alpha = float(scale*new_rank) - elif dynamic_method=="sv_cumulative": + elif dynamic_method == "sv_cumulative": # Calculate new dim and alpha based off cumulative sum new_rank = index_sv_cumulative(S, dynamic_param) + 1 new_alpha = float(scale*new_rank) - elif dynamic_method=="sv_fro": + elif dynamic_method == "sv_fro": # Calculate new dim and alpha based off sqrt sum of squares new_rank = index_sv_fro(S, dynamic_param) + 1 new_alpha = float(scale*new_rank) @@ -159,19 +159,17 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1): new_rank = rank new_alpha = float(scale*new_rank) - - if S[0] <= MIN_SV: # Zero matrix, set dim to 1 + if S[0] <= MIN_SV: # Zero matrix, set dim to 1 new_rank = 1 new_alpha = float(scale*new_rank) - elif new_rank > rank: # cap max rank at rank + elif new_rank > rank: # cap max rank at rank new_rank = rank new_alpha = float(scale*new_rank) - # Calculate resize info s_sum = torch.sum(torch.abs(S)) s_rank = torch.sum(torch.abs(S[:new_rank])) - + S_squared = S.pow(2) s_fro = torch.sqrt(torch.sum(S_squared)) s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank])) @@ -187,176 +185,181 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1): def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose): - network_alpha = None - network_dim = None - verbose_str = "\n" - fro_list = [] - - # Extract loaded lora dim and alpha - for key, value in lora_sd.items(): - if network_alpha is None and 'alpha' in key: - network_alpha = value - if network_dim is None and 'lora_down' in key and len(value.size()) == 2: - network_dim = value.size()[0] - if network_alpha is not None and network_dim is not None: - break - if network_alpha is None: - network_alpha = network_dim - - scale = network_alpha/network_dim - - if dynamic_method: - print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}") - - lora_down_weight = None - lora_up_weight = None - - o_lora_sd = lora_sd.copy() - block_down_name = None - block_up_name = None - - with torch.no_grad(): - for key, value in tqdm(lora_sd.items()): - weight_name = None - if 'lora_down' in key: - block_down_name = key.rsplit('.lora_down', 1)[0] - weight_name = key.rsplit(".", 1)[-1] - lora_down_weight = value - else: - continue - - # find corresponding lora_up and alpha - block_up_name = block_down_name - lora_up_weight = lora_sd.get(block_up_name + '.lora_up.' + weight_name, None) - lora_alpha = lora_sd.get(block_down_name + '.alpha', None) - - weights_loaded = (lora_down_weight is not None and lora_up_weight is not None) - - if weights_loaded: - - conv2d = (len(lora_down_weight.size()) == 4) - if lora_alpha is None: - scale = 1.0 - else: - scale = lora_alpha/lora_down_weight.size()[0] - - if conv2d: - full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device) - param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale) - else: - full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device) - param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale) - - if verbose: - max_ratio = param_dict['max_ratio'] - sum_retained = param_dict['sum_retained'] - fro_retained = param_dict['fro_retained'] - if not np.isnan(fro_retained): - fro_list.append(float(fro_retained)) - - verbose_str+=f"{block_down_name:75} | " - verbose_str+=f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}" - - if verbose and dynamic_method: - verbose_str+=f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n" - else: - verbose_str+=f"\n" - - new_alpha = param_dict['new_alpha'] - o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous() - o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous() - o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict['new_alpha']).to(save_dtype) - - block_down_name = None - block_up_name = None - lora_down_weight = None - lora_up_weight = None - weights_loaded = False - del param_dict - - if verbose: - print(verbose_str) - - print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}") - print("resizing complete") - return o_lora_sd, network_dim, new_alpha + network_alpha = None + network_dim = None + verbose_str = "\n" + fro_list = [] + + # Extract loaded lora dim and alpha + for key, value in lora_sd.items(): + if network_alpha is None and 'alpha' in key: + network_alpha = value + if network_dim is None and 'lora_down' in key and len(value.size()) == 2: + network_dim = value.size()[0] + if network_alpha is not None and network_dim is not None: + break + if network_alpha is None: + network_alpha = network_dim + + scale = network_alpha/network_dim + + if dynamic_method: + print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}") + + lora_down_weight = None + lora_up_weight = None + + o_lora_sd = lora_sd.copy() + block_down_name = None + block_up_name = None + + with torch.no_grad(): + for key, value in tqdm(lora_sd.items()): + weight_name = None + if 'lora_down' in key: + block_down_name = key.rsplit('.lora_down', 1)[0] + weight_name = key.rsplit(".", 1)[-1] + lora_down_weight = value + else: + continue + + # find corresponding lora_up and alpha + block_up_name = block_down_name + lora_up_weight = lora_sd.get(block_up_name + '.lora_up.' + weight_name, None) + lora_alpha = lora_sd.get(block_down_name + '.alpha', None) + + weights_loaded = (lora_down_weight is not None and lora_up_weight is not None) + + if weights_loaded: + + conv2d = (len(lora_down_weight.size()) == 4) + if lora_alpha is None: + scale = 1.0 + else: + scale = lora_alpha/lora_down_weight.size()[0] + + if conv2d: + full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device) + param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale) + else: + full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device) + param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale) + + if verbose: + max_ratio = param_dict['max_ratio'] + sum_retained = param_dict['sum_retained'] + fro_retained = param_dict['fro_retained'] + if not np.isnan(fro_retained): + fro_list.append(float(fro_retained)) + + verbose_str += f"{block_down_name:75} | " + verbose_str += f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}" + + if verbose and dynamic_method: + verbose_str += f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n" + else: + verbose_str += "\n" + + new_alpha = param_dict['new_alpha'] + o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous() + o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous() + o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict['new_alpha']).to(save_dtype) + + block_down_name = None + block_up_name = None + lora_down_weight = None + lora_up_weight = None + weights_loaded = False + del param_dict + + if verbose: + print(verbose_str) + + print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}") + print("resizing complete") + return o_lora_sd, network_dim, new_alpha def resize(args): - if args.save_to is None or not (args.save_to.endswith('.ckpt') or args.save_to.endswith('.pt') or args.save_to.endswith('.pth') or args.save_to.endswith('.safetensors')): - raise Exception("The --save_to argument must be specified and must be a .ckpt , .pt, .pth or .safetensors file.") - - - def str_to_dtype(p): - if p == 'float': - return torch.float - if p == 'fp16': - return torch.float16 - if p == 'bf16': - return torch.bfloat16 - return None - - if args.dynamic_method and not args.dynamic_param: - raise Exception("If using dynamic_method, then dynamic_param is required") - - merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32 - save_dtype = str_to_dtype(args.save_precision) - if save_dtype is None: - save_dtype = merge_dtype - - print("loading Model...") - lora_sd, metadata = load_state_dict(args.model, merge_dtype) - - print("Resizing Lora...") - state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose) - - # update metadata - if metadata is None: - metadata = {} - - comment = metadata.get("ss_training_comment", "") - - if not args.dynamic_method: - metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}" - metadata["ss_network_dim"] = str(args.new_rank) - metadata["ss_network_alpha"] = str(new_alpha) - else: - metadata["ss_training_comment"] = f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}" - metadata["ss_network_dim"] = 'Dynamic' - metadata["ss_network_alpha"] = 'Dynamic' + if ( + args.save_to is None or + not (args.save_to.endswith('.ckpt') or + args.save_to.endswith('.pt') or + args.save_to.endswith('.pth') or + args.save_to.endswith('.safetensors')) + ): + raise Exception("The --save_to argument must be specified and must be a .ckpt , .pt, .pth or .safetensors file.") + + def str_to_dtype(p): + if p == 'float': + return torch.float + if p == 'fp16': + return torch.float16 + if p == 'bf16': + return torch.bfloat16 + return None + + if args.dynamic_method and not args.dynamic_param: + raise Exception("If using dynamic_method, then dynamic_param is required") + + merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32 + save_dtype = str_to_dtype(args.save_precision) + if save_dtype is None: + save_dtype = merge_dtype + + print("loading Model...") + lora_sd, metadata = load_state_dict(args.model, merge_dtype) + + print("Resizing Lora...") + state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose) + + # update metadata + if metadata is None: + metadata = {} + + comment = metadata.get("ss_training_comment", "") + + if not args.dynamic_method: + metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}" + metadata["ss_network_dim"] = str(args.new_rank) + metadata["ss_network_alpha"] = str(new_alpha) + else: + metadata["ss_training_comment"] = f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}" + metadata["ss_network_dim"] = 'Dynamic' + metadata["ss_network_alpha"] = 'Dynamic' - model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) - metadata["sshs_model_hash"] = model_hash - metadata["sshs_legacy_hash"] = legacy_hash + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash - print(f"saving model to: {args.save_to}") - save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) + print(f"saving model to: {args.save_to}") + save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - - parser.add_argument("--save_precision", type=str, default=None, - choices=[None, "float", "fp16", "bf16"], help="precision in saving, float if omitted / 保存時の精度、未指定時はfloat") - parser.add_argument("--new_rank", type=int, default=4, - help="Specify rank of output LoRA / 出力するLoRAのrank (dim)") - parser.add_argument("--save_to", type=str, default=None, - help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors") - parser.add_argument("--model", type=str, default=None, - help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors") - parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") - parser.add_argument("--verbose", action="store_true", - help="Display verbose resizing information / rank変更時の詳細情報を出力する") - parser.add_argument("--dynamic_method", type=str, default=None, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"], - help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank") - parser.add_argument("--dynamic_param", type=float, default=None, - help="Specify target for dynamic reduction") - - return parser + parser = argparse.ArgumentParser() + + parser.add_argument("--save_precision", type=str, default=None, + choices=[None, "float", "fp16", "bf16"], help="precision in saving, float if omitted / 保存時の精度、未指定時はfloat") + parser.add_argument("--new_rank", type=int, default=4, + help="Specify rank of output LoRA / 出力するLoRAのrank (dim)") + parser.add_argument("--save_to", type=str, default=None, + help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors") + parser.add_argument("--model", type=str, default=None, + help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors") + parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") + parser.add_argument("--verbose", action="store_true", + help="Display verbose resizing information / rank変更時の詳細情報を出力する") + parser.add_argument("--dynamic_method", type=str, default=None, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"], + help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank") + parser.add_argument("--dynamic_param", type=float, default=None, + help="Specify target for dynamic reduction") + + return parser if __name__ == '__main__': - parser = setup_parser() + parser = setup_parser() - args = parser.parse_args() - resize(args) + args = parser.parse_args() + resize(args) From 1492bcbfa2ba0aa4de26d829b7db1397348f1785 Mon Sep 17 00:00:00 2001 From: mgz <49577754+mgz-dev@users.noreply.github.com> Date: Sat, 3 Feb 2024 23:18:55 -0600 Subject: [PATCH 045/103] add --new_conv_rank option update script to also take a separate conv rank value --- networks/resize_lora.py | 44 +++++++++++++++++++++-------------------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/networks/resize_lora.py b/networks/resize_lora.py index 3c866f1eb..c60862687 100644 --- a/networks/resize_lora.py +++ b/networks/resize_lora.py @@ -2,11 +2,12 @@ # This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py # Thanks to cloneofsimo +import os import argparse import torch -from safetensors.torch import load_file, save_file, safe_open +from safetensors.torch import load_file, save_file from tqdm import tqdm -from library import train_util, model_util +from library import train_util import numpy as np MIN_SV = 1e-6 @@ -14,32 +15,29 @@ # Model save and load functions def load_state_dict(file_name, dtype): - if model_util.is_safetensors(file_name): + if os.path.splitext(file_name)[1] == ".safetensors": sd = load_file(file_name) - with safe_open(file_name, framework="pt") as f: - metadata = f.metadata() + metadata = train_util.load_metadata_from_safetensors(file_name) else: - sd = torch.load(file_name, map_location='cpu') - metadata = None + sd = torch.load(file_name, map_location="cpu") + metadata = {} for key in list(sd.keys()): - if type(sd[key]) == torch.Tensor: + if isinstance(sd[key], torch.Tensor): sd[key] = sd[key].to(dtype) return sd, metadata - -def save_to_file(file_name, model, state_dict, dtype, metadata): +def save_to_file(file_name, state_dict, dtype, metadata): if dtype is not None: for key in list(state_dict.keys()): - if type(state_dict[key]) == torch.Tensor: + if isinstance(state_dict[key], torch.Tensor): state_dict[key] = state_dict[key].to(dtype) - if model_util.is_safetensors(file_name): - save_file(model, file_name, metadata) + if os.path.splitext(file_name)[1] == ".safetensors": + save_file(state_dict, file_name, metadata=metadata) else: - torch.save(model, file_name) - + torch.save(state_dict, file_name) # Indexing functions @@ -54,8 +52,8 @@ def index_sv_cumulative(S, target): def index_sv_fro(S, target): S_squared = S.pow(2) - s_fro_sq = float(torch.sum(S_squared)) - sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq + S_fro_sq = float(torch.sum(S_squared)) + sum_S_squared = torch.cumsum(S_squared, dim=0)/S_fro_sq index = int(torch.searchsorted(sum_S_squared, target**2)) + 1 index = max(1, min(index, len(S)-1)) @@ -184,7 +182,7 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1): return param_dict -def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose): +def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose): network_alpha = None network_dim = None verbose_str = "\n" @@ -240,7 +238,7 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn if conv2d: full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device) - param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale) + param_dict = extract_conv(full_weight_matrix, new_conv_rank, dynamic_method, dynamic_param, device, scale) else: full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device) param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale) @@ -290,6 +288,8 @@ def resize(args): ): raise Exception("The --save_to argument must be specified and must be a .ckpt , .pt, .pth or .safetensors file.") + args.new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank + def str_to_dtype(p): if p == 'float': return torch.float @@ -311,7 +311,7 @@ def str_to_dtype(p): lora_sd, metadata = load_state_dict(args.model, merge_dtype) print("Resizing Lora...") - state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose) + state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, args.new_conv_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose) # update metadata if metadata is None: @@ -333,7 +333,7 @@ def str_to_dtype(p): metadata["sshs_legacy_hash"] = legacy_hash print(f"saving model to: {args.save_to}") - save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) + save_to_file(args.save_to, state_dict, save_dtype, metadata) def setup_parser() -> argparse.ArgumentParser: @@ -343,6 +343,8 @@ def setup_parser() -> argparse.ArgumentParser: choices=[None, "float", "fp16", "bf16"], help="precision in saving, float if omitted / 保存時の精度、未指定時はfloat") parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)") + parser.add_argument("--new_conv_rank", type=int, default=None, + help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ") parser.add_argument("--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors") parser.add_argument("--model", type=str, default=None, From e793d7780d779855f23210d1c88368fd9286666e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 4 Feb 2024 17:31:01 +0900 Subject: [PATCH 046/103] reduce peak VRAM in sample gen --- library/train_util.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/library/train_util.py b/library/train_util.py index 1377997c9..32198774a 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4820,6 +4820,10 @@ def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, p controlnet=controlnet, controlnet_image=controlnet_image, ) + + with torch.cuda.device(torch.cuda.current_device()): + torch.cuda.empty_cache() + image = pipeline.latents_to_image(latents)[0] # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list From 5f6bf29e5290f074059179116af5c991c8f523b7 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Sun, 4 Feb 2024 16:14:34 +0700 Subject: [PATCH 047/103] Replace print with logger if they are logs (#905) * Add get_my_logger() * Use logger instead of print * Fix log level * Removed line-breaks for readability * Use setup_logging() * Add rich to requirements.txt * Make simple * Use logger instead of print --------- Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com> --- fine_tune.py | 14 +- finetune/blip/blip.py | 6 +- finetune/clean_captions_and_tags.py | 42 +-- finetune/make_captions.py | 22 +- finetune/make_captions_by_git.py | 23 +- finetune/merge_captions_to_metadata.py | 20 +- finetune/merge_dd_tags_to_metadata.py | 20 +- finetune/prepare_buckets_latents.py | 26 +- finetune/tag_images_by_wd14_tagger.py | 31 ++- gen_img_diffusers.py | 167 ++++++------ library/config_util.py | 44 +-- library/custom_train_functions.py | 17 +- library/huggingface_util.py | 17 +- library/ipex/hijacks.py | 2 +- library/lpw_stable_diffusion.py | 3 +- library/model_util.py | 30 ++- library/original_unet.py | 12 +- library/sai_model_spec.py | 6 +- library/sdxl_model_util.py | 27 +- library/sdxl_original_unet.py | 33 +-- library/sdxl_train_util.py | 38 +-- library/slicing_vae.py | 25 +- library/train_util.py | 329 ++++++++++++----------- library/utils.py | 19 +- networks/check_lora_weights.py | 13 +- networks/control_net_lllite.py | 35 +-- networks/control_net_lllite_for_train.py | 43 +-- networks/dylora.py | 27 +- networks/extract_lora_from_dylora.py | 17 +- networks/extract_lora_from_models.py | 25 +- networks/lora.py | 105 ++++---- networks/lora_diffusers.py | 69 ++--- networks/lora_fa.py | 103 +++---- networks/lora_interrogator.py | 24 +- networks/merge_lora.py | 37 +-- networks/merge_lora_old.py | 27 +- networks/oft.py | 19 +- networks/resize_lora.py | 18 +- networks/sdxl_merge_lora.py | 35 +-- networks/svd_merge_lora.py | 23 +- requirements.txt | 2 + sdxl_gen_img.py | 167 ++++++------ sdxl_minimal_inference.py | 14 +- sdxl_train.py | 20 +- sdxl_train_control_net_lllite.py | 23 +- sdxl_train_control_net_lllite_old.py | 23 +- sdxl_train_network.py | 11 +- tools/cache_latents.py | 19 +- tools/cache_text_encoder_outputs.py | 19 +- tools/canny.py | 6 +- tools/convert_diffusers20_original_sd.py | 17 +- tools/detect_face_rotate.py | 22 +- tools/latent_upscaler.py | 21 +- tools/merge_models.py | 27 +- tools/original_control_net.py | 21 +- tools/resize_images_to_resolution.py | 9 +- tools/show_metadata.py | 8 +- train_controlnet.py | 17 +- train_db.py | 16 +- train_network.py | 21 +- train_textual_inversion.py | 10 +- train_textual_inversion_XTI.py | 70 ++--- 62 files changed, 1195 insertions(+), 961 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 982dc8aec..a75c46735 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -18,6 +18,10 @@ from accelerate.utils import set_seed from diffusers import DDPMScheduler +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) import library.train_util as train_util import library.config_util as config_util from library.config_util import ( @@ -49,11 +53,11 @@ def train(args): if args.dataset_class is None: blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, False, True)) if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) @@ -86,7 +90,7 @@ def train(args): train_util.debug_dataset(train_dataset_group) return if len(train_dataset_group) == 0: - print( + logger.error( "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" ) return @@ -97,7 +101,7 @@ def train(args): ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする @@ -461,7 +465,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): train_util.save_sd_model_on_train_end( args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae ) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: diff --git a/finetune/blip/blip.py b/finetune/blip/blip.py index 7851fb08b..7d192cb26 100644 --- a/finetune/blip/blip.py +++ b/finetune/blip/blip.py @@ -21,6 +21,10 @@ import os from urllib.parse import urlparse from timm.models.hub import download_cached_file +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) class BLIP_Base(nn.Module): def __init__(self, @@ -235,6 +239,6 @@ def load_checkpoint(model,url_or_filename): del state_dict[key] msg = model.load_state_dict(state_dict,strict=False) - print('load checkpoint from %s'%url_or_filename) + logger.info('load checkpoint from %s'%url_or_filename) return model,msg diff --git a/finetune/clean_captions_and_tags.py b/finetune/clean_captions_and_tags.py index 68839eccc..5aeb17425 100644 --- a/finetune/clean_captions_and_tags.py +++ b/finetune/clean_captions_and_tags.py @@ -8,6 +8,10 @@ import re from tqdm import tqdm +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ') PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ') @@ -36,13 +40,13 @@ def clean_tags(image_key, tags): tokens = tags.split(", rating") if len(tokens) == 1: # WD14 taggerのときはこちらになるのでメッセージは出さない - # print("no rating:") - # print(f"{image_key} {tags}") + # logger.info("no rating:") + # logger.info(f"{image_key} {tags}") pass else: if len(tokens) > 2: - print("multiple ratings:") - print(f"{image_key} {tags}") + logger.info("multiple ratings:") + logger.info(f"{image_key} {tags}") tags = tokens[0] tags = ", " + tags.replace(", ", ", , ") + ", " # カンマ付きで検索をするための身も蓋もない対策 @@ -124,43 +128,43 @@ def clean_caption(caption): def main(args): if os.path.exists(args.in_json): - print(f"loading existing metadata: {args.in_json}") + logger.info(f"loading existing metadata: {args.in_json}") with open(args.in_json, "rt", encoding='utf-8') as f: metadata = json.load(f) else: - print("no metadata / メタデータファイルがありません") + logger.error("no metadata / メタデータファイルがありません") return - print("cleaning captions and tags.") + logger.info("cleaning captions and tags.") image_keys = list(metadata.keys()) for image_key in tqdm(image_keys): tags = metadata[image_key].get('tags') if tags is None: - print(f"image does not have tags / メタデータにタグがありません: {image_key}") + logger.error(f"image does not have tags / メタデータにタグがありません: {image_key}") else: org = tags tags = clean_tags(image_key, tags) metadata[image_key]['tags'] = tags if args.debug and org != tags: - print("FROM: " + org) - print("TO: " + tags) + logger.info("FROM: " + org) + logger.info("TO: " + tags) caption = metadata[image_key].get('caption') if caption is None: - print(f"image does not have caption / メタデータにキャプションがありません: {image_key}") + logger.error(f"image does not have caption / メタデータにキャプションがありません: {image_key}") else: org = caption caption = clean_caption(caption) metadata[image_key]['caption'] = caption if args.debug and org != caption: - print("FROM: " + org) - print("TO: " + caption) + logger.info("FROM: " + org) + logger.info("TO: " + caption) # metadataを書き出して終わり - print(f"writing metadata: {args.out_json}") + logger.info(f"writing metadata: {args.out_json}") with open(args.out_json, "wt", encoding='utf-8') as f: json.dump(metadata, f, indent=2) - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: @@ -178,10 +182,10 @@ def setup_parser() -> argparse.ArgumentParser: args, unknown = parser.parse_known_args() if len(unknown) == 1: - print("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.") - print("All captions and tags in the metadata are processed.") - print("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。") - print("メタデータ内のすべてのキャプションとタグが処理されます。") + logger.warning("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.") + logger.warning("All captions and tags in the metadata are processed.") + logger.warning("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。") + logger.warning("メタデータ内のすべてのキャプションとタグが処理されます。") args.in_json = args.out_json args.out_json = unknown[0] elif len(unknown) > 0: diff --git a/finetune/make_captions.py b/finetune/make_captions.py index 074576bc2..5d6d48047 100644 --- a/finetune/make_captions.py +++ b/finetune/make_captions.py @@ -15,6 +15,10 @@ sys.path.append(os.path.dirname(__file__)) from blip.blip import blip_decoder, is_url import library.train_util as train_util +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -47,7 +51,7 @@ def __getitem__(self, idx): # convert to tensor temporarily so dataloader will accept it tensor = IMAGE_TRANSFORM(image) except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") + logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") return None return (tensor, img_path) @@ -74,21 +78,21 @@ def main(args): args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path cwd = os.getcwd() - print("Current Working Directory is: ", cwd) + logger.info(f"Current Working Directory is: {cwd}") os.chdir("finetune") if not is_url(args.caption_weights) and not os.path.isfile(args.caption_weights): args.caption_weights = os.path.join("..", args.caption_weights) - print(f"load images from {args.train_data_dir}") + logger.info(f"load images from {args.train_data_dir}") train_data_dir_path = Path(args.train_data_dir) image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) - print(f"found {len(image_paths)} images.") + logger.info(f"found {len(image_paths)} images.") - print(f"loading BLIP caption: {args.caption_weights}") + logger.info(f"loading BLIP caption: {args.caption_weights}") model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit="large", med_config="./blip/med_config.json") model.eval() model = model.to(DEVICE) - print("BLIP loaded") + logger.info("BLIP loaded") # captioningする def run_batch(path_imgs): @@ -108,7 +112,7 @@ def run_batch(path_imgs): with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f: f.write(caption + "\n") if args.debug: - print(image_path, caption) + logger.info(f'{image_path} {caption}') # 読み込みの高速化のためにDataLoaderを使うオプション if args.max_data_loader_n_workers is not None: @@ -138,7 +142,7 @@ def run_batch(path_imgs): raw_image = raw_image.convert("RGB") img_tensor = IMAGE_TRANSFORM(raw_image) except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") + logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") continue b_imgs.append((image_path, img_tensor)) @@ -148,7 +152,7 @@ def run_batch(path_imgs): if len(b_imgs) > 0: run_batch(b_imgs) - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: diff --git a/finetune/make_captions_by_git.py b/finetune/make_captions_by_git.py index b3c5cc423..e50ed8ca8 100644 --- a/finetune/make_captions_by_git.py +++ b/finetune/make_captions_by_git.py @@ -10,7 +10,10 @@ from transformers.generation.utils import GenerationMixin import library.train_util as train_util - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -35,8 +38,8 @@ def remove_words(captions, debug): for pat in PATTERN_REPLACE: cap = pat.sub("", cap) if debug and cap != caption: - print(caption) - print(cap) + logger.info(caption) + logger.info(cap) removed_caps.append(cap) return removed_caps @@ -70,16 +73,16 @@ def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs) GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch """ - print(f"load images from {args.train_data_dir}") + logger.info(f"load images from {args.train_data_dir}") train_data_dir_path = Path(args.train_data_dir) image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) - print(f"found {len(image_paths)} images.") + logger.info(f"found {len(image_paths)} images.") # できればcacheに依存せず明示的にダウンロードしたい - print(f"loading GIT: {args.model_id}") + logger.info(f"loading GIT: {args.model_id}") git_processor = AutoProcessor.from_pretrained(args.model_id) git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE) - print("GIT loaded") + logger.info("GIT loaded") # captioningする def run_batch(path_imgs): @@ -97,7 +100,7 @@ def run_batch(path_imgs): with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f: f.write(caption + "\n") if args.debug: - print(image_path, caption) + logger.info(f"{image_path} {caption}") # 読み込みの高速化のためにDataLoaderを使うオプション if args.max_data_loader_n_workers is not None: @@ -126,7 +129,7 @@ def run_batch(path_imgs): if image.mode != "RGB": image = image.convert("RGB") except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") + logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") continue b_imgs.append((image_path, image)) @@ -137,7 +140,7 @@ def run_batch(path_imgs): if len(b_imgs) > 0: run_batch(b_imgs) - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: diff --git a/finetune/merge_captions_to_metadata.py b/finetune/merge_captions_to_metadata.py index 241f6f902..60765b863 100644 --- a/finetune/merge_captions_to_metadata.py +++ b/finetune/merge_captions_to_metadata.py @@ -5,26 +5,30 @@ from tqdm import tqdm import library.train_util as train_util import os +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def main(args): assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" train_data_dir_path = Path(args.train_data_dir) image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) - print(f"found {len(image_paths)} images.") + logger.info(f"found {len(image_paths)} images.") if args.in_json is None and Path(args.out_json).is_file(): args.in_json = args.out_json if args.in_json is not None: - print(f"loading existing metadata: {args.in_json}") + logger.info(f"loading existing metadata: {args.in_json}") metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8')) - print("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます") + logger.warning("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます") else: - print("new metadata will be created / 新しいメタデータファイルが作成されます") + logger.info("new metadata will be created / 新しいメタデータファイルが作成されます") metadata = {} - print("merge caption texts to metadata json.") + logger.info("merge caption texts to metadata json.") for image_path in tqdm(image_paths): caption_path = image_path.with_suffix(args.caption_extension) caption = caption_path.read_text(encoding='utf-8').strip() @@ -38,12 +42,12 @@ def main(args): metadata[image_key]['caption'] = caption if args.debug: - print(image_key, caption) + logger.info(f"{image_key} {caption}") # metadataを書き出して終わり - print(f"writing metadata: {args.out_json}") + logger.info(f"writing metadata: {args.out_json}") Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8') - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: diff --git a/finetune/merge_dd_tags_to_metadata.py b/finetune/merge_dd_tags_to_metadata.py index db1bff6da..9ef8f14b0 100644 --- a/finetune/merge_dd_tags_to_metadata.py +++ b/finetune/merge_dd_tags_to_metadata.py @@ -5,26 +5,30 @@ from tqdm import tqdm import library.train_util as train_util import os +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def main(args): assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" train_data_dir_path = Path(args.train_data_dir) image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) - print(f"found {len(image_paths)} images.") + logger.info(f"found {len(image_paths)} images.") if args.in_json is None and Path(args.out_json).is_file(): args.in_json = args.out_json if args.in_json is not None: - print(f"loading existing metadata: {args.in_json}") + logger.info(f"loading existing metadata: {args.in_json}") metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8')) - print("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます") + logger.warning("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます") else: - print("new metadata will be created / 新しいメタデータファイルが作成されます") + logger.info("new metadata will be created / 新しいメタデータファイルが作成されます") metadata = {} - print("merge tags to metadata json.") + logger.info("merge tags to metadata json.") for image_path in tqdm(image_paths): tags_path = image_path.with_suffix(args.caption_extension) tags = tags_path.read_text(encoding='utf-8').strip() @@ -38,13 +42,13 @@ def main(args): metadata[image_key]['tags'] = tags if args.debug: - print(image_key, tags) + logger.info(f"{image_key} {tags}") # metadataを書き出して終わり - print(f"writing metadata: {args.out_json}") + logger.info(f"writing metadata: {args.out_json}") Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8') - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py index 1bccb1d3b..4292ac133 100644 --- a/finetune/prepare_buckets_latents.py +++ b/finetune/prepare_buckets_latents.py @@ -13,6 +13,10 @@ import library.model_util as model_util import library.train_util as train_util +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -51,22 +55,22 @@ def get_npz_filename(data_dir, image_key, is_full_path, recursive): def main(args): # assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります" if args.bucket_reso_steps % 8 > 0: - print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります") + logger.warning(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります") if args.bucket_reso_steps % 32 > 0: - print( + logger.warning( f"WARNING: bucket_reso_steps is not divisible by 32. It is not working with SDXL / bucket_reso_stepsが32で割り切れません。SDXLでは動作しません" ) train_data_dir_path = Path(args.train_data_dir) image_paths: List[str] = [str(p) for p in train_util.glob_images_pathlib(train_data_dir_path, args.recursive)] - print(f"found {len(image_paths)} images.") + logger.info(f"found {len(image_paths)} images.") if os.path.exists(args.in_json): - print(f"loading existing metadata: {args.in_json}") + logger.info(f"loading existing metadata: {args.in_json}") with open(args.in_json, "rt", encoding="utf-8") as f: metadata = json.load(f) else: - print(f"no metadata / メタデータファイルがありません: {args.in_json}") + logger.error(f"no metadata / メタデータファイルがありません: {args.in_json}") return weight_dtype = torch.float32 @@ -89,7 +93,7 @@ def main(args): if not args.bucket_no_upscale: bucket_manager.make_buckets() else: - print( + logger.warning( "min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます" ) @@ -130,7 +134,7 @@ def process_batch(is_last): if image.mode != "RGB": image = image.convert("RGB") except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") + logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") continue image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0] @@ -183,15 +187,15 @@ def process_batch(is_last): for i, reso in enumerate(bucket_manager.resos): count = bucket_counts.get(reso, 0) if count > 0: - print(f"bucket {i} {reso}: {count}") + logger.info(f"bucket {i} {reso}: {count}") img_ar_errors = np.array(img_ar_errors) - print(f"mean ar error: {np.mean(img_ar_errors)}") + logger.info(f"mean ar error: {np.mean(img_ar_errors)}") # metadataを書き出して終わり - print(f"writing metadata: {args.out_json}") + logger.info(f"writing metadata: {args.out_json}") with open(args.out_json, "wt", encoding="utf-8") as f: json.dump(metadata, f, indent=2) - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index fbf328e83..7a751b177 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -11,6 +11,10 @@ from tqdm import tqdm import library.train_util as train_util +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) # from wd14 tagger IMAGE_SIZE = 448 @@ -58,7 +62,7 @@ def __getitem__(self, idx): image = preprocess_image(image) tensor = torch.tensor(image) except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") + logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") return None return (tensor, img_path) @@ -79,7 +83,7 @@ def main(args): # depreacatedの警告が出るけどなくなったらその時 # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22 if not os.path.exists(args.model_dir) or args.force_download: - print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}") + logger.info(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}") files = FILES if args.onnx: files += FILES_ONNX @@ -95,7 +99,7 @@ def main(args): force_filename=file, ) else: - print("using existing wd14 tagger model") + logger.info("using existing wd14 tagger model") # 画像を読み込む if args.onnx: @@ -103,8 +107,8 @@ def main(args): import onnxruntime as ort onnx_path = f"{args.model_dir}/model.onnx" - print("Running wd14 tagger with onnx") - print(f"loading onnx model: {onnx_path}") + logger.info("Running wd14 tagger with onnx") + logger.info(f"loading onnx model: {onnx_path}") if not os.path.exists(onnx_path): raise Exception( @@ -121,7 +125,7 @@ def main(args): if args.batch_size != batch_size and type(batch_size) != str: # some rebatch model may use 'N' as dynamic axes - print( + logger.warning( f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}" ) args.batch_size = batch_size @@ -156,7 +160,7 @@ def main(args): train_data_dir_path = Path(args.train_data_dir) image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) - print(f"found {len(image_paths)} images.") + logger.info(f"found {len(image_paths)} images.") tag_freq = {} @@ -237,7 +241,10 @@ def run_batch(path_imgs): with open(caption_file, "wt", encoding="utf-8") as f: f.write(tag_text + "\n") if args.debug: - print(f"\n{image_path}:\n Character tags: {character_tag_text}\n General tags: {general_tag_text}") + logger.info("") + logger.info(f"{image_path}:") + logger.info(f"\tCharacter tags: {character_tag_text}") + logger.info(f"\tGeneral tags: {general_tag_text}") # 読み込みの高速化のためにDataLoaderを使うオプション if args.max_data_loader_n_workers is not None: @@ -269,7 +276,7 @@ def run_batch(path_imgs): image = image.convert("RGB") image = preprocess_image(image) except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") + logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") continue b_imgs.append((image_path, image)) @@ -284,11 +291,11 @@ def run_batch(path_imgs): if args.frequency_tags: sorted_tags = sorted(tag_freq.items(), key=lambda x: x[1], reverse=True) - print("\nTag frequencies:") + logger.info("Tag frequencies:") for tag, freq in sorted_tags: - print(f"{tag}: {freq}") + logger.info(f"{tag}: {freq}") - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index a207ad5a1..bf09cdd3b 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -104,6 +104,10 @@ from library.original_unet import FlashAttentionFunction from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) # scheduler: SCHEDULER_LINEAR_START = 0.00085 @@ -139,12 +143,12 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): if mem_eff_attn: - print("Enable memory efficient attention for U-Net") + logger.info("Enable memory efficient attention for U-Net") # これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い unet.set_use_memory_efficient_attention(False, True) elif xformers: - print("Enable xformers for U-Net") + logger.info("Enable xformers for U-Net") try: import xformers.ops except ImportError: @@ -152,7 +156,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio unet.set_use_memory_efficient_attention(True, False) elif sdpa: - print("Enable SDPA for U-Net") + logger.info("Enable SDPA for U-Net") unet.set_use_memory_efficient_attention(False, False) unet.set_use_sdpa(True) @@ -168,7 +172,7 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform def replace_vae_attn_to_memory_efficient(): - print("VAE Attention.forward has been replaced to FlashAttention (not xformers)") + logger.info("VAE Attention.forward has been replaced to FlashAttention (not xformers)") flash_func = FlashAttentionFunction def forward_flash_attn(self, hidden_states, **kwargs): @@ -224,7 +228,7 @@ def forward_flash_attn_0_14(self, hidden_states, **kwargs): def replace_vae_attn_to_xformers(): - print("VAE: Attention.forward has been replaced to xformers") + logger.info("VAE: Attention.forward has been replaced to xformers") import xformers.ops def forward_xformers(self, hidden_states, **kwargs): @@ -280,7 +284,7 @@ def forward_xformers_0_14(self, hidden_states, **kwargs): def replace_vae_attn_to_sdpa(): - print("VAE: Attention.forward has been replaced to sdpa") + logger.info("VAE: Attention.forward has been replaced to sdpa") def forward_sdpa(self, hidden_states, **kwargs): residual = hidden_states @@ -684,7 +688,7 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 if not do_classifier_free_guidance and negative_scale is not None: - print(f"negative_scale is ignored if guidance scalle <= 1.0") + logger.warning(f"negative_scale is ignored if guidance scalle <= 1.0") negative_scale = None # get unconditional embeddings for classifier free guidance @@ -766,11 +770,11 @@ def __call__( clip_text_input = prompt_tokens if clip_text_input.shape[1] > self.tokenizer.model_max_length: # TODO 75文字を超えたら警告を出す? - print("trim text input", clip_text_input.shape) + logger.info(f"trim text input {clip_text_input.shape}") clip_text_input = torch.cat( [clip_text_input[:, : self.tokenizer.model_max_length - 1], clip_text_input[:, -1].unsqueeze(1)], dim=1 ) - print("trimmed", clip_text_input.shape) + logger.info(f"trimmed {clip_text_input.shape}") for i, clip_prompt in enumerate(clip_prompts): if clip_prompt is not None: # clip_promptがあれば上書きする @@ -1699,7 +1703,7 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length: if word.strip() == "BREAK": # pad until next multiple of tokenizer's max token length pad_len = pipe.tokenizer.model_max_length - (len(text_token) % pipe.tokenizer.model_max_length) - print(f"BREAK pad_len: {pad_len}") + logger.info(f"BREAK pad_len: {pad_len}") for i in range(pad_len): # v2のときEOSをつけるべきかどうかわからないぜ # if i == 0: @@ -1729,7 +1733,7 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length: tokens.append(text_token) weights.append(text_weight) if truncated: - print("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") return tokens, weights @@ -2041,7 +2045,7 @@ def handle_dynamic_prompt_variants(prompt, repeat_count): elif len(count_range) == 2: count_range = [int(count_range[0]), int(count_range[1])] else: - print(f"invalid count range: {count_range}") + logger.warning(f"invalid count range: {count_range}") count_range = [1, 1] if count_range[0] > count_range[1]: count_range = [count_range[1], count_range[0]] @@ -2111,7 +2115,7 @@ def replacer(): # def load_clip_l14_336(dtype): -# print(f"loading CLIP: {CLIP_ID_L14_336}") +# logger.info(f"loading CLIP: {CLIP_ID_L14_336}") # text_encoder = CLIPTextModel.from_pretrained(CLIP_ID_L14_336, torch_dtype=dtype) # return text_encoder @@ -2158,9 +2162,9 @@ def main(args): # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" if args.v_parameterization and not args.v2: - print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") + logger.warning("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") if args.v2 and args.clip_skip is not None: - print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") + logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") # モデルを読み込む if not os.path.isfile(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う @@ -2170,10 +2174,10 @@ def main(args): use_stable_diffusion_format = os.path.isfile(args.ckpt) if use_stable_diffusion_format: - print("load StableDiffusion checkpoint") + logger.info("load StableDiffusion checkpoint") text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt) else: - print("load Diffusers pretrained models") + logger.info("load Diffusers pretrained models") loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype) text_encoder = loading_pipe.text_encoder vae = loading_pipe.vae @@ -2196,21 +2200,21 @@ def main(args): # VAEを読み込む if args.vae is not None: vae = model_util.load_vae(args.vae, dtype) - print("additional VAE loaded") + logger.info("additional VAE loaded") # # 置換するCLIPを読み込む # if args.replace_clip_l14_336: # text_encoder = load_clip_l14_336(dtype) - # print(f"large clip {CLIP_ID_L14_336} is loaded") + # logger.info(f"large clip {CLIP_ID_L14_336} is loaded") if args.clip_guidance_scale > 0.0 or args.clip_image_guidance_scale: - print("prepare clip model") + logger.info("prepare clip model") clip_model = CLIPModel.from_pretrained(CLIP_MODEL_PATH, torch_dtype=dtype) else: clip_model = None if args.vgg16_guidance_scale > 0.0: - print("prepare resnet model") + logger.info("prepare resnet model") vgg16_model = torchvision.models.vgg16(torchvision.models.VGG16_Weights.IMAGENET1K_V1) else: vgg16_model = None @@ -2222,7 +2226,7 @@ def main(args): replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa) # tokenizerを読み込む - print("loading tokenizer") + logger.info("loading tokenizer") if use_stable_diffusion_format: tokenizer = train_util.load_tokenizer(args) @@ -2281,7 +2285,7 @@ def reset_sampler_noises(self, noises): self.sampler_noises = noises def randn(self, shape, device=None, dtype=None, layout=None, generator=None): - # print("replacing", shape, len(self.sampler_noises), self.sampler_noise_index) + # logger.info(f"replacing {shape} {len(self.sampler_noises)} {self.sampler_noise_index}") if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises): noise = self.sampler_noises[self.sampler_noise_index] if shape != noise.shape: @@ -2290,7 +2294,7 @@ def randn(self, shape, device=None, dtype=None, layout=None, generator=None): noise = None if noise == None: - print(f"unexpected noise request: {self.sampler_noise_index}, {shape}") + logger.warning(f"unexpected noise request: {self.sampler_noise_index}, {shape}") noise = torch.randn(shape, dtype=dtype, device=device, generator=generator) self.sampler_noise_index += 1 @@ -2321,7 +2325,7 @@ def __getattr__(self, item): # clip_sample=Trueにする if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: - print("set clip_sample to True") + logger.info("set clip_sample to True") scheduler.config.clip_sample = True # deviceを決定する @@ -2378,7 +2382,7 @@ def __getattr__(self, item): network_merge = 0 for i, network_module in enumerate(args.network_module): - print("import network module:", network_module) + logger.info(f"import network module: {network_module}") imported_module = importlib.import_module(network_module) network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] @@ -2396,7 +2400,7 @@ def __getattr__(self, item): raise ValueError("No weight. Weight is required.") network_weight = args.network_weights[i] - print("load network weights from:", network_weight) + logger.info(f"load network weights from: {network_weight}") if model_util.is_safetensors(network_weight) and args.network_show_meta: from safetensors.torch import safe_open @@ -2404,7 +2408,7 @@ def __getattr__(self, item): with safe_open(network_weight, framework="pt") as f: metadata = f.metadata() if metadata is not None: - print(f"metadata for: {network_weight}: {metadata}") + logger.info(f"metadata for: {network_weight}: {metadata}") network, weights_sd = imported_module.create_network_from_weights( network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs @@ -2414,20 +2418,20 @@ def __getattr__(self, item): mergeable = network.is_mergeable() if network_merge and not mergeable: - print("network is not mergiable. ignore merge option.") + logger.warning("network is not mergiable. ignore merge option.") if not mergeable or i >= network_merge: # not merging network.apply_to(text_encoder, unet) info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい - print(f"weights are loaded: {info}") + logger.info(f"weights are loaded: {info}") if args.opt_channels_last: network.to(memory_format=torch.channels_last) network.to(dtype).to(device) if network_pre_calc: - print("backup original weights") + logger.info("backup original weights") network.backup_weights() networks.append(network) @@ -2441,7 +2445,7 @@ def __getattr__(self, item): # upscalerの指定があれば取得する upscaler = None if args.highres_fix_upscaler: - print("import upscaler module:", args.highres_fix_upscaler) + logger.info(f"import upscaler module {args.highres_fix_upscaler}") imported_module = importlib.import_module(args.highres_fix_upscaler) us_kwargs = {} @@ -2450,7 +2454,7 @@ def __getattr__(self, item): key, value = net_arg.split("=") us_kwargs[key] = value - print("create upscaler") + logger.info("create upscaler") upscaler = imported_module.create_upscaler(**us_kwargs) upscaler.to(dtype).to(device) @@ -2467,7 +2471,7 @@ def __getattr__(self, item): control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) if args.opt_channels_last: - print(f"set optimizing: channels last") + logger.info(f"set optimizing: channels last") text_encoder.to(memory_format=torch.channels_last) vae.to(memory_format=torch.channels_last) unet.to(memory_format=torch.channels_last) @@ -2499,7 +2503,7 @@ def __getattr__(self, item): args.vgg16_guidance_layer, ) pipe.set_control_nets(control_nets) - print("pipeline is ready.") + logger.info("pipeline is ready.") if args.diffusers_xformers: pipe.enable_xformers_memory_efficient_attention() @@ -2542,7 +2546,7 @@ def __getattr__(self, item): ), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}" token_ids = tokenizer.convert_tokens_to_ids(token_strings) - print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids}") + logger.info(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids}") assert ( min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1 ), f"token ids is not ordered" @@ -2601,7 +2605,7 @@ def __getattr__(self, item): ), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}" token_ids = tokenizer.convert_tokens_to_ids(token_strings) - print(f"XTI embeddings `{token_string}` loaded. Tokens are added: {token_ids}") + logger.info(f"XTI embeddings `{token_string}` loaded. Tokens are added: {token_ids}") # if num_vectors_per_token > 1: pipe.add_token_replacement(token_ids[0], token_ids) @@ -2626,7 +2630,7 @@ def __getattr__(self, item): # promptを取得する if args.from_file is not None: - print(f"reading prompts from {args.from_file}") + logger.info(f"reading prompts from {args.from_file}") with open(args.from_file, "r", encoding="utf-8") as f: prompt_list = f.read().splitlines() prompt_list = [d for d in prompt_list if len(d.strip()) > 0] @@ -2655,7 +2659,7 @@ def load_images(path): for p in paths: image = Image.open(p) if image.mode != "RGB": - print(f"convert image to RGB from {image.mode}: {p}") + logger.info(f"convert image to RGB from {image.mode}: {p}") image = image.convert("RGB") images.append(image) @@ -2671,24 +2675,24 @@ def resize_images(imgs, size): return resized if args.image_path is not None: - print(f"load image for img2img: {args.image_path}") + logger.info(f"load image for img2img: {args.image_path}") init_images = load_images(args.image_path) assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}" - print(f"loaded {len(init_images)} images for img2img") + logger.info(f"loaded {len(init_images)} images for img2img") else: init_images = None if args.mask_path is not None: - print(f"load mask for inpainting: {args.mask_path}") + logger.info(f"load mask for inpainting: {args.mask_path}") mask_images = load_images(args.mask_path) assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}" - print(f"loaded {len(mask_images)} mask images for inpainting") + logger.info(f"loaded {len(mask_images)} mask images for inpainting") else: mask_images = None # promptがないとき、画像のPngInfoから取得する if init_images is not None and len(prompt_list) == 0 and not args.interactive: - print("get prompts from images' meta data") + logger.info("get prompts from images' meta data") for img in init_images: if "prompt" in img.text: prompt = img.text["prompt"] @@ -2717,17 +2721,17 @@ def resize_images(imgs, size): h = int(h * args.highres_fix_scale + 0.5) if init_images is not None: - print(f"resize img2img source images to {w}*{h}") + logger.info(f"resize img2img source images to {w}*{h}") init_images = resize_images(init_images, (w, h)) if mask_images is not None: - print(f"resize img2img mask images to {w}*{h}") + logger.info(f"resize img2img mask images to {w}*{h}") mask_images = resize_images(mask_images, (w, h)) regional_network = False if networks and mask_images: # mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応 regional_network = True - print("use mask as region") + logger.info("use mask as region") size = None for i, network in enumerate(networks): @@ -2752,14 +2756,14 @@ def resize_images(imgs, size): prev_image = None # for VGG16 guided if args.guide_image_path is not None: - print(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}") + logger.info(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}") guide_images = [] for p in args.guide_image_path: guide_images.extend(load_images(p)) - print(f"loaded {len(guide_images)} guide images for guidance") + logger.info(f"loaded {len(guide_images)} guide images for guidance") if len(guide_images) == 0: - print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}") + logger.info(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}") guide_images = None else: guide_images = None @@ -2785,7 +2789,7 @@ def resize_images(imgs, size): max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples for gen_iter in range(args.n_iter): - print(f"iteration {gen_iter+1}/{args.n_iter}") + logger.info(f"iteration {gen_iter+1}/{args.n_iter}") iter_seed = random.randint(0, 0x7FFFFFFF) # shuffle prompt list @@ -2801,7 +2805,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling - print("process 1st stage") + logger.info("process 1st stage") batch_1st = [] for _, base, ext in batch: width_1st = int(ext.width * args.highres_fix_scale + 0.5) @@ -2827,7 +2831,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): images_1st = process_batch(batch_1st, True, True) # 2nd stageのバッチを作成して以下処理する - print("process 2nd stage") + logger.info("process 2nd stage") width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height if upscaler: @@ -2978,7 +2982,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): n.restore_weights() for n in networks: n.pre_calculation() - print("pre-calculation... done") + logger.info("pre-calculation... done") images = pipe( prompts, @@ -3045,7 +3049,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): cv2.waitKey() cv2.destroyAllWindows() except ImportError: - print("opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません") + logger.info("opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません") return images @@ -3058,7 +3062,8 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): # interactive valid = False while not valid: - print("\nType prompt:") + logger.info("") + logger.info("Type prompt:") try: raw_prompt = input() except EOFError: @@ -3101,38 +3106,38 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): prompt_args = raw_prompt.strip().split(" --") prompt = prompt_args[0] - print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") + logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") for parg in prompt_args[1:]: try: m = re.match(r"w (\d+)", parg, re.IGNORECASE) if m: width = int(m.group(1)) - print(f"width: {width}") + logger.info(f"width: {width}") continue m = re.match(r"h (\d+)", parg, re.IGNORECASE) if m: height = int(m.group(1)) - print(f"height: {height}") + logger.info(f"height: {height}") continue m = re.match(r"s (\d+)", parg, re.IGNORECASE) if m: # steps steps = max(1, min(1000, int(m.group(1)))) - print(f"steps: {steps}") + logger.info(f"steps: {steps}") continue m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) if m: # seed seeds = [int(d) for d in m.group(1).split(",")] - print(f"seeds: {seeds}") + logger.info(f"seeds: {seeds}") continue m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) if m: # scale scale = float(m.group(1)) - print(f"scale: {scale}") + logger.info(f"scale: {scale}") continue m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) @@ -3141,25 +3146,25 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): negative_scale = None else: negative_scale = float(m.group(1)) - print(f"negative scale: {negative_scale}") + logger.info(f"negative scale: {negative_scale}") continue m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) if m: # strength strength = float(m.group(1)) - print(f"strength: {strength}") + logger.info(f"strength: {strength}") continue m = re.match(r"n (.+)", parg, re.IGNORECASE) if m: # negative prompt negative_prompt = m.group(1) - print(f"negative prompt: {negative_prompt}") + logger.info(f"negative prompt: {negative_prompt}") continue m = re.match(r"c (.+)", parg, re.IGNORECASE) if m: # clip prompt clip_prompt = m.group(1) - print(f"clip prompt: {clip_prompt}") + logger.info(f"clip prompt: {clip_prompt}") continue m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) @@ -3167,47 +3172,47 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): network_muls = [float(v) for v in m.group(1).split(",")] while len(network_muls) < len(networks): network_muls.append(network_muls[-1]) - print(f"network mul: {network_muls}") + logger.info(f"network mul: {network_muls}") continue # Deep Shrink m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink depth 1 ds_depth_1 = int(m.group(1)) - print(f"deep shrink depth 1: {ds_depth_1}") + logger.info(f"deep shrink depth 1: {ds_depth_1}") continue m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink timesteps 1 ds_timesteps_1 = int(m.group(1)) ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink timesteps 1: {ds_timesteps_1}") + logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}") continue m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink depth 2 ds_depth_2 = int(m.group(1)) ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink depth 2: {ds_depth_2}") + logger.info(f"deep shrink depth 2: {ds_depth_2}") continue m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink timesteps 2 ds_timesteps_2 = int(m.group(1)) ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink timesteps 2: {ds_timesteps_2}") + logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}") continue m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink ratio ds_ratio = float(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink ratio: {ds_ratio}") + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink ratio: {ds_ratio}") continue except ValueError as ex: - print(f"Exception in parsing / 解析エラー: {parg}") - print(ex) + logger.info(f"Exception in parsing / 解析エラー: {parg}") + logger.info(ex) # override Deep Shrink if ds_depth_1 is not None: @@ -3225,7 +3230,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): if len(predefined_seeds) > 0: seed = predefined_seeds.pop(0) else: - print("predefined seeds are exhausted") + logger.info("predefined seeds are exhausted") seed = None elif args.iter_same_seed: seed = iter_seed @@ -3235,7 +3240,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): if seed is None: seed = random.randint(0, 0x7FFFFFFF) if args.interactive: - print(f"seed: {seed}") + logger.info(f"seed: {seed}") # prepare init image, guide image and mask init_image = mask_image = guide_image = None @@ -3251,7 +3256,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): width = width - width % 32 height = height - height % 32 if width != init_image.size[0] or height != init_image.size[1]: - print( + logger.info( f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます" ) @@ -3267,9 +3272,9 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): guide_image = guide_images[global_step % len(guide_images)] elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0: if prev_image is None: - print("Generate 1st image without guide image.") + logger.info("Generate 1st image without guide image.") else: - print("Use previous image as guide image.") + logger.info("Use previous image as guide image.") guide_image = prev_image if regional_network: @@ -3311,7 +3316,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): process_batch(batch_data, highres_fix) batch_data.clear() - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: diff --git a/library/config_util.py b/library/config_util.py index a98c2b90d..fc4b36175 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -40,7 +40,10 @@ ControlNetDataset, DatasetGroup, ) - +from .utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def add_config_arguments(parser: argparse.ArgumentParser): parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル") @@ -345,7 +348,7 @@ def sanitize_user_config(self, user_config: dict) -> dict: return self.user_config_validator(user_config) except MultipleInvalid: # TODO: エラー発生時のメッセージをわかりやすくする - print("Invalid user config / ユーザ設定の形式が正しくないようです") + logger.error("Invalid user config / ユーザ設定の形式が正しくないようです") raise # NOTE: In nature, argument parser result is not needed to be sanitize @@ -355,7 +358,7 @@ def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> return self.argparse_config_validator(argparse_namespace) except MultipleInvalid: # XXX: this should be a bug - print("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。") + logger.error("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。") raise # NOTE: value would be overwritten by latter dict if there is already the same key @@ -538,13 +541,13 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu " ", ) - print(info) + logger.info(f'{info}') # make buckets first because it determines the length of dataset # and set the same seed for all datasets - seed = random.randint(0, 2**31) # actual seed is seed + epoch_no + seed = random.randint(0, 2**31) # actual seed is seed + epoch_no for i, dataset in enumerate(datasets): - print(f"[Dataset {i}]") + logger.info(f"[Dataset {i}]") dataset.make_buckets() dataset.set_seed(seed) @@ -557,7 +560,7 @@ def extract_dreambooth_params(name: str) -> Tuple[int, str]: try: n_repeats = int(tokens[0]) except ValueError as e: - print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}") + logger.warning(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}") return 0, "" caption_by_folder = "_".join(tokens[1:]) return n_repeats, caption_by_folder @@ -629,17 +632,13 @@ def load_user_config(file: str) -> dict: with open(file, "r") as f: config = json.load(f) except Exception: - print( - f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}" - ) + logger.error(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") raise elif file.name.lower().endswith(".toml"): try: config = toml.load(file) except Exception: - print( - f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}" - ) + logger.error(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") raise else: raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}") @@ -665,23 +664,26 @@ def load_user_config(file: str) -> dict: argparse_namespace = parser.parse_args(remain) train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning) - print("[argparse_namespace]") - print(vars(argparse_namespace)) + logger.info("[argparse_namespace]") + logger.info(f'{vars(argparse_namespace)}') user_config = load_user_config(config_args.dataset_config) - print("\n[user_config]") - print(user_config) + logger.info("") + logger.info("[user_config]") + logger.info(f'{user_config}') sanitizer = ConfigSanitizer( config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout ) sanitized_user_config = sanitizer.sanitize_user_config(user_config) - print("\n[sanitized_user_config]") - print(sanitized_user_config) + logger.info("") + logger.info("[sanitized_user_config]") + logger.info(f'{sanitized_user_config}') blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace) - print("\n[blueprint]") - print(blueprint) + logger.info("") + logger.info("[blueprint]") + logger.info(f'{blueprint}') diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index e0a026dae..a56474622 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -3,7 +3,10 @@ import random import re from typing import List, Optional, Union - +from .utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def prepare_scheduler_for_custom_training(noise_scheduler, device): if hasattr(noise_scheduler, "all_snr"): @@ -21,7 +24,7 @@ def prepare_scheduler_for_custom_training(noise_scheduler, device): def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler): # fix beta: zero terminal SNR - print(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891") + logger.info(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891") def enforce_zero_terminal_snr(betas): # Convert betas to alphas_bar_sqrt @@ -49,8 +52,8 @@ def enforce_zero_terminal_snr(betas): alphas = 1.0 - betas alphas_cumprod = torch.cumprod(alphas, dim=0) - # print("original:", noise_scheduler.betas) - # print("fixed:", betas) + # logger.info(f"original: {noise_scheduler.betas}") + # logger.info(f"fixed: {betas}") noise_scheduler.betas = betas noise_scheduler.alphas = alphas @@ -79,13 +82,13 @@ def get_snr_scale(timesteps, noise_scheduler): snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 scale = snr_t / (snr_t + 1) # # show debug info - # print(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}") + # logger.info(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}") return scale def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss): scale = get_snr_scale(timesteps, noise_scheduler) - # print(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}") + # logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}") loss = loss + loss / scale * v_pred_like_loss return loss @@ -268,7 +271,7 @@ def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int): tokens.append(text_token) weights.append(text_weight) if truncated: - print("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") return tokens, weights diff --git a/library/huggingface_util.py b/library/huggingface_util.py index 376fdb1e6..57b19d982 100644 --- a/library/huggingface_util.py +++ b/library/huggingface_util.py @@ -4,7 +4,10 @@ import argparse import os from library.utils import fire_in_thread - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None): api = HfApi( @@ -33,9 +36,9 @@ def upload( try: api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private) except Exception as e: # とりあえずRepositoryNotFoundErrorは確認したが他にあると困るので - print("===========================================") - print(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}") - print("===========================================") + logger.error("===========================================") + logger.error(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}") + logger.error("===========================================") is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir()) @@ -56,9 +59,9 @@ def uploader(): path_in_repo=path_in_repo, ) except Exception as e: # RuntimeErrorを確認済みだが他にあると困るので - print("===========================================") - print(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}") - print("===========================================") + logger.error("===========================================") + logger.error(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}") + logger.error("===========================================") if args.async_upload and not force_sync_upload: fire_in_thread(uploader) diff --git a/library/ipex/hijacks.py b/library/ipex/hijacks.py index 7b2d26d4a..b1b9ccf0e 100644 --- a/library/ipex/hijacks.py +++ b/library/ipex/hijacks.py @@ -12,7 +12,7 @@ class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument if isinstance(device_ids, list) and len(device_ids) > 1: - print("IPEX backend doesn't support DataParallel on multiple XPU devices") + logger.error("IPEX backend doesn't support DataParallel on multiple XPU devices") return module.to("xpu") def return_null_context(*args, **kwargs): # pylint: disable=unused-argument diff --git a/library/lpw_stable_diffusion.py b/library/lpw_stable_diffusion.py index 3963e9b15..5717233d4 100644 --- a/library/lpw_stable_diffusion.py +++ b/library/lpw_stable_diffusion.py @@ -17,7 +17,6 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker from diffusers.utils import logging - try: from diffusers.utils import PIL_INTERPOLATION except ImportError: @@ -626,7 +625,7 @@ def check_inputs(self, prompt, height, width, strength, callback_steps): raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") if height % 8 != 0 or width % 8 != 0: - print(height, width) + logger.info(f'{height} {width}') raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if (callback_steps is None) or ( diff --git a/library/model_util.py b/library/model_util.py index 4361b4994..42c4960f5 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -13,6 +13,10 @@ from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel from safetensors.torch import load_file, save_file from library.original_unet import UNet2DConditionModel +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) # DiffUsers版StableDiffusionのモデルパラメータ NUM_TRAIN_TIMESTEPS = 1000 @@ -944,7 +948,7 @@ def convert_vae_state_dict(vae_state_dict): for k, v in new_state_dict.items(): for weight_name in weights_to_convert: if f"mid.attn_1.{weight_name}.weight" in k: - # print(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1") + # logger.info(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1") new_state_dict[k] = reshape_weight_for_sd(v) return new_state_dict @@ -1002,7 +1006,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt unet = UNet2DConditionModel(**unet_config).to(device) info = unet.load_state_dict(converted_unet_checkpoint) - print("loading u-net:", info) + logger.info(f"loading u-net: {info}") # Convert the VAE model. vae_config = create_vae_diffusers_config() @@ -1010,7 +1014,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt vae = AutoencoderKL(**vae_config).to(device) info = vae.load_state_dict(converted_vae_checkpoint) - print("loading vae:", info) + logger.info(f"loading vae: {info}") # convert text_model if v2: @@ -1044,7 +1048,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt # logging.set_verbosity_error() # don't show annoying warning # text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device) # logging.set_verbosity_warning() - # print(f"config: {text_model.config}") + # logger.info(f"config: {text_model.config}") cfg = CLIPTextConfig( vocab_size=49408, hidden_size=768, @@ -1067,7 +1071,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt ) text_model = CLIPTextModel._from_config(cfg) info = text_model.load_state_dict(converted_text_encoder_checkpoint) - print("loading text encoder:", info) + logger.info(f"loading text encoder: {info}") return text_model, vae, unet @@ -1142,7 +1146,7 @@ def convert_key(key): # 最後の層などを捏造するか if make_dummy_weights: - print("make dummy weights for resblock.23, text_projection and logit scale.") + logger.info("make dummy weights for resblock.23, text_projection and logit scale.") keys = list(new_sd.keys()) for key in keys: if key.startswith("transformer.resblocks.22."): @@ -1261,14 +1265,14 @@ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_mod def load_vae(vae_id, dtype): - print(f"load VAE: {vae_id}") + logger.info(f"load VAE: {vae_id}") if os.path.isdir(vae_id) or not os.path.isfile(vae_id): # Diffusers local/remote try: vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype) except EnvironmentError as e: - print(f"exception occurs in loading vae: {e}") - print("retry with subfolder='vae'") + logger.error(f"exception occurs in loading vae: {e}") + logger.error("retry with subfolder='vae'") vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype) return vae @@ -1340,13 +1344,13 @@ def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64) if __name__ == "__main__": resos = make_bucket_resolutions((512, 768)) - print(len(resos)) - print(resos) + logger.info(f"{len(resos)}") + logger.info(f"{resos}") aspect_ratios = [w / h for w, h in resos] - print(aspect_ratios) + logger.info(f"{aspect_ratios}") ars = set() for ar in aspect_ratios: if ar in ars: - print("error! duplicate ar:", ar) + logger.error(f"error! duplicate ar: {ar}") ars.add(ar) diff --git a/library/original_unet.py b/library/original_unet.py index 030c5c9ec..e944ff22b 100644 --- a/library/original_unet.py +++ b/library/original_unet.py @@ -113,6 +113,10 @@ from torch import nn from torch.nn import functional as F from einops import rearrange +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280) TIMESTEP_INPUT_DIM = BLOCK_OUT_CHANNELS[0] @@ -1380,7 +1384,7 @@ def __init__( ): super().__init__() assert sample_size is not None, "sample_size must be specified" - print( + logger.info( f"UNet2DConditionModel: {sample_size}, {attention_head_dim}, {cross_attention_dim}, {use_linear_projection}, {upcast_attention}" ) @@ -1514,7 +1518,7 @@ def set_use_sdpa(self, sdpa: bool) -> None: def set_gradient_checkpointing(self, value=False): modules = self.down_blocks + [self.mid_block] + self.up_blocks for module in modules: - print(module.__class__.__name__, module.gradient_checkpointing, "->", value) + logger.info(f"{module.__class__.__name__} {module.gradient_checkpointing} -> {value}") module.gradient_checkpointing = value # endregion @@ -1709,14 +1713,14 @@ def __call__(self, *args, **kwargs): def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5): if ds_depth_1 is None: - print("Deep Shrink is disabled.") + logger.info("Deep Shrink is disabled.") self.ds_depth_1 = None self.ds_timesteps_1 = None self.ds_depth_2 = None self.ds_timesteps_2 = None self.ds_ratio = None else: - print( + logger.info( f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]" ) self.ds_depth_1 = ds_depth_1 diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py index 472686ba4..a63bd82ec 100644 --- a/library/sai_model_spec.py +++ b/library/sai_model_spec.py @@ -5,6 +5,10 @@ import os from typing import List, Optional, Tuple, Union import safetensors +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) r""" # Metadata Example @@ -231,7 +235,7 @@ def build_metadata( # # assert all values are filled # assert all([v is not None for v in metadata.values()]), metadata if not all([v is not None for v in metadata.values()]): - print(f"Internal error: some metadata values are None: {metadata}") + logger.error(f"Internal error: some metadata values are None: {metadata}") return metadata diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index 08b90c393..f03f1bae5 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -7,7 +7,10 @@ from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel from library import model_util from library import sdxl_original_unet - +from .utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) VAE_SCALE_FACTOR = 0.13025 MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_base_v1-0" @@ -131,7 +134,7 @@ def convert_key(key): # temporary workaround for text_projection.weight.weight for Playground-v2 if "text_projection.weight.weight" in new_sd: - print(f"convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight") + logger.info("convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight") new_sd["text_projection.weight"] = new_sd["text_projection.weight.weight"] del new_sd["text_projection.weight.weight"] @@ -186,20 +189,20 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty checkpoint = None # U-Net - print("building U-Net") + logger.info("building U-Net") with init_empty_weights(): unet = sdxl_original_unet.SdxlUNet2DConditionModel() - print("loading U-Net from checkpoint") + logger.info("loading U-Net from checkpoint") unet_sd = {} for k in list(state_dict.keys()): if k.startswith("model.diffusion_model."): unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k) info = _load_state_dict_on_device(unet, unet_sd, device=map_location, dtype=dtype) - print("U-Net: ", info) + logger.info(f"U-Net: {info}") # Text Encoders - print("building text encoders") + logger.info("building text encoders") # Text Encoder 1 is same to Stability AI's SDXL text_model1_cfg = CLIPTextConfig( @@ -252,7 +255,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty with init_empty_weights(): text_model2 = CLIPTextModelWithProjection(text_model2_cfg) - print("loading text encoders from checkpoint") + logger.info("loading text encoders from checkpoint") te1_sd = {} te2_sd = {} for k in list(state_dict.keys()): @@ -266,22 +269,22 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty te1_sd.pop("text_model.embeddings.position_ids") info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location) # remain fp32 - print("text encoder 1:", info1) + logger.info(f"text encoder 1: {info1}") converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77) info2 = _load_state_dict_on_device(text_model2, converted_sd, device=map_location) # remain fp32 - print("text encoder 2:", info2) + logger.info(f"text encoder 2: {info2}") # prepare vae - print("building VAE") + logger.info("building VAE") vae_config = model_util.create_vae_diffusers_config() with init_empty_weights(): vae = AutoencoderKL(**vae_config) - print("loading VAE from checkpoint") + logger.info("loading VAE from checkpoint") converted_vae_checkpoint = model_util.convert_ldm_vae_checkpoint(state_dict, vae_config) info = _load_state_dict_on_device(vae, converted_vae_checkpoint, device=map_location, dtype=dtype) - print("VAE:", info) + logger.info(f"VAE: {info}") ckpt_info = (epoch, global_step) if epoch is not None else None return text_model1, text_model2, vae, unet, logit_scale, ckpt_info diff --git a/library/sdxl_original_unet.py b/library/sdxl_original_unet.py index babda8ec5..673cf9f65 100644 --- a/library/sdxl_original_unet.py +++ b/library/sdxl_original_unet.py @@ -30,7 +30,10 @@ from torch import nn from torch.nn import functional as F from einops import rearrange - +from .utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) IN_CHANNELS: int = 4 OUT_CHANNELS: int = 4 @@ -332,7 +335,7 @@ def forward_body(self, x, emb): def forward(self, x, emb): if self.training and self.gradient_checkpointing: - # print("ResnetBlock2D: gradient_checkpointing") + # logger.info("ResnetBlock2D: gradient_checkpointing") def create_custom_forward(func): def custom_forward(*inputs): @@ -366,7 +369,7 @@ def forward_body(self, hidden_states): def forward(self, hidden_states): if self.training and self.gradient_checkpointing: - # print("Downsample2D: gradient_checkpointing") + # logger.info("Downsample2D: gradient_checkpointing") def create_custom_forward(func): def custom_forward(*inputs): @@ -653,7 +656,7 @@ def forward_body(self, hidden_states, context=None, timestep=None): def forward(self, hidden_states, context=None, timestep=None): if self.training and self.gradient_checkpointing: - # print("BasicTransformerBlock: checkpointing") + # logger.info("BasicTransformerBlock: checkpointing") def create_custom_forward(func): def custom_forward(*inputs): @@ -796,7 +799,7 @@ def forward_body(self, hidden_states, output_size=None): def forward(self, hidden_states, output_size=None): if self.training and self.gradient_checkpointing: - # print("Upsample2D: gradient_checkpointing") + # logger.info("Upsample2D: gradient_checkpointing") def create_custom_forward(func): def custom_forward(*inputs): @@ -1046,7 +1049,7 @@ def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool) -> N for block in blocks: for module in block: if hasattr(module, "set_use_memory_efficient_attention"): - # print(module.__class__.__name__) + # logger.info(module.__class__.__name__) module.set_use_memory_efficient_attention(xformers, mem_eff) def set_use_sdpa(self, sdpa: bool) -> None: @@ -1061,7 +1064,7 @@ def set_gradient_checkpointing(self, value=False): for block in blocks: for module in block.modules(): if hasattr(module, "gradient_checkpointing"): - # print(module.__class__.__name__, module.gradient_checkpointing, "->", value) + # logger.info(f{module.__class__.__name__} {module.gradient_checkpointing} -> {value}") module.gradient_checkpointing = value # endregion @@ -1083,7 +1086,7 @@ def forward(self, x, timesteps=None, context=None, y=None, **kwargs): def call_module(module, h, emb, context): x = h for layer in module: - # print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None) + # logger.info(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None) if isinstance(layer, ResnetBlock2D): x = layer(x, emb) elif isinstance(layer, Transformer2DModel): @@ -1135,14 +1138,14 @@ def __call__(self, *args, **kwargs): def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5): if ds_depth_1 is None: - print("Deep Shrink is disabled.") + logger.info("Deep Shrink is disabled.") self.ds_depth_1 = None self.ds_timesteps_1 = None self.ds_depth_2 = None self.ds_timesteps_2 = None self.ds_ratio = None else: - print( + logger.info( f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]" ) self.ds_depth_1 = ds_depth_1 @@ -1229,7 +1232,7 @@ def call_module(module, h, emb, context): if __name__ == "__main__": import time - print("create unet") + logger.info("create unet") unet = SdxlUNet2DConditionModel() unet.to("cuda") @@ -1238,7 +1241,7 @@ def call_module(module, h, emb, context): unet.train() # 使用メモリ量確認用の疑似学習ループ - print("preparing optimizer") + logger.info("preparing optimizer") # optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working @@ -1253,12 +1256,12 @@ def call_module(module, h, emb, context): scaler = torch.cuda.amp.GradScaler(enabled=True) - print("start training") + logger.info("start training") steps = 10 batch_size = 1 for step in range(steps): - print(f"step {step}") + logger.info(f"step {step}") if step == 1: time_start = time.perf_counter() @@ -1278,4 +1281,4 @@ def call_module(module, h, emb, context): optimizer.zero_grad(set_to_none=True) time_end = time.perf_counter() - print(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps") + logger.info(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps") diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index 5ad748d15..aedab6c76 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -9,6 +9,10 @@ from transformers import CLIPTokenizer from library import model_util, sdxl_model_util, train_util, sdxl_original_unet from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline +from .utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) TOKENIZER1_PATH = "openai/clip-vit-large-patch14" TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" @@ -21,7 +25,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype): model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16 for pi in range(accelerator.state.num_processes): if pi == accelerator.state.local_process_index: - print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") + logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") ( load_stable_diffusion_format, @@ -62,7 +66,7 @@ def _load_target_model( load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers if load_stable_diffusion_format: - print(f"load StableDiffusion checkpoint: {name_or_path}") + logger.info(f"load StableDiffusion checkpoint: {name_or_path}") ( text_encoder1, text_encoder2, @@ -76,7 +80,7 @@ def _load_target_model( from diffusers import StableDiffusionXLPipeline variant = "fp16" if weight_dtype == torch.float16 else None - print(f"load Diffusers pretrained models: {name_or_path}, variant={variant}") + logger.info(f"load Diffusers pretrained models: {name_or_path}, variant={variant}") try: try: pipe = StableDiffusionXLPipeline.from_pretrained( @@ -84,12 +88,12 @@ def _load_target_model( ) except EnvironmentError as ex: if variant is not None: - print("try to load fp32 model") + logger.info("try to load fp32 model") pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, variant=None, tokenizer=None) else: raise ex except EnvironmentError as ex: - print( + logger.error( f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}" ) raise ex @@ -112,7 +116,7 @@ def _load_target_model( with init_empty_weights(): unet = sdxl_original_unet.SdxlUNet2DConditionModel() # overwrite unet sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device, dtype=model_dtype) - print("U-Net converted to original U-Net") + logger.info("U-Net converted to original U-Net") logit_scale = None ckpt_info = None @@ -120,13 +124,13 @@ def _load_target_model( # VAEを読み込む if vae_path is not None: vae = model_util.load_vae(vae_path, weight_dtype) - print("additional VAE loaded") + logger.info("additional VAE loaded") return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info def load_tokenizers(args: argparse.Namespace): - print("prepare tokenizers") + logger.info("prepare tokenizers") original_paths = [TOKENIZER1_PATH, TOKENIZER2_PATH] tokeniers = [] @@ -135,14 +139,14 @@ def load_tokenizers(args: argparse.Namespace): if args.tokenizer_cache_dir: local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_")) if os.path.exists(local_tokenizer_path): - print(f"load tokenizer from cache: {local_tokenizer_path}") + logger.info(f"load tokenizer from cache: {local_tokenizer_path}") tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) if tokenizer is None: tokenizer = CLIPTokenizer.from_pretrained(original_path) if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path): - print(f"save Tokenizer to cache: {local_tokenizer_path}") + logger.info(f"save Tokenizer to cache: {local_tokenizer_path}") tokenizer.save_pretrained(local_tokenizer_path) if i == 1: @@ -151,7 +155,7 @@ def load_tokenizers(args: argparse.Namespace): tokeniers.append(tokenizer) if hasattr(args, "max_token_length") and args.max_token_length is not None: - print(f"update token length: {args.max_token_length}") + logger.info(f"update token length: {args.max_token_length}") return tokeniers @@ -332,23 +336,23 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser): def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True): assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません" if args.v_parameterization: - print("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります") + logger.warning("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります") if args.clip_skip is not None: - print("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません") + logger.warning("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません") # if args.multires_noise_iterations: - # print( + # logger.info( # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります" # ) # else: # if args.noise_offset is None: # args.noise_offset = DEFAULT_NOISE_OFFSET # elif args.noise_offset != DEFAULT_NOISE_OFFSET: - # print( + # logger.info( # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています" # ) - # print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました") + # logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました") assert ( not hasattr(args, "weighted_captions") or not args.weighted_captions @@ -357,7 +361,7 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin if supportTextEncoderCaching: if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: args.cache_text_encoder_outputs = True - print( + logger.warning( "cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / " + "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました" ) diff --git a/library/slicing_vae.py b/library/slicing_vae.py index 5c4e056d3..ea7653429 100644 --- a/library/slicing_vae.py +++ b/library/slicing_vae.py @@ -26,7 +26,10 @@ from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block from diffusers.models.vae import DecoderOutput, DiagonalGaussianDistribution from diffusers.models.autoencoder_kl import AutoencoderKLOutput - +from .utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def slice_h(x, num_slices): # slice with pad 1 both sides: to eliminate side effect of padding of conv2d @@ -89,7 +92,7 @@ def resblock_forward(_self, num_slices, input_tensor, temb, **kwargs): # sliced_tensor = torch.chunk(x, num_div, dim=1) # sliced_weight = torch.chunk(norm.weight, num_div, dim=0) # sliced_bias = torch.chunk(norm.bias, num_div, dim=0) - # print(sliced_tensor[0].shape, num_div, sliced_weight[0].shape, sliced_bias[0].shape) + # logger.info(sliced_tensor[0].shape, num_div, sliced_weight[0].shape, sliced_bias[0].shape) # normed_tensor = [] # for i in range(num_div): # n = torch.group_norm(sliced_tensor[i], norm.num_groups, sliced_weight[i], sliced_bias[i], norm.eps) @@ -243,7 +246,7 @@ def forward(*args, **kwargs): self.num_slices = num_slices div = num_slices / (2 ** (len(self.down_blocks) - 1)) # 深い層はそこまで分割しなくていいので適宜減らす - # print(f"initial divisor: {div}") + # logger.info(f"initial divisor: {div}") if div >= 2: div = int(div) for resnet in self.mid_block.resnets: @@ -253,11 +256,11 @@ def forward(*args, **kwargs): for i, down_block in enumerate(self.down_blocks[::-1]): if div >= 2: div = int(div) - # print(f"down block: {i} divisor: {div}") + # logger.info(f"down block: {i} divisor: {div}") for resnet in down_block.resnets: resnet.forward = wrapper(resblock_forward, resnet, div) if down_block.downsamplers is not None: - # print("has downsample") + # logger.info("has downsample") for downsample in down_block.downsamplers: downsample.forward = wrapper(self.downsample_forward, downsample, div * 2) div *= 2 @@ -307,7 +310,7 @@ def forward(self, x): def downsample_forward(self, _self, num_slices, hidden_states): assert hidden_states.shape[1] == _self.channels assert _self.use_conv and _self.padding == 0 - print("downsample forward", num_slices, hidden_states.shape) + logger.info(f"downsample forward {num_slices} {hidden_states.shape}") org_device = hidden_states.device cpu_device = torch.device("cpu") @@ -350,7 +353,7 @@ def downsample_forward(self, _self, num_slices, hidden_states): hidden_states = torch.cat([hidden_states, x], dim=2) hidden_states = hidden_states.to(org_device) - # print("downsample forward done", hidden_states.shape) + # logger.info(f"downsample forward done {hidden_states.shape}") return hidden_states @@ -426,7 +429,7 @@ def forward(*args, **kwargs): self.num_slices = num_slices div = num_slices / (2 ** (len(self.up_blocks) - 1)) - print(f"initial divisor: {div}") + logger.info(f"initial divisor: {div}") if div >= 2: div = int(div) for resnet in self.mid_block.resnets: @@ -436,11 +439,11 @@ def forward(*args, **kwargs): for i, up_block in enumerate(self.up_blocks): if div >= 2: div = int(div) - # print(f"up block: {i} divisor: {div}") + # logger.info(f"up block: {i} divisor: {div}") for resnet in up_block.resnets: resnet.forward = wrapper(resblock_forward, resnet, div) if up_block.upsamplers is not None: - # print("has upsample") + # logger.info("has upsample") for upsample in up_block.upsamplers: upsample.forward = wrapper(self.upsample_forward, upsample, div * 2) div *= 2 @@ -528,7 +531,7 @@ def upsample_forward(self, _self, num_slices, hidden_states, output_size=None): del x hidden_states = torch.cat(sliced, dim=2) - # print("us hidden_states", hidden_states.shape) + # logger.info(f"us hidden_states {hidden_states.shape}") del sliced hidden_states = hidden_states.to(org_device) diff --git a/library/train_util.py b/library/train_util.py index ba428e508..2031debe8 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -6,6 +6,7 @@ import datetime import importlib import json +import logging import pathlib import re import shutil @@ -64,7 +65,10 @@ import library.model_util as model_util import library.huggingface_util as huggingface_util import library.sai_model_spec as sai_model_spec - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) # from library.attention_processors import FlashAttnProcessor # from library.hypernetwork import replace_attentions_for_hypernetwork from library.original_unet import UNet2DConditionModel @@ -211,7 +215,7 @@ def add_if_new_reso(self, reso): self.reso_to_id[reso] = bucket_id self.resos.append(reso) self.buckets.append([]) - # print(reso, bucket_id, len(self.buckets)) + # logger.info(reso, bucket_id, len(self.buckets)) def round_to_steps(self, x): x = int(x + 0.5) @@ -237,7 +241,7 @@ def select_bucket(self, image_width, image_height): scale = reso[0] / image_width resized_size = (int(image_width * scale + 0.5), int(image_height * scale + 0.5)) - # print("use predef", image_width, image_height, reso, resized_size) + # logger.info(f"use predef, {image_width}, {image_height}, {reso}, {resized_size}") else: # 縮小のみを行う if image_width * image_height > self.max_area: @@ -256,21 +260,21 @@ def select_bucket(self, image_width, image_height): b_width_in_hr = self.round_to_steps(b_height_rounded * aspect_ratio) ar_height_rounded = b_width_in_hr / b_height_rounded - # print(b_width_rounded, b_height_in_wr, ar_width_rounded) - # print(b_width_in_hr, b_height_rounded, ar_height_rounded) + # logger.info(b_width_rounded, b_height_in_wr, ar_width_rounded) + # logger.info(b_width_in_hr, b_height_rounded, ar_height_rounded) if abs(ar_width_rounded - aspect_ratio) < abs(ar_height_rounded - aspect_ratio): resized_size = (b_width_rounded, int(b_width_rounded / aspect_ratio + 0.5)) else: resized_size = (int(b_height_rounded * aspect_ratio + 0.5), b_height_rounded) - # print(resized_size) + # logger.info(resized_size) else: resized_size = (image_width, image_height) # リサイズは不要 # 画像のサイズ未満をbucketのサイズとする(paddingせずにcroppingする) bucket_width = resized_size[0] - resized_size[0] % self.reso_steps bucket_height = resized_size[1] - resized_size[1] % self.reso_steps - # print("use arbitrary", image_width, image_height, resized_size, bucket_width, bucket_height) + # logger.info(f"use arbitrary {image_width}, {image_height}, {resized_size}, {bucket_width}, {bucket_height}") reso = (bucket_width, bucket_height) @@ -779,15 +783,15 @@ def make_buckets(self): bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る) min_size and max_size are ignored when enable_bucket is False """ - print("loading image sizes.") + logger.info("loading image sizes.") for info in tqdm(self.image_data.values()): if info.image_size is None: info.image_size = self.get_image_size(info.absolute_path) if self.enable_bucket: - print("make buckets") + logger.info("make buckets") else: - print("prepare dataset") + logger.info("prepare dataset") # bucketを作成し、画像をbucketに振り分ける if self.enable_bucket: @@ -802,7 +806,7 @@ def make_buckets(self): if not self.bucket_no_upscale: self.bucket_manager.make_buckets() else: - print( + logger.warning( "min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます" ) @@ -813,7 +817,7 @@ def make_buckets(self): image_width, image_height ) - # print(image_info.image_key, image_info.bucket_reso) + # logger.info(image_info.image_key, image_info.bucket_reso) img_ar_errors.append(abs(ar_error)) self.bucket_manager.sort() @@ -831,17 +835,17 @@ def make_buckets(self): # bucket情報を表示、格納する if self.enable_bucket: self.bucket_info = {"buckets": {}} - print("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)") + logger.info("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)") for i, (reso, bucket) in enumerate(zip(self.bucket_manager.resos, self.bucket_manager.buckets)): count = len(bucket) if count > 0: self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)} - print(f"bucket {i}: resolution {reso}, count: {len(bucket)}") + logger.info(f"bucket {i}: resolution {reso}, count: {len(bucket)}") img_ar_errors = np.array(img_ar_errors) mean_img_ar_error = np.mean(np.abs(img_ar_errors)) self.bucket_info["mean_img_ar_error"] = mean_img_ar_error - print(f"mean ar error (without repeats): {mean_img_ar_error}") + logger.info(f"mean ar error (without repeats): {mean_img_ar_error}") # データ参照用indexを作る。このindexはdatasetのshuffleに用いられる self.buckets_indices: List(BucketBatchIndex) = [] @@ -861,7 +865,7 @@ def make_buckets(self): # num_of_image_types = len(set(bucket)) # bucket_batch_size = min(self.batch_size, num_of_image_types) # batch_count = int(math.ceil(len(bucket) / bucket_batch_size)) - # # print(bucket_index, num_of_image_types, bucket_batch_size, batch_count) + # # logger.info(bucket_index, num_of_image_types, bucket_batch_size, batch_count) # for batch_index in range(batch_count): # self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index)) # ↑ここまで @@ -900,7 +904,7 @@ def is_text_encoder_output_cacheable(self): def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと - print("caching latents.") + logger.info("caching latents.") image_infos = list(self.image_data.values()) @@ -910,7 +914,7 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc # split by resolution batches = [] batch = [] - print("checking cache validity...") + logger.info("checking cache validity...") for info in tqdm(image_infos): subset = self.image_to_subset[info.image_key] @@ -947,7 +951,7 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc return # iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded - print("caching latents...") + logger.info("caching latents...") for batch in tqdm(batches, smoothing=1, total=len(batches)): cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.random_crop) @@ -961,10 +965,10 @@ def cache_text_encoder_outputs( # latentsのキャッシュと同様に、ディスクへのキャッシュに対応する # またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと - print("caching text encoder outputs.") + logger.info("caching text encoder outputs.") image_infos = list(self.image_data.values()) - print("checking cache existence...") + logger.info("checking cache existence...") image_infos_to_cache = [] for info in tqdm(image_infos): # subset = self.image_to_subset[info.image_key] @@ -1005,7 +1009,7 @@ def cache_text_encoder_outputs( batches.append(batch) # iterate batches: call text encoder and cache outputs for memory or disk - print("caching text encoder outputs...") + logger.info("caching text encoder outputs...") for batch in tqdm(batches): infos, input_ids1, input_ids2 = zip(*batch) input_ids1 = torch.stack(input_ids1, dim=0) @@ -1404,7 +1408,7 @@ def read_caption(img_path, caption_extension): try: lines = f.readlines() except UnicodeDecodeError as e: - print(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}") + logger.error(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}") raise e assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" caption = lines[0].strip() @@ -1413,11 +1417,11 @@ def read_caption(img_path, caption_extension): def load_dreambooth_dir(subset: DreamBoothSubset): if not os.path.isdir(subset.image_dir): - print(f"not directory: {subset.image_dir}") + logger.warning(f"not directory: {subset.image_dir}") return [], [] img_paths = glob_images(subset.image_dir, "*") - print(f"found directory {subset.image_dir} contains {len(img_paths)} image files") + logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う captions = [] @@ -1425,7 +1429,7 @@ def load_dreambooth_dir(subset: DreamBoothSubset): for img_path in img_paths: cap_for_img = read_caption(img_path, subset.caption_extension) if cap_for_img is None and subset.class_tokens is None: - print( + logger.warning( f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}" ) captions.append("") @@ -1444,36 +1448,36 @@ def load_dreambooth_dir(subset: DreamBoothSubset): number_of_missing_captions_to_show = 5 remaining_missing_captions = number_of_missing_captions - number_of_missing_captions_to_show - print( + logger.warning( f"No caption file found for {number_of_missing_captions} images. Training will continue without captions for these images. If class token exists, it will be used. / {number_of_missing_captions}枚の画像にキャプションファイルが見つかりませんでした。これらの画像についてはキャプションなしで学習を続行します。class tokenが存在する場合はそれを使います。" ) for i, missing_caption in enumerate(missing_captions): if i >= number_of_missing_captions_to_show: - print(missing_caption + f"... and {remaining_missing_captions} more") + logger.warning(missing_caption + f"... and {remaining_missing_captions} more") break - print(missing_caption) + logger.warning(missing_caption) return img_paths, captions - print("prepare images.") + logger.info("prepare images.") num_train_images = 0 num_reg_images = 0 reg_infos: List[ImageInfo] = [] for subset in subsets: if subset.num_repeats < 1: - print( + logger.warning( f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}" ) continue if subset in self.subsets: - print( + logger.warning( f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します" ) continue img_paths, captions = load_dreambooth_dir(subset) if len(img_paths) < 1: - print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します") + logger.warning(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します") continue if subset.is_reg: @@ -1491,15 +1495,15 @@ def load_dreambooth_dir(subset: DreamBoothSubset): subset.img_count = len(img_paths) self.subsets.append(subset) - print(f"{num_train_images} train images with repeating.") + logger.info(f"{num_train_images} train images with repeating.") self.num_train_images = num_train_images - print(f"{num_reg_images} reg images.") + logger.info(f"{num_reg_images} reg images.") if num_train_images < num_reg_images: - print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") + logger.warning("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") if num_reg_images == 0: - print("no regularization images / 正則化画像が見つかりませんでした") + logger.warning("no regularization images / 正則化画像が見つかりませんでした") else: # num_repeatsを計算する:どうせ大した数ではないのでループで処理する n = 0 @@ -1544,27 +1548,27 @@ def __init__( for subset in subsets: if subset.num_repeats < 1: - print( + logger.warning( f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}" ) continue if subset in self.subsets: - print( + logger.warning( f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します" ) continue # メタデータを読み込む if os.path.exists(subset.metadata_file): - print(f"loading existing metadata: {subset.metadata_file}") + logger.info(f"loading existing metadata: {subset.metadata_file}") with open(subset.metadata_file, "rt", encoding="utf-8") as f: metadata = json.load(f) else: raise ValueError(f"no metadata / メタデータファイルがありません: {subset.metadata_file}") if len(metadata) < 1: - print(f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します") + logger.warning(f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します") continue tags_list = [] @@ -1642,14 +1646,14 @@ def __init__( if not npz_any: use_npz_latents = False - print(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します") + logger.warning(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します") elif not npz_all: use_npz_latents = False - print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します") + logger.warning(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します") if flip_aug_in_subset: - print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません") + logger.warning("maybe no flipped files / 反転されたnpzファイルがないのかもしれません") # else: - # print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません") + # logger.info("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません") # check min/max bucket size sizes = set() @@ -1665,7 +1669,7 @@ def __init__( if sizes is None: if use_npz_latents: use_npz_latents = False - print(f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します") + logger.warning(f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します") assert ( resolution is not None @@ -1679,8 +1683,8 @@ def __init__( self.bucket_no_upscale = bucket_no_upscale else: if not enable_bucket: - print("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします") - print("using bucket info in metadata / メタデータ内のbucket情報を使います") + logger.info("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします") + logger.info("using bucket info in metadata / メタデータ内のbucket情報を使います") self.enable_bucket = True assert ( @@ -1803,7 +1807,7 @@ def __init__( assert subset is not None, "internal error: subset not found" if not os.path.isdir(subset.conditioning_data_dir): - print(f"not directory: {subset.conditioning_data_dir}") + logger.warning(f"not directory: {subset.conditioning_data_dir}") continue img_basename = os.path.basename(info.absolute_path) @@ -1923,14 +1927,14 @@ def enable_XTI(self, *args, **kwargs): def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): for i, dataset in enumerate(self.datasets): - print(f"[Dataset {i}]") + logger.info(f"[Dataset {i}]") dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) def cache_text_encoder_outputs( self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True ): for i, dataset in enumerate(self.datasets): - print(f"[Dataset {i}]") + logger.info(f"[Dataset {i}]") dataset.cache_text_encoder_outputs(tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process) def set_caching_mode(self, caching_mode): @@ -2014,12 +2018,13 @@ def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, fli def debug_dataset(train_dataset, show_input_ids=False): - print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") - print("`S` for next step, `E` for next epoch no. , Escape for exit. / Sキーで次のステップ、Eキーで次のエポック、Escキーで中断、終了します") + logger.info(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") + logger.info("`S` for next step, `E` for next epoch no. , Escape for exit. / Sキーで次のステップ、Eキーで次のエポック、Escキーで中断、終了します") epoch = 1 while True: - print(f"\nepoch: {epoch}") + logger.info(f"") + logger.info(f"epoch: {epoch}") steps = (epoch - 1) * len(train_dataset) + 1 indices = list(range(len(train_dataset))) @@ -2029,11 +2034,11 @@ def debug_dataset(train_dataset, show_input_ids=False): for i, idx in enumerate(indices): train_dataset.set_current_epoch(epoch) train_dataset.set_current_step(steps) - print(f"steps: {steps} ({i + 1}/{len(train_dataset)})") + logger.info(f"steps: {steps} ({i + 1}/{len(train_dataset)})") example = train_dataset[idx] if example["latents"] is not None: - print(f"sample has latents from npz file: {example['latents'].size()}") + logger.info(f"sample has latents from npz file: {example['latents'].size()}") for j, (ik, cap, lw, iid, orgsz, crptl, trgsz, flpdz) in enumerate( zip( example["image_keys"], @@ -2046,26 +2051,26 @@ def debug_dataset(train_dataset, show_input_ids=False): example["flippeds"], ) ): - print( + logger.info( f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop top left: {crptl}, target size: {trgsz}, flipped: {flpdz}' ) if "network_multipliers" in example: print(f"network multiplier: {example['network_multipliers'][j]}") if show_input_ids: - print(f"input ids: {iid}") + logger.info(f"input ids: {iid}") if "input_ids2" in example: - print(f"input ids2: {example['input_ids2'][j]}") + logger.info(f"input ids2: {example['input_ids2'][j]}") if example["images"] is not None: im = example["images"][j] - print(f"image size: {im.size()}") + logger.info(f"image size: {im.size()}") im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c im = im[:, :, ::-1] # RGB -> BGR (OpenCV) if "conditioning_images" in example: cond_img = example["conditioning_images"][j] - print(f"conditioning image size: {cond_img.size()}") + logger.info(f"conditioning image size: {cond_img.size()}") cond_img = ((cond_img.numpy() + 1.0) * 127.5).astype(np.uint8) cond_img = np.transpose(cond_img, (1, 2, 0)) cond_img = cond_img[:, :, ::-1] @@ -2213,12 +2218,12 @@ def trim_and_resize_if_required( if image_width > reso[0]: trim_size = image_width - reso[0] p = trim_size // 2 if not random_crop else random.randint(0, trim_size) - # print("w", trim_size, p) + # logger.info(f"w {trim_size} {p}") image = image[:, p : p + reso[0]] if image_height > reso[1]: trim_size = image_height - reso[1] p = trim_size // 2 if not random_crop else random.randint(0, trim_size) - # print("h", trim_size, p) + # logger.info(f"h {trim_size} {p}) image = image[p : p + reso[1]] # random cropの場合のcropされた値をどうcrop left/topに反映するべきか全くアイデアがない @@ -2456,7 +2461,7 @@ def get_git_revision_hash() -> str: # def replace_unet_cross_attn_to_xformers(): -# print("CrossAttention.forward has been replaced to enable xformers.") +# logger.info("CrossAttention.forward has been replaced to enable xformers.") # try: # import xformers.ops # except ImportError: @@ -2499,10 +2504,10 @@ def get_git_revision_hash() -> str: # diffusers.models.attention.CrossAttention.forward = forward_xformers def replace_unet_modules(unet: UNet2DConditionModel, mem_eff_attn, xformers, sdpa): if mem_eff_attn: - print("Enable memory efficient attention for U-Net") + logger.info("Enable memory efficient attention for U-Net") unet.set_use_memory_efficient_attention(False, True) elif xformers: - print("Enable xformers for U-Net") + logger.info("Enable xformers for U-Net") try: import xformers.ops except ImportError: @@ -2510,7 +2515,7 @@ def replace_unet_modules(unet: UNet2DConditionModel, mem_eff_attn, xformers, sdp unet.set_use_memory_efficient_attention(True, False) elif sdpa: - print("Enable SDPA for U-Net") + logger.info("Enable SDPA for U-Net") unet.set_use_sdpa(True) @@ -2521,17 +2526,17 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform replace_vae_attn_to_memory_efficient() elif xformers: # とりあえずDiffusersのxformersを使う。AttentionがあるのはMidBlockのみ - print("Use Diffusers xformers for VAE") + logger.info("Use Diffusers xformers for VAE") vae.encoder.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) vae.decoder.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) def replace_vae_attn_to_memory_efficient(): - print("AttentionBlock.forward has been replaced to FlashAttention (not xformers)") + logger.info("AttentionBlock.forward has been replaced to FlashAttention (not xformers)") flash_func = FlashAttentionFunction def forward_flash_attn(self, hidden_states): - print("forward_flash_attn") + logger.info("forward_flash_attn") q_bucket_size = 512 k_bucket_size = 1024 @@ -3127,13 +3132,13 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: def verify_training_args(args: argparse.Namespace): if args.v_parameterization and not args.v2: - print("v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません") + logger.warning("v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません") if args.v2 and args.clip_skip is not None: - print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") + logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") if args.cache_latents_to_disk and not args.cache_latents: args.cache_latents = True - print( + logger.warning( "cache_latents_to_disk is enabled, so cache_latents is also enabled / cache_latents_to_diskが有効なため、cache_latentsを有効にします" ) @@ -3164,7 +3169,7 @@ def verify_training_args(args: argparse.Namespace): ) if args.zero_terminal_snr and not args.v_parameterization: - print( + logger.warning( f"zero_terminal_snr is enabled, but v_parameterization is not enabled. training will be unexpected" + " / zero_terminal_snrが有効ですが、v_parameterizationが有効ではありません。学習結果は想定外になる可能性があります" ) @@ -3334,7 +3339,7 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar if args.output_config: # check if config file exists if os.path.exists(config_path): - print(f"Config file already exists. Aborting... / 出力先の設定ファイルが既に存在します: {config_path}") + logger.error(f"Config file already exists. Aborting... / 出力先の設定ファイルが既に存在します: {config_path}") exit(1) # convert args to dictionary @@ -3362,14 +3367,14 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar with open(config_path, "w") as f: toml.dump(args_dict, f) - print(f"Saved config file / 設定ファイルを保存しました: {config_path}") + logger.info(f"Saved config file / 設定ファイルを保存しました: {config_path}") exit(0) if not os.path.exists(config_path): - print(f"{config_path} not found.") + logger.info(f"{config_path} not found.") exit(1) - print(f"Loading settings from {config_path}...") + logger.info(f"Loading settings from {config_path}...") with open(config_path, "r") as f: config_dict = toml.load(f) @@ -3388,7 +3393,7 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar config_args = argparse.Namespace(**ignore_nesting_dict) args = parser.parse_args(namespace=config_args) args.config_file = os.path.splitext(args.config_file)[0] - print(args.config_file) + logger.info(args.config_file) return args @@ -3403,11 +3408,11 @@ def resume_from_local_or_hf_if_specified(accelerator, args): return if not args.resume_from_huggingface: - print(f"resume training from local state: {args.resume}") + logger.info(f"resume training from local state: {args.resume}") accelerator.load_state(args.resume) return - print(f"resume training from huggingface state: {args.resume}") + logger.info(f"resume training from huggingface state: {args.resume}") repo_id = args.resume.split("/")[0] + "/" + args.resume.split("/")[1] path_in_repo = "/".join(args.resume.split("/")[2:]) revision = None @@ -3419,7 +3424,7 @@ def resume_from_local_or_hf_if_specified(accelerator, args): repo_type = "model" else: path_in_repo, revision, repo_type = divided - print(f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}") + logger.info(f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}") list_files = huggingface_util.list_dir( repo_id=repo_id, @@ -3491,7 +3496,7 @@ def get_optimizer(args, trainable_params): # value = tuple(value) optimizer_kwargs[key] = value - # print("optkwargs:", optimizer_kwargs) + # logger.info(f"optkwargs {optimizer}_{kwargs}") lr = args.learning_rate optimizer = None @@ -3501,7 +3506,7 @@ def get_optimizer(args, trainable_params): import lion_pytorch except ImportError: raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです") - print(f"use Lion optimizer | {optimizer_kwargs}") + logger.info(f"use Lion optimizer | {optimizer_kwargs}") optimizer_class = lion_pytorch.Lion optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) @@ -3512,14 +3517,14 @@ def get_optimizer(args, trainable_params): raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです") if optimizer_type == "AdamW8bit".lower(): - print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}") + logger.info(f"use 8-bit AdamW optimizer | {optimizer_kwargs}") optimizer_class = bnb.optim.AdamW8bit optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "SGDNesterov8bit".lower(): - print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}") + logger.info(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}") if "momentum" not in optimizer_kwargs: - print( + logger.warning( f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します" ) optimizer_kwargs["momentum"] = 0.9 @@ -3528,7 +3533,7 @@ def get_optimizer(args, trainable_params): optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs) elif optimizer_type == "Lion8bit".lower(): - print(f"use 8-bit Lion optimizer | {optimizer_kwargs}") + logger.info(f"use 8-bit Lion optimizer | {optimizer_kwargs}") try: optimizer_class = bnb.optim.Lion8bit except AttributeError: @@ -3536,7 +3541,7 @@ def get_optimizer(args, trainable_params): "No Lion8bit. The version of bitsandbytes installed seems to be old. Please install 0.38.0 or later. / Lion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.38.0以上をインストールしてください" ) elif optimizer_type == "PagedAdamW8bit".lower(): - print(f"use 8-bit PagedAdamW optimizer | {optimizer_kwargs}") + logger.info(f"use 8-bit PagedAdamW optimizer | {optimizer_kwargs}") try: optimizer_class = bnb.optim.PagedAdamW8bit except AttributeError: @@ -3544,7 +3549,7 @@ def get_optimizer(args, trainable_params): "No PagedAdamW8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedAdamW8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" ) elif optimizer_type == "PagedLion8bit".lower(): - print(f"use 8-bit Paged Lion optimizer | {optimizer_kwargs}") + logger.info(f"use 8-bit Paged Lion optimizer | {optimizer_kwargs}") try: optimizer_class = bnb.optim.PagedLion8bit except AttributeError: @@ -3555,7 +3560,7 @@ def get_optimizer(args, trainable_params): optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "PagedAdamW".lower(): - print(f"use PagedAdamW optimizer | {optimizer_kwargs}") + logger.info(f"use PagedAdamW optimizer | {optimizer_kwargs}") try: import bitsandbytes as bnb except ImportError: @@ -3569,7 +3574,7 @@ def get_optimizer(args, trainable_params): optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "PagedAdamW32bit".lower(): - print(f"use 32-bit PagedAdamW optimizer | {optimizer_kwargs}") + logger.info(f"use 32-bit PagedAdamW optimizer | {optimizer_kwargs}") try: import bitsandbytes as bnb except ImportError: @@ -3583,16 +3588,16 @@ def get_optimizer(args, trainable_params): optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "SGDNesterov".lower(): - print(f"use SGD with Nesterov optimizer | {optimizer_kwargs}") + logger.info(f"use SGD with Nesterov optimizer | {optimizer_kwargs}") if "momentum" not in optimizer_kwargs: - print(f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します") + logger.info(f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します") optimizer_kwargs["momentum"] = 0.9 optimizer_class = torch.optim.SGD optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs) elif optimizer_type.startswith("DAdapt".lower()) or optimizer_type == "Prodigy".lower(): - # check lr and lr_count, and print warning + # check lr and lr_count, and logger.info warning actual_lr = lr lr_count = 1 if type(trainable_params) == list and type(trainable_params[0]) == dict: @@ -3603,12 +3608,12 @@ def get_optimizer(args, trainable_params): lr_count = len(lrs) if actual_lr <= 0.1: - print( + logger.warning( f"learning rate is too low. If using D-Adaptation or Prodigy, set learning rate around 1.0 / 学習率が低すぎるようです。D-AdaptationまたはProdigyの使用時は1.0前後の値を指定してください: lr={actual_lr}" ) - print("recommend option: lr=1.0 / 推奨は1.0です") + logger.warning("recommend option: lr=1.0 / 推奨は1.0です") if lr_count > 1: - print( + logger.warning( f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-AdaptationまたはProdigyで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}" ) @@ -3624,25 +3629,25 @@ def get_optimizer(args, trainable_params): # set optimizer if optimizer_type == "DAdaptation".lower() or optimizer_type == "DAdaptAdamPreprint".lower(): optimizer_class = experimental.DAdaptAdamPreprint - print(f"use D-Adaptation AdamPreprint optimizer | {optimizer_kwargs}") + logger.info(f"use D-Adaptation AdamPreprint optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptAdaGrad".lower(): optimizer_class = dadaptation.DAdaptAdaGrad - print(f"use D-Adaptation AdaGrad optimizer | {optimizer_kwargs}") + logger.info(f"use D-Adaptation AdaGrad optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptAdam".lower(): optimizer_class = dadaptation.DAdaptAdam - print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}") + logger.info(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptAdan".lower(): optimizer_class = dadaptation.DAdaptAdan - print(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}") + logger.info(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptAdanIP".lower(): optimizer_class = experimental.DAdaptAdanIP - print(f"use D-Adaptation AdanIP optimizer | {optimizer_kwargs}") + logger.info(f"use D-Adaptation AdanIP optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptLion".lower(): optimizer_class = dadaptation.DAdaptLion - print(f"use D-Adaptation Lion optimizer | {optimizer_kwargs}") + logger.info(f"use D-Adaptation Lion optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptSGD".lower(): optimizer_class = dadaptation.DAdaptSGD - print(f"use D-Adaptation SGD optimizer | {optimizer_kwargs}") + logger.info(f"use D-Adaptation SGD optimizer | {optimizer_kwargs}") else: raise ValueError(f"Unknown optimizer type: {optimizer_type}") @@ -3655,7 +3660,7 @@ def get_optimizer(args, trainable_params): except ImportError: raise ImportError("No Prodigy / Prodigy がインストールされていないようです") - print(f"use Prodigy optimizer | {optimizer_kwargs}") + logger.info(f"use Prodigy optimizer | {optimizer_kwargs}") optimizer_class = prodigyopt.Prodigy optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) @@ -3664,14 +3669,14 @@ def get_optimizer(args, trainable_params): if "relative_step" not in optimizer_kwargs: optimizer_kwargs["relative_step"] = True # default if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False): - print(f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします") + logger.info(f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします") optimizer_kwargs["relative_step"] = True - print(f"use Adafactor optimizer | {optimizer_kwargs}") + logger.info(f"use Adafactor optimizer | {optimizer_kwargs}") if optimizer_kwargs["relative_step"]: - print(f"relative_step is true / relative_stepがtrueです") + logger.info(f"relative_step is true / relative_stepがtrueです") if lr != 0.0: - print(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます") + logger.warning(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます") args.learning_rate = None # trainable_paramsがgroupだった時の処理:lrを削除する @@ -3683,37 +3688,37 @@ def get_optimizer(args, trainable_params): if has_group_lr: # 一応argsを無効にしておく TODO 依存関係が逆転してるのであまり望ましくない - print(f"unet_lr and text_encoder_lr are ignored / unet_lrとtext_encoder_lrは無視されます") + logger.warning(f"unet_lr and text_encoder_lr are ignored / unet_lrとtext_encoder_lrは無視されます") args.unet_lr = None args.text_encoder_lr = None if args.lr_scheduler != "adafactor": - print(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します") + logger.info(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します") args.lr_scheduler = f"adafactor:{lr}" # ちょっと微妙だけど lr = None else: if args.max_grad_norm != 0.0: - print( + logger.warning( f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません" ) if args.lr_scheduler != "constant_with_warmup": - print(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません") + logger.warning(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません") if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0: - print(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません") + logger.warning(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません") optimizer_class = transformers.optimization.Adafactor optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "AdamW".lower(): - print(f"use AdamW optimizer | {optimizer_kwargs}") + logger.info(f"use AdamW optimizer | {optimizer_kwargs}") optimizer_class = torch.optim.AdamW optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) if optimizer is None: # 任意のoptimizerを使う optimizer_type = args.optimizer_type # lowerでないやつ(微妙) - print(f"use {optimizer_type} | {optimizer_kwargs}") + logger.info(f"use {optimizer_type} | {optimizer_kwargs}") if "." not in optimizer_type: optimizer_module = torch.optim else: @@ -3759,7 +3764,7 @@ def wrap_check_needless_num_warmup_steps(return_vals): # using any lr_scheduler from other library if args.lr_scheduler_type: lr_scheduler_type = args.lr_scheduler_type - print(f"use {lr_scheduler_type} | {lr_scheduler_kwargs} as lr_scheduler") + logger.info(f"use {lr_scheduler_type} | {lr_scheduler_kwargs} as lr_scheduler") if "." not in lr_scheduler_type: # default to use torch.optim lr_scheduler_module = torch.optim.lr_scheduler else: @@ -3775,7 +3780,7 @@ def wrap_check_needless_num_warmup_steps(return_vals): type(optimizer) == transformers.optimization.Adafactor ), f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください" initial_lr = float(name.split(":")[1]) - # print("adafactor scheduler init lr", initial_lr) + # logger.info(f"adafactor scheduler init lr {initial_lr}") return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr)) name = SchedulerType(name) @@ -3840,20 +3845,20 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool): if support_metadata: if args.in_json is not None and (args.color_aug or args.random_crop): - print( + logger.warning( f"latents in npz is ignored when color_aug or random_crop is True / color_augまたはrandom_cropを有効にした場合、npzファイルのlatentsは無視されます" ) def load_tokenizer(args: argparse.Namespace): - print("prepare tokenizer") + logger.info("prepare tokenizer") original_path = V2_STABLE_DIFFUSION_PATH if args.v2 else TOKENIZER_PATH tokenizer: CLIPTokenizer = None if args.tokenizer_cache_dir: local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_")) if os.path.exists(local_tokenizer_path): - print(f"load tokenizer from cache: {local_tokenizer_path}") + logger.info(f"load tokenizer from cache: {local_tokenizer_path}") tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) # same for v1 and v2 if tokenizer is None: @@ -3863,10 +3868,10 @@ def load_tokenizer(args: argparse.Namespace): tokenizer = CLIPTokenizer.from_pretrained(original_path) if hasattr(args, "max_token_length") and args.max_token_length is not None: - print(f"update token length: {args.max_token_length}") + logger.info(f"update token length: {args.max_token_length}") if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path): - print(f"save Tokenizer to cache: {local_tokenizer_path}") + logger.info(f"save Tokenizer to cache: {local_tokenizer_path}") tokenizer.save_pretrained(local_tokenizer_path) return tokenizer @@ -3946,17 +3951,17 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une name_or_path = os.path.realpath(name_or_path) if os.path.islink(name_or_path) else name_or_path load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers if load_stable_diffusion_format: - print(f"load StableDiffusion checkpoint: {name_or_path}") + logger.info(f"load StableDiffusion checkpoint: {name_or_path}") text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint( args.v2, name_or_path, device, unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2 ) else: # Diffusers model is loaded to CPU - print(f"load Diffusers pretrained models: {name_or_path}") + logger.info(f"load Diffusers pretrained models: {name_or_path}") try: pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None) except EnvironmentError as ex: - print( + logger.error( f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}" ) raise ex @@ -3967,7 +3972,7 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une # Diffusers U-Net to original U-Net # TODO *.ckpt/*.safetensorsのv2と同じ形式にここで変換すると良さそう - # print(f"unet config: {unet.config}") + # logger.info(f"unet config: {unet.config}") original_unet = UNet2DConditionModel( unet.config.sample_size, unet.config.attention_head_dim, @@ -3977,12 +3982,12 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une ) original_unet.load_state_dict(unet.state_dict()) unet = original_unet - print("U-Net converted to original U-Net") + logger.info("U-Net converted to original U-Net") # VAEを読み込む if args.vae is not None: vae = model_util.load_vae(args.vae, weight_dtype) - print("additional VAE loaded") + logger.info("additional VAE loaded") return text_encoder, vae, unet, load_stable_diffusion_format @@ -3991,7 +3996,7 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio # load models for each process for pi in range(accelerator.state.num_processes): if pi == accelerator.state.local_process_index: - print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") + logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model( args, @@ -4300,7 +4305,8 @@ def save_sd_model_on_epoch_end_or_stepwise_common( ckpt_name = get_step_ckpt_name(args, ext, global_step) ckpt_file = os.path.join(args.output_dir, ckpt_name) - print(f"\nsaving checkpoint: {ckpt_file}") + logger.info("") + logger.info(f"saving checkpoint: {ckpt_file}") sd_saver(ckpt_file, epoch_no, global_step) if args.huggingface_repo_id is not None: @@ -4315,7 +4321,7 @@ def save_sd_model_on_epoch_end_or_stepwise_common( remove_ckpt_file = os.path.join(args.output_dir, remove_ckpt_name) if os.path.exists(remove_ckpt_file): - print(f"removing old checkpoint: {remove_ckpt_file}") + logger.info(f"removing old checkpoint: {remove_ckpt_file}") os.remove(remove_ckpt_file) else: @@ -4324,7 +4330,8 @@ def save_sd_model_on_epoch_end_or_stepwise_common( else: out_dir = os.path.join(args.output_dir, STEP_DIFFUSERS_DIR_NAME.format(model_name, global_step)) - print(f"\nsaving model: {out_dir}") + logger.info("") + logger.info(f"saving model: {out_dir}") diffusers_saver(out_dir) if args.huggingface_repo_id is not None: @@ -4338,7 +4345,7 @@ def save_sd_model_on_epoch_end_or_stepwise_common( remove_out_dir = os.path.join(args.output_dir, STEP_DIFFUSERS_DIR_NAME.format(model_name, remove_no)) if os.path.exists(remove_out_dir): - print(f"removing old model: {remove_out_dir}") + logger.info(f"removing old model: {remove_out_dir}") shutil.rmtree(remove_out_dir) if args.save_state: @@ -4351,13 +4358,14 @@ def save_sd_model_on_epoch_end_or_stepwise_common( def save_and_remove_state_on_epoch_end(args: argparse.Namespace, accelerator, epoch_no): model_name = default_if_none(args.output_name, DEFAULT_EPOCH_NAME) - print(f"\nsaving state at epoch {epoch_no}") + logger.info("") + logger.info(f"saving state at epoch {epoch_no}") os.makedirs(args.output_dir, exist_ok=True) state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)) accelerator.save_state(state_dir) if args.save_state_to_huggingface: - print("uploading state to huggingface.") + logger.info("uploading state to huggingface.") huggingface_util.upload(args, state_dir, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no)) last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs @@ -4365,20 +4373,21 @@ def save_and_remove_state_on_epoch_end(args: argparse.Namespace, accelerator, ep remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epochs state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no)) if os.path.exists(state_dir_old): - print(f"removing old state: {state_dir_old}") + logger.info(f"removing old state: {state_dir_old}") shutil.rmtree(state_dir_old) def save_and_remove_state_stepwise(args: argparse.Namespace, accelerator, step_no): model_name = default_if_none(args.output_name, DEFAULT_STEP_NAME) - print(f"\nsaving state at step {step_no}") + logger.info("") + logger.info(f"saving state at step {step_no}") os.makedirs(args.output_dir, exist_ok=True) state_dir = os.path.join(args.output_dir, STEP_STATE_NAME.format(model_name, step_no)) accelerator.save_state(state_dir) if args.save_state_to_huggingface: - print("uploading state to huggingface.") + logger.info("uploading state to huggingface.") huggingface_util.upload(args, state_dir, "/" + STEP_STATE_NAME.format(model_name, step_no)) last_n_steps = args.save_last_n_steps_state if args.save_last_n_steps_state else args.save_last_n_steps @@ -4390,21 +4399,22 @@ def save_and_remove_state_stepwise(args: argparse.Namespace, accelerator, step_n if remove_step_no > 0: state_dir_old = os.path.join(args.output_dir, STEP_STATE_NAME.format(model_name, remove_step_no)) if os.path.exists(state_dir_old): - print(f"removing old state: {state_dir_old}") + logger.info(f"removing old state: {state_dir_old}") shutil.rmtree(state_dir_old) def save_state_on_train_end(args: argparse.Namespace, accelerator): model_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME) - print("\nsaving last state.") + logger.info("") + logger.info("saving last state.") os.makedirs(args.output_dir, exist_ok=True) state_dir = os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)) accelerator.save_state(state_dir) if args.save_state_to_huggingface: - print("uploading last state to huggingface.") + logger.info("uploading last state to huggingface.") huggingface_util.upload(args, state_dir, "/" + LAST_STATE_NAME.format(model_name)) @@ -4453,7 +4463,7 @@ def save_sd_model_on_train_end_common( ckpt_name = model_name + (".safetensors" if use_safetensors else ".ckpt") ckpt_file = os.path.join(args.output_dir, ckpt_name) - print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}") + logger.info(f"save trained model as StableDiffusion checkpoint to {ckpt_file}") sd_saver(ckpt_file, epoch, global_step) if args.huggingface_repo_id is not None: @@ -4462,7 +4472,7 @@ def save_sd_model_on_train_end_common( out_dir = os.path.join(args.output_dir, model_name) os.makedirs(out_dir, exist_ok=True) - print(f"save trained model as Diffusers to {out_dir}") + logger.info(f"save trained model as Diffusers to {out_dir}") diffusers_saver(out_dir) if args.huggingface_repo_id is not None: @@ -4572,7 +4582,7 @@ def get_my_scheduler( # clip_sample=Trueにする if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: - # print("set clip_sample to True") + # logger.info("set clip_sample to True") scheduler.config.clip_sample = True return scheduler @@ -4631,8 +4641,8 @@ def line_to_prompt_dict(line: str) -> dict: continue except ValueError as ex: - print(f"Exception in parsing / 解析エラー: {parg}") - print(ex) + logger.error(f"Exception in parsing / 解析エラー: {parg}") + logger.error(ex) return prompt_dict @@ -4668,9 +4678,10 @@ def sample_images_common( if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch return - print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}") + logger.info("") + logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") if not os.path.isfile(args.sample_prompts): - print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") + logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") return org_vae_device = vae.device # CPUにいるはず @@ -4770,15 +4781,15 @@ def sample_images_common( height = max(64, height - height % 8) # round to divisible by 8 width = max(64, width - width % 8) # round to divisible by 8 - print(f"prompt: {prompt}") - print(f"negative_prompt: {negative_prompt}") - print(f"height: {height}") - print(f"width: {width}") - print(f"sample_steps: {sample_steps}") - print(f"scale: {scale}") - print(f"sample_sampler: {sampler_name}") + logger.info(f"prompt: {prompt}") + logger.info(f"negative_prompt: {negative_prompt}") + logger.info(f"height: {height}") + logger.info(f"width: {width}") + logger.info(f"sample_steps: {sample_steps}") + logger.info(f"scale: {scale}") + logger.info(f"sample_sampler: {sampler_name}") if seed is not None: - print(f"seed: {seed}") + logger.info(f"seed: {seed}") with accelerator.autocast(): latents = pipeline( prompt=prompt, @@ -4844,7 +4855,7 @@ def __getitem__(self, idx): # convert to tensor temporarily so dataloader will accept it tensor_pil = transforms.functional.pil_to_tensor(image) except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") + logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") return None return (tensor_pil, img_path) diff --git a/library/utils.py b/library/utils.py index 7d801a676..ebcaba308 100644 --- a/library/utils.py +++ b/library/utils.py @@ -1,6 +1,23 @@ +import logging import threading from typing import * def fire_in_thread(f, *args, **kwargs): - threading.Thread(target=f, args=args, kwargs=kwargs).start() \ No newline at end of file + threading.Thread(target=f, args=args, kwargs=kwargs).start() + + +def setup_logging(log_level=logging.INFO): + if logging.root.handlers: # Already configured + return + + from rich.logging import RichHandler + + handler = RichHandler() + formatter = logging.Formatter( + fmt="%(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + handler.setFormatter(formatter) + logging.root.setLevel(log_level) + logging.root.addHandler(handler) diff --git a/networks/check_lora_weights.py b/networks/check_lora_weights.py index 51f581b29..6ec60a89b 100644 --- a/networks/check_lora_weights.py +++ b/networks/check_lora_weights.py @@ -2,10 +2,13 @@ import os import torch from safetensors.torch import load_file - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def main(file): - print(f"loading: {file}") + logger.info(f"loading: {file}") if os.path.splitext(file)[1] == ".safetensors": sd = load_file(file) else: @@ -17,16 +20,16 @@ def main(file): for key in keys: if "lora_up" in key or "lora_down" in key: values.append((key, sd[key])) - print(f"number of LoRA modules: {len(values)}") + logger.info(f"number of LoRA modules: {len(values)}") if args.show_all_keys: for key in [k for k in keys if k not in values]: values.append((key, sd[key])) - print(f"number of all modules: {len(values)}") + logger.info(f"number of all modules: {len(values)}") for key, value in values: value = value.to(torch.float32) - print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}") + logger.info(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}") def setup_parser() -> argparse.ArgumentParser: diff --git a/networks/control_net_lllite.py b/networks/control_net_lllite.py index 4ebfef7a4..c9377bee8 100644 --- a/networks/control_net_lllite.py +++ b/networks/control_net_lllite.py @@ -2,7 +2,10 @@ from typing import Optional, List, Type import torch from library import sdxl_original_unet - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) # input_blocksに適用するかどうか / if True, input_blocks are not applied SKIP_INPUT_BLOCKS = False @@ -125,7 +128,7 @@ def set_cond_image(self, cond_image): return # timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance - # print(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}") + # logger.info(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}") cx = self.conditioning1(cond_image) if not self.is_conv2d: # reshape / b,c,h,w -> b,h*w,c @@ -155,7 +158,7 @@ def forward(self, x): cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1) if self.use_zeros_for_batch_uncond: cx[0::2] = 0.0 # uncond is zero - # print(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}") + # logger.info(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}") # downで入力の次元数を削減し、conditioning image embeddingと結合する # 加算ではなくchannel方向に結合することで、うまいこと混ぜてくれることを期待している @@ -286,7 +289,7 @@ def create_modules( # create module instances self.unet_modules: List[LLLiteModule] = create_modules(unet, target_modules, LLLiteModule) - print(f"create ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.") + logger.info(f"create ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.") def forward(self, x): return x # dummy @@ -319,7 +322,7 @@ def load_weights(self, file): return info def apply_to(self): - print("applying LLLite for U-Net...") + logger.info("applying LLLite for U-Net...") for module in self.unet_modules: module.apply_to() self.add_module(module.lllite_name, module) @@ -374,19 +377,19 @@ def save_weights(self, file, dtype, metadata): # sdxl_original_unet.USE_REENTRANT = False # test shape etc - print("create unet") + logger.info("create unet") unet = sdxl_original_unet.SdxlUNet2DConditionModel() unet.to("cuda").to(torch.float16) - print("create ControlNet-LLLite") + logger.info("create ControlNet-LLLite") control_net = ControlNetLLLite(unet, 32, 64) control_net.apply_to() control_net.to("cuda") - print(control_net) + logger.info(control_net) - # print number of parameters - print("number of parameters", sum(p.numel() for p in control_net.parameters() if p.requires_grad)) + # logger.info number of parameters + logger.info(f"number of parameters {sum(p.numel() for p in control_net.parameters() if p.requires_grad)}") input() @@ -398,12 +401,12 @@ def save_weights(self, file, dtype, metadata): # # visualize # import torchviz - # print("run visualize") + # logger.info("run visualize") # controlnet.set_control(conditioning_image) # output = unet(x, t, ctx, y) - # print("make_dot") + # logger.info("make_dot") # image = torchviz.make_dot(output, params=dict(controlnet.named_parameters())) - # print("render") + # logger.info("render") # image.format = "svg" # "png" # image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time # input() @@ -414,12 +417,12 @@ def save_weights(self, file, dtype, metadata): scaler = torch.cuda.amp.GradScaler(enabled=True) - print("start training") + logger.info("start training") steps = 10 sample_param = [p for p in control_net.named_parameters() if "up" in p[0]][0] for step in range(steps): - print(f"step {step}") + logger.info(f"step {step}") batch_size = 1 conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0 @@ -439,7 +442,7 @@ def save_weights(self, file, dtype, metadata): scaler.step(optimizer) scaler.update() optimizer.zero_grad(set_to_none=True) - print(sample_param) + logger.info(f"{sample_param}") # from safetensors.torch import save_file diff --git a/networks/control_net_lllite_for_train.py b/networks/control_net_lllite_for_train.py index 026880015..65b3520cf 100644 --- a/networks/control_net_lllite_for_train.py +++ b/networks/control_net_lllite_for_train.py @@ -6,7 +6,10 @@ from typing import Optional, List, Type import torch from library import sdxl_original_unet - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) # input_blocksに適用するかどうか / if True, input_blocks are not applied SKIP_INPUT_BLOCKS = False @@ -270,7 +273,7 @@ def apply_to_modules( # create module instances self.lllite_modules = apply_to_modules(self, target_modules) - print(f"enable ControlNet LLLite for U-Net: {len(self.lllite_modules)} modules.") + logger.info(f"enable ControlNet LLLite for U-Net: {len(self.lllite_modules)} modules.") # def prepare_optimizer_params(self): def prepare_params(self): @@ -281,8 +284,8 @@ def prepare_params(self): train_params.append(p) else: non_train_params.append(p) - print(f"count of trainable parameters: {len(train_params)}") - print(f"count of non-trainable parameters: {len(non_train_params)}") + logger.info(f"count of trainable parameters: {len(train_params)}") + logger.info(f"count of non-trainable parameters: {len(non_train_params)}") for p in non_train_params: p.requires_grad_(False) @@ -388,7 +391,7 @@ def load_lllite_weights(self, file, non_lllite_unet_sd=None): matches = pattern.findall(module_name) if matches is not None: for m in matches: - print(module_name, m) + logger.info(f"{module_name} {m}") module_name = module_name.replace(m, m.replace("_", "@")) module_name = module_name.replace("_", ".") module_name = module_name.replace("@", "_") @@ -407,7 +410,7 @@ def forward(self, x, timesteps=None, context=None, y=None, cond_image=None, **kw def replace_unet_linear_and_conv2d(): - print("replace torch.nn.Linear and torch.nn.Conv2d to LLLiteLinear and LLLiteConv2d in U-Net") + logger.info("replace torch.nn.Linear and torch.nn.Conv2d to LLLiteLinear and LLLiteConv2d in U-Net") sdxl_original_unet.torch.nn.Linear = LLLiteLinear sdxl_original_unet.torch.nn.Conv2d = LLLiteConv2d @@ -419,10 +422,10 @@ def replace_unet_linear_and_conv2d(): replace_unet_linear_and_conv2d() # test shape etc - print("create unet") + logger.info("create unet") unet = SdxlUNet2DConditionModelControlNetLLLite() - print("enable ControlNet-LLLite") + logger.info("enable ControlNet-LLLite") unet.apply_lllite(32, 64, None, False, 1.0) unet.to("cuda") # .to(torch.float16) @@ -439,14 +442,14 @@ def replace_unet_linear_and_conv2d(): # unet_sd[converted_key] = model_sd[key] # info = unet.load_lllite_weights("r:/lllite_from_unet.safetensors", unet_sd) - # print(info) + # logger.info(info) - # print(unet) + # logger.info(unet) - # print number of parameters + # logger.info number of parameters params = unet.prepare_params() - print("number of parameters", sum(p.numel() for p in params)) - # print("type any key to continue") + logger.info(f"number of parameters {sum(p.numel() for p in params)}") + # logger.info("type any key to continue") # input() unet.set_use_memory_efficient_attention(True, False) @@ -455,12 +458,12 @@ def replace_unet_linear_and_conv2d(): # # visualize # import torchviz - # print("run visualize") + # logger.info("run visualize") # controlnet.set_control(conditioning_image) # output = unet(x, t, ctx, y) - # print("make_dot") + # logger.info("make_dot") # image = torchviz.make_dot(output, params=dict(controlnet.named_parameters())) - # print("render") + # logger.info("render") # image.format = "svg" # "png" # image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time # input() @@ -471,13 +474,13 @@ def replace_unet_linear_and_conv2d(): scaler = torch.cuda.amp.GradScaler(enabled=True) - print("start training") + logger.info("start training") steps = 10 batch_size = 1 sample_param = [p for p in unet.named_parameters() if ".lllite_up." in p[0]][0] for step in range(steps): - print(f"step {step}") + logger.info(f"step {step}") conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0 x = torch.randn(batch_size, 4, 128, 128).cuda() @@ -494,9 +497,9 @@ def replace_unet_linear_and_conv2d(): scaler.step(optimizer) scaler.update() optimizer.zero_grad(set_to_none=True) - print(sample_param) + logger.info(sample_param) # from safetensors.torch import save_file - # print("save weights") + # logger.info("save weights") # unet.save_lllite_weights("r:/lllite_from_unet.safetensors", torch.float16, None) diff --git a/networks/dylora.py b/networks/dylora.py index e5a55d198..262699014 100644 --- a/networks/dylora.py +++ b/networks/dylora.py @@ -15,7 +15,10 @@ from typing import List, Tuple, Union import torch from torch import nn - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) class DyLoRAModule(torch.nn.Module): """ @@ -223,7 +226,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh elif "lora_down" in key: dim = value.size()[0] modules_dim[lora_name] = dim - # print(lora_name, value.size(), dim) + # logger.info(f"{lora_name} {value.size()} {dim}") # support old LoRA without alpha for key in modules_dim.keys(): @@ -267,11 +270,11 @@ def __init__( self.apply_to_conv = apply_to_conv if modules_dim is not None: - print(f"create LoRA network from weights") + logger.info("create LoRA network from weights") else: - print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}") + logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}") if self.apply_to_conv: - print(f"apply LoRA to Conv2d with kernel size (3,3).") + logger.info("apply LoRA to Conv2d with kernel size (3,3).") # create module instances def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[DyLoRAModule]: @@ -308,7 +311,7 @@ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules return loras self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) - print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights target_modules = DyLoRANetwork.UNET_TARGET_REPLACE_MODULE @@ -316,7 +319,7 @@ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules target_modules += DyLoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 self.unet_loras = create_modules(True, unet, target_modules) - print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") def set_multiplier(self, multiplier): self.multiplier = multiplier @@ -336,12 +339,12 @@ def load_weights(self, file): def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): if apply_text_encoder: - print("enable LoRA for text encoder") + logger.info("enable LoRA for text encoder") else: self.text_encoder_loras = [] if apply_unet: - print("enable LoRA for U-Net") + logger.info("enable LoRA for U-Net") else: self.unet_loras = [] @@ -359,12 +362,12 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device): apply_unet = True if apply_text_encoder: - print("enable LoRA for text encoder") + logger.info("enable LoRA for text encoder") else: self.text_encoder_loras = [] if apply_unet: - print("enable LoRA for U-Net") + logger.info("enable LoRA for U-Net") else: self.unet_loras = [] @@ -375,7 +378,7 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device): sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] lora.merge_to(sd_for_lora, dtype, device) - print(f"weights are merged") + logger.info(f"weights are merged") """ def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): diff --git a/networks/extract_lora_from_dylora.py b/networks/extract_lora_from_dylora.py index 0abee9836..1184cd8a5 100644 --- a/networks/extract_lora_from_dylora.py +++ b/networks/extract_lora_from_dylora.py @@ -10,7 +10,10 @@ from tqdm import tqdm from library import train_util, model_util import numpy as np - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def load_state_dict(file_name): if model_util.is_safetensors(file_name): @@ -40,13 +43,13 @@ def split_lora_model(lora_sd, unit): rank = value.size()[0] if rank > max_rank: max_rank = rank - print(f"Max rank: {max_rank}") + logger.info(f"Max rank: {max_rank}") rank = unit split_models = [] new_alpha = None while rank < max_rank: - print(f"Splitting rank {rank}") + logger.info(f"Splitting rank {rank}") new_sd = {} for key, value in lora_sd.items(): if "lora_down" in key: @@ -57,7 +60,7 @@ def split_lora_model(lora_sd, unit): # なぜかscaleするとおかしくなる…… # this_rank = lora_sd[key.replace("alpha", "lora_down.weight")].size()[0] # scale = math.sqrt(this_rank / rank) # rank is > unit - # print(key, value.size(), this_rank, rank, value, scale) + # logger.info(key, value.size(), this_rank, rank, value, scale) # new_alpha = value * scale # always same # new_sd[key] = new_alpha new_sd[key] = value @@ -69,10 +72,10 @@ def split_lora_model(lora_sd, unit): def split(args): - print("loading Model...") + logger.info("loading Model...") lora_sd, metadata = load_state_dict(args.model) - print("Splitting Model...") + logger.info("Splitting Model...") original_rank, split_models = split_lora_model(lora_sd, args.unit) comment = metadata.get("ss_training_comment", "") @@ -94,7 +97,7 @@ def split(args): filename, ext = os.path.splitext(args.save_to) model_file_name = filename + f"-{new_rank:04d}{ext}" - print(f"saving model to: {model_file_name}") + logger.info(f"saving model to: {model_file_name}") save_to_file(model_file_name, state_dict, new_metadata) diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py index b9027adba..43c1d0058 100644 --- a/networks/extract_lora_from_models.py +++ b/networks/extract_lora_from_models.py @@ -11,7 +11,10 @@ from tqdm import tqdm from library import sai_model_spec, model_util, sdxl_model_util import lora - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) # CLAMP_QUANTILE = 0.99 # MIN_DIFF = 1e-1 @@ -66,14 +69,14 @@ def str_to_dtype(p): # load models if not sdxl: - print(f"loading original SD model : {model_org}") + logger.info(f"loading original SD model : {model_org}") text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org) text_encoders_o = [text_encoder_o] if load_dtype is not None: text_encoder_o = text_encoder_o.to(load_dtype) unet_o = unet_o.to(load_dtype) - print(f"loading tuned SD model : {model_tuned}") + logger.info(f"loading tuned SD model : {model_tuned}") text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned) text_encoders_t = [text_encoder_t] if load_dtype is not None: @@ -85,7 +88,7 @@ def str_to_dtype(p): device_org = load_original_model_to if load_original_model_to else "cpu" device_tuned = load_tuned_model_to if load_tuned_model_to else "cpu" - print(f"loading original SDXL model : {model_org}") + logger.info(f"loading original SDXL model : {model_org}") text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org ) @@ -95,7 +98,7 @@ def str_to_dtype(p): text_encoder_o2 = text_encoder_o2.to(load_dtype) unet_o = unet_o.to(load_dtype) - print(f"loading original SDXL model : {model_tuned}") + logger.info(f"loading original SDXL model : {model_tuned}") text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned ) @@ -135,7 +138,7 @@ def str_to_dtype(p): # Text Encoder might be same if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff: text_encoder_different = True - print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}") + logger.info(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}") diffs[lora_name] = diff @@ -144,7 +147,7 @@ def str_to_dtype(p): del text_encoder if not text_encoder_different: - print("Text encoder is same. Extract U-Net only.") + logger.warning("Text encoder is same. Extract U-Net only.") lora_network_o.text_encoder_loras = [] diffs = {} # clear diffs @@ -166,7 +169,7 @@ def str_to_dtype(p): del unet_t # make LoRA with svd - print("calculating by svd") + logger.info("calculating by svd") lora_weights = {} with torch.no_grad(): for lora_name, mat in tqdm(list(diffs.items())): @@ -185,7 +188,7 @@ def str_to_dtype(p): if device: mat = mat.to(device) - # print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim) + # logger.info(lora_name, mat.size(), mat.device, rank, in_dim, out_dim) rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim if conv2d: @@ -230,7 +233,7 @@ def str_to_dtype(p): lora_network_save.apply_to(text_encoders_o, unet_o) # create internal module references for state_dict info = lora_network_save.load_state_dict(lora_sd) - print(f"Loading extracted LoRA weights: {info}") + logger.info(f"Loading extracted LoRA weights: {info}") dir_name = os.path.dirname(save_to) if dir_name and not os.path.exists(dir_name): @@ -257,7 +260,7 @@ def str_to_dtype(p): metadata.update(sai_metadata) lora_network_save.save_weights(save_to, save_dtype, metadata) - print(f"LoRA weights are saved to: {save_to}") + logger.info(f"LoRA weights are saved to: {save_to}") def setup_parser() -> argparse.ArgumentParser: diff --git a/networks/lora.py b/networks/lora.py index 0c75cd428..eaf656ac8 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -11,7 +11,10 @@ import numpy as np import torch import re - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") @@ -46,7 +49,7 @@ def __init__( # if limit_rank: # self.lora_dim = min(lora_dim, in_dim, out_dim) # if self.lora_dim != lora_dim: - # print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") + # logger.info(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") # else: self.lora_dim = lora_dim @@ -177,7 +180,7 @@ def merge_to(self, sd, dtype, device): else: # conv2d 3x3 conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # print(conved.size(), weight.size(), module.stride, module.padding) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) weight = weight + self.multiplier * conved * self.scale # set weight to org_module @@ -216,7 +219,7 @@ def set_region(self, region): self.region_mask = None def default_forward(self, x): - # print("default_forward", self.lora_name, x.size()) + # logger.info(f"default_forward {self.lora_name} {x.size()}") return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale def forward(self, x): @@ -245,7 +248,7 @@ def get_mask_for_x(self, x): if mask is None: # raise ValueError(f"mask is None for resolution {area}") # emb_layers in SDXL doesn't have mask - # print(f"mask is None for resolution {area}, {x.size()}") + # logger.info(f"mask is None for resolution {area}, {x.size()}") mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1) return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts if len(x.size()) != 4: @@ -262,7 +265,7 @@ def regional_forward(self, x): # apply mask for LoRA result lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale mask = self.get_mask_for_x(lx) - # print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size()) + # logger.info(f"regional {self.lora_name} {self.network.sub_prompt_index} {lx.size()} {mask.size()}") lx = lx * mask x = self.org_forward(x) @@ -291,7 +294,7 @@ def postp_to_q(self, x): if has_real_uncond: query[-self.network.batch_size :] = x[-self.network.batch_size :] - # print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts) + # logger.info(f"postp_to_q {self.lora_name} {x.size()} {query.size()} {self.network.num_sub_prompts}") return query def sub_prompt_forward(self, x): @@ -306,7 +309,7 @@ def sub_prompt_forward(self, x): lx = x[emb_idx :: self.network.num_sub_prompts] lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale - # print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx) + # logger.info(f"sub_prompt_forward {self.lora_name} {x.size()} {lx.size()} {emb_idx}") x = self.org_forward(x) x[emb_idx :: self.network.num_sub_prompts] += lx @@ -314,7 +317,7 @@ def sub_prompt_forward(self, x): return x def to_out_forward(self, x): - # print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network) + # logger.info(f"to_out_forward {self.lora_name} {x.size()} {self.network.is_last_network}") if self.network.is_last_network: masks = [None] * self.network.num_sub_prompts @@ -332,7 +335,7 @@ def to_out_forward(self, x): ) self.network.shared[self.lora_name] = (lx, masks) - # print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts) + # logger.info(f"to_out_forward {lx.size()} {lx1.size()} {self.network.sub_prompt_index} {self.network.num_sub_prompts}") lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1 masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1) @@ -351,7 +354,7 @@ def to_out_forward(self, x): if has_real_uncond: out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond - # print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts) + # logger.info(f"to_out_forward {self.lora_name} {self.network.sub_prompt_index} {self.network.num_sub_prompts}") # if num_sub_prompts > num of LoRAs, fill with zero for i in range(len(masks)): if masks[i] is None: @@ -374,7 +377,7 @@ def to_out_forward(self, x): x1 = x1 + lx1 out[self.network.batch_size + i] = x1 - # print("to_out_forward", x.size(), out.size(), has_real_uncond) + # logger.info(f"to_out_forward {x.size()} {out.size()} {has_real_uncond}") return out @@ -511,7 +514,7 @@ def parse_floats(s): len(block_dims) == num_total_blocks ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください" else: - print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります") + logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります") block_dims = [network_dim] * num_total_blocks if block_alphas is not None: @@ -520,7 +523,7 @@ def parse_floats(s): len(block_alphas) == num_total_blocks ), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください" else: - print( + logger.warning( f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります" ) block_alphas = [network_alpha] * num_total_blocks @@ -540,13 +543,13 @@ def parse_floats(s): else: if conv_alpha is None: conv_alpha = 1.0 - print( + logger.warning( f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります" ) conv_block_alphas = [conv_alpha] * num_total_blocks else: if conv_dim is not None: - print( + logger.warning( f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります" ) conv_block_dims = [conv_dim] * num_total_blocks @@ -586,7 +589,7 @@ def get_list(name_with_suffix) -> List[float]: elif name == "zeros": return [0.0 + base_lr] * max_len else: - print( + logger.error( "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros" % (name) ) @@ -598,14 +601,14 @@ def get_list(name_with_suffix) -> List[float]: up_lr_weight = get_list(up_lr_weight) if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len): - print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len) - print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len) + logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len) + logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len) up_lr_weight = up_lr_weight[:max_len] down_lr_weight = down_lr_weight[:max_len] if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len): - print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len) - print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len) + logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len) + logger.warning("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len) if down_lr_weight != None and len(down_lr_weight) < max_len: down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight)) @@ -613,24 +616,24 @@ def get_list(name_with_suffix) -> List[float]: up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight)) if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None): - print("apply block learning rate / 階層別学習率を適用します。") + logger.info("apply block learning rate / 階層別学習率を適用します。") if down_lr_weight != None: down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight] - print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight) + logger.info(f"down_lr_weight (shallower -> deeper, 浅い層->深い層): {down_lr_weight}") else: - print("down_lr_weight: all 1.0, すべて1.0") + logger.info("down_lr_weight: all 1.0, すべて1.0") if mid_lr_weight != None: mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0 - print("mid_lr_weight:", mid_lr_weight) + logger.info(f"mid_lr_weight: {mid_lr_weight}") else: - print("mid_lr_weight: 1.0") + logger.info("mid_lr_weight: 1.0") if up_lr_weight != None: up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight] - print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight) + logger.info(f"up_lr_weight (deeper -> shallower, 深い層->浅い層): {up_lr_weight}") else: - print("up_lr_weight: all 1.0, すべて1.0") + logger.info("up_lr_weight: all 1.0, すべて1.0") return down_lr_weight, mid_lr_weight, up_lr_weight @@ -711,7 +714,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh elif "lora_down" in key: dim = value.size()[0] modules_dim[lora_name] = dim - # print(lora_name, value.size(), dim) + # logger.info(lora_name, value.size(), dim) # support old LoRA without alpha for key in modules_dim.keys(): @@ -786,20 +789,20 @@ def __init__( self.module_dropout = module_dropout if modules_dim is not None: - print(f"create LoRA network from weights") + logger.info(f"create LoRA network from weights") elif block_dims is not None: - print(f"create LoRA network from block_dims") - print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") - print(f"block_dims: {block_dims}") - print(f"block_alphas: {block_alphas}") + logger.info(f"create LoRA network from block_dims") + logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") + logger.info(f"block_dims: {block_dims}") + logger.info(f"block_alphas: {block_alphas}") if conv_block_dims is not None: - print(f"conv_block_dims: {conv_block_dims}") - print(f"conv_block_alphas: {conv_block_alphas}") + logger.info(f"conv_block_dims: {conv_block_dims}") + logger.info(f"conv_block_alphas: {conv_block_alphas}") else: - print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") - print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") + logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") if self.conv_lora_dim is not None: - print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") + logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") # create module instances def create_modules( @@ -884,15 +887,15 @@ def create_modules( for i, text_encoder in enumerate(text_encoders): if len(text_encoders) > 1: index = i + 1 - print(f"create LoRA for Text Encoder {index}:") + logger.info(f"create LoRA for Text Encoder {index}:") else: index = None - print(f"create LoRA for Text Encoder:") + logger.info(f"create LoRA for Text Encoder:") text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) self.text_encoder_loras.extend(text_encoder_loras) skipped_te += skipped - print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE @@ -900,15 +903,15 @@ def create_modules( target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) - print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") skipped = skipped_te + skipped_un if varbose and len(skipped) > 0: - print( + logger.warning( f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" ) for name in skipped: - print(f"\t{name}") + logger.info(f"\t{name}") self.up_lr_weight: List[float] = None self.down_lr_weight: List[float] = None @@ -939,12 +942,12 @@ def load_weights(self, file): def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): if apply_text_encoder: - print("enable LoRA for text encoder") + logger.info("enable LoRA for text encoder") else: self.text_encoder_loras = [] if apply_unet: - print("enable LoRA for U-Net") + logger.info("enable LoRA for U-Net") else: self.unet_loras = [] @@ -966,12 +969,12 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device): apply_unet = True if apply_text_encoder: - print("enable LoRA for text encoder") + logger.info("enable LoRA for text encoder") else: self.text_encoder_loras = [] if apply_unet: - print("enable LoRA for U-Net") + logger.info("enable LoRA for U-Net") else: self.unet_loras = [] @@ -982,7 +985,7 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device): sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] lora.merge_to(sd_for_lora, dtype, device) - print(f"weights are merged") + logger.info(f"weights are merged") # 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない def set_block_lr_weight( @@ -1128,7 +1131,7 @@ def set_current_generation(self, batch_size, num_sub_prompts, width, height, sha device = ref_weight.device def resize_add(mh, mw): - # print(mh, mw, mh * mw) + # logger.info(mh, mw, mh * mw) m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16 m = m.to(device, dtype=dtype) mask_dic[mh * mw] = m diff --git a/networks/lora_diffusers.py b/networks/lora_diffusers.py index 47d75ac4d..31ca6feb3 100644 --- a/networks/lora_diffusers.py +++ b/networks/lora_diffusers.py @@ -10,7 +10,10 @@ from tqdm import tqdm from transformers import CLIPTextModel import torch - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def make_unet_conversion_map() -> Dict[str, str]: unet_conversion_map_layer = [] @@ -248,7 +251,7 @@ def create_network_from_weights( elif "lora_down" in key: dim = value.size()[0] modules_dim[lora_name] = dim - # print(lora_name, value.size(), dim) + # logger.info(f"{lora_name} {value.size()} {dim}") # support old LoRA without alpha for key in modules_dim.keys(): @@ -291,12 +294,12 @@ def __init__( super().__init__() self.multiplier = multiplier - print(f"create LoRA network from weights") + logger.info("create LoRA network from weights") # convert SDXL Stability AI's U-Net modules to Diffusers converted = self.convert_unet_modules(modules_dim, modules_alpha) if converted: - print(f"converted {converted} Stability AI's U-Net LoRA modules to Diffusers (SDXL)") + logger.info(f"converted {converted} Stability AI's U-Net LoRA modules to Diffusers (SDXL)") # create module instances def create_modules( @@ -331,7 +334,7 @@ def create_modules( lora_name = lora_name.replace(".", "_") if lora_name not in modules_dim: - # print(f"skipped {lora_name} (not found in modules_dim)") + # logger.info(f"skipped {lora_name} (not found in modules_dim)") skipped.append(lora_name) continue @@ -362,18 +365,18 @@ def create_modules( text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) self.text_encoder_loras.extend(text_encoder_loras) skipped_te += skipped - print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") if len(skipped_te) > 0: - print(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.") + logger.warning(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.") # extend U-Net target modules to include Conv2d 3x3 target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 self.unet_loras: List[LoRAModule] self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) - print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") if len(skipped_un) > 0: - print(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.") + logger.warning(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.") # assertion names = set() @@ -420,11 +423,11 @@ def set_multiplier(self, multiplier): def apply_to(self, multiplier=1.0, apply_text_encoder=True, apply_unet=True): if apply_text_encoder: - print("enable LoRA for text encoder") + logger.info("enable LoRA for text encoder") for lora in self.text_encoder_loras: lora.apply_to(multiplier) if apply_unet: - print("enable LoRA for U-Net") + logger.info("enable LoRA for U-Net") for lora in self.unet_loras: lora.apply_to(multiplier) @@ -433,16 +436,16 @@ def unapply_to(self): lora.unapply_to() def merge_to(self, multiplier=1.0): - print("merge LoRA weights to original weights") + logger.info("merge LoRA weights to original weights") for lora in tqdm(self.text_encoder_loras + self.unet_loras): lora.merge_to(multiplier) - print(f"weights are merged") + logger.info(f"weights are merged") def restore_from(self, multiplier=1.0): - print("restore LoRA weights from original weights") + logger.info("restore LoRA weights from original weights") for lora in tqdm(self.text_encoder_loras + self.unet_loras): lora.restore_from(multiplier) - print(f"weights are restored") + logger.info(f"weights are restored") def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): # convert SDXL Stability AI's state dict to Diffusers' based state dict @@ -463,7 +466,7 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): my_state_dict = self.state_dict() for key in state_dict.keys(): if state_dict[key].size() != my_state_dict[key].size(): - # print(f"convert {key} from {state_dict[key].size()} to {my_state_dict[key].size()}") + # logger.info(f"convert {key} from {state_dict[key].size()} to {my_state_dict[key].size()}") state_dict[key] = state_dict[key].view(my_state_dict[key].size()) return super().load_state_dict(state_dict, strict) @@ -490,7 +493,7 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): image_prefix = args.model_id.replace("/", "_") + "_" # load Diffusers model - print(f"load model from {args.model_id}") + logger.info(f"load model from {args.model_id}") pipe: Union[StableDiffusionPipeline, StableDiffusionXLPipeline] if args.sdxl: # use_safetensors=True does not work with 0.18.2 @@ -503,7 +506,7 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if args.sdxl else [pipe.text_encoder] # load LoRA weights - print(f"load LoRA weights from {args.lora_weights}") + logger.info(f"load LoRA weights from {args.lora_weights}") if os.path.splitext(args.lora_weights)[1] == ".safetensors": from safetensors.torch import load_file @@ -512,10 +515,10 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): lora_sd = torch.load(args.lora_weights) # create by LoRA weights and load weights - print(f"create LoRA network") + logger.info(f"create LoRA network") lora_network: LoRANetwork = create_network_from_weights(text_encoders, pipe.unet, lora_sd, multiplier=1.0) - print(f"load LoRA network weights") + logger.info(f"load LoRA network weights") lora_network.load_state_dict(lora_sd) lora_network.to(device, dtype=pipe.unet.dtype) # required to apply_to. merge_to works without this @@ -544,34 +547,34 @@ def seed_everything(seed): random.seed(seed) # create image with original weights - print(f"create image with original weights") + logger.info(f"create image with original weights") seed_everything(args.seed) image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image.save(image_prefix + "original.png") # apply LoRA network to the model: slower than merge_to, but can be reverted easily - print(f"apply LoRA network to the model") + logger.info(f"apply LoRA network to the model") lora_network.apply_to(multiplier=1.0) - print(f"create image with applied LoRA") + logger.info(f"create image with applied LoRA") seed_everything(args.seed) image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image.save(image_prefix + "applied_lora.png") # unapply LoRA network to the model - print(f"unapply LoRA network to the model") + logger.info(f"unapply LoRA network to the model") lora_network.unapply_to() - print(f"create image with unapplied LoRA") + logger.info(f"create image with unapplied LoRA") seed_everything(args.seed) image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image.save(image_prefix + "unapplied_lora.png") # merge LoRA network to the model: faster than apply_to, but requires back-up of original weights (or unmerge_to) - print(f"merge LoRA network to the model") + logger.info(f"merge LoRA network to the model") lora_network.merge_to(multiplier=1.0) - print(f"create image with LoRA") + logger.info(f"create image with LoRA") seed_everything(args.seed) image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image.save(image_prefix + "merged_lora.png") @@ -579,31 +582,31 @@ def seed_everything(seed): # restore (unmerge) LoRA weights: numerically unstable # マージされた重みを元に戻す。計算誤差のため、元の重みと完全に一致しないことがあるかもしれない # 保存したstate_dictから元の重みを復元するのが確実 - print(f"restore (unmerge) LoRA weights") + logger.info(f"restore (unmerge) LoRA weights") lora_network.restore_from(multiplier=1.0) - print(f"create image without LoRA") + logger.info(f"create image without LoRA") seed_everything(args.seed) image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image.save(image_prefix + "unmerged_lora.png") # restore original weights - print(f"restore original weights") + logger.info(f"restore original weights") pipe.unet.load_state_dict(org_unet_sd) pipe.text_encoder.load_state_dict(org_text_encoder_sd) if args.sdxl: pipe.text_encoder_2.load_state_dict(org_text_encoder_2_sd) - print(f"create image with restored original weights") + logger.info(f"create image with restored original weights") seed_everything(args.seed) image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image.save(image_prefix + "restore_original.png") # use convenience function to merge LoRA weights - print(f"merge LoRA weights with convenience function") + logger.info(f"merge LoRA weights with convenience function") merge_lora_weights(pipe, lora_sd, multiplier=1.0) - print(f"create image with merged LoRA weights") + logger.info(f"create image with merged LoRA weights") seed_everything(args.seed) image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image.save(image_prefix + "convenience_merged_lora.png") diff --git a/networks/lora_fa.py b/networks/lora_fa.py index a357d7f7f..919222ce8 100644 --- a/networks/lora_fa.py +++ b/networks/lora_fa.py @@ -14,7 +14,10 @@ import numpy as np import torch import re - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") @@ -49,7 +52,7 @@ def __init__( # if limit_rank: # self.lora_dim = min(lora_dim, in_dim, out_dim) # if self.lora_dim != lora_dim: - # print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") + # logger.info(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") # else: self.lora_dim = lora_dim @@ -197,7 +200,7 @@ def merge_to(self, sd, dtype, device): else: # conv2d 3x3 conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # print(conved.size(), weight.size(), module.stride, module.padding) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) weight = weight + self.multiplier * conved * self.scale # set weight to org_module @@ -236,7 +239,7 @@ def set_region(self, region): self.region_mask = None def default_forward(self, x): - # print("default_forward", self.lora_name, x.size()) + # logger.info("default_forward", self.lora_name, x.size()) return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale def forward(self, x): @@ -278,7 +281,7 @@ def regional_forward(self, x): # apply mask for LoRA result lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale mask = self.get_mask_for_x(lx) - # print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size()) + # logger.info("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size()) lx = lx * mask x = self.org_forward(x) @@ -307,7 +310,7 @@ def postp_to_q(self, x): if has_real_uncond: query[-self.network.batch_size :] = x[-self.network.batch_size :] - # print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts) + # logger.info("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts) return query def sub_prompt_forward(self, x): @@ -322,7 +325,7 @@ def sub_prompt_forward(self, x): lx = x[emb_idx :: self.network.num_sub_prompts] lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale - # print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx) + # logger.info("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx) x = self.org_forward(x) x[emb_idx :: self.network.num_sub_prompts] += lx @@ -330,7 +333,7 @@ def sub_prompt_forward(self, x): return x def to_out_forward(self, x): - # print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network) + # logger.info("to_out_forward", self.lora_name, x.size(), self.network.is_last_network) if self.network.is_last_network: masks = [None] * self.network.num_sub_prompts @@ -348,7 +351,7 @@ def to_out_forward(self, x): ) self.network.shared[self.lora_name] = (lx, masks) - # print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts) + # logger.info("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts) lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1 masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1) @@ -367,7 +370,7 @@ def to_out_forward(self, x): if has_real_uncond: out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond - # print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts) + # logger.info("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts) # for i in range(len(masks)): # if masks[i] is None: # masks[i] = torch.zeros_like(masks[-1]) @@ -389,7 +392,7 @@ def to_out_forward(self, x): x1 = x1 + lx1 out[self.network.batch_size + i] = x1 - # print("to_out_forward", x.size(), out.size(), has_real_uncond) + # logger.info("to_out_forward", x.size(), out.size(), has_real_uncond) return out @@ -526,7 +529,7 @@ def parse_floats(s): len(block_dims) == num_total_blocks ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください" else: - print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります") + logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります") block_dims = [network_dim] * num_total_blocks if block_alphas is not None: @@ -535,7 +538,7 @@ def parse_floats(s): len(block_alphas) == num_total_blocks ), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください" else: - print( + logger.warning( f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります" ) block_alphas = [network_alpha] * num_total_blocks @@ -555,13 +558,13 @@ def parse_floats(s): else: if conv_alpha is None: conv_alpha = 1.0 - print( + logger.warning( f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります" ) conv_block_alphas = [conv_alpha] * num_total_blocks else: if conv_dim is not None: - print( + logger.warning( f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります" ) conv_block_dims = [conv_dim] * num_total_blocks @@ -601,7 +604,7 @@ def get_list(name_with_suffix) -> List[float]: elif name == "zeros": return [0.0 + base_lr] * max_len else: - print( + logger.error( "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros" % (name) ) @@ -613,14 +616,14 @@ def get_list(name_with_suffix) -> List[float]: up_lr_weight = get_list(up_lr_weight) if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len): - print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len) - print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len) + logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len) + logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len) up_lr_weight = up_lr_weight[:max_len] down_lr_weight = down_lr_weight[:max_len] if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len): - print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len) - print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len) + logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len) + logger.warning("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len) if down_lr_weight != None and len(down_lr_weight) < max_len: down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight)) @@ -628,24 +631,24 @@ def get_list(name_with_suffix) -> List[float]: up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight)) if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None): - print("apply block learning rate / 階層別学習率を適用します。") + logger.info("apply block learning rate / 階層別学習率を適用します。") if down_lr_weight != None: down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight] - print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight) + logger.info(f"down_lr_weight (shallower -> deeper, 浅い層->深い層): {down_lr_weight}") else: - print("down_lr_weight: all 1.0, すべて1.0") + logger.info("down_lr_weight: all 1.0, すべて1.0") if mid_lr_weight != None: mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0 - print("mid_lr_weight:", mid_lr_weight) + logger.info(f"mid_lr_weight: {mid_lr_weight}") else: - print("mid_lr_weight: 1.0") + logger.info("mid_lr_weight: 1.0") if up_lr_weight != None: up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight] - print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight) + logger.info(f"up_lr_weight (deeper -> shallower, 深い層->浅い層): {up_lr_weight}") else: - print("up_lr_weight: all 1.0, すべて1.0") + logger.info("up_lr_weight: all 1.0, すべて1.0") return down_lr_weight, mid_lr_weight, up_lr_weight @@ -726,7 +729,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh elif "lora_down" in key: dim = value.size()[0] modules_dim[lora_name] = dim - # print(lora_name, value.size(), dim) + # logger.info(lora_name, value.size(), dim) # support old LoRA without alpha for key in modules_dim.keys(): @@ -801,20 +804,20 @@ def __init__( self.module_dropout = module_dropout if modules_dim is not None: - print(f"create LoRA network from weights") + logger.info(f"create LoRA network from weights") elif block_dims is not None: - print(f"create LoRA network from block_dims") - print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") - print(f"block_dims: {block_dims}") - print(f"block_alphas: {block_alphas}") + logger.info(f"create LoRA network from block_dims") + logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") + logger.info(f"block_dims: {block_dims}") + logger.info(f"block_alphas: {block_alphas}") if conv_block_dims is not None: - print(f"conv_block_dims: {conv_block_dims}") - print(f"conv_block_alphas: {conv_block_alphas}") + logger.info(f"conv_block_dims: {conv_block_dims}") + logger.info(f"conv_block_alphas: {conv_block_alphas}") else: - print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") - print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") + logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") if self.conv_lora_dim is not None: - print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") + logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") # create module instances def create_modules( @@ -899,15 +902,15 @@ def create_modules( for i, text_encoder in enumerate(text_encoders): if len(text_encoders) > 1: index = i + 1 - print(f"create LoRA for Text Encoder {index}:") + logger.info(f"create LoRA for Text Encoder {index}:") else: index = None - print(f"create LoRA for Text Encoder:") + logger.info(f"create LoRA for Text Encoder:") text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) self.text_encoder_loras.extend(text_encoder_loras) skipped_te += skipped - print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE @@ -915,15 +918,15 @@ def create_modules( target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) - print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") skipped = skipped_te + skipped_un if varbose and len(skipped) > 0: - print( + logger.warning( f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" ) for name in skipped: - print(f"\t{name}") + logger.info(f"\t{name}") self.up_lr_weight: List[float] = None self.down_lr_weight: List[float] = None @@ -954,12 +957,12 @@ def load_weights(self, file): def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): if apply_text_encoder: - print("enable LoRA for text encoder") + logger.info("enable LoRA for text encoder") else: self.text_encoder_loras = [] if apply_unet: - print("enable LoRA for U-Net") + logger.info("enable LoRA for U-Net") else: self.unet_loras = [] @@ -981,12 +984,12 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device): apply_unet = True if apply_text_encoder: - print("enable LoRA for text encoder") + logger.info("enable LoRA for text encoder") else: self.text_encoder_loras = [] if apply_unet: - print("enable LoRA for U-Net") + logger.info("enable LoRA for U-Net") else: self.unet_loras = [] @@ -997,7 +1000,7 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device): sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] lora.merge_to(sd_for_lora, dtype, device) - print(f"weights are merged") + logger.info(f"weights are merged") # 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない def set_block_lr_weight( @@ -1144,7 +1147,7 @@ def set_current_generation(self, batch_size, num_sub_prompts, width, height, sha device = ref_weight.device def resize_add(mh, mw): - # print(mh, mw, mh * mw) + # logger.info(mh, mw, mh * mw) m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16 m = m.to(device, dtype=dtype) mask_dic[mh * mw] = m diff --git a/networks/lora_interrogator.py b/networks/lora_interrogator.py index 0dc066fd1..d00aaaf08 100644 --- a/networks/lora_interrogator.py +++ b/networks/lora_interrogator.py @@ -9,6 +9,10 @@ import library.model_util as model_util import lora +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) TOKENIZER_PATH = "openai/clip-vit-large-patch14" V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う @@ -20,12 +24,12 @@ def interrogate(args): weights_dtype = torch.float16 # いろいろ準備する - print(f"loading SD model: {args.sd_model}") + logger.info(f"loading SD model: {args.sd_model}") args.pretrained_model_name_or_path = args.sd_model args.vae = None text_encoder, vae, unet, _ = train_util._load_target_model(args,weights_dtype, DEVICE) - print(f"loading LoRA: {args.model}") + logger.info(f"loading LoRA: {args.model}") network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet) # text encoder向けの重みがあるかチェックする:本当はlora側でやるのがいい @@ -35,11 +39,11 @@ def interrogate(args): has_te_weight = True break if not has_te_weight: - print("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません") + logger.error("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません") return del vae - print("loading tokenizer") + logger.info("loading tokenizer") if args.v2: tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer") else: @@ -53,7 +57,7 @@ def interrogate(args): # トークンをひとつひとつ当たっていく token_id_start = 0 token_id_end = max(tokenizer.all_special_ids) - print(f"interrogate tokens are: {token_id_start} to {token_id_end}") + logger.info(f"interrogate tokens are: {token_id_start} to {token_id_end}") def get_all_embeddings(text_encoder): embs = [] @@ -79,24 +83,24 @@ def get_all_embeddings(text_encoder): embs.extend(encoder_hidden_states) return torch.stack(embs) - print("get original text encoder embeddings.") + logger.info("get original text encoder embeddings.") orig_embs = get_all_embeddings(text_encoder) network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0) info = network.load_state_dict(weights_sd, strict=False) - print(f"Loading LoRA weights: {info}") + logger.info(f"Loading LoRA weights: {info}") network.to(DEVICE, dtype=weights_dtype) network.eval() del unet - print("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)") - print("get text encoder embeddings with lora.") + logger.info("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)") + logger.info("get text encoder embeddings with lora.") lora_embs = get_all_embeddings(text_encoder) # 比べる:とりあえず単純に差分の絶対値で - print("comparing...") + logger.info("comparing...") diffs = {} for i, (orig_emb, lora_emb) in enumerate(zip(orig_embs, tqdm(lora_embs))): diff = torch.mean(torch.abs(orig_emb - lora_emb)) diff --git a/networks/merge_lora.py b/networks/merge_lora.py index 71492621e..fea8a3f32 100644 --- a/networks/merge_lora.py +++ b/networks/merge_lora.py @@ -7,7 +7,10 @@ from library import sai_model_spec, train_util import library.model_util as model_util import lora - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def load_state_dict(file_name, dtype): if os.path.splitext(file_name)[1] == ".safetensors": @@ -61,10 +64,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): name_to_module[lora_name] = child_module for model, ratio in zip(models, ratios): - print(f"loading: {model}") + logger.info(f"loading: {model}") lora_sd, _ = load_state_dict(model, merge_dtype) - print(f"merging...") + logger.info(f"merging...") for key in lora_sd.keys(): if "lora_down" in key: up_key = key.replace("lora_down", "lora_up") @@ -73,10 +76,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): # find original module for this lora module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight" if module_name not in name_to_module: - print(f"no module found for LoRA weight: {key}") + logger.info(f"no module found for LoRA weight: {key}") continue module = name_to_module[module_name] - # print(f"apply {key} to {module}") + # logger.info(f"apply {key} to {module}") down_weight = lora_sd[key] up_weight = lora_sd[up_key] @@ -104,7 +107,7 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): else: # conv2d 3x3 conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # print(conved.size(), weight.size(), module.stride, module.padding) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) weight = weight + ratio * conved * scale module.weight = torch.nn.Parameter(weight) @@ -118,7 +121,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): v2 = None base_model = None for model, ratio in zip(models, ratios): - print(f"loading: {model}") + logger.info(f"loading: {model}") lora_sd, lora_metadata = load_state_dict(model, merge_dtype) if lora_metadata is not None: @@ -151,10 +154,10 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): if lora_module_name not in base_alphas: base_alphas[lora_module_name] = alpha - print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") + logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") # merge - print(f"merging...") + logger.info(f"merging...") for key in lora_sd.keys(): if "alpha" in key: continue @@ -196,8 +199,8 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): merged_sd[key_down] = merged_sd[key_down][perm] merged_sd[key_up] = merged_sd[key_up][:,perm] - print("merged model") - print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") + logger.info("merged model") + logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") # check all dims are same dims_list = list(set(base_dims.values())) @@ -239,7 +242,7 @@ def str_to_dtype(p): save_dtype = merge_dtype if args.sd_model is not None: - print(f"loading SD model: {args.sd_model}") + logger.info(f"loading SD model: {args.sd_model}") text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model) @@ -264,18 +267,18 @@ def str_to_dtype(p): ) if args.v2: # TODO read sai modelspec - print( + logger.warning( "Cannot determine if model is for v-prediction, so save metadata as v-prediction / modelがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します" ) - print(f"saving SD model to: {args.save_to}") + logger.info(f"saving SD model to: {args.save_to}") model_util.save_stable_diffusion_checkpoint( args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, sai_metadata, save_dtype, vae ) else: state_dict, metadata, v2 = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) - print(f"calculating hashes and creating metadata...") + logger.info(f"calculating hashes and creating metadata...") model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) metadata["sshs_model_hash"] = model_hash @@ -289,12 +292,12 @@ def str_to_dtype(p): ) if v2: # TODO read sai modelspec - print( + logger.warning( "Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します" ) metadata.update(sai_metadata) - print(f"saving model to: {args.save_to}") + logger.info(f"saving model to: {args.save_to}") save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) diff --git a/networks/merge_lora_old.py b/networks/merge_lora_old.py index ffd6b2b40..334d127b7 100644 --- a/networks/merge_lora_old.py +++ b/networks/merge_lora_old.py @@ -6,7 +6,10 @@ from safetensors.torch import load_file, save_file import library.model_util as model_util import lora - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def load_state_dict(file_name, dtype): if os.path.splitext(file_name)[1] == '.safetensors': @@ -54,10 +57,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): name_to_module[lora_name] = child_module for model, ratio in zip(models, ratios): - print(f"loading: {model}") + logger.info(f"loading: {model}") lora_sd = load_state_dict(model, merge_dtype) - print(f"merging...") + logger.info(f"merging...") for key in lora_sd.keys(): if "lora_down" in key: up_key = key.replace("lora_down", "lora_up") @@ -66,10 +69,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): # find original module for this lora module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight" if module_name not in name_to_module: - print(f"no module found for LoRA weight: {key}") + logger.info(f"no module found for LoRA weight: {key}") continue module = name_to_module[module_name] - # print(f"apply {key} to {module}") + # logger.info(f"apply {key} to {module}") down_weight = lora_sd[key] up_weight = lora_sd[up_key] @@ -96,10 +99,10 @@ def merge_lora_models(models, ratios, merge_dtype): alpha = None dim = None for model, ratio in zip(models, ratios): - print(f"loading: {model}") + logger.info(f"loading: {model}") lora_sd = load_state_dict(model, merge_dtype) - print(f"merging...") + logger.info(f"merging...") for key in lora_sd.keys(): if 'alpha' in key: if key in merged_sd: @@ -117,7 +120,7 @@ def merge_lora_models(models, ratios, merge_dtype): dim = lora_sd[key].size()[0] merged_sd[key] = lora_sd[key] * ratio - print(f"dim (rank): {dim}, alpha: {alpha}") + logger.info(f"dim (rank): {dim}, alpha: {alpha}") if alpha is None: alpha = dim @@ -142,19 +145,21 @@ def str_to_dtype(p): save_dtype = merge_dtype if args.sd_model is not None: - print(f"loading SD model: {args.sd_model}") + logger.info(f"loading SD model: {args.sd_model}") text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model) merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype) - print(f"\nsaving SD model to: {args.save_to}") + logger.info("") + logger.info(f"saving SD model to: {args.save_to}") model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, save_dtype, vae) else: state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype) - print(f"\nsaving model to: {args.save_to}") + logger.info(f"") + logger.info(f"saving model to: {args.save_to}") save_to_file(args.save_to, state_dict, state_dict, save_dtype) diff --git a/networks/oft.py b/networks/oft.py index 1d088f877..461a98698 100644 --- a/networks/oft.py +++ b/networks/oft.py @@ -8,7 +8,10 @@ import numpy as np import torch import re - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") @@ -237,7 +240,7 @@ def __init__( self.dim = dim self.alpha = alpha - print( + logger.info( f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}" ) @@ -258,7 +261,7 @@ def create_modules( if is_linear or is_conv2d_1x1 or (is_conv2d and enable_conv): oft_name = prefix + "." + name + "." + child_name oft_name = oft_name.replace(".", "_") - # print(oft_name) + # logger.info(oft_name) oft = module_class( oft_name, @@ -279,7 +282,7 @@ def create_modules( target_modules += OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules) - print(f"create OFT for U-Net: {len(self.unet_ofts)} modules.") + logger.info(f"create OFT for U-Net: {len(self.unet_ofts)} modules.") # assertion names = set() @@ -316,7 +319,7 @@ def is_mergeable(self): # TODO refactor to common function with apply_to def merge_to(self, text_encoder, unet, weights_sd, dtype, device): - print("enable OFT for U-Net") + logger.info("enable OFT for U-Net") for oft in self.unet_ofts: sd_for_lora = {} @@ -326,7 +329,7 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device): oft.load_state_dict(sd_for_lora, False) oft.merge_to() - print(f"weights are merged") + logger.info(f"weights are merged") # 二つのText Encoderに別々の学習率を設定できるようにするといいかも def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): @@ -338,11 +341,11 @@ def enumerate_params(ofts): for oft in ofts: params.extend(oft.parameters()) - # print num of params + # logger.info num of params num_params = 0 for p in params: num_params += p.numel() - print(f"OFT params: {num_params}") + logger.info(f"OFT params: {num_params}") return params param_data = {"params": enumerate_params(self.unet_ofts)} diff --git a/networks/resize_lora.py b/networks/resize_lora.py index 03fc545e7..c5932a893 100644 --- a/networks/resize_lora.py +++ b/networks/resize_lora.py @@ -8,6 +8,10 @@ from tqdm import tqdm from library import train_util, model_util import numpy as np +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) MIN_SV = 1e-6 @@ -206,7 +210,7 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn scale = network_alpha/network_dim if dynamic_method: - print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}") + logger.info(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}") lora_down_weight = None lora_up_weight = None @@ -275,10 +279,10 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn del param_dict if verbose: - print(verbose_str) + logger.info(verbose_str) - print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}") - print("resizing complete") + logger.info(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}") + logger.info("resizing complete") return o_lora_sd, network_dim, new_alpha @@ -304,10 +308,10 @@ def str_to_dtype(p): if save_dtype is None: save_dtype = merge_dtype - print("loading Model...") + logger.info("loading Model...") lora_sd, metadata = load_state_dict(args.model, merge_dtype) - print("Resizing Lora...") + logger.info("Resizing Lora...") state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose) # update metadata @@ -329,7 +333,7 @@ def str_to_dtype(p): metadata["sshs_model_hash"] = model_hash metadata["sshs_legacy_hash"] = legacy_hash - print(f"saving model to: {args.save_to}") + logger.info(f"saving model to: {args.save_to}") save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) diff --git a/networks/sdxl_merge_lora.py b/networks/sdxl_merge_lora.py index c513eb59f..3383a80de 100644 --- a/networks/sdxl_merge_lora.py +++ b/networks/sdxl_merge_lora.py @@ -8,7 +8,10 @@ from library import sai_model_spec, sdxl_model_util, train_util import library.model_util as model_util import lora - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def load_state_dict(file_name, dtype): if os.path.splitext(file_name)[1] == ".safetensors": @@ -66,10 +69,10 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ name_to_module[lora_name] = child_module for model, ratio in zip(models, ratios): - print(f"loading: {model}") + logger.info(f"loading: {model}") lora_sd, _ = load_state_dict(model, merge_dtype) - print(f"merging...") + logger.info(f"merging...") for key in tqdm(lora_sd.keys()): if "lora_down" in key: up_key = key.replace("lora_down", "lora_up") @@ -78,10 +81,10 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ # find original module for this lora module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight" if module_name not in name_to_module: - print(f"no module found for LoRA weight: {key}") + logger.info(f"no module found for LoRA weight: {key}") continue module = name_to_module[module_name] - # print(f"apply {key} to {module}") + # logger.info(f"apply {key} to {module}") down_weight = lora_sd[key] up_weight = lora_sd[up_key] @@ -92,7 +95,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ # W <- W + U * D weight = module.weight - # print(module_name, down_weight.size(), up_weight.size()) + # logger.info(module_name, down_weight.size(), up_weight.size()) if len(weight.size()) == 2: # linear weight = weight + ratio * (up_weight @ down_weight) * scale @@ -107,7 +110,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ else: # conv2d 3x3 conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # print(conved.size(), weight.size(), module.stride, module.padding) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) weight = weight + ratio * conved * scale module.weight = torch.nn.Parameter(weight) @@ -121,7 +124,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): v2 = None base_model = None for model, ratio in zip(models, ratios): - print(f"loading: {model}") + logger.info(f"loading: {model}") lora_sd, lora_metadata = load_state_dict(model, merge_dtype) if lora_metadata is not None: @@ -154,10 +157,10 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): if lora_module_name not in base_alphas: base_alphas[lora_module_name] = alpha - print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") + logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") # merge - print(f"merging...") + logger.info(f"merging...") for key in tqdm(lora_sd.keys()): if "alpha" in key: continue @@ -200,8 +203,8 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): merged_sd[key_down] = merged_sd[key_down][perm] merged_sd[key_up] = merged_sd[key_up][:,perm] - print("merged model") - print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") + logger.info("merged model") + logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") # check all dims are same dims_list = list(set(base_dims.values())) @@ -243,7 +246,7 @@ def str_to_dtype(p): save_dtype = merge_dtype if args.sd_model is not None: - print(f"loading SD model: {args.sd_model}") + logger.info(f"loading SD model: {args.sd_model}") ( text_model1, @@ -265,14 +268,14 @@ def str_to_dtype(p): None, False, False, True, False, False, time.time(), title=title, merged_from=merged_from ) - print(f"saving SD model to: {args.save_to}") + logger.info(f"saving SD model to: {args.save_to}") sdxl_model_util.save_stable_diffusion_checkpoint( args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype ) else: state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) - print(f"calculating hashes and creating metadata...") + logger.info(f"calculating hashes and creating metadata...") model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) metadata["sshs_model_hash"] = model_hash @@ -286,7 +289,7 @@ def str_to_dtype(p): ) metadata.update(sai_metadata) - print(f"saving model to: {args.save_to}") + logger.info(f"saving model to: {args.save_to}") save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py index ee2dae300..cb00a6000 100644 --- a/networks/svd_merge_lora.py +++ b/networks/svd_merge_lora.py @@ -5,7 +5,12 @@ from safetensors.torch import load_file, save_file from tqdm import tqdm from library import sai_model_spec, train_util - +import library.model_util as model_util +import lora +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) CLAMP_QUANTILE = 0.99 @@ -38,12 +43,12 @@ def save_to_file(file_name, state_dict, dtype, metadata): def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype): - print(f"new rank: {new_rank}, new conv rank: {new_conv_rank}") + logger.info(f"new rank: {new_rank}, new conv rank: {new_conv_rank}") merged_sd = {} v2 = None base_model = None for model, ratio in zip(models, ratios): - print(f"loading: {model}") + logger.info(f"loading: {model}") lora_sd, lora_metadata = load_state_dict(model, merge_dtype) if lora_metadata is not None: @@ -53,7 +58,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None) # merge - print(f"merging...") + logger.info(f"merging...") for key in tqdm(list(lora_sd.keys())): if "lora_down" not in key: continue @@ -70,7 +75,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty out_dim = up_weight.size()[0] conv2d = len(down_weight.size()) == 4 kernel_size = None if not conv2d else down_weight.size()[2:4] - # print(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size) + # logger.info(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size) # make original weight if not exist if lora_module_name not in merged_sd: @@ -107,7 +112,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty merged_sd[lora_module_name] = weight # extract from merged weights - print("extract new lora...") + logger.info("extract new lora...") merged_lora_sd = {} with torch.no_grad(): for lora_module_name, mat in tqdm(list(merged_sd.items())): @@ -185,7 +190,7 @@ def str_to_dtype(p): args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype ) - print(f"calculating hashes and creating metadata...") + logger.info(f"calculating hashes and creating metadata...") model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) metadata["sshs_model_hash"] = model_hash @@ -200,12 +205,12 @@ def str_to_dtype(p): ) if v2: # TODO read sai modelspec - print( + logger.warning( "Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します" ) metadata.update(sai_metadata) - print(f"saving model to: {args.save_to}") + logger.info(f"saving model to: {args.save_to}") save_to_file(args.save_to, state_dict, save_dtype, metadata) diff --git a/requirements.txt b/requirements.txt index c09e62b7f..a399d8bd9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,5 +29,7 @@ huggingface-hub==0.20.1 # protobuf==3.20.3 # open clip for SDXL open-clip-torch==2.20.0 +# For logging +rich==13.7.0 # for kohya_ss library -e . diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index 0db9e340e..ae004463e 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -55,6 +55,10 @@ from library.sdxl_original_unet import InferSdxlUNet2DConditionModel from library.original_unet import FlashAttentionFunction from networks.control_net_lllite import ControlNetLLLite +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) # scheduler: SCHEDULER_LINEAR_START = 0.00085 @@ -76,12 +80,12 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): if mem_eff_attn: - print("Enable memory efficient attention for U-Net") + logger.info("Enable memory efficient attention for U-Net") # これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い unet.set_use_memory_efficient_attention(False, True) elif xformers: - print("Enable xformers for U-Net") + logger.info("Enable xformers for U-Net") try: import xformers.ops except ImportError: @@ -89,7 +93,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio unet.set_use_memory_efficient_attention(True, False) elif sdpa: - print("Enable SDPA for U-Net") + logger.info("Enable SDPA for U-Net") unet.set_use_memory_efficient_attention(False, False) unet.set_use_sdpa(True) @@ -106,7 +110,7 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform def replace_vae_attn_to_memory_efficient(): - print("VAE Attention.forward has been replaced to FlashAttention (not xformers)") + logger.info("VAE Attention.forward has been replaced to FlashAttention (not xformers)") flash_func = FlashAttentionFunction def forward_flash_attn(self, hidden_states, **kwargs): @@ -162,7 +166,7 @@ def forward_flash_attn_0_14(self, hidden_states, **kwargs): def replace_vae_attn_to_xformers(): - print("VAE: Attention.forward has been replaced to xformers") + logger.info("VAE: Attention.forward has been replaced to xformers") import xformers.ops def forward_xformers(self, hidden_states, **kwargs): @@ -218,7 +222,7 @@ def forward_xformers_0_14(self, hidden_states, **kwargs): def replace_vae_attn_to_sdpa(): - print("VAE: Attention.forward has been replaced to sdpa") + logger.info("VAE: Attention.forward has been replaced to sdpa") def forward_sdpa(self, hidden_states, **kwargs): residual = hidden_states @@ -352,7 +356,7 @@ def get_token_replacer(self, tokenizer): token_replacements = self.token_replacements_list[tokenizer_index] def replace_tokens(tokens): - # print("replace_tokens", tokens, "=>", token_replacements) + # logger.info("replace_tokens", tokens, "=>", token_replacements) if isinstance(tokens, torch.Tensor): tokens = tokens.tolist() @@ -444,7 +448,7 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 if not do_classifier_free_guidance and negative_scale is not None: - print(f"negative_scale is ignored if guidance scalle <= 1.0") + logger.info(f"negative_scale is ignored if guidance scalle <= 1.0") negative_scale = None # get unconditional embeddings for classifier free guidance @@ -548,7 +552,7 @@ def __call__( text_pool = text_pool[num_sub_prompts - 1 :: num_sub_prompts] # last subprompt if init_image is not None and self.clip_vision_model is not None: - print(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}") + logger.info(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}") vision_input = self.clip_vision_processor(init_image, return_tensors="pt", device=self.device) pixel_values = vision_input["pixel_values"].to(self.device, dtype=text_embeddings.dtype) @@ -715,7 +719,7 @@ def __call__( if not enabled or ratio >= 1.0: continue if ratio < i / len(timesteps): - print(f"ControlNet {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})") + logger.info(f"ControlNet {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})") control_net.set_cond_image(None) each_control_net_enabled[j] = False @@ -935,7 +939,7 @@ def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: L if word.strip() == "BREAK": # pad until next multiple of tokenizer's max token length pad_len = tokenizer.model_max_length - (len(text_token) % tokenizer.model_max_length) - print(f"BREAK pad_len: {pad_len}") + logger.info(f"BREAK pad_len: {pad_len}") for i in range(pad_len): # v2のときEOSをつけるべきかどうかわからないぜ # if i == 0: @@ -965,7 +969,7 @@ def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: L tokens.append(text_token) weights.append(text_weight) if truncated: - print("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + logger.warning("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") return tokens, weights @@ -1238,7 +1242,7 @@ def handle_dynamic_prompt_variants(prompt, repeat_count): elif len(count_range) == 2: count_range = [int(count_range[0]), int(count_range[1])] else: - print(f"invalid count range: {count_range}") + logger.warning(f"invalid count range: {count_range}") count_range = [1, 1] if count_range[0] > count_range[1]: count_range = [count_range[1], count_range[0]] @@ -1308,7 +1312,7 @@ def replacer(): # def load_clip_l14_336(dtype): -# print(f"loading CLIP: {CLIP_ID_L14_336}") +# logger.info(f"loading CLIP: {CLIP_ID_L14_336}") # text_encoder = CLIPTextModel.from_pretrained(CLIP_ID_L14_336, torch_dtype=dtype) # return text_encoder @@ -1378,7 +1382,7 @@ def main(args): replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa) # tokenizerを読み込む - print("loading tokenizer") + logger.info("loading tokenizer") tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) # schedulerを用意する @@ -1452,7 +1456,7 @@ def reset_sampler_noises(self, noises): self.sampler_noises = noises def randn(self, shape, device=None, dtype=None, layout=None, generator=None): - # print("replacing", shape, len(self.sampler_noises), self.sampler_noise_index) + # logger.info("replacing", shape, len(self.sampler_noises), self.sampler_noise_index) if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises): noise = self.sampler_noises[self.sampler_noise_index] if shape != noise.shape: @@ -1461,7 +1465,7 @@ def randn(self, shape, device=None, dtype=None, layout=None, generator=None): noise = None if noise == None: - print(f"unexpected noise request: {self.sampler_noise_index}, {shape}") + logger.warning(f"unexpected noise request: {self.sampler_noise_index}, {shape}") noise = torch.randn(shape, dtype=dtype, device=device, generator=generator) self.sampler_noise_index += 1 @@ -1493,7 +1497,7 @@ def __getattr__(self, item): # ↓以下は結局PipeでFalseに設定されるので意味がなかった # # clip_sample=Trueにする # if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: - # print("set clip_sample to True") + # logger.info("set clip_sample to True") # scheduler.config.clip_sample = True # deviceを決定する @@ -1522,7 +1526,7 @@ def __getattr__(self, item): vae_dtype = dtype if args.no_half_vae: - print("set vae_dtype to float32") + logger.info("set vae_dtype to float32") vae_dtype = torch.float32 vae.to(vae_dtype).to(device) vae.eval() @@ -1547,10 +1551,10 @@ def __getattr__(self, item): network_merge = args.network_merge_n_models else: network_merge = 0 - print(f"network_merge: {network_merge}") + logger.info(f"network_merge: {network_merge}") for i, network_module in enumerate(args.network_module): - print("import network module:", network_module) + logger.info(f"import network module: {network_module}") imported_module = importlib.import_module(network_module) network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] @@ -1568,7 +1572,7 @@ def __getattr__(self, item): raise ValueError("No weight. Weight is required.") network_weight = args.network_weights[i] - print("load network weights from:", network_weight) + logger.info(f"load network weights from: {network_weight}") if model_util.is_safetensors(network_weight) and args.network_show_meta: from safetensors.torch import safe_open @@ -1576,7 +1580,7 @@ def __getattr__(self, item): with safe_open(network_weight, framework="pt") as f: metadata = f.metadata() if metadata is not None: - print(f"metadata for: {network_weight}: {metadata}") + logger.info(f"metadata for: {network_weight}: {metadata}") network, weights_sd = imported_module.create_network_from_weights( network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs @@ -1586,20 +1590,20 @@ def __getattr__(self, item): mergeable = network.is_mergeable() if network_merge and not mergeable: - print("network is not mergiable. ignore merge option.") + logger.warning("network is not mergiable. ignore merge option.") if not mergeable or i >= network_merge: # not merging network.apply_to([text_encoder1, text_encoder2], unet) info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい - print(f"weights are loaded: {info}") + logger.info(f"weights are loaded: {info}") if args.opt_channels_last: network.to(memory_format=torch.channels_last) network.to(dtype).to(device) if network_pre_calc: - print("backup original weights") + logger.info("backup original weights") network.backup_weights() networks.append(network) @@ -1613,7 +1617,7 @@ def __getattr__(self, item): # upscalerの指定があれば取得する upscaler = None if args.highres_fix_upscaler: - print("import upscaler module:", args.highres_fix_upscaler) + logger.info(f"import upscaler module: {args.highres_fix_upscaler}") imported_module = importlib.import_module(args.highres_fix_upscaler) us_kwargs = {} @@ -1622,7 +1626,7 @@ def __getattr__(self, item): key, value = net_arg.split("=") us_kwargs[key] = value - print("create upscaler") + logger.info("create upscaler") upscaler = imported_module.create_upscaler(**us_kwargs) upscaler.to(dtype).to(device) @@ -1639,7 +1643,7 @@ def __getattr__(self, item): # control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) if args.control_net_lllite_models: for i, model_file in enumerate(args.control_net_lllite_models): - print(f"loading ControlNet-LLLite: {model_file}") + logger.info(f"loading ControlNet-LLLite: {model_file}") from safetensors.torch import load_file @@ -1670,7 +1674,7 @@ def __getattr__(self, item): control_nets.append((control_net, ratio)) if args.opt_channels_last: - print(f"set optimizing: channels last") + logger.info(f"set optimizing: channels last") text_encoder1.to(memory_format=torch.channels_last) text_encoder2.to(memory_format=torch.channels_last) vae.to(memory_format=torch.channels_last) @@ -1694,7 +1698,7 @@ def __getattr__(self, item): args.clip_skip, ) pipe.set_control_nets(control_nets) - print("pipeline is ready.") + logger.info("pipeline is ready.") if args.diffusers_xformers: pipe.enable_xformers_memory_efficient_attention() @@ -1736,7 +1740,7 @@ def __getattr__(self, item): token_ids1 = tokenizer1.convert_tokens_to_ids(token_strings) token_ids2 = tokenizer2.convert_tokens_to_ids(token_strings) - print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}") + logger.info(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}") assert ( min(token_ids1) == token_ids1[0] and token_ids1[-1] == token_ids1[0] + len(token_ids1) - 1 ), f"token ids1 is not ordered" @@ -1766,7 +1770,7 @@ def __getattr__(self, item): # promptを取得する if args.from_file is not None: - print(f"reading prompts from {args.from_file}") + logger.info(f"reading prompts from {args.from_file}") with open(args.from_file, "r", encoding="utf-8") as f: prompt_list = f.read().splitlines() prompt_list = [d for d in prompt_list if len(d.strip()) > 0] @@ -1795,7 +1799,7 @@ def load_images(path): for p in paths: image = Image.open(p) if image.mode != "RGB": - print(f"convert image to RGB from {image.mode}: {p}") + logger.info(f"convert image to RGB from {image.mode}: {p}") image = image.convert("RGB") images.append(image) @@ -1811,14 +1815,14 @@ def resize_images(imgs, size): return resized if args.image_path is not None: - print(f"load image for img2img: {args.image_path}") + logger.info(f"load image for img2img: {args.image_path}") init_images = load_images(args.image_path) assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}" - print(f"loaded {len(init_images)} images for img2img") + logger.info(f"loaded {len(init_images)} images for img2img") # CLIP Vision if args.clip_vision_strength is not None: - print(f"load CLIP Vision model: {CLIP_VISION_MODEL}") + logger.info(f"load CLIP Vision model: {CLIP_VISION_MODEL}") vision_model = CLIPVisionModelWithProjection.from_pretrained(CLIP_VISION_MODEL, projection_dim=1280) vision_model.to(device, dtype) processor = CLIPImageProcessor.from_pretrained(CLIP_VISION_MODEL) @@ -1826,22 +1830,22 @@ def resize_images(imgs, size): pipe.clip_vision_model = vision_model pipe.clip_vision_processor = processor pipe.clip_vision_strength = args.clip_vision_strength - print(f"CLIP Vision model loaded.") + logger.info(f"CLIP Vision model loaded.") else: init_images = None if args.mask_path is not None: - print(f"load mask for inpainting: {args.mask_path}") + logger.info(f"load mask for inpainting: {args.mask_path}") mask_images = load_images(args.mask_path) assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}" - print(f"loaded {len(mask_images)} mask images for inpainting") + logger.info(f"loaded {len(mask_images)} mask images for inpainting") else: mask_images = None # promptがないとき、画像のPngInfoから取得する if init_images is not None and len(prompt_list) == 0 and not args.interactive: - print("get prompts from images' metadata") + logger.info("get prompts from images' metadata") for img in init_images: if "prompt" in img.text: prompt = img.text["prompt"] @@ -1870,17 +1874,17 @@ def resize_images(imgs, size): h = int(h * args.highres_fix_scale + 0.5) if init_images is not None: - print(f"resize img2img source images to {w}*{h}") + logger.info(f"resize img2img source images to {w}*{h}") init_images = resize_images(init_images, (w, h)) if mask_images is not None: - print(f"resize img2img mask images to {w}*{h}") + logger.info(f"resize img2img mask images to {w}*{h}") mask_images = resize_images(mask_images, (w, h)) regional_network = False if networks and mask_images: # mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応 regional_network = True - print("use mask as region") + logger.info("use mask as region") size = None for i, network in enumerate(networks): @@ -1905,14 +1909,14 @@ def resize_images(imgs, size): prev_image = None # for VGG16 guided if args.guide_image_path is not None: - print(f"load image for ControlNet guidance: {args.guide_image_path}") + logger.info(f"load image for ControlNet guidance: {args.guide_image_path}") guide_images = [] for p in args.guide_image_path: guide_images.extend(load_images(p)) - print(f"loaded {len(guide_images)} guide images for guidance") + logger.info(f"loaded {len(guide_images)} guide images for guidance") if len(guide_images) == 0: - print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}") + logger.warning(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}") guide_images = None else: guide_images = None @@ -1938,7 +1942,7 @@ def resize_images(imgs, size): max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples for gen_iter in range(args.n_iter): - print(f"iteration {gen_iter+1}/{args.n_iter}") + logger.info(f"iteration {gen_iter+1}/{args.n_iter}") iter_seed = random.randint(0, 0x7FFFFFFF) # バッチ処理の関数 @@ -1950,7 +1954,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling - print("process 1st stage") + logger.info("process 1st stage") batch_1st = [] for _, base, ext in batch: @@ -1995,7 +1999,7 @@ def scale_and_round(x): images_1st = process_batch(batch_1st, True, True) # 2nd stageのバッチを作成して以下処理する - print("process 2nd stage") + logger.info("process 2nd stage") width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height if upscaler: @@ -2161,7 +2165,7 @@ def scale_and_round(x): n.restore_weights() for n in networks: n.pre_calculation() - print("pre-calculation... done") + logger.info("pre-calculation... done") images = pipe( prompts, @@ -2240,7 +2244,7 @@ def scale_and_round(x): cv2.waitKey() cv2.destroyAllWindows() except ImportError: - print("opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません") + logger.error("opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません") return images @@ -2253,7 +2257,8 @@ def scale_and_round(x): # interactive valid = False while not valid: - print("\nType prompt:") + logger.info("") + logger.info("Type prompt:") try: raw_prompt = input() except EOFError: @@ -2302,74 +2307,74 @@ def scale_and_round(x): prompt_args = raw_prompt.strip().split(" --") prompt = prompt_args[0] - print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") + logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") for parg in prompt_args[1:]: try: m = re.match(r"w (\d+)", parg, re.IGNORECASE) if m: width = int(m.group(1)) - print(f"width: {width}") + logger.info(f"width: {width}") continue m = re.match(r"h (\d+)", parg, re.IGNORECASE) if m: height = int(m.group(1)) - print(f"height: {height}") + logger.info(f"height: {height}") continue m = re.match(r"ow (\d+)", parg, re.IGNORECASE) if m: original_width = int(m.group(1)) - print(f"original width: {original_width}") + logger.info(f"original width: {original_width}") continue m = re.match(r"oh (\d+)", parg, re.IGNORECASE) if m: original_height = int(m.group(1)) - print(f"original height: {original_height}") + logger.info(f"original height: {original_height}") continue m = re.match(r"nw (\d+)", parg, re.IGNORECASE) if m: original_width_negative = int(m.group(1)) - print(f"original width negative: {original_width_negative}") + logger.info(f"original width negative: {original_width_negative}") continue m = re.match(r"nh (\d+)", parg, re.IGNORECASE) if m: original_height_negative = int(m.group(1)) - print(f"original height negative: {original_height_negative}") + logger.info(f"original height negative: {original_height_negative}") continue m = re.match(r"ct (\d+)", parg, re.IGNORECASE) if m: crop_top = int(m.group(1)) - print(f"crop top: {crop_top}") + logger.info(f"crop top: {crop_top}") continue m = re.match(r"cl (\d+)", parg, re.IGNORECASE) if m: crop_left = int(m.group(1)) - print(f"crop left: {crop_left}") + logger.info(f"crop left: {crop_left}") continue m = re.match(r"s (\d+)", parg, re.IGNORECASE) if m: # steps steps = max(1, min(1000, int(m.group(1)))) - print(f"steps: {steps}") + logger.info(f"steps: {steps}") continue m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) if m: # seed seeds = [int(d) for d in m.group(1).split(",")] - print(f"seeds: {seeds}") + logger.info(f"seeds: {seeds}") continue m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) if m: # scale scale = float(m.group(1)) - print(f"scale: {scale}") + logger.info(f"scale: {scale}") continue m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) @@ -2378,25 +2383,25 @@ def scale_and_round(x): negative_scale = None else: negative_scale = float(m.group(1)) - print(f"negative scale: {negative_scale}") + logger.info(f"negative scale: {negative_scale}") continue m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) if m: # strength strength = float(m.group(1)) - print(f"strength: {strength}") + logger.info(f"strength: {strength}") continue m = re.match(r"n (.+)", parg, re.IGNORECASE) if m: # negative prompt negative_prompt = m.group(1) - print(f"negative prompt: {negative_prompt}") + logger.info(f"negative prompt: {negative_prompt}") continue m = re.match(r"c (.+)", parg, re.IGNORECASE) if m: # clip prompt clip_prompt = m.group(1) - print(f"clip prompt: {clip_prompt}") + logger.info(f"clip prompt: {clip_prompt}") continue m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) @@ -2404,47 +2409,47 @@ def scale_and_round(x): network_muls = [float(v) for v in m.group(1).split(",")] while len(network_muls) < len(networks): network_muls.append(network_muls[-1]) - print(f"network mul: {network_muls}") + logger.info(f"network mul: {network_muls}") continue # Deep Shrink m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink depth 1 ds_depth_1 = int(m.group(1)) - print(f"deep shrink depth 1: {ds_depth_1}") + logger.info(f"deep shrink depth 1: {ds_depth_1}") continue m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink timesteps 1 ds_timesteps_1 = int(m.group(1)) ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink timesteps 1: {ds_timesteps_1}") + logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}") continue m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink depth 2 ds_depth_2 = int(m.group(1)) ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink depth 2: {ds_depth_2}") + logger.info(f"deep shrink depth 2: {ds_depth_2}") continue m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink timesteps 2 ds_timesteps_2 = int(m.group(1)) ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink timesteps 2: {ds_timesteps_2}") + logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}") continue m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink ratio ds_ratio = float(m.group(1)) ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink ratio: {ds_ratio}") + logger.info(f"deep shrink ratio: {ds_ratio}") continue except ValueError as ex: - print(f"Exception in parsing / 解析エラー: {parg}") - print(ex) + logger.error(f"Exception in parsing / 解析エラー: {parg}") + logger.error(f"{ex}") # override Deep Shrink if ds_depth_1 is not None: @@ -2462,7 +2467,7 @@ def scale_and_round(x): if len(predefined_seeds) > 0: seed = predefined_seeds.pop(0) else: - print("predefined seeds are exhausted") + logger.error("predefined seeds are exhausted") seed = None elif args.iter_same_seed: seeds = iter_seed @@ -2472,7 +2477,7 @@ def scale_and_round(x): if seed is None: seed = random.randint(0, 0x7FFFFFFF) if args.interactive: - print(f"seed: {seed}") + logger.info(f"seed: {seed}") # prepare init image, guide image and mask init_image = mask_image = guide_image = None @@ -2488,7 +2493,7 @@ def scale_and_round(x): width = width - width % 32 height = height - height % 32 if width != init_image.size[0] or height != init_image.size[1]: - print( + logger.warning( f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます" ) @@ -2548,7 +2553,7 @@ def scale_and_round(x): process_batch(batch_data, highres_fix) batch_data.clear() - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: diff --git a/sdxl_minimal_inference.py b/sdxl_minimal_inference.py index 15a70678f..47bc3d986 100644 --- a/sdxl_minimal_inference.py +++ b/sdxl_minimal_inference.py @@ -23,6 +23,10 @@ from library import model_util, sdxl_model_util import networks.lora as lora +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) # scheduler: このあたりの設定はSD1/2と同じでいいらしい # scheduler: The settings around here seem to be the same as SD1/2 @@ -140,7 +144,7 @@ def get_timestep_embedding(x, outdim): vae_dtype = DTYPE if DTYPE == torch.float16: - print("use float32 for vae") + logger.info("use float32 for vae") vae_dtype = torch.float32 vae.to(DEVICE, dtype=vae_dtype) vae.eval() @@ -187,7 +191,7 @@ def generate_image(prompt, prompt2, negative_prompt, seed=None): emb1 = get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256) emb2 = get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256) emb3 = get_timestep_embedding(torch.FloatTensor([target_height, target_width]).unsqueeze(0), 256) - # print("emb1", emb1.shape) + # logger.info("emb1", emb1.shape) c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(DEVICE, dtype=DTYPE) uc_vector = c_vector.clone().to(DEVICE, dtype=DTYPE) # ちょっとここ正しいかどうかわからない I'm not sure if this is right @@ -217,7 +221,7 @@ def call_text_encoder(text, text2): enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True) text_embedding2_penu = enc_out["hidden_states"][-2] - # print("hidden_states2", text_embedding2_penu.shape) + # logger.info("hidden_states2", text_embedding2_penu.shape) text_embedding2_pool = enc_out["text_embeds"] # do not support Textual Inversion # 連結して終了 concat and finish @@ -226,7 +230,7 @@ def call_text_encoder(text, text2): # cond c_ctx, c_ctx_pool = call_text_encoder(prompt, prompt2) - # print(c_ctx.shape, c_ctx_p.shape, c_vector.shape) + # logger.info(c_ctx.shape, c_ctx_p.shape, c_vector.shape) c_vector = torch.cat([c_ctx_pool, c_vector], dim=1) # uncond @@ -323,4 +327,4 @@ def call_text_encoder(text, text2): seed = int(seed) generate_image(prompt, prompt2, negative_prompt, seed) - print("Done!") + logger.info("Done!") diff --git a/sdxl_train.py b/sdxl_train.py index a3f6f3a17..8f1740af6 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -20,6 +20,10 @@ from library import sdxl_model_util import library.train_util as train_util +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) import library.config_util as config_util import library.sdxl_train_util as sdxl_train_util from library.config_util import ( @@ -117,18 +121,18 @@ def train(args): if args.dataset_class is None: blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) ) else: if use_dreambooth_method: - print("Using DreamBooth method.") + logger.info("Using DreamBooth method.") user_config = { "datasets": [ { @@ -139,7 +143,7 @@ def train(args): ] } else: - print("Training with captions.") + logger.info("Training with captions.") user_config = { "datasets": [ { @@ -169,7 +173,7 @@ def train(args): train_util.debug_dataset(train_dataset_group, True) return if len(train_dataset_group) == 0: - print( + logger.error( "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" ) return @@ -185,7 +189,7 @@ def train(args): ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする @@ -537,7 +541,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 # assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 # assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 - # print("text encoder outputs verified") + # logger.info("text encoder outputs verified") # get size embeddings orig_size = batch["original_sizes_hw"] @@ -724,7 +728,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): logit_scale, ckpt_info, ) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 7a88feb84..36dd72e64 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -45,7 +45,10 @@ apply_debiased_estimation, ) import networks.control_net_lllite_for_train as control_net_lllite_for_train - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) # TODO 他のスクリプトと共通化する def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): @@ -78,11 +81,11 @@ def train(args): # データセットを準備する blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) if use_user_config: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "conditioning_data_dir"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) @@ -114,7 +117,7 @@ def train(args): train_util.debug_dataset(train_dataset_group) return if len(train_dataset_group) == 0: - print( + logger.error( "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" ) return @@ -124,7 +127,7 @@ def train(args): train_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" else: - print("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません") + logger.warning("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません") if args.cache_text_encoder_outputs: assert ( @@ -132,7 +135,7 @@ def train(args): ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) is_main_process = accelerator.is_main_process @@ -231,8 +234,8 @@ def train(args): accelerator.print("prepare optimizer, data loader etc.") trainable_params = list(unet.prepare_params()) - print(f"trainable params count: {len(trainable_params)}") - print(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}") + logger.info(f"trainable params count: {len(trainable_params)}") + logger.info(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}") _, _, optimizer = train_util.get_optimizer(args, trainable_params) @@ -324,7 +327,7 @@ def train(args): accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") - # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") @@ -548,7 +551,7 @@ def remove_model(old_ckpt_name): ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) save_model(ckpt_name, unet, global_step, num_train_epochs, force_sync_upload=True) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index b94bf5c1b..554b79cee 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -41,7 +41,10 @@ apply_debiased_estimation, ) import networks.control_net_lllite as control_net_lllite - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) # TODO 他のスクリプトと共通化する def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): @@ -74,11 +77,11 @@ def train(args): # データセットを準備する blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) if use_user_config: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "conditioning_data_dir"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) @@ -110,7 +113,7 @@ def train(args): train_util.debug_dataset(train_dataset_group) return if len(train_dataset_group) == 0: - print( + logger.error( "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" ) return @@ -120,7 +123,7 @@ def train(args): train_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" else: - print("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません") + logger.warning("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません") if args.cache_text_encoder_outputs: assert ( @@ -128,7 +131,7 @@ def train(args): ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) is_main_process = accelerator.is_main_process @@ -199,8 +202,8 @@ def train(args): accelerator.print("prepare optimizer, data loader etc.") trainable_params = list(network.prepare_optimizer_params()) - print(f"trainable params count: {len(trainable_params)}") - print(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}") + logger.info(f"trainable params count: {len(trainable_params)}") + logger.info(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}") _, _, optimizer = train_util.get_optimizer(args, trainable_params) @@ -297,7 +300,7 @@ def train(args): accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") - # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") @@ -516,7 +519,7 @@ def remove_model(old_ckpt_name): ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 5d363280d..da006a5fa 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -7,7 +7,10 @@ from library import sdxl_model_util, sdxl_train_util, train_util import train_network - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) class SdxlNetworkTrainer(train_network.NetworkTrainer): def __init__(self): @@ -60,7 +63,7 @@ def cache_text_encoder_outputs_if_needed( if args.cache_text_encoder_outputs: if not args.lowram: # メモリ消費を減らす - print("move vae and unet to cpu to save memory") + logger.info("move vae and unet to cpu to save memory") org_vae_device = vae.device org_unet_device = unet.device vae.to("cpu") @@ -85,7 +88,7 @@ def cache_text_encoder_outputs_if_needed( torch.cuda.empty_cache() if not args.lowram: - print("move vae and unet back to original device") + logger.info("move vae and unet back to original device") vae.to(org_vae_device) unet.to(org_unet_device) else: @@ -143,7 +146,7 @@ def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, wei # assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 # assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 # assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 - # print("text encoder outputs verified") + # logger.info("text encoder outputs verified") return encoder_hidden_states1, encoder_hidden_states2, pool2 diff --git a/tools/cache_latents.py b/tools/cache_latents.py index 17916ef70..e25506e41 100644 --- a/tools/cache_latents.py +++ b/tools/cache_latents.py @@ -16,7 +16,10 @@ ConfigSanitizer, BlueprintGenerator, ) - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def cache_to_disk(args: argparse.Namespace) -> None: train_util.prepare_dataset_args(args, True) @@ -41,18 +44,18 @@ def cache_to_disk(args: argparse.Namespace) -> None: if args.dataset_class is None: blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) ) else: if use_dreambooth_method: - print("Using DreamBooth method.") + logger.info("Using DreamBooth method.") user_config = { "datasets": [ { @@ -63,7 +66,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: ] } else: - print("Training with captions.") + logger.info("Training with captions.") user_config = { "datasets": [ { @@ -90,7 +93,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする @@ -98,7 +101,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: vae_dtype = torch.float32 if args.no_half_vae else weight_dtype # モデルを読み込む - print("load model") + logger.info("load model") if args.sdxl: (_, _, _, vae, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) else: @@ -152,7 +155,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: if args.skip_existing: if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug): - print(f"Skipping {image_info.latents_npz} because it already exists.") + logger.warning(f"Skipping {image_info.latents_npz} because it already exists.") continue image_infos.append(image_info) diff --git a/tools/cache_text_encoder_outputs.py b/tools/cache_text_encoder_outputs.py index 7d9b13d68..46bffc4e0 100644 --- a/tools/cache_text_encoder_outputs.py +++ b/tools/cache_text_encoder_outputs.py @@ -16,7 +16,10 @@ ConfigSanitizer, BlueprintGenerator, ) - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def cache_to_disk(args: argparse.Namespace) -> None: train_util.prepare_dataset_args(args, True) @@ -48,18 +51,18 @@ def cache_to_disk(args: argparse.Namespace) -> None: if args.dataset_class is None: blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) ) else: if use_dreambooth_method: - print("Using DreamBooth method.") + logger.info("Using DreamBooth method.") user_config = { "datasets": [ { @@ -70,7 +73,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: ] } else: - print("Training with captions.") + logger.info("Training with captions.") user_config = { "datasets": [ { @@ -95,14 +98,14 @@ def cache_to_disk(args: argparse.Namespace) -> None: collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする weight_dtype, _ = train_util.prepare_dtype(args) # モデルを読み込む - print("load model") + logger.info("load model") if args.sdxl: (_, text_encoder1, text_encoder2, _, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) text_encoders = [text_encoder1, text_encoder2] @@ -147,7 +150,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: if args.skip_existing: if os.path.exists(image_info.text_encoder_outputs_npz): - print(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.") + logger.warning(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.") continue image_info.input_ids1 = input_ids1 diff --git a/tools/canny.py b/tools/canny.py index 5e0806898..f2190975c 100644 --- a/tools/canny.py +++ b/tools/canny.py @@ -1,6 +1,10 @@ import argparse import cv2 +import logging +from library.utils import setup_logging +setup_logging() +logger = logging.getLogger(__name__) def canny(args): img = cv2.imread(args.input) @@ -10,7 +14,7 @@ def canny(args): # canny_img = 255 - canny_img cv2.imwrite(args.output, canny_img) - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: diff --git a/tools/convert_diffusers20_original_sd.py b/tools/convert_diffusers20_original_sd.py index fe30996aa..572ee2f0c 100644 --- a/tools/convert_diffusers20_original_sd.py +++ b/tools/convert_diffusers20_original_sd.py @@ -6,7 +6,10 @@ from diffusers import StableDiffusionPipeline import library.model_util as model_util - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def convert(args): # 引数を確認する @@ -30,7 +33,7 @@ def convert(args): # モデルを読み込む msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else "")) - print(f"loading {msg}: {args.model_to_load}") + logger.info(f"loading {msg}: {args.model_to_load}") if is_load_ckpt: v2_model = args.v2 @@ -48,13 +51,13 @@ def convert(args): if args.v1 == args.v2: # 自動判定する v2_model = unet.config.cross_attention_dim == 1024 - print("checking model version: model is " + ("v2" if v2_model else "v1")) + logger.info("checking model version: model is " + ("v2" if v2_model else "v1")) else: v2_model = not args.v1 # 変換して保存する msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers" - print(f"converting and saving as {msg}: {args.model_to_save}") + logger.info(f"converting and saving as {msg}: {args.model_to_save}") if is_save_ckpt: original_model = args.model_to_load if is_load_ckpt else None @@ -70,15 +73,15 @@ def convert(args): save_dtype=save_dtype, vae=vae, ) - print(f"model saved. total converted state_dict keys: {key_count}") + logger.info(f"model saved. total converted state_dict keys: {key_count}") else: - print( + logger.info( f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}" ) model_util.save_diffusers_checkpoint( v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors ) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: diff --git a/tools/detect_face_rotate.py b/tools/detect_face_rotate.py index 68dec6cae..bbc643edc 100644 --- a/tools/detect_face_rotate.py +++ b/tools/detect_face_rotate.py @@ -15,6 +15,10 @@ from anime_face_detector import create_detector from tqdm import tqdm import numpy as np +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) KP_REYE = 11 KP_LEYE = 19 @@ -24,7 +28,7 @@ def detect_faces(detector, image, min_size): preds = detector(image) # bgr - # print(len(preds)) + # logger.info(len(preds)) faces = [] for pred in preds: @@ -78,7 +82,7 @@ def process(args): assert args.crop_ratio is None or args.resize_face_size is None, f"crop_ratio指定時はresize_face_sizeは指定できません" # アニメ顔検出モデルを読み込む - print("loading face detector.") + logger.info("loading face detector.") detector = create_detector('yolov3') # cropの引数を解析する @@ -97,7 +101,7 @@ def process(args): crop_h_ratio, crop_v_ratio = [float(t) for t in tokens] # 画像を処理する - print("processing.") + logger.info("processing.") output_extension = ".png" os.makedirs(args.dst_dir, exist_ok=True) @@ -111,7 +115,7 @@ def process(args): if len(image.shape) == 2: image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) if image.shape[2] == 4: - print(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}") + logger.warning(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}") image = image[:, :, :3].copy() # copyをしないと内部的に透明度情報が付いたままになるらしい h, w = image.shape[:2] @@ -144,11 +148,11 @@ def process(args): # 顔サイズを基準にリサイズする scale = args.resize_face_size / face_size if scale < cur_crop_width / w: - print( + logger.warning( f"image width too small in face size based resizing / 顔を基準にリサイズすると画像の幅がcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}") scale = cur_crop_width / w if scale < cur_crop_height / h: - print( + logger.warning( f"image height too small in face size based resizing / 顔を基準にリサイズすると画像の高さがcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}") scale = cur_crop_height / h elif crop_h_ratio is not None: @@ -157,10 +161,10 @@ def process(args): else: # 切り出しサイズ指定あり if w < cur_crop_width: - print(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}") + logger.warning(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}") scale = cur_crop_width / w if h < cur_crop_height: - print(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}") + logger.warning(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}") scale = cur_crop_height / h if args.resize_fit: scale = max(cur_crop_width / w, cur_crop_height / h) @@ -198,7 +202,7 @@ def process(args): face_img = face_img[y:y + cur_crop_height] # # debug - # print(path, cx, cy, angle) + # logger.info(path, cx, cy, angle) # crp = cv2.resize(image, (image.shape[1]//8, image.shape[0]//8)) # cv2.imshow("image", crp) # if cv2.waitKey() == 27: diff --git a/tools/latent_upscaler.py b/tools/latent_upscaler.py index ab1fa3390..4fa599e29 100644 --- a/tools/latent_upscaler.py +++ b/tools/latent_upscaler.py @@ -14,7 +14,10 @@ from torch import nn from tqdm import tqdm from PIL import Image - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) class ResidualBlock(nn.Module): def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1): @@ -216,7 +219,7 @@ def upscale( upsampled_images = upsampled_images / 127.5 - 1.0 # convert upsample images to latents with batch size - # print("Encoding upsampled (LANCZOS4) images...") + # logger.info("Encoding upsampled (LANCZOS4) images...") upsampled_latents = [] for i in tqdm(range(0, upsampled_images.shape[0], vae_batch_size)): batch = upsampled_images[i : i + vae_batch_size].to(vae.device) @@ -227,7 +230,7 @@ def upscale( upsampled_latents = torch.cat(upsampled_latents, dim=0) # upscale (refine) latents with this model with batch size - print("Upscaling latents...") + logger.info("Upscaling latents...") upscaled_latents = [] for i in range(0, upsampled_latents.shape[0], batch_size): with torch.no_grad(): @@ -242,7 +245,7 @@ def create_upscaler(**kwargs): weights = kwargs["weights"] model = Upscaler() - print(f"Loading weights from {weights}...") + logger.info(f"Loading weights from {weights}...") if os.path.splitext(weights)[1] == ".safetensors": from safetensors.torch import load_file @@ -261,14 +264,14 @@ def upscale_images(args: argparse.Namespace): # load VAE with Diffusers assert args.vae_path is not None, "VAE path is required" - print(f"Loading VAE from {args.vae_path}...") + logger.info(f"Loading VAE from {args.vae_path}...") vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae") vae.to(DEVICE, dtype=us_dtype) # prepare model - print("Preparing model...") + logger.info("Preparing model...") upscaler: Upscaler = create_upscaler(weights=args.weights) - # print("Loading weights from", args.weights) + # logger.info("Loading weights from", args.weights) # upscaler.load_state_dict(torch.load(args.weights)) upscaler.eval() upscaler.to(DEVICE, dtype=us_dtype) @@ -303,14 +306,14 @@ def upscale_images(args: argparse.Namespace): image_debug.save(dest_file_name) # upscale - print("Upscaling...") + logger.info("Upscaling...") upscaled_latents = upscaler.upscale( vae, images, None, us_dtype, width * 2, height * 2, batch_size=args.batch_size, vae_batch_size=args.vae_batch_size ) upscaled_latents /= 0.18215 # decode with batch - print("Decoding...") + logger.info("Decoding...") upscaled_images = [] for i in tqdm(range(0, upscaled_latents.shape[0], args.vae_batch_size)): with torch.no_grad(): diff --git a/tools/merge_models.py b/tools/merge_models.py index 391bfe677..8f1fbf2f8 100644 --- a/tools/merge_models.py +++ b/tools/merge_models.py @@ -5,7 +5,10 @@ from safetensors import safe_open from safetensors.torch import load_file, save_file from tqdm import tqdm - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def is_unet_key(key): # VAE or TextEncoder, the last one is for SDXL @@ -45,10 +48,10 @@ def merge(args): # check if all models are safetensors for model in args.models: if not model.endswith("safetensors"): - print(f"Model {model} is not a safetensors model") + logger.info(f"Model {model} is not a safetensors model") exit() if not os.path.isfile(model): - print(f"Model {model} does not exist") + logger.info(f"Model {model} does not exist") exit() assert args.ratios is None or len(args.models) == len(args.ratios), "ratios must be the same length as models" @@ -65,7 +68,7 @@ def merge(args): if merged_sd is None: # load first model - print(f"Loading model {model}, ratio = {ratio}...") + logger.info(f"Loading model {model}, ratio = {ratio}...") merged_sd = {} with safe_open(model, framework="pt", device=args.device) as f: for key in tqdm(f.keys()): @@ -81,11 +84,11 @@ def merge(args): value = ratio * value.to(dtype) # first model's value * ratio merged_sd[key] = value - print(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if args.unet_only else "")) + logger.info(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if args.unet_only else "")) continue # load other models - print(f"Loading model {model}, ratio = {ratio}...") + logger.info(f"Loading model {model}, ratio = {ratio}...") with safe_open(model, framework="pt", device=args.device) as f: model_keys = f.keys() @@ -93,7 +96,7 @@ def merge(args): _, new_key = replace_text_encoder_key(key) if new_key not in merged_sd: if args.show_skipped and new_key not in first_model_keys: - print(f"Skip: {new_key}") + logger.info(f"Skip: {new_key}") continue value = f.get_tensor(key) @@ -104,7 +107,7 @@ def merge(args): for key in merged_sd.keys(): if key in model_keys: continue - print(f"Key {key} not in model {model}, use first model's value") + logger.warning(f"Key {key} not in model {model}, use first model's value") if key in supplementary_key_ratios: supplementary_key_ratios[key] += ratio else: @@ -112,7 +115,7 @@ def merge(args): # add supplementary keys' value (including VAE and TextEncoder) if len(supplementary_key_ratios) > 0: - print("add first model's value") + logger.info("add first model's value") with safe_open(args.models[0], framework="pt", device=args.device) as f: for key in tqdm(f.keys()): _, new_key = replace_text_encoder_key(key) @@ -120,7 +123,7 @@ def merge(args): continue if is_unet_key(new_key): # not VAE or TextEncoder - print(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}") + logger.warning(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}") value = f.get_tensor(key) # original key @@ -134,7 +137,7 @@ def merge(args): if not output_file.endswith(".safetensors"): output_file = output_file + ".safetensors" - print(f"Saving to {output_file}...") + logger.info(f"Saving to {output_file}...") # convert to save_dtype for k in merged_sd.keys(): @@ -142,7 +145,7 @@ def merge(args): save_file(merged_sd, output_file) - print("Done!") + logger.info("Done!") if __name__ == "__main__": diff --git a/tools/original_control_net.py b/tools/original_control_net.py index cd47bd76a..21392dc47 100644 --- a/tools/original_control_net.py +++ b/tools/original_control_net.py @@ -7,7 +7,10 @@ from library.original_unet import UNet2DConditionModel, SampleOutput import library.model_util as model_util - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) class ControlNetInfo(NamedTuple): unet: Any @@ -51,7 +54,7 @@ def load_control_net(v2, unet, model): # control sdからキー変換しつつU-Netに対応する部分のみ取り出し、DiffusersのU-Netに読み込む # state dictを読み込む - print(f"ControlNet: loading control SD model : {model}") + logger.info(f"ControlNet: loading control SD model : {model}") if model_util.is_safetensors(model): ctrl_sd_sd = load_file(model) @@ -61,7 +64,7 @@ def load_control_net(v2, unet, model): # 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む is_difference = "difference" in ctrl_sd_sd - print("ControlNet: loading difference:", is_difference) + logger.info(f"ControlNet: loading difference: {is_difference}") # ControlNetには存在しないキーがあるので、まず現在のU-NetでSD版の全keyを作っておく # またTransfer Controlの元weightとなる @@ -89,13 +92,13 @@ def load_control_net(v2, unet, model): # ControlNetのU-Netを作成する ctrl_unet = UNet2DConditionModel(**unet_config) info = ctrl_unet.load_state_dict(ctrl_unet_du_sd) - print("ControlNet: loading Control U-Net:", info) + logger.info(f"ControlNet: loading Control U-Net: {info}") # U-Net以外のControlNetを作成する # TODO support middle only ctrl_net = ControlNet() info = ctrl_net.load_state_dict(zero_conv_sd) - print("ControlNet: loading ControlNet:", info) + logger.info("ControlNet: loading ControlNet: {info}") ctrl_unet.to(unet.device, dtype=unet.dtype) ctrl_net.to(unet.device, dtype=unet.dtype) @@ -117,7 +120,7 @@ def canny(img): return canny - print("Unsupported prep type:", prep_type) + logger.info(f"Unsupported prep type: {prep_type}") return None @@ -174,7 +177,7 @@ def call_unet_and_control_net( cnet_idx = step % cnet_cnt cnet_info = control_nets[cnet_idx] - # print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio) + # logger.info(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio) if cnet_info.ratio < current_ratio: return original_unet(sample, timestep, encoder_hidden_states) @@ -192,7 +195,7 @@ def call_unet_and_control_net( # ControlNet cnet_outs_list = [] for i, cnet_info in enumerate(control_nets): - # print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio) + # logger.info(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio) if cnet_info.ratio < current_ratio: continue guided_hint = guided_hints[i] @@ -232,7 +235,7 @@ def unet_forward( upsample_size = None if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): - print("Forward upsample size to force interpolation output size.") + logger.info("Forward upsample size to force interpolation output size.") forward_upsample_size = True # 1. time diff --git a/tools/resize_images_to_resolution.py b/tools/resize_images_to_resolution.py index 2d3224c4e..b8069fc1d 100644 --- a/tools/resize_images_to_resolution.py +++ b/tools/resize_images_to_resolution.py @@ -6,7 +6,10 @@ import math from PIL import Image import numpy as np - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False): # Split the max_resolution string by "," and strip any whitespaces @@ -83,7 +86,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi image.save(os.path.join(dst_img_folder, new_filename), quality=100) proc = "Resized" if current_pixels > max_pixels else "Saved" - print(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}") + logger.info(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}") # If other files with same basename, copy them with resolution suffix if copy_associated_files: @@ -94,7 +97,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi continue for max_resolution in max_resolutions: new_asoc_file = base + '+' + max_resolution + ext - print(f"Copy {asoc_file} as {new_asoc_file}") + logger.info(f"Copy {asoc_file} as {new_asoc_file}") shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file)) diff --git a/tools/show_metadata.py b/tools/show_metadata.py index 92ca7b1c8..05bfbe0a4 100644 --- a/tools/show_metadata.py +++ b/tools/show_metadata.py @@ -1,6 +1,10 @@ import json import argparse from safetensors import safe_open +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, required=True) @@ -10,10 +14,10 @@ metadata = f.metadata() if metadata is None: - print("No metadata found") + logger.error("No metadata found") else: # metadata is json dict, but not pretty printed # sort by key and pretty print print(json.dumps(metadata, indent=4, sort_keys=True)) - \ No newline at end of file + diff --git a/train_controlnet.py b/train_controlnet.py index 7b0b2bbfe..fbeb7afdb 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -35,7 +35,10 @@ pyramid_noise_like, apply_noise_offset, ) - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) # TODO 他のスクリプトと共通化する def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): @@ -69,11 +72,11 @@ def train(args): # データセットを準備する blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) if use_user_config: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "conditioning_data_dir"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) @@ -103,7 +106,7 @@ def train(args): train_util.debug_dataset(train_dataset_group) return if len(train_dataset_group) == 0: - print( + logger.error( "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" ) return @@ -114,7 +117,7 @@ def train(args): ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) is_main_process = accelerator.is_main_process @@ -310,7 +313,7 @@ def train(args): accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") - # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") @@ -567,7 +570,7 @@ def remove_model(old_ckpt_name): ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) save_model(ckpt_name, controlnet, force_sync_upload=True) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: diff --git a/train_db.py b/train_db.py index 888cad25e..984386665 100644 --- a/train_db.py +++ b/train_db.py @@ -35,6 +35,10 @@ scale_v_prediction_loss_like_noise_prediction, apply_debiased_estimation, ) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) # perlin_noise, @@ -54,11 +58,11 @@ def train(args): if args.dataset_class is None: blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, False, True)) if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "reg_data_dir"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) @@ -93,13 +97,13 @@ def train(args): ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") if args.gradient_accumulation_steps > 1: - print( + logger.warning( f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong" ) - print( + logger.warning( f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデル(U-NetおよびText Encoder)の学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です" ) @@ -449,7 +453,7 @@ def train(args): train_util.save_sd_model_on_train_end( args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae ) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: diff --git a/train_network.py b/train_network.py index 8d102ae8f..99c95aacd 100644 --- a/train_network.py +++ b/train_network.py @@ -41,7 +41,10 @@ add_v_prediction_like_loss, apply_debiased_estimation, ) - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) class NetworkTrainer: def __init__(self): @@ -153,18 +156,18 @@ def train(self, args): if args.dataset_class is None: blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) if use_user_config: - print(f"Loading dataset config from {args.dataset_config}") + logger.info(f"Loading dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "reg_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) ) else: if use_dreambooth_method: - print("Using DreamBooth method.") + logger.info("Using DreamBooth method.") user_config = { "datasets": [ { @@ -175,7 +178,7 @@ def train(self, args): ] } else: - print("Training with captions.") + logger.info("Training with captions.") user_config = { "datasets": [ { @@ -204,7 +207,7 @@ def train(self, args): train_util.debug_dataset(train_dataset_group) return if len(train_dataset_group) == 0: - print( + logger.error( "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" ) return @@ -217,7 +220,7 @@ def train(self, args): self.assert_extra_args(args, train_dataset_group) # acceleratorを準備する - print("preparing accelerator") + logger.info("preparing accelerator") accelerator = train_util.prepare_accelerator(args) is_main_process = accelerator.is_main_process @@ -310,7 +313,7 @@ def train(self, args): if hasattr(network, "prepare_network"): network.prepare_network(args) if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"): - print( + logger.warning( "warning: scale_weight_norms is specified but the network does not support it / scale_weight_normsが指定されていますが、ネットワークが対応していません" ) args.scale_weight_norms = False @@ -938,7 +941,7 @@ def remove_model(old_ckpt_name): ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 441c1e00b..d3d10f176 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -32,6 +32,10 @@ add_v_prediction_like_loss, apply_debiased_estimation, ) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) imagenet_templates_small = [ "a photo of a {}", @@ -178,7 +182,7 @@ def train(self, args): tokenizers = tokenizer_or_list if isinstance(tokenizer_or_list, list) else [tokenizer_or_list] # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする @@ -288,7 +292,7 @@ def train(self, args): ] } else: - print("Train with captions.") + logger.info("Train with captions.") user_config = { "datasets": [ { @@ -736,7 +740,7 @@ def remove_model(old_ckpt_name): ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) save_model(ckpt_name, updated_embs_list, global_step, num_train_epochs, force_sync_upload=True) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 7046a4808..45ced3cfa 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -36,6 +36,10 @@ ) import library.original_unet as original_unet from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) imagenet_templates_small = [ "a photo of a {}", @@ -99,7 +103,7 @@ def train(args): train_util.prepare_dataset_args(args, True) if args.sample_every_n_steps is not None or args.sample_every_n_epochs is not None: - print( + logger.warning( "sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません" ) assert ( @@ -114,7 +118,7 @@ def train(args): tokenizer = train_util.load_tokenizer(args) # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする @@ -127,7 +131,7 @@ def train(args): if args.init_word is not None: init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False) if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token: - print( + logger.warning( f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}" ) else: @@ -141,7 +145,7 @@ def train(args): ), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}" token_ids = tokenizer.convert_tokens_to_ids(token_strings) - print(f"tokens are added: {token_ids}") + logger.info(f"tokens are added: {token_ids}") assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered" assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}" @@ -169,7 +173,7 @@ def train(args): tokenizer.add_tokens(token_strings_XTI) token_ids_XTI = tokenizer.convert_tokens_to_ids(token_strings_XTI) - print(f"tokens are added (XTI): {token_ids_XTI}") + logger.info(f"tokens are added (XTI): {token_ids_XTI}") # Resize the token embeddings as we are adding new special tokens to the tokenizer text_encoder.resize_token_embeddings(len(tokenizer)) @@ -178,7 +182,7 @@ def train(args): if init_token_ids is not None: for i, token_id in enumerate(token_ids_XTI): token_embeds[token_id] = token_embeds[init_token_ids[(i // 16) % len(init_token_ids)]] - # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) + # logger.info(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) # load weights if args.weights is not None: @@ -186,22 +190,22 @@ def train(args): assert len(token_ids) == len( embeddings ), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}" - # print(token_ids, embeddings.size()) + # logger.info(token_ids, embeddings.size()) for token_id, embedding in zip(token_ids_XTI, embeddings): token_embeds[token_id] = embedding - # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) - print(f"weighs loaded") + # logger.info(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) + logger.info(f"weighs loaded") - print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}") + logger.info(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}") # データセットを準備する blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False)) if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "reg_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.info( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) @@ -209,14 +213,14 @@ def train(args): else: use_dreambooth_method = args.in_json is None if use_dreambooth_method: - print("Use DreamBooth method.") + logger.info("Use DreamBooth method.") user_config = { "datasets": [ {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)} ] } else: - print("Train with captions.") + logger.info("Train with captions.") user_config = { "datasets": [ { @@ -240,7 +244,7 @@ def train(args): # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 if use_template: - print(f"use template for training captions. is object: {args.use_object_template}") + logger.info(f"use template for training captions. is object: {args.use_object_template}") templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small replace_to = " ".join(token_strings) captions = [] @@ -264,7 +268,7 @@ def train(args): train_util.debug_dataset(train_dataset_group, show_input_ids=True) return if len(train_dataset_group) == 0: - print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください") + logger.error("No data found. Please verify arguments / 画像がありません。引数指定を確認してください") return if cache_latents: @@ -297,7 +301,7 @@ def train(args): text_encoder.gradient_checkpointing_enable() # 学習に必要なクラスを準備する - print("prepare optimizer, data loader etc.") + logger.info("prepare optimizer, data loader etc.") trainable_params = text_encoder.get_input_embeddings().parameters() _, _, optimizer = train_util.get_optimizer(args, trainable_params) @@ -318,7 +322,7 @@ def train(args): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + logger.info(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -332,7 +336,7 @@ def train(args): ) index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0] - # print(len(index_no_updates), torch.sum(index_no_updates)) + # logger.info(len(index_no_updates), torch.sum(index_no_updates)) orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone() # Freeze all parameters except for the token embeddings in text encoder @@ -370,15 +374,15 @@ def train(args): # 学習する total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - print("running training / 学習開始") - print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") - print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") - print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - print(f" num epochs / epoch数: {num_train_epochs}") - print(f" batch size per device / バッチサイズ: {args.train_batch_size}") - print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + logger.info("running training / 学習開始") + logger.info(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + logger.info(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + logger.info(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + logger.info(f" num epochs / epoch数: {num_train_epochs}") + logger.info(f" batch size per device / バッチサイズ: {args.train_batch_size}") + logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + logger.info(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + logger.info(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") global_step = 0 @@ -403,7 +407,8 @@ def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False): os.makedirs(args.output_dir, exist_ok=True) ckpt_file = os.path.join(args.output_dir, ckpt_name) - print(f"\nsaving checkpoint: {ckpt_file}") + logger.info("") + logger.info(f"saving checkpoint: {ckpt_file}") save_weights(ckpt_file, embs, save_dtype) if args.huggingface_repo_id is not None: huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) @@ -411,12 +416,13 @@ def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False): def remove_model(old_ckpt_name): old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) if os.path.exists(old_ckpt_file): - print(f"removing old checkpoint: {old_ckpt_file}") + logger.info(f"removing old checkpoint: {old_ckpt_file}") os.remove(old_ckpt_file) # training loop for epoch in range(num_train_epochs): - print(f"\nepoch {epoch+1}/{num_train_epochs}") + logger.info("") + logger.info(f"epoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 text_encoder.train() @@ -586,7 +592,7 @@ def remove_model(old_ckpt_name): ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) save_model(ckpt_name, updated_embs, global_step, num_train_epochs, force_sync_upload=True) - print("model saved.") + logger.info("model saved.") def save_weights(file, updated_embs, save_dtype): From 6279b33736f3e41b43cd271b12a66c3de76dffda Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 4 Feb 2024 18:28:54 +0900 Subject: [PATCH 048/103] fallback to basic logging if rich is not installed --- library/utils.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/library/utils.py b/library/utils.py index ebcaba308..1b5d7eb66 100644 --- a/library/utils.py +++ b/library/utils.py @@ -1,4 +1,5 @@ import logging +import sys import threading from typing import * @@ -11,9 +12,15 @@ def setup_logging(log_level=logging.INFO): if logging.root.handlers: # Already configured return - from rich.logging import RichHandler + try: + from rich.logging import RichHandler + + handler = RichHandler() + except ImportError: + print("rich is not installed, using basic logging") + handler = logging.StreamHandler(sys.stdout) # same as print + handler.propagate = False - handler = RichHandler() formatter = logging.Formatter( fmt="%(message)s", datefmt="%Y-%m-%d %H:%M:%S", From efd3b5897304d312b619c5d341f63d8b060df226 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 4 Feb 2024 20:44:10 +0900 Subject: [PATCH 049/103] Add logging arguments and update logging setup --- fine_tune.py | 16 +++++++++---- library/utils.py | 59 ++++++++++++++++++++++++++++++++++++++---------- 2 files changed, 59 insertions(+), 16 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index a75c46735..86f5b253f 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -18,9 +18,11 @@ from accelerate.utils import set_seed from diffusers import DDPMScheduler -from library.utils import setup_logging +from library.utils import setup_logging, add_logging_arguments + setup_logging() import logging + logger = logging.getLogger(__name__) import library.train_util as train_util import library.config_util as config_util @@ -41,6 +43,7 @@ def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) + setup_logging(args, reset=True) cache_latents = args.cache_latents @@ -227,7 +230,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -291,7 +296,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): if accelerator.is_main_process: init_kwargs = {} if args.wandb_run_name: - init_kwargs['wandb'] = {'name': args.wandb_run_name} + init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) @@ -471,6 +476,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) # モジュール指定がないのがちょっと気持ち悪い / bit weird that this does not have module prefix train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) @@ -479,7 +485,9 @@ def setup_parser() -> argparse.ArgumentParser: config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser) - parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する") + parser.add_argument( + "--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する" + ) parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") parser.add_argument( "--learning_rate_te", diff --git a/library/utils.py b/library/utils.py index 1b5d7eb66..af0563dfa 100644 --- a/library/utils.py +++ b/library/utils.py @@ -8,18 +8,53 @@ def fire_in_thread(f, *args, **kwargs): threading.Thread(target=f, args=args, kwargs=kwargs).start() -def setup_logging(log_level=logging.INFO): - if logging.root.handlers: # Already configured - return - - try: - from rich.logging import RichHandler - - handler = RichHandler() - except ImportError: - print("rich is not installed, using basic logging") - handler = logging.StreamHandler(sys.stdout) # same as print - handler.propagate = False +def add_logging_arguments(parser): + parser.add_argument( + "--console_log_level", + type=str, + default=None, + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + help="Set the logging level, default is INFO / ログレベルを設定する。デフォルトはINFO", + ) + parser.add_argument( + "--console_log_file", + type=str, + default=None, + help="Log to a file instead of stdout / 標準出力ではなくファイルにログを出力する", + ) + parser.add_argument("--console_log_simple", action="store_true", help="Simple log output / シンプルなログ出力") + + +def setup_logging(args=None, log_level=None, reset=False): + if logging.root.handlers: + if reset: + # remove all handlers + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + else: + return + + if log_level is None and args is not None: + log_level = args.console_log_level + if log_level is None: + log_level = "INFO" + log_level = getattr(logging, log_level) + + if args is not None and args.console_log_file: + handler = logging.FileHandler(args.console_log_file, mode="w") + else: + handler = None + if not args or not args.console_log_simple: + try: + from rich.logging import RichHandler + + handler = RichHandler() + except ImportError: + print("rich is not installed, using basic logging") + + if handler is None: + handler = logging.StreamHandler(sys.stdout) # same as print + handler.propagate = False formatter = logging.Formatter( fmt="%(message)s", From 74fe0453b204e5b2c718e821c692595ff2e76c35 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 8 Feb 2024 20:58:54 +0900 Subject: [PATCH 050/103] add comment for get_preferred_device --- library/device_utils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/library/device_utils.py b/library/device_utils.py index 93371ca64..546fb386c 100644 --- a/library/device_utils.py +++ b/library/device_utils.py @@ -15,6 +15,7 @@ try: import intel_extension_for_pytorch as ipex # noqa + HAS_XPU = torch.xpu.is_available() except Exception: HAS_XPU = False @@ -32,6 +33,9 @@ def clean_memory(): @functools.lru_cache(maxsize=None) def get_preferred_device() -> torch.device: + r""" + Do not call this function from training scripts. Use accelerator.device instead. + """ if HAS_CUDA: device = torch.device("cuda") elif HAS_XPU: @@ -43,6 +47,7 @@ def get_preferred_device() -> torch.device: print(f"get_preferred_device() -> {device}") return device + def init_ipex(): """ Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`. @@ -54,10 +59,11 @@ def init_ipex(): try: if HAS_XPU: from library.ipex import ipex_init + is_initialized, error_message = ipex_init() if not is_initialized: print("failed to initialize ipex:", error_message) else: return except Exception as e: - print("failed to initialize ipex:", e) \ No newline at end of file + print("failed to initialize ipex:", e) From 9b8ea12d34ecb2db7d4bbdad52569d1455df042b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 8 Feb 2024 21:06:39 +0900 Subject: [PATCH 051/103] update log initialization without rich --- library/utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/library/utils.py b/library/utils.py index af0563dfa..0b7acf32d 100644 --- a/library/utils.py +++ b/library/utils.py @@ -34,12 +34,14 @@ def setup_logging(args=None, log_level=None, reset=False): else: return + # log_level can be set by the caller or by the args, the caller has priority. If not set, use INFO if log_level is None and args is not None: log_level = args.console_log_level if log_level is None: log_level = "INFO" log_level = getattr(logging, log_level) + msg_init = None if args is not None and args.console_log_file: handler = logging.FileHandler(args.console_log_file, mode="w") else: @@ -50,7 +52,8 @@ def setup_logging(args=None, log_level=None, reset=False): handler = RichHandler() except ImportError: - print("rich is not installed, using basic logging") + # print("rich is not installed, using basic logging") + msg_init = "rich is not installed, using basic logging" if handler is None: handler = logging.StreamHandler(sys.stdout) # same as print @@ -63,3 +66,7 @@ def setup_logging(args=None, log_level=None, reset=False): handler.setFormatter(formatter) logging.root.setLevel(log_level) logging.root.addHandler(handler) + + if msg_init is not None: + logger = logging.getLogger(__name__) + logger.info(msg_init) From 055f02e1e16f2b5d71d48e9aba7e751b825e6b46 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 8 Feb 2024 21:16:42 +0900 Subject: [PATCH 052/103] add logging args for training scripts --- fine_tune.py | 3 +- sdxl_train.py | 25 +++++++++++++---- sdxl_train_control_net_lllite.py | 28 ++++++++++++++----- sdxl_train_control_net_lllite_old.py | 26 +++++++++++++---- train_controlnet.py | 16 ++++++++--- train_db.py | 11 ++++++-- train_network.py | 42 ++++++++++++++++++++++------ train_textual_inversion.py | 13 +++++++-- train_textual_inversion_XTI.py | 27 +++++++++++++----- 9 files changed, 145 insertions(+), 46 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 86f5b253f..66f072884 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -24,6 +24,7 @@ import logging logger = logging.getLogger(__name__) + import library.train_util as train_util import library.config_util as config_util from library.config_util import ( @@ -476,7 +477,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() - add_logging_arguments(parser) # モジュール指定がないのがちょっと気持ち悪い / bit weird that this does not have module prefix + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) diff --git a/sdxl_train.py b/sdxl_train.py index 8f1740af6..3a09e4c84 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -20,10 +20,14 @@ from library import sdxl_model_util import library.train_util as train_util -from library.utils import setup_logging + +from library.utils import setup_logging, add_logging_arguments + setup_logging() import logging + logger = logging.getLogger(__name__) + import library.config_util as config_util import library.sdxl_train_util as sdxl_train_util from library.config_util import ( @@ -96,7 +100,9 @@ def train(args): train_util.prepare_dataset_args(args, True) sdxl_train_util.verify_sdxl_training_args(args) - assert not args.weighted_captions, "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + assert ( + not args.weighted_captions + ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" assert ( not args.train_text_encoder or not args.cache_text_encoder_outputs ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" @@ -367,7 +373,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -437,7 +445,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") - accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) # accelerator.print( # f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" # ) @@ -457,7 +467,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): if accelerator.is_main_process: init_kwargs = {} if args.wandb_run_name: - init_kwargs['wandb'] = {'name': args.wandb_run_name} + init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) @@ -734,6 +744,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, True, True) train_util.add_training_arguments(parser, False) @@ -756,7 +767,9 @@ def setup_parser() -> argparse.ArgumentParser: help="learning rate for text encoder 2 (BiG-G) / text encoder 2 (BiG-G)の学習率", ) - parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する") + parser.add_argument( + "--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する" + ) parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") parser.add_argument( "--no_half_vae", diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 36dd72e64..15ee2f85c 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -45,11 +45,14 @@ apply_debiased_estimation, ) import networks.control_net_lllite_for_train as control_net_lllite_for_train -from library.utils import setup_logging +from library.utils import setup_logging, add_logging_arguments + setup_logging() import logging + logger = logging.getLogger(__name__) + # TODO 他のスクリプトと共通化する def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): logs = { @@ -127,7 +130,9 @@ def train(args): train_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" else: - logger.warning("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません") + logger.warning( + "WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません" + ) if args.cache_text_encoder_outputs: assert ( @@ -257,7 +262,9 @@ def train(args): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -326,7 +333,9 @@ def train(args): accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") - accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") @@ -344,7 +353,7 @@ def train(args): if accelerator.is_main_process: init_kwargs = {} if args.wandb_run_name: - init_kwargs['wandb'] = {'name': args.wandb_run_name} + init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( @@ -557,6 +566,7 @@ def remove_model(old_ckpt_name): def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) @@ -572,8 +582,12 @@ def setup_parser() -> argparse.ArgumentParser: choices=[None, "ckpt", "pt", "safetensors"], help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", ) - parser.add_argument("--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数") - parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み") + parser.add_argument( + "--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数" + ) + parser.add_argument( + "--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み" + ) parser.add_argument("--network_dim", type=int, default=None, help="network dimensions (rank) / モジュールの次元数") parser.add_argument( "--network_dropout", diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 554b79cee..542cd7347 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -41,11 +41,14 @@ apply_debiased_estimation, ) import networks.control_net_lllite as control_net_lllite -from library.utils import setup_logging +from library.utils import setup_logging, add_logging_arguments + setup_logging() import logging + logger = logging.getLogger(__name__) + # TODO 他のスクリプトと共通化する def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): logs = { @@ -123,7 +126,9 @@ def train(args): train_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" else: - logger.warning("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません") + logger.warning( + "WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません" + ) if args.cache_text_encoder_outputs: assert ( @@ -225,7 +230,9 @@ def train(args): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -299,7 +306,9 @@ def train(args): accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") - accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") @@ -525,6 +534,7 @@ def remove_model(old_ckpt_name): def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) @@ -540,8 +550,12 @@ def setup_parser() -> argparse.ArgumentParser: choices=[None, "ckpt", "pt", "safetensors"], help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", ) - parser.add_argument("--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数") - parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み") + parser.add_argument( + "--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数" + ) + parser.add_argument( + "--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み" + ) parser.add_argument("--network_dim", type=int, default=None, help="network dimensions (rank) / モジュールの次元数") parser.add_argument( "--network_dropout", diff --git a/train_controlnet.py b/train_controlnet.py index fbeb7afdb..7489c48fe 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -35,11 +35,14 @@ pyramid_noise_like, apply_noise_offset, ) -from library.utils import setup_logging +from library.utils import setup_logging, add_logging_arguments + setup_logging() import logging + logger = logging.getLogger(__name__) + # TODO 他のスクリプトと共通化する def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): logs = { @@ -256,7 +259,9 @@ def train(args): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -312,7 +317,9 @@ def train(args): accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") - accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") @@ -335,7 +342,7 @@ def train(args): if accelerator.is_main_process: init_kwargs = {} if args.wandb_run_name: - init_kwargs['wandb'] = {'name': args.wandb_run_name} + init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( @@ -576,6 +583,7 @@ def remove_model(old_ckpt_name): def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) diff --git a/train_db.py b/train_db.py index 984386665..09cb69009 100644 --- a/train_db.py +++ b/train_db.py @@ -35,9 +35,11 @@ scale_v_prediction_loss_like_noise_prediction, apply_debiased_estimation, ) -from library.utils import setup_logging +from library.utils import setup_logging, add_logging_arguments + setup_logging() import logging + logger = logging.getLogger(__name__) # perlin_noise, @@ -197,7 +199,9 @@ def train(args): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -268,7 +272,7 @@ def train(args): if accelerator.is_main_process: init_kwargs = {} if args.wandb_run_name: - init_kwargs['wandb'] = {'name': args.wandb_run_name} + init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) @@ -459,6 +463,7 @@ def train(args): def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, False, True) train_util.add_training_arguments(parser, True) diff --git a/train_network.py b/train_network.py index 99c95aacd..bd5e3693e 100644 --- a/train_network.py +++ b/train_network.py @@ -41,11 +41,14 @@ add_v_prediction_like_loss, apply_debiased_estimation, ) -from library.utils import setup_logging +from library.utils import setup_logging, add_logging_arguments + setup_logging() import logging + logger = logging.getLogger(__name__) + class NetworkTrainer: def __init__(self): self.vae_scale_factor = 0.18215 @@ -947,6 +950,7 @@ def remove_model(old_ckpt_name): def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, True, True) train_util.add_training_arguments(parser, True) @@ -954,7 +958,9 @@ def setup_parser() -> argparse.ArgumentParser: config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser) - parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない") + parser.add_argument( + "--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない" + ) parser.add_argument( "--save_model_as", type=str, @@ -966,10 +972,17 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率") - parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み") - parser.add_argument("--network_module", type=str, default=None, help="network module to train / 学習対象のネットワークのモジュール") parser.add_argument( - "--network_dim", type=int, default=None, help="network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)" + "--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み" + ) + parser.add_argument( + "--network_module", type=str, default=None, help="network module to train / 学習対象のネットワークのモジュール" + ) + parser.add_argument( + "--network_dim", + type=int, + default=None, + help="network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)", ) parser.add_argument( "--network_alpha", @@ -984,14 +997,25 @@ def setup_parser() -> argparse.ArgumentParser: help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする(0またはNoneはdropoutなし、1は全ニューロンをdropout)", ) parser.add_argument( - "--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数" + "--network_args", + type=str, + default=None, + nargs="*", + help="additional arguments for network (key=value) / ネットワークへの追加の引数", ) - parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する") parser.add_argument( - "--network_train_text_encoder_only", action="store_true", help="only training Text Encoder part / Text Encoder関連部分のみ学習する" + "--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する" ) parser.add_argument( - "--training_comment", type=str, default=None, help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列" + "--network_train_text_encoder_only", + action="store_true", + help="only training Text Encoder part / Text Encoder関連部分のみ学習する", + ) + parser.add_argument( + "--training_comment", + type=str, + default=None, + help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列", ) parser.add_argument( "--dim_from_weights", diff --git a/train_textual_inversion.py b/train_textual_inversion.py index d3d10f176..5900ad19d 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -32,9 +32,11 @@ add_v_prediction_like_loss, apply_debiased_estimation, ) -from library.utils import setup_logging +from library.utils import setup_logging, add_logging_arguments + setup_logging() import logging + logger = logging.getLogger(__name__) imagenet_templates_small = [ @@ -746,6 +748,7 @@ def remove_model(old_ckpt_name): def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, True, False) train_util.add_training_arguments(parser, True) @@ -761,7 +764,9 @@ def setup_parser() -> argparse.ArgumentParser: help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)", ) - parser.add_argument("--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み") + parser.add_argument( + "--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み" + ) parser.add_argument( "--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数" ) @@ -771,7 +776,9 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること", ) - parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可") + parser.add_argument( + "--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可" + ) parser.add_argument( "--use_object_template", action="store_true", diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 45ced3cfa..9b7e04c89 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -36,9 +36,11 @@ ) import library.original_unet as original_unet from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI -from library.utils import setup_logging +from library.utils import setup_logging, add_logging_arguments + setup_logging() import logging + logger = logging.getLogger(__name__) imagenet_templates_small = [ @@ -322,7 +324,9 @@ def train(args): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - logger.info(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + logger.info( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -380,7 +384,9 @@ def train(args): logger.info(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") logger.info(f" num epochs / epoch数: {num_train_epochs}") logger.info(f" batch size per device / バッチサイズ: {args.train_batch_size}") - logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + logger.info( + f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + ) logger.info(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") logger.info(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") @@ -397,10 +403,12 @@ def train(args): if accelerator.is_main_process: init_kwargs = {} if args.wandb_run_name: - init_kwargs['wandb'] = {'name': args.wandb_run_name} + init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) + accelerator.init_trackers( + "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs + ) # function for saving/removing def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False): @@ -653,6 +661,7 @@ def load_weights(file): def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, True, False) train_util.add_training_arguments(parser, True) @@ -668,7 +677,9 @@ def setup_parser() -> argparse.ArgumentParser: help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)", ) - parser.add_argument("--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み") + parser.add_argument( + "--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み" + ) parser.add_argument( "--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数" ) @@ -678,7 +689,9 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること", ) - parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可") + parser.add_argument( + "--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可" + ) parser.add_argument( "--use_object_template", action="store_true", From 5d9e2873f644777e045865ce3af718b726dbcee5 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 8 Feb 2024 21:38:02 +0900 Subject: [PATCH 053/103] make rich to output to stderr instead of stdout --- library/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/library/utils.py b/library/utils.py index 0b7acf32d..8291a487d 100644 --- a/library/utils.py +++ b/library/utils.py @@ -49,8 +49,10 @@ def setup_logging(args=None, log_level=None, reset=False): if not args or not args.console_log_simple: try: from rich.logging import RichHandler + from rich.console import Console + from rich.logging import RichHandler - handler = RichHandler() + handler = RichHandler(console=Console(stderr=True)) except ImportError: # print("rich is not installed, using basic logging") msg_init = "rich is not installed, using basic logging" From 720259639399542ee408989b86ed7690aad0ca19 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 10 Feb 2024 09:59:12 +0900 Subject: [PATCH 054/103] log to print tag frequencies --- finetune/tag_images_by_wd14_tagger.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index 7a751b177..b56d921a3 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -291,9 +291,9 @@ def run_batch(path_imgs): if args.frequency_tags: sorted_tags = sorted(tag_freq.items(), key=lambda x: x[1], reverse=True) - logger.info("Tag frequencies:") + print("Tag frequencies:") for tag, freq in sorted_tags: - logger.info(f"{tag}: {freq}") + print(f"{tag}: {freq}") logger.info("done!") From e24d9606a2c890cf9f015239e04c82eecbda8bce Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 12 Feb 2024 11:10:52 +0900 Subject: [PATCH 055/103] add clean_memory_on_device and use it from training --- fine_tune.py | 4 ++-- library/device_utils.py | 15 +++++++++++++++ library/sdxl_train_util.py | 4 ++-- library/train_util.py | 24 +++++++++++++----------- sdxl_train.py | 6 +++--- sdxl_train_control_net_lllite.py | 6 +++--- sdxl_train_control_net_lllite_old.py | 6 +++--- sdxl_train_network.py | 6 +++--- train_controlnet.py | 6 +++--- train_db.py | 4 ++-- train_network.py | 4 ++-- train_textual_inversion.py | 4 ++-- train_textual_inversion_XTI.py | 4 ++-- 13 files changed, 55 insertions(+), 38 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index a6a5c1e25..72bae9727 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -10,7 +10,7 @@ from tqdm import tqdm import torch -from library.device_utils import init_ipex, clean_memory +from library.device_utils import init_ipex, clean_memory_on_device init_ipex() from accelerate.utils import set_seed @@ -156,7 +156,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") - clean_memory() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() diff --git a/library/device_utils.py b/library/device_utils.py index 546fb386c..8823c5d9a 100644 --- a/library/device_utils.py +++ b/library/device_utils.py @@ -31,6 +31,21 @@ def clean_memory(): torch.mps.empty_cache() +def clean_memory_on_device(device: torch.device): + r""" + Clean memory on the specified device, will be called from training scripts. + """ + gc.collect() + + # device may "cuda" or "cuda:0", so we need to check the type of device + if device.type == "cuda": + torch.cuda.empty_cache() + if device.type == "xpu": + torch.xpu.empty_cache() + if device.type == "mps": + torch.mps.empty_cache() + + @functools.lru_cache(maxsize=None) def get_preferred_device() -> torch.device: r""" diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index 0ecf4feb6..aa36c5445 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -4,7 +4,7 @@ from typing import Optional import torch -from library.device_utils import init_ipex, clean_memory +from library.device_utils import init_ipex, clean_memory_on_device init_ipex() from accelerate import init_empty_weights @@ -50,7 +50,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype): unet.to(accelerator.device) vae.to(accelerator.device) - clean_memory() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info diff --git a/library/train_util.py b/library/train_util.py index bf224b3e2..47a767bd9 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -32,7 +32,7 @@ from tqdm import tqdm import torch -from library.device_utils import init_ipex, clean_memory +from library.device_utils import init_ipex, clean_memory_on_device init_ipex() @@ -2285,7 +2285,7 @@ def cache_batch_latents( info.latents_flipped = flipped_latent if not HIGH_VRAM: - clean_memory() + clean_memory_on_device(vae.device) def cache_batch_text_encoder_outputs( @@ -4026,7 +4026,7 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio unet.to(accelerator.device) vae.to(accelerator.device) - clean_memory() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() return text_encoder, vae, unet, load_stable_diffusion_format @@ -4695,7 +4695,7 @@ def sample_images_common( distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here org_vae_device = vae.device # CPUにいるはず - vae.to(distributed_state.device) + vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device # unwrap unet and text_encoder(s) unet = accelerator.unwrap_model(unet) @@ -4752,7 +4752,11 @@ def sample_images_common( # save random state to restore later rng_state = torch.get_rng_state() - cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None # TODO mps etc. support + cuda_rng_state = None + try: + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + except Exception: + pass if distributed_state.num_processes <= 1: # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. @@ -4774,8 +4778,10 @@ def sample_images_common( # clear pipeline and cache to reduce vram usage del pipeline - with torch.cuda.device(torch.cuda.current_device()): - torch.cuda.empty_cache() + # I'm not sure which of these is the correct way to clear the memory, but accelerator's device is used in the pipeline, so I'm using it here. + # with torch.cuda.device(torch.cuda.current_device()): + # torch.cuda.empty_cache() + clean_memory_on_device(accelerator.device) torch.set_rng_state(rng_state) if cuda_rng_state is not None: @@ -4870,10 +4876,6 @@ def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, p pass # endregion - # # clear pipeline and cache to reduce vram usage - # del pipeline - # torch.cuda.empty_cache() - # region 前処理用 diff --git a/sdxl_train.py b/sdxl_train.py index b0dcdbe9d..b7789d4b2 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -10,7 +10,7 @@ from tqdm import tqdm import torch -from library.device_utils import init_ipex, clean_memory +from library.device_utils import init_ipex, clean_memory_on_device init_ipex() from accelerate.utils import set_seed @@ -250,7 +250,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") - clean_memory() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() @@ -403,7 +403,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 text_encoder1.to("cpu", dtype=torch.float32) text_encoder2.to("cpu", dtype=torch.float32) - clean_memory() + clean_memory_on_device(accelerator.device) else: # make sure Text Encoders are on GPU text_encoder1.to(accelerator.device) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 1b0696247..c37b847c4 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -14,7 +14,7 @@ from tqdm import tqdm import torch -from library.device_utils import init_ipex, clean_memory +from library.device_utils import init_ipex, clean_memory_on_device init_ipex() from torch.nn.parallel import DistributedDataParallel as DDP @@ -162,7 +162,7 @@ def train(args): accelerator.is_main_process, ) vae.to("cpu") - clean_memory() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() @@ -287,7 +287,7 @@ def train(args): # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 text_encoder1.to("cpu", dtype=torch.float32) text_encoder2.to("cpu", dtype=torch.float32) - clean_memory() + clean_memory_on_device(accelerator.device) else: # make sure Text Encoders are on GPU text_encoder1.to(accelerator.device) diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index b74e3b90b..35747c1e9 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -11,7 +11,7 @@ from tqdm import tqdm import torch -from library.device_utils import init_ipex, clean_memory +from library.device_utils import init_ipex, clean_memory_on_device init_ipex() from torch.nn.parallel import DistributedDataParallel as DDP @@ -161,7 +161,7 @@ def train(args): accelerator.is_main_process, ) vae.to("cpu") - clean_memory() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() @@ -260,7 +260,7 @@ def train(args): # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 text_encoder1.to("cpu", dtype=torch.float32) text_encoder2.to("cpu", dtype=torch.float32) - clean_memory() + clean_memory_on_device(accelerator.device) else: # make sure Text Encoders are on GPU text_encoder1.to(accelerator.device) diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 205b526e8..fd04a5726 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -1,7 +1,7 @@ import argparse import torch -from library.device_utils import init_ipex, clean_memory +from library.device_utils import init_ipex, clean_memory_on_device init_ipex() from library import sdxl_model_util, sdxl_train_util, train_util @@ -64,7 +64,7 @@ def cache_text_encoder_outputs_if_needed( org_unet_device = unet.device vae.to("cpu") unet.to("cpu") - clean_memory() + clean_memory_on_device(accelerator.device) # When TE is not be trained, it will not be prepared so we need to use explicit autocast with accelerator.autocast(): @@ -79,7 +79,7 @@ def cache_text_encoder_outputs_if_needed( text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU text_encoders[1].to("cpu", dtype=torch.float32) - clean_memory() + clean_memory_on_device(accelerator.device) if not args.lowram: print("move vae and unet back to original device") diff --git a/train_controlnet.py b/train_controlnet.py index e7a06ae14..6ff2e781c 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -11,7 +11,7 @@ from tqdm import tqdm import torch -from library.device_utils import init_ipex, clean_memory +from library.device_utils import init_ipex, clean_memory_on_device init_ipex() from torch.nn.parallel import DistributedDataParallel as DDP @@ -217,8 +217,8 @@ def train(args): accelerator.is_main_process, ) vae.to("cpu") - clean_memory() - + clean_memory_on_device(accelerator.device) + accelerator.wait_for_everyone() if args.gradient_checkpointing: diff --git a/train_db.py b/train_db.py index f6795dcef..085ef2ce3 100644 --- a/train_db.py +++ b/train_db.py @@ -11,7 +11,7 @@ from tqdm import tqdm import torch -from library.device_utils import init_ipex, clean_memory +from library.device_utils import init_ipex, clean_memory_on_device init_ipex() from accelerate.utils import set_seed @@ -136,7 +136,7 @@ def train(args): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") - clean_memory() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() diff --git a/train_network.py b/train_network.py index f8dd6ab23..046edb5bd 100644 --- a/train_network.py +++ b/train_network.py @@ -12,7 +12,7 @@ from tqdm import tqdm import torch -from library.device_utils import init_ipex, clean_memory +from library.device_utils import init_ipex, clean_memory_on_device init_ipex() from torch.nn.parallel import DistributedDataParallel as DDP @@ -265,7 +265,7 @@ def train(self, args): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") - clean_memory() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() diff --git a/train_textual_inversion.py b/train_textual_inversion.py index d18837bd4..c34cfc969 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -7,7 +7,7 @@ from tqdm import tqdm import torch -from library.device_utils import init_ipex, clean_memory +from library.device_utils import init_ipex, clean_memory_on_device init_ipex() from accelerate.utils import set_seed @@ -361,7 +361,7 @@ def train(self, args): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") - clean_memory() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 9d4b0aef1..b63200820 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -8,7 +8,7 @@ from tqdm import tqdm import torch -from library.device_utils import init_ipex, clean_memory +from library.device_utils import init_ipex, clean_memory_on_device init_ipex() from accelerate.utils import set_seed @@ -284,7 +284,7 @@ def train(args): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") - clean_memory() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() From e579648ce9d86de22eb99d23dc4052fe01748373 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 12 Feb 2024 11:12:41 +0900 Subject: [PATCH 056/103] fix help for highvram arg --- library/train_util.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 47a767bd9..1c3b4bdfe 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3045,7 +3045,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument( "--highvram", action="store_true", - help="disable low VRAM optimization. e.g. do not clear CUDA cache after each latent caching (for machines which have bigger VRAM) / VRAMが少ない環境向け最適化を無効にする。たとえば各latentのキャッシュ後のCUDAキャッシュを行わない等(VRAMが多い環境向け)", + help="disable low VRAM optimization. e.g. do not clear CUDA cache after each latent caching (for machines which have bigger VRAM) " + + "/ VRAMが少ない環境向け最適化を無効にする。たとえば各latentのキャッシュ後のCUDAキャッシュクリアを行わない等(VRAMが多い環境向け)", ) parser.add_argument( From c7487191152d4120b4bfa92e2168c8174ca942fd Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 12 Feb 2024 12:59:45 +0900 Subject: [PATCH 057/103] fix indent --- gen_img_diffusers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index b346a5fb3..e4b40a337 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -3299,7 +3299,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink ratio ds_ratio = float(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override logger.info(f"deep shrink ratio: {ds_ratio}") continue From d3745db7644b2bb89357e227c8de781b0060ec33 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 12 Feb 2024 13:15:21 +0900 Subject: [PATCH 058/103] add args for logging --- gen_img_diffusers.py | 115 ++++++++++++++++++++++++++++--------- sdxl_gen_img.py | 131 +++++++++++++++++++++++++++++++++---------- 2 files changed, 192 insertions(+), 54 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index e4b40a337..f368ed0fb 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -105,9 +105,11 @@ from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI -from library.utils import setup_logging +from library.utils import setup_logging, add_logging_arguments + setup_logging() import logging + logger = logging.getLogger(__name__) # scheduler: @@ -1592,7 +1594,9 @@ def cond_fn_vgg16_b1(self, latents, timestep, index, text_embeddings, noise_pred image_embeddings = self.vgg16_feat_model(image)["feat"] # バッチサイズが複数だと正しく動くかわからない - loss = ((image_embeddings - guide_embeddings) ** 2).mean() * guidance_scale # MSE style transferでコンテンツの損失はMSEなので + loss = ( + (image_embeddings - guide_embeddings) ** 2 + ).mean() * guidance_scale # MSE style transferでコンテンツの損失はMSEなので grads = -torch.autograd.grad(loss, latents)[0] if isinstance(self.scheduler, LMSDiscreteScheduler): @@ -2610,7 +2614,9 @@ def __getattr__(self, item): embeds = next(iter(data.values())) if type(embeds) != torch.Tensor: - raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {embeds_file}") + raise ValueError( + f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {embeds_file}" + ) num_vectors_per_token = embeds.size()[0] token_string = os.path.splitext(os.path.basename(embeds_file))[0] @@ -2840,7 +2846,9 @@ def resize_images(imgs, size): logger.info(f"loaded {len(guide_images)} guide images for guidance") if len(guide_images) == 0: - logger.info(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}") + logger.info( + f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}" + ) guide_images = None else: guide_images = None @@ -3134,7 +3142,9 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): cv2.waitKey() cv2.destroyAllWindows() except ImportError: - logger.info("opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません") + logger.info( + "opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません" + ) return images @@ -3484,16 +3494,25 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() - parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む") + add_logging_arguments(parser) + + parser.add_argument( + "--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む" + ) parser.add_argument( "--v_parameterization", action="store_true", help="enable v-parameterization training / v-parameterization学習を有効にする" ) parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト") parser.add_argument( - "--from_file", type=str, default=None, help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む" + "--from_file", + type=str, + default=None, + help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む", ) parser.add_argument( - "--interactive", action="store_true", help="interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)" + "--interactive", + action="store_true", + help="interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)", ) parser.add_argument( "--no_preview", action="store_true", help="do not show generated image in interactive mode / 対話モードで画像を表示しない" @@ -3505,7 +3524,9 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--strength", type=float, default=None, help="img2img strength / img2img時のstrength") parser.add_argument("--images_per_prompt", type=int, default=1, help="number of images per prompt / プロンプトあたりの出力枚数") parser.add_argument("--outdir", type=str, default="outputs", help="dir to write results to / 生成画像の出力先") - parser.add_argument("--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファイル名を連番にする") + parser.add_argument( + "--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファイル名を連番にする" + ) parser.add_argument( "--use_original_file_name", action="store_true", @@ -3559,9 +3580,14 @@ def setup_parser() -> argparse.ArgumentParser: default=7.5, help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty)) / guidance scale", ) - parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ") parser.add_argument( - "--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ" + "--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ" + ) + parser.add_argument( + "--vae", + type=str, + default=None, + help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ", ) parser.add_argument( "--tokenizer_cache_dir", @@ -3597,25 +3623,46 @@ def setup_parser() -> argparse.ArgumentParser: help="use xformers by diffusers (Hypernetworks doesn't work) / Diffusersでxformersを使用する(Hypernetwork利用不可)", ) parser.add_argument( - "--opt_channels_last", action="store_true", help="set channels last option to model / モデルにchannels lastを指定し最適化する" + "--opt_channels_last", + action="store_true", + help="set channels last option to model / モデルにchannels lastを指定し最適化する", ) parser.add_argument( - "--network_module", type=str, default=None, nargs="*", help="additional network module to use / 追加ネットワークを使う時そのモジュール名" + "--network_module", + type=str, + default=None, + nargs="*", + help="additional network module to use / 追加ネットワークを使う時そのモジュール名", ) parser.add_argument( "--network_weights", type=str, default=None, nargs="*", help="additional network weights to load / 追加ネットワークの重み" ) - parser.add_argument("--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率") parser.add_argument( - "--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数" + "--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率" + ) + parser.add_argument( + "--network_args", + type=str, + default=None, + nargs="*", + help="additional arguments for network (key=value) / ネットワークへの追加の引数", ) - parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する") parser.add_argument( - "--network_merge_n_models", type=int, default=None, help="merge this number of networks / この数だけネットワークをマージする" + "--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する" ) - parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする") parser.add_argument( - "--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する" + "--network_merge_n_models", + type=int, + default=None, + help="merge this number of networks / この数だけネットワークをマージする", + ) + parser.add_argument( + "--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする" + ) + parser.add_argument( + "--network_pre_calc", + action="store_true", + help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する", ) parser.add_argument( "--network_regional_mask_max_color_codes", @@ -3637,7 +3684,9 @@ def setup_parser() -> argparse.ArgumentParser: nargs="*", help="Embeddings files of Extended Textual Inversion / Extended Textual Inversionのembeddings", ) - parser.add_argument("--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う") + parser.add_argument( + "--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う" + ) parser.add_argument( "--max_embeddings_multiples", type=int, @@ -3678,7 +3727,10 @@ def setup_parser() -> argparse.ArgumentParser: help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする", ) parser.add_argument( - "--highres_fix_steps", type=int, default=28, help="1st stage steps for highres fix / highres fixの最初のステージのステップ数" + "--highres_fix_steps", + type=int, + default=28, + help="1st stage steps for highres fix / highres fixの最初のステージのステップ数", ) parser.add_argument( "--highres_fix_strength", @@ -3687,7 +3739,9 @@ def setup_parser() -> argparse.ArgumentParser: help="1st stage img2img strength for highres fix / highres fixの最初のステージのimg2img時のstrength、省略時はstrengthと同じ", ) parser.add_argument( - "--highres_fix_save_1st", action="store_true", help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する" + "--highres_fix_save_1st", + action="store_true", + help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する", ) parser.add_argument( "--highres_fix_latents_upscaling", @@ -3695,7 +3749,10 @@ def setup_parser() -> argparse.ArgumentParser: help="use latents upscaling for highres fix / highres fixでlatentで拡大する", ) parser.add_argument( - "--highres_fix_upscaler", type=str, default=None, help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名" + "--highres_fix_upscaler", + type=str, + default=None, + help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名", ) parser.add_argument( "--highres_fix_upscaler_args", @@ -3710,14 +3767,21 @@ def setup_parser() -> argparse.ArgumentParser: ) parser.add_argument( - "--negative_scale", type=float, default=None, help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する" + "--negative_scale", + type=float, + default=None, + help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する", ) parser.add_argument( "--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名" ) parser.add_argument( - "--control_net_preps", type=str, default=None, nargs="*", help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名" + "--control_net_preps", + type=str, + default=None, + nargs="*", + help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名", ) parser.add_argument("--control_net_weights", type=float, default=None, nargs="*", help="ControlNet weights / ControlNetの重み") parser.add_argument( @@ -3801,4 +3865,5 @@ def setup_parser() -> argparse.ArgumentParser: parser = setup_parser() args = parser.parse_args() + setup_logging(args, reset=True) main(args) diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index 8d351a2d8..d8eb9a19c 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -56,9 +56,11 @@ from library.original_unet import FlashAttentionFunction from networks.control_net_lllite import ControlNetLLLite from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL -from library.utils import setup_logging +from library.utils import setup_logging, add_logging_arguments + setup_logging() import logging + logger = logging.getLogger(__name__) # scheduler: @@ -2061,7 +2063,9 @@ def resize_images(imgs, size): logger.info(f"loaded {len(guide_images)} guide images for guidance") if len(guide_images) == 0: - logger.warning(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}") + logger.warning( + f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}" + ) guide_images = None else: guide_images = None @@ -2351,7 +2355,7 @@ def scale_and_round(x): highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate( - zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts ) + zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts) ): if highres_fix: seed -= 1 # record original seed @@ -2397,7 +2401,9 @@ def scale_and_round(x): cv2.waitKey() cv2.destroyAllWindows() except ImportError: - logger.error("opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません") + logger.error( + "opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません" + ) return images @@ -2787,7 +2793,9 @@ def scale_and_round(x): b1 = BatchData( False, - BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt), + BatchDataBase( + global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt + ), BatchDataExt( width, height, @@ -2828,12 +2836,19 @@ def scale_and_round(x): def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) + parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト") parser.add_argument( - "--from_file", type=str, default=None, help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む" + "--from_file", + type=str, + default=None, + help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む", ) parser.add_argument( - "--interactive", action="store_true", help="interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)" + "--interactive", + action="store_true", + help="interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)", ) parser.add_argument( "--no_preview", action="store_true", help="do not show generated image in interactive mode / 対話モードで画像を表示しない" @@ -2845,7 +2860,9 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--strength", type=float, default=None, help="img2img strength / img2img時のstrength") parser.add_argument("--images_per_prompt", type=int, default=1, help="number of images per prompt / プロンプトあたりの出力枚数") parser.add_argument("--outdir", type=str, default="outputs", help="dir to write results to / 生成画像の出力先") - parser.add_argument("--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファイル名を連番にする") + parser.add_argument( + "--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファイル名を連番にする" + ) parser.add_argument( "--use_original_file_name", action="store_true", @@ -2856,10 +2873,16 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ") parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅") parser.add_argument( - "--original_height", type=int, default=None, help="original height for SDXL conditioning / SDXLの条件付けに用いるoriginal heightの値" + "--original_height", + type=int, + default=None, + help="original height for SDXL conditioning / SDXLの条件付けに用いるoriginal heightの値", ) parser.add_argument( - "--original_width", type=int, default=None, help="original width for SDXL conditioning / SDXLの条件付けに用いるoriginal widthの値" + "--original_width", + type=int, + default=None, + help="original width for SDXL conditioning / SDXLの条件付けに用いるoriginal widthの値", ) parser.add_argument( "--original_height_negative", @@ -2873,8 +2896,12 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="original width for SDXL unconditioning / SDXLのネガティブ条件付けに用いるoriginal widthの値", ) - parser.add_argument("--crop_top", type=int, default=None, help="crop top for SDXL conditioning / SDXLの条件付けに用いるcrop topの値") - parser.add_argument("--crop_left", type=int, default=None, help="crop left for SDXL conditioning / SDXLの条件付けに用いるcrop leftの値") + parser.add_argument( + "--crop_top", type=int, default=None, help="crop top for SDXL conditioning / SDXLの条件付けに用いるcrop topの値" + ) + parser.add_argument( + "--crop_left", type=int, default=None, help="crop left for SDXL conditioning / SDXLの条件付けに用いるcrop leftの値" + ) parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ") parser.add_argument( "--vae_batch_size", @@ -2888,7 +2915,9 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="number of slices to split image into for VAE to reduce VRAM usage, None for no splitting (default), slower if specified. 16 or 32 recommended / VAE処理時にVRAM使用量削減のため画像を分割するスライス数、Noneの場合は分割しない(デフォルト)、指定すると遅くなる。16か32程度を推奨", ) - parser.add_argument("--no_half_vae", action="store_true", help="do not use fp16/bf16 precision for VAE / VAE処理時にfp16/bf16を使わない") + parser.add_argument( + "--no_half_vae", action="store_true", help="do not use fp16/bf16 precision for VAE / VAE処理時にfp16/bf16を使わない" + ) parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数") parser.add_argument( "--sampler", @@ -2920,9 +2949,14 @@ def setup_parser() -> argparse.ArgumentParser: default=7.5, help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty)) / guidance scale", ) - parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ") parser.add_argument( - "--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ" + "--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ" + ) + parser.add_argument( + "--vae", + type=str, + default=None, + help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ", ) parser.add_argument( "--tokenizer_cache_dir", @@ -2953,25 +2987,46 @@ def setup_parser() -> argparse.ArgumentParser: help="use xformers by diffusers (Hypernetworks doesn't work) / Diffusersでxformersを使用する(Hypernetwork利用不可)", ) parser.add_argument( - "--opt_channels_last", action="store_true", help="set channels last option to model / モデルにchannels lastを指定し最適化する" + "--opt_channels_last", + action="store_true", + help="set channels last option to model / モデルにchannels lastを指定し最適化する", ) parser.add_argument( - "--network_module", type=str, default=None, nargs="*", help="additional network module to use / 追加ネットワークを使う時そのモジュール名" + "--network_module", + type=str, + default=None, + nargs="*", + help="additional network module to use / 追加ネットワークを使う時そのモジュール名", ) parser.add_argument( "--network_weights", type=str, default=None, nargs="*", help="additional network weights to load / 追加ネットワークの重み" ) - parser.add_argument("--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率") parser.add_argument( - "--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数" + "--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率" + ) + parser.add_argument( + "--network_args", + type=str, + default=None, + nargs="*", + help="additional arguments for network (key=value) / ネットワークへの追加の引数", + ) + parser.add_argument( + "--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する" + ) + parser.add_argument( + "--network_merge_n_models", + type=int, + default=None, + help="merge this number of networks / この数だけネットワークをマージする", ) - parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する") parser.add_argument( - "--network_merge_n_models", type=int, default=None, help="merge this number of networks / この数だけネットワークをマージする" + "--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする" ) - parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする") parser.add_argument( - "--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する" + "--network_pre_calc", + action="store_true", + help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する", ) parser.add_argument( "--network_regional_mask_max_color_codes", @@ -2986,7 +3041,9 @@ def setup_parser() -> argparse.ArgumentParser: nargs="*", help="Embeddings files of Textual Inversion / Textual Inversionのembeddings", ) - parser.add_argument("--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う") + parser.add_argument( + "--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う" + ) parser.add_argument( "--max_embeddings_multiples", type=int, @@ -3003,7 +3060,10 @@ def setup_parser() -> argparse.ArgumentParser: help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする", ) parser.add_argument( - "--highres_fix_steps", type=int, default=28, help="1st stage steps for highres fix / highres fixの最初のステージのステップ数" + "--highres_fix_steps", + type=int, + default=28, + help="1st stage steps for highres fix / highres fixの最初のステージのステップ数", ) parser.add_argument( "--highres_fix_strength", @@ -3012,7 +3072,9 @@ def setup_parser() -> argparse.ArgumentParser: help="1st stage img2img strength for highres fix / highres fixの最初のステージのimg2img時のstrength、省略時はstrengthと同じ", ) parser.add_argument( - "--highres_fix_save_1st", action="store_true", help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する" + "--highres_fix_save_1st", + action="store_true", + help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する", ) parser.add_argument( "--highres_fix_latents_upscaling", @@ -3020,7 +3082,10 @@ def setup_parser() -> argparse.ArgumentParser: help="use latents upscaling for highres fix / highres fixでlatentで拡大する", ) parser.add_argument( - "--highres_fix_upscaler", type=str, default=None, help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名" + "--highres_fix_upscaler", + type=str, + default=None, + help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名", ) parser.add_argument( "--highres_fix_upscaler_args", @@ -3035,11 +3100,18 @@ def setup_parser() -> argparse.ArgumentParser: ) parser.add_argument( - "--negative_scale", type=float, default=None, help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する" + "--negative_scale", + type=float, + default=None, + help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する", ) parser.add_argument( - "--control_net_lllite_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名" + "--control_net_lllite_models", + type=str, + default=None, + nargs="*", + help="ControlNet models to use / 使用するControlNetのモデル名", ) # parser.add_argument( # "--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名" @@ -3138,4 +3210,5 @@ def setup_parser() -> argparse.ArgumentParser: parser = setup_parser() args = parser.parse_args() + setup_logging(args, reset=True) main(args) From cbe9c5dc068cd81c5f7d53c7aeba601d5241e7f9 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 12 Feb 2024 14:17:27 +0900 Subject: [PATCH 059/103] supprt deep shink with regional lora, add prompter module --- gen_img.py | 168 ++++++++++++++++++++++++++++++++++++----------- networks/lora.py | 38 +++++++++-- 2 files changed, 161 insertions(+), 45 deletions(-) diff --git a/gen_img.py b/gen_img.py index f5a740c3e..a24220a0a 100644 --- a/gen_img.py +++ b/gen_img.py @@ -3,6 +3,8 @@ from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable import glob import importlib +import importlib.util +import sys import inspect import time import zipfile @@ -333,6 +335,10 @@ def __init__( self.scheduler = scheduler self.safety_checker = None + self.clip_vision_model: CLIPVisionModelWithProjection = None + self.clip_vision_processor: CLIPImageProcessor = None + self.clip_vision_strength = 0.0 + # Textual Inversion self.token_replacements_list = [] for _ in range(len(self.text_encoders)): @@ -419,6 +425,7 @@ def __call__( callback_steps: Optional[int] = 1, img2img_noise=None, clip_guide_images=None, + emb_normalize_mode: str = "original", **kwargs, ): # TODO support secondary prompt @@ -493,6 +500,7 @@ def __call__( clip_skip=self.clip_skip, token_replacer=token_replacer, device=self.device, + emb_normalize_mode=emb_normalize_mode, **kwargs, ) tes_text_embs.append(text_embeddings) @@ -508,6 +516,7 @@ def __call__( clip_skip=self.clip_skip, token_replacer=token_replacer, device=self.device, + emb_normalize_mode=emb_normalize_mode, **kwargs, ) tes_real_uncond_embs.append(real_uncond_embeddings) @@ -1099,7 +1108,7 @@ def get_unweighted_text_embeddings( # in sdxl, value of clip_skip is same for Text Encoder 1 and 2 enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True) text_embedding = enc_out["hidden_states"][-clip_skip] - if not is_sdxl: # SD 1.5 requires final_layer_norm + if not is_sdxl: # SD 1.5 requires final_layer_norm text_embedding = text_encoder.text_model.final_layer_norm(text_embedding) if pool is None: pool = enc_out.get("text_embeds", None) # use 1st chunk, if provided @@ -1122,7 +1131,7 @@ def get_unweighted_text_embeddings( else: enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True) text_embeddings = enc_out["hidden_states"][-clip_skip] - if not is_sdxl: # SD 1.5 requires final_layer_norm + if not is_sdxl: # SD 1.5 requires final_layer_norm text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings) pool = enc_out.get("text_embeds", None) # text encoder 1 doesn't return this if pool is not None: @@ -1143,6 +1152,7 @@ def get_weighted_text_embeddings( clip_skip: int = 1, token_replacer=None, device=None, + emb_normalize_mode: Optional[str] = "original", # "original", "abs", "none" **kwargs, ): max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 @@ -1239,16 +1249,34 @@ def get_weighted_text_embeddings( # assign weights to the prompts and normalize in the sense of mean # TODO: should we normalize by chunk or in a whole (current implementation)? # →全体でいいんじゃないかな + if (not skip_parsing) and (not skip_weighting): - previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) - text_embeddings *= prompt_weights.unsqueeze(-1) - current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) - text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) - if uncond_prompt is not None: - previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) - uncond_embeddings *= uncond_weights.unsqueeze(-1) - current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) - uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + if emb_normalize_mode == "abs": + previous_mean = text_embeddings.float().abs().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= prompt_weights.unsqueeze(-1) + current_mean = text_embeddings.float().abs().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + if uncond_prompt is not None: + previous_mean = uncond_embeddings.float().abs().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= uncond_weights.unsqueeze(-1) + current_mean = uncond_embeddings.float().abs().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + + elif emb_normalize_mode == "none": + text_embeddings *= prompt_weights.unsqueeze(-1) + if uncond_prompt is not None: + uncond_embeddings *= uncond_weights.unsqueeze(-1) + + else: # "original" + previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= prompt_weights.unsqueeze(-1) + current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + if uncond_prompt is not None: + previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= uncond_weights.unsqueeze(-1) + current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) if uncond_prompt is not None: return text_embeddings, text_pool, uncond_embeddings, uncond_pool, prompt_tokens @@ -1427,6 +1455,27 @@ class BatchData(NamedTuple): ext: BatchDataExt +class ListPrompter: + def __init__(self, prompts: List[str]): + self.prompts = prompts + self.index = 0 + + def shuffle(self): + random.shuffle(self.prompts) + + def __len__(self): + return len(self.prompts) + + def __call__(self, *args, **kwargs): + if self.index >= len(self.prompts): + self.index = 0 # reset + return None + + prompt = self.prompts[self.index] + self.index += 1 + return prompt + + def main(args): if args.fp16: dtype = torch.float16 @@ -1951,15 +2000,35 @@ def __getattr__(self, item): token_embeds2[token_id] = embed # promptを取得する + prompt_list = None if args.from_file is not None: print(f"reading prompts from {args.from_file}") with open(args.from_file, "r", encoding="utf-8") as f: prompt_list = f.read().splitlines() prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"] + prompter = ListPrompter(prompt_list) + + elif args.from_module is not None: + + def load_module_from_path(module_name, file_path): + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Module '{module_name}' cannot be loaded from '{file_path}'") + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + print(f"reading prompts from module: {args.from_module}") + prompt_module = load_module_from_path("prompt_module", args.from_module) + + prompter = prompt_module.get_prompter(args, pipe, networks) + elif args.prompt is not None: - prompt_list = [args.prompt] + prompter = ListPrompter([args.prompt]) + else: - prompt_list = [] + prompter = None # interactive mode if args.interactive: args.n_iter = 1 @@ -2026,14 +2095,16 @@ def resize_images(imgs, size): mask_images = None # promptがないとき、画像のPngInfoから取得する - if init_images is not None and len(prompt_list) == 0 and not args.interactive: + if init_images is not None and prompter is None and not args.interactive: print("get prompts from images' metadata") + prompt_list = [] for img in init_images: if "prompt" in img.text: prompt = img.text["prompt"] if "negative-prompt" in img.text: prompt += " --n " + img.text["negative-prompt"] prompt_list.append(prompt) + prompter = ListPrompter(prompt_list) # プロンプトと画像を一致させるため指定回数だけ繰り返す(画像を増幅する) l = [] @@ -2105,15 +2176,18 @@ def resize_images(imgs, size): else: guide_images = None - # seed指定時はseedを決めておく + # 新しい乱数生成器を作成する if args.seed is not None: - # dynamic promptを使うと足りなくなる→images_per_promptを適当に大きくしておいてもらう - random.seed(args.seed) - predefined_seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.n_iter * len(prompt_list) * args.images_per_prompt)] - if len(predefined_seeds) == 1: - predefined_seeds[0] = args.seed + if prompt_list and len(prompt_list) == 1 and args.images_per_prompt == 1: + # 引数のseedをそのまま使う + def fixed_seed(*args, **kwargs): + return args.seed + + seed_random = SimpleNamespace(randint=fixed_seed) + else: + seed_random = random.Random(args.seed) else: - predefined_seeds = None + seed_random = random.Random() # デフォルト画像サイズを設定する:img2imgではこれらの値は無視される(またはW*Hにリサイズ済み) if args.W is None: @@ -2127,11 +2201,14 @@ def resize_images(imgs, size): for gen_iter in range(args.n_iter): print(f"iteration {gen_iter+1}/{args.n_iter}") - iter_seed = random.randint(0, 0x7FFFFFFF) + if args.iter_same_seed: + iter_seed = seed_random.randint(0, 2**32 - 1) + else: + iter_seed = None # shuffle prompt list if args.shuffle_prompts: - random.shuffle(prompt_list) + prompter.shuffle() # バッチ処理の関数 def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): @@ -2352,7 +2429,8 @@ def scale_and_round(x): for n, m in zip(networks, network_muls if network_muls else network_default_muls): n.set_multiplier(m) if regional_network: - n.set_current_generation(batch_size, num_sub_prompts, width, height, shared) + # TODO バッチから ds_ratio を取り出すべき + n.set_current_generation(batch_size, num_sub_prompts, width, height, shared, unet.ds_ratio) if not regional_network and network_pre_calc: for n in networks: @@ -2386,6 +2464,7 @@ def scale_and_round(x): return_latents=return_latents, clip_prompts=clip_prompts, clip_guide_images=guide_images, + emb_normalize_mode=args.emb_normalize_mode, ) if highres_1st and not args.highres_fix_save_1st: # return images or latents return images @@ -2451,8 +2530,8 @@ def scale_and_round(x): prompt_index = 0 global_step = 0 batch_data = [] - while args.interactive or prompt_index < len(prompt_list): - if len(prompt_list) == 0: + while True: + if args.interactive: # interactive valid = False while not valid: @@ -2466,7 +2545,9 @@ def scale_and_round(x): if not valid: # EOF, end app break else: - raw_prompt = prompt_list[prompt_index] + raw_prompt = prompter(args, pipe, seed_random, iter_seed, prompt_index, global_step) + if raw_prompt is None: + break # sd-dynamic-prompts like variants: # count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration) @@ -2513,7 +2594,8 @@ def scale_and_round(x): prompt_args = raw_prompt.strip().split(" --") prompt = prompt_args[0] - print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") + length = len(prompter) if hasattr(prompter, "__len__") else 0 + print(f"prompt {prompt_index+1}/{length}: {prompt}") for parg in prompt_args[1:]: try: @@ -2731,23 +2813,17 @@ def scale_and_round(x): # prepare seed if seeds is not None: # given in prompt - # 数が足りないなら前のをそのまま使う + # num_images_per_promptが多い場合は足りなくなるので、足りない分は前のを使う if len(seeds) > 0: seed = seeds.pop(0) else: - if predefined_seeds is not None: - if len(predefined_seeds) > 0: - seed = predefined_seeds.pop(0) - else: - print("predefined seeds are exhausted") - seed = None - elif args.iter_same_seed: - seeds = iter_seed + if args.iter_same_seed: + seed = iter_seed else: seed = None # 前のを消す if seed is None: - seed = random.randint(0, 0x7FFFFFFF) + seed = seed_random.randint(0, 2**32 - 1) if args.interactive: print(f"seed: {seed}") @@ -2853,6 +2929,15 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む", ) + parser.add_argument( + "--from_module", + type=str, + default=None, + help="if specified, load prompts from this module / 指定時はプロンプトをモジュールから読み込む", + ) + parser.add_argument( + "--prompter_module_args", type=str, default=None, help="args for prompter module / prompterモジュールの引数" + ) parser.add_argument( "--interactive", action="store_true", @@ -3067,6 +3152,13 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="max embedding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる", ) + parser.add_argument( + "--emb_normalize_mode", + type=str, + default="original", + choices=["original", "none", "abs"], + help="embedding normalization mode / embeddingの正規化モード", + ) parser.add_argument( "--guide_image_path", type=str, default=None, nargs="*", help="image to ControlNet / ControlNetでガイドに使う画像" ) diff --git a/networks/lora.py b/networks/lora.py index eaf656ac8..948b30b0e 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -12,8 +12,10 @@ import torch import re from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") @@ -248,7 +250,8 @@ def get_mask_for_x(self, x): if mask is None: # raise ValueError(f"mask is None for resolution {area}") # emb_layers in SDXL doesn't have mask - # logger.info(f"mask is None for resolution {area}, {x.size()}") + # if "emb" not in self.lora_name: + # print(f"mask is None for resolution {self.lora_name}, {area}, {x.size()}") mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1) return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts if len(x.size()) != 4: @@ -265,7 +268,9 @@ def regional_forward(self, x): # apply mask for LoRA result lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale mask = self.get_mask_for_x(lx) - # logger.info(f"regional {self.lora_name} {self.network.sub_prompt_index} {lx.size()} {mask.size()}") + # print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size()) + # if mask.ndim > lx.ndim: # in some resolution, lx is 2d and mask is 3d (the reason is not checked) + # mask = mask.squeeze(-1) lx = lx * mask x = self.org_forward(x) @@ -514,7 +519,9 @@ def parse_floats(s): len(block_dims) == num_total_blocks ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください" else: - logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります") + logger.warning( + f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります" + ) block_dims = [network_dim] * num_total_blocks if block_alphas is not None: @@ -792,7 +799,9 @@ def __init__( logger.info(f"create LoRA network from weights") elif block_dims is not None: logger.info(f"create LoRA network from block_dims") - logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) logger.info(f"block_dims: {block_dims}") logger.info(f"block_alphas: {block_alphas}") if conv_block_dims is not None: @@ -800,9 +809,13 @@ def __init__( logger.info(f"conv_block_alphas: {conv_block_alphas}") else: logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") - logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) if self.conv_lora_dim is not None: - logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") + logger.info( + f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" + ) # create module instances def create_modules( @@ -929,6 +942,10 @@ def set_multiplier(self, multiplier): for lora in self.text_encoder_loras + self.unet_loras: lora.multiplier = self.multiplier + def set_enabled(self, is_enabled): + for lora in self.text_encoder_loras + self.unet_loras: + lora.enabled = is_enabled + def load_weights(self, file): if os.path.splitext(file)[1] == ".safetensors": from safetensors.torch import load_file @@ -1116,7 +1133,7 @@ def set_region(self, sub_prompt_index, is_last_network, mask): for lora in self.text_encoder_loras + self.unet_loras: lora.set_network(self) - def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared): + def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared, ds_ratio=None): self.batch_size = batch_size self.num_sub_prompts = num_sub_prompts self.current_size = (height, width) @@ -1142,6 +1159,13 @@ def resize_add(mh, mw): resize_add(h, w) if h % 2 == 1 or w % 2 == 1: # add extra shape if h/w is not divisible by 2 resize_add(h + h % 2, w + w % 2) + + # deep shrink + if ds_ratio is not None: + hd = int(h * ds_ratio) + wd = int(w * ds_ratio) + resize_add(hd, wd) + h = (h + 1) // 2 w = (w + 1) // 2 From 93bed60762d7823a18c03229655478fb33d6a95a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 12 Feb 2024 14:49:29 +0900 Subject: [PATCH 060/103] fix to work `--console_log_xxx` options --- library/utils.py | 2 +- sdxl_train.py | 1 + sdxl_train_control_net_lllite.py | 1 + sdxl_train_control_net_lllite_old.py | 1 + train_controlnet.py | 1 + train_db.py | 1 + train_network.py | 1 + train_textual_inversion.py | 1 + train_textual_inversion_XTI.py | 1 + 9 files changed, 9 insertions(+), 1 deletion(-) diff --git a/library/utils.py b/library/utils.py index a56c6d942..3037c055d 100644 --- a/library/utils.py +++ b/library/utils.py @@ -25,7 +25,7 @@ def add_logging_arguments(parser): "--console_log_file", type=str, default=None, - help="Log to a file instead of stdout / 標準出力ではなくファイルにログを出力する", + help="Log to a file instead of stderr / 標準エラー出力ではなくファイルにログを出力する", ) parser.add_argument("--console_log_simple", action="store_true", help="Simple log output / シンプルなログ出力") diff --git a/sdxl_train.py b/sdxl_train.py index 3a09e4c84..8ee4a8b6e 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -99,6 +99,7 @@ def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) sdxl_train_util.verify_sdxl_training_args(args) + setup_logging(args, reset=True) assert ( not args.weighted_captions diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 15ee2f85c..dba7c0bbe 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -71,6 +71,7 @@ def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) sdxl_train_util.verify_sdxl_training_args(args) + setup_logging(args, reset=True) cache_latents = args.cache_latents use_user_config = args.dataset_config is not None diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 542cd7347..d4159eec2 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -67,6 +67,7 @@ def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) sdxl_train_util.verify_sdxl_training_args(args) + setup_logging(args, reset=True) cache_latents = args.cache_latents use_user_config = args.dataset_config is not None diff --git a/train_controlnet.py b/train_controlnet.py index 7489c48fe..6e0e53189 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -62,6 +62,7 @@ def train(args): # training_started_at = time.time() train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) + setup_logging(args, reset=True) cache_latents = args.cache_latents use_user_config = args.dataset_config is not None diff --git a/train_db.py b/train_db.py index 09cb69009..9bcf9d8ce 100644 --- a/train_db.py +++ b/train_db.py @@ -48,6 +48,7 @@ def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, False) + setup_logging(args, reset=True) cache_latents = args.cache_latents diff --git a/train_network.py b/train_network.py index bd5e3693e..c481c7ac9 100644 --- a/train_network.py +++ b/train_network.py @@ -142,6 +142,7 @@ def train(self, args): training_started_at = time.time() train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) + setup_logging(args, reset=True) cache_latents = args.cache_latents use_dreambooth_method = args.in_json is None diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 5900ad19d..07a912ffe 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -174,6 +174,7 @@ def train(self, args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) + setup_logging(args, reset=True) cache_latents = args.cache_latents diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 9b7e04c89..d7799acd4 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -100,6 +100,7 @@ def train(args): if args.output_name is None: args.output_name = args.token_string use_template = args.use_object_template or args.use_style_template + setup_logging(args, reset=True) train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) From 71ebcc5e259ccc73eaf97ad720d54c667fc99825 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 12 Feb 2024 14:52:19 +0900 Subject: [PATCH 061/103] update readme and gradual latent doc --- README.md | 65 ++++++++++++++++++++------------------- docs/gen_img_README-ja.md | 33 ++++++++++++++++++++ 2 files changed, 66 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index 3b3b2c0e7..d95159525 100644 --- a/README.md +++ b/README.md @@ -1,35 +1,3 @@ -# Gradual Latent について - -latentのサイズを徐々に大きくしていくHires fixです。`sdxl_gen_img.py` に以下のオプションが追加されています。 - -- `--gradual_latent_timesteps` : latentのサイズを大きくし始めるタイムステップを指定します。デフォルトは None で、Gradual Latentを使用しません。750 くらいから始めてみてください。 -- `--gradual_latent_ratio` : latentの初期サイズを指定します。デフォルトは 0.5 で、デフォルトの latent サイズの半分のサイズから始めます。 -- `--gradual_latent_ratio_step`: latentのサイズを大きくする割合を指定します。デフォルトは 0.125 で、latentのサイズを 0.625, 0.75, 0.875, 1.0 と徐々に大きくします。 -- `--gradual_latent_ratio_every_n_steps`: latentのサイズを大きくする間隔を指定します。デフォルトは 3 で、3ステップごとに latent のサイズを大きくします。 - -それぞれのオプションは、プロンプトオプション、`--glt`、`--glr`、`--gls`、`--gle` でも指定できます。 - -サンプラーに手を加えているため、__サンプラーに `euler_a` を指定してください。__ 他のサンプラーでは動作しません。 - -`gen_img_diffusers.py` にも同様のオプションが追加されていますが、試した範囲ではどうやっても乱れた画像しか生成できませんでした。 - -# About Gradual Latent - -Gradual Latent is a Hires fix that gradually increases the size of the latent. `sdxl_gen_img.py` has the following options added. - -- `--gradual_latent_timesteps`: Specifies the timestep to start increasing the size of the latent. The default is None, which means Gradual Latent is not used. Please try around 750 at first. -- `--gradual_latent_ratio`: Specifies the initial size of the latent. The default is 0.5, which means it starts with half the default latent size. -- `--gradual_latent_ratio_step`: Specifies the ratio to increase the size of the latent. The default is 0.125, which means the latent size is gradually increased to 0.625, 0.75, 0.875, 1.0. -- `--gradual_latent_ratio_every_n_steps`: Specifies the interval to increase the size of the latent. The default is 3, which means the latent size is increased every 3 steps. - -Each option can also be specified with prompt options, `--glt`, `--glr`, `--gls`, `--gle`. - -__Please specify `euler_a` for the sampler.__ Because the source code of the sampler is modified. It will not work with other samplers. - -`gen_img_diffusers.py` also has the same options, but in the range I tried, it only generated distorted images no matter what I did. - ---- - __SDXL is now supported. The sdxl branch has been merged into the main branch. If you update the repository, please follow the upgrade instructions. Also, the version of accelerate has been updated, so please run accelerate config again.__ The documentation for SDXL training is [here](./README.md#sdxl-training). This repository contains training, generation and utility scripts for Stable Diffusion. @@ -281,6 +249,39 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum ## Change History +### Working in progress + +- The log output has been improved. PR [#905](https://github.com/kohya-ss/sd-scripts/pull/905) Thanks to shirayu! + - The log is formatted by default. The `rich` library is required. Please see [Upgrade](#upgrade) and update the library. + - If `rich` is not installed, the log output will be the same as before. + - The following options are available in each training script: + - `--console_log_simple` option can be used to switch to the previous log output. + - `--console_log_level` option can be used to specify the log level. The default is `INFO`. + - `--console_log_file` option can be used to output the log to a file. The default is `None` (output to the console). +- The sample image generation during multi-GPU training is now done with multiple GPUs. PR [#1061](https://github.com/kohya-ss/sd-scripts/pull/1061) Thanks to DKnight54! +- The IPEX support is improved. PR [#1086](https://github.com/kohya-ss/sd-scripts/pull/1086) Thanks to Disty0! +- Fixed a bug that `svd_merge_lora.py` crashes in some cases. PR [#1087](https://github.com/kohya-ss/sd-scripts/pull/1087) Thanks to mgz-dev! +- The common image generation script `gen_img.py` for SD 1/2 and SDXL is added. The basic functions are the same as the scripts for SD 1/2 and SDXL, but some new features are added. + - External scripts to generate prompts can be supported. It can be called with `--from_module` option. (The documentation will be added later) + - The normalization method after prompt weighting can be specified with `--emb_normalize_mode` option. `original` is the original method, `abs` is the normalization with the average of the absolute values, `none` is no normalization. +- Gradual Latent Hires fix is added to each generation script. See [here](./docs/gen_img_README-ja.md#about-gradual-latent) for details. + +- ログ出力が改善されました。 PR [#905](https://github.com/kohya-ss/sd-scripts/pull/905) shirayu 氏に感謝します。 + - デフォルトでログが成形されます。`rich` ライブラリが必要なため、[Upgrade](#upgrade) を参照し更新をお願いします。 + - `rich` がインストールされていない場合は、従来のログ出力になります。 + - 各学習スクリプトでは以下のオプションが有効です。 + - `--console_log_simple` オプションで従来のログ出力に切り替えられます。 + - `--console_log_level` でログレベルを指定できます。デフォルトは `INFO` です。 + - `--console_log_file` でログファイルを出力できます。デフォルトは `None`(コンソールに出力) です。 +- 複数 GPU 学習時に学習中のサンプル画像生成を複数 GPU で行うようになりました。 PR [#1061](https://github.com/kohya-ss/sd-scripts/pull/1061) DKnight54 氏に感謝します。 +- IPEX サポートが改善されました。 PR [#1086](https://github.com/kohya-ss/sd-scripts/pull/1086) Disty0 氏に感謝します。 +- `svd_merge_lora.py` が場合によってエラーになる不具合が修正されました。 PR [#1087](https://github.com/kohya-ss/sd-scripts/pull/1087) mgz-dev 氏に感謝します。 +- SD 1/2 および SDXL 共通の生成スクリプト `gen_img.py` を追加しました。基本的な機能は SD 1/2、SDXL 向けスクリプトと同じですが、いくつかの新機能が追加されています。 + - プロンプトを動的に生成する外部スクリプトをサポートしました。 `--from_module` で呼び出せます。(ドキュメントはのちほど追加します) + - プロンプト重みづけ後の正規化方法を `--emb_normalize_mode` で指定できます。`original` は元の方法、`abs` は絶対値の平均値で正規化、`none` は正規化を行いません。 +- Gradual Latent Hires fix を各生成スクリプトに追加しました。詳細は [こちら](./docs/gen_img_README-ja.md#about-gradual-latent)。 + + ### Jan 27, 2024 / 2024/1/27: v0.8.3 - Fixed a bug that the training crashes when `--fp8_base` is specified with `--save_state`. PR [#1079](https://github.com/kohya-ss/sd-scripts/pull/1079) Thanks to feffy380! diff --git a/docs/gen_img_README-ja.md b/docs/gen_img_README-ja.md index cf35f1df7..8f4442d00 100644 --- a/docs/gen_img_README-ja.md +++ b/docs/gen_img_README-ja.md @@ -452,3 +452,36 @@ python gen_img_diffusers.py --ckpt wd-v1-3-full-pruned-half.ckpt - `--network_show_meta` : 追加ネットワークのメタデータを表示します。 + +--- + +# About Gradual Latent + +Gradual Latent is a Hires fix that gradually increases the size of the latent. `gen_img.py`, `sdxl_gen_img.py`, and `gen_img_diffusers.py` have the following options. + +- `--gradual_latent_timesteps`: Specifies the timestep to start increasing the size of the latent. The default is None, which means Gradual Latent is not used. Please try around 750 at first. +- `--gradual_latent_ratio`: Specifies the initial size of the latent. The default is 0.5, which means it starts with half the default latent size. +- `--gradual_latent_ratio_step`: Specifies the ratio to increase the size of the latent. The default is 0.125, which means the latent size is gradually increased to 0.625, 0.75, 0.875, 1.0. +- `--gradual_latent_ratio_every_n_steps`: Specifies the interval to increase the size of the latent. The default is 3, which means the latent size is increased every 3 steps. + +Each option can also be specified with prompt options, `--glt`, `--glr`, `--gls`, `--gle`. + +__Please specify `euler_a` for the sampler.__ Because the source code of the sampler is modified. It will not work with other samplers. + +It is more effective with SD 1.5. It is quite subtle with SDXL. + +# Gradual Latent について + +latentのサイズを徐々に大きくしていくHires fixです。`gen_img.py` 、``sdxl_gen_img.py` 、`gen_img_diffusers.py` に以下のオプションが追加されています。 + +- `--gradual_latent_timesteps` : latentのサイズを大きくし始めるタイムステップを指定します。デフォルトは None で、Gradual Latentを使用しません。750 くらいから始めてみてください。 +- `--gradual_latent_ratio` : latentの初期サイズを指定します。デフォルトは 0.5 で、デフォルトの latent サイズの半分のサイズから始めます。 +- `--gradual_latent_ratio_step`: latentのサイズを大きくする割合を指定します。デフォルトは 0.125 で、latentのサイズを 0.625, 0.75, 0.875, 1.0 と徐々に大きくします。 +- `--gradual_latent_ratio_every_n_steps`: latentのサイズを大きくする間隔を指定します。デフォルトは 3 で、3ステップごとに latent のサイズを大きくします。 + +それぞれのオプションは、プロンプトオプション、`--glt`、`--glr`、`--gls`、`--gle` でも指定できます。 + +サンプラーに手を加えているため、__サンプラーに `euler_a` を指定してください。__ 他のサンプラーでは動作しません。 + +SD 1.5 のほうが効果があります。SDXL ではかなり微妙です。 + From 75e4a951d09928e88be419911eb185a5bc6f0de7 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 17 Feb 2024 12:04:12 +0900 Subject: [PATCH 062/103] update readme --- README.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/README.md b/README.md index d95159525..4a53c1542 100644 --- a/README.md +++ b/README.md @@ -259,6 +259,9 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - `--console_log_level` option can be used to specify the log level. The default is `INFO`. - `--console_log_file` option can be used to output the log to a file. The default is `None` (output to the console). - The sample image generation during multi-GPU training is now done with multiple GPUs. PR [#1061](https://github.com/kohya-ss/sd-scripts/pull/1061) Thanks to DKnight54! +- The support for mps devices is improved. PR [#1054](https://github.com/kohya-ss/sd-scripts/pull/1054) Thanks to akx! If mps device exists instead of CUDA, the mps device is used automatically. +- An option `--highvram` to disable the optimization for environments with little VRAM is added to the training scripts. If you specify it when there is enough VRAM, the operation will be faster. + - Currently, only the cache part of latents is optimized. - The IPEX support is improved. PR [#1086](https://github.com/kohya-ss/sd-scripts/pull/1086) Thanks to Disty0! - Fixed a bug that `svd_merge_lora.py` crashes in some cases. PR [#1087](https://github.com/kohya-ss/sd-scripts/pull/1087) Thanks to mgz-dev! - The common image generation script `gen_img.py` for SD 1/2 and SDXL is added. The basic functions are the same as the scripts for SD 1/2 and SDXL, but some new features are added. @@ -274,6 +277,9 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - `--console_log_level` でログレベルを指定できます。デフォルトは `INFO` です。 - `--console_log_file` でログファイルを出力できます。デフォルトは `None`(コンソールに出力) です。 - 複数 GPU 学習時に学習中のサンプル画像生成を複数 GPU で行うようになりました。 PR [#1061](https://github.com/kohya-ss/sd-scripts/pull/1061) DKnight54 氏に感謝します。 +- mps デバイスのサポートが改善されました。 PR [#1054](https://github.com/kohya-ss/sd-scripts/pull/1054) akx 氏に感謝します。CUDA ではなく mps が存在する場合には自動的に mps デバイスを使用します。 +- 学習スクリプトに VRAMが少ない環境向け最適化を無効にするオプション `--highvram` を追加しました。VRAM に余裕がある場合に指定すると動作が高速化されます。 + - 現在は latents のキャッシュ部分のみ高速化されます。 - IPEX サポートが改善されました。 PR [#1086](https://github.com/kohya-ss/sd-scripts/pull/1086) Disty0 氏に感謝します。 - `svd_merge_lora.py` が場合によってエラーになる不具合が修正されました。 PR [#1087](https://github.com/kohya-ss/sd-scripts/pull/1087) mgz-dev 氏に感謝します。 - SD 1/2 および SDXL 共通の生成スクリプト `gen_img.py` を追加しました。基本的な機能は SD 1/2、SDXL 向けスクリプトと同じですが、いくつかの新機能が追加されています。 From d1fb480887f30f3525f060f1f9b42957f822bb39 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 18 Feb 2024 09:13:24 +0900 Subject: [PATCH 063/103] format by black --- library/train_util.py | 232 +++++++++++++++++++++++++++++++----------- 1 file changed, 170 insertions(+), 62 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 60af6e893..d2b69edb5 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -70,8 +70,10 @@ import library.huggingface_util as huggingface_util import library.sai_model_spec as sai_model_spec from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) # from library.attention_processors import FlashAttnProcessor # from library.hypernetwork import replace_attentions_for_hypernetwork @@ -1483,7 +1485,9 @@ def load_dreambooth_dir(subset: DreamBoothSubset): img_paths, captions = load_dreambooth_dir(subset) if len(img_paths) < 1: - logger.warning(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します") + logger.warning( + f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します" + ) continue if subset.is_reg: @@ -1574,7 +1578,9 @@ def __init__( raise ValueError(f"no metadata / メタデータファイルがありません: {subset.metadata_file}") if len(metadata) < 1: - logger.warning(f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します") + logger.warning( + f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します" + ) continue tags_list = [] @@ -1655,7 +1661,9 @@ def __init__( logger.warning(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します") elif not npz_all: use_npz_latents = False - logger.warning(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します") + logger.warning( + f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します" + ) if flip_aug_in_subset: logger.warning("maybe no flipped files / 反転されたnpzファイルがないのかもしれません") # else: @@ -1675,7 +1683,9 @@ def __init__( if sizes is None: if use_npz_latents: use_npz_latents = False - logger.warning(f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します") + logger.warning( + f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します" + ) assert ( resolution is not None @@ -1871,7 +1881,9 @@ def __getitem__(self, index): assert ( cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1] ), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}" - cond_img = cv2.resize(cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + cond_img = cv2.resize( + cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA + ) # INTER_AREAでやりたいのでcv2でリサイズ # TODO support random crop # 現在サポートしているcropはrandomではなく中央のみ @@ -2025,7 +2037,9 @@ def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, fli def debug_dataset(train_dataset, show_input_ids=False): logger.info(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") - logger.info("`S` for next step, `E` for next epoch no. , Escape for exit. / Sキーで次のステップ、Eキーで次のエポック、Escキーで中断、終了します") + logger.info( + "`S` for next step, `E` for next epoch no. , Escape for exit. / Sキーで次のステップ、Eキーで次のエポック、Escキーで中断、終了します" + ) epoch = 1 while True: @@ -2686,7 +2700,9 @@ def get_sai_model_spec( def add_sd_models_arguments(parser: argparse.ArgumentParser): # for pretrained models - parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む") + parser.add_argument( + "--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む" + ) parser.add_argument( "--v_parameterization", action="store_true", help="enable v-parameterization training / v-parameterization学習を有効にする" ) @@ -2726,7 +2742,10 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率") parser.add_argument( - "--max_grad_norm", default=1.0, type=float, help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない" + "--max_grad_norm", + default=1.0, + type=float, + help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない", ) parser.add_argument( @@ -2773,13 +2792,23 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool): - parser.add_argument("--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ") - parser.add_argument("--output_name", type=str, default=None, help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名") parser.add_argument( - "--huggingface_repo_id", type=str, default=None, help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名" + "--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ" ) parser.add_argument( - "--huggingface_repo_type", type=str, default=None, help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類" + "--output_name", type=str, default=None, help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名" + ) + parser.add_argument( + "--huggingface_repo_id", + type=str, + default=None, + help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名", + ) + parser.add_argument( + "--huggingface_repo_type", + type=str, + default=None, + help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類", ) parser.add_argument( "--huggingface_path_in_repo", @@ -2815,10 +2844,16 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: help="precision in saving / 保存時に精度を変更して保存する", ) parser.add_argument( - "--save_every_n_epochs", type=int, default=None, help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する" + "--save_every_n_epochs", + type=int, + default=None, + help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する", ) parser.add_argument( - "--save_every_n_steps", type=int, default=None, help="save checkpoint every N steps / 学習中のモデルを指定ステップごとに保存する" + "--save_every_n_steps", + type=int, + default=None, + help="save checkpoint every N steps / 学習中のモデルを指定ステップごとに保存する", ) parser.add_argument( "--save_n_epoch_ratio", @@ -2870,7 +2905,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: action="store_true", help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う", ) - parser.add_argument("--torch_compile", action="store_true", help="use torch.compile (requires PyTorch 2.0) / torch.compile を使う") + parser.add_argument( + "--torch_compile", action="store_true", help="use torch.compile (requires PyTorch 2.0) / torch.compile を使う" + ) parser.add_argument( "--dynamo_backend", type=str, @@ -2888,7 +2925,10 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: help="use sdpa for CrossAttention (requires PyTorch 2.0) / CrossAttentionにsdpaを使う(PyTorch 2.0が必要)", ) parser.add_argument( - "--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ" + "--vae", + type=str, + default=None, + help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ", ) parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数") @@ -2920,7 +2960,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数", ) parser.add_argument( - "--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度" + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help="use mixed precision / 混合精度を使う場合、その精度", ) parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する") parser.add_argument( @@ -2962,7 +3006,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: choices=["tensorboard", "wandb", "all"], help="what logging tool(s) to use (if 'all', TensorBoard and WandB are both used) / ログ出力に使用するツール (allを指定するとTensorBoardとWandBの両方が使用される)", ) - parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列") + parser.add_argument( + "--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列" + ) parser.add_argument( "--log_tracker_name", type=str, @@ -3050,14 +3096,19 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument( "--highvram", action="store_true", - help="disable low VRAM optimization. e.g. do not clear CUDA cache after each latent caching (for machines which have bigger VRAM) " + - "/ VRAMが少ない環境向け最適化を無効にする。たとえば各latentのキャッシュ後のCUDAキャッシュクリアを行わない等(VRAMが多い環境向け)", + help="disable low VRAM optimization. e.g. do not clear CUDA cache after each latent caching (for machines which have bigger VRAM) " + + "/ VRAMが少ない環境向け最適化を無効にする。たとえば各latentのキャッシュ後のCUDAキャッシュクリアを行わない等(VRAMが多い環境向け)", ) parser.add_argument( - "--sample_every_n_steps", type=int, default=None, help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する" + "--sample_every_n_steps", + type=int, + default=None, + help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する", + ) + parser.add_argument( + "--sample_at_first", action="store_true", help="generate sample images before training / 学習前にサンプル出力する" ) - parser.add_argument("--sample_at_first", action="store_true", help="generate sample images before training / 学習前にサンプル出力する") parser.add_argument( "--sample_every_n_epochs", type=int, @@ -3065,7 +3116,10 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: help="generate sample images every N epochs (overwrites n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)", ) parser.add_argument( - "--sample_prompts", type=str, default=None, help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル" + "--sample_prompts", + type=str, + default=None, + help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル", ) parser.add_argument( "--sample_sampler", @@ -3152,7 +3206,9 @@ def verify_training_args(args: argparse.Namespace): HIGH_VRAM = True if args.v_parameterization and not args.v2: - logger.warning("v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません") + logger.warning( + "v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません" + ) if args.v2 and args.clip_skip is not None: logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") @@ -3199,8 +3255,12 @@ def add_dataset_arguments( parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool ): # dataset common - parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument("--shuffle_caption", action="store_true", help="shuffle separated caption / 区切られたcaptionの各要素をshuffleする") + parser.add_argument( + "--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ" + ) + parser.add_argument( + "--shuffle_caption", action="store_true", help="shuffle separated caption / 区切られたcaptionの各要素をshuffleする" + ) parser.add_argument("--caption_separator", type=str, default=",", help="separator for caption / captionの区切り文字") parser.add_argument( "--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子" @@ -3236,8 +3296,12 @@ def add_dataset_arguments( default=None, help="suffix for caption text / captionのテキストの末尾に付ける文字列", ) - parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする") - parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする") + parser.add_argument( + "--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする" + ) + parser.add_argument( + "--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする" + ) parser.add_argument( "--face_crop_aug_range", type=str, @@ -3250,7 +3314,9 @@ def add_dataset_arguments( help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)", ) parser.add_argument( - "--debug_dataset", action="store_true", help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)" + "--debug_dataset", + action="store_true", + help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)", ) parser.add_argument( "--resolution", @@ -3263,14 +3329,18 @@ def add_dataset_arguments( action="store_true", help="cache latents to main memory to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをメインメモリにcacheする(augmentationは使用不可) ", ) - parser.add_argument("--vae_batch_size", type=int, default=1, help="batch size for caching latents / latentのcache時のバッチサイズ") + parser.add_argument( + "--vae_batch_size", type=int, default=1, help="batch size for caching latents / latentのcache時のバッチサイズ" + ) parser.add_argument( "--cache_latents_to_disk", action="store_true", help="cache latents to disk to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをディスクにcacheする(augmentationは使用不可)", ) parser.add_argument( - "--enable_bucket", action="store_true", help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする" + "--enable_bucket", + action="store_true", + help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする", ) parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度") parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度") @@ -3281,7 +3351,9 @@ def add_dataset_arguments( help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します", ) parser.add_argument( - "--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します" + "--bucket_no_upscale", + action="store_true", + help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します", ) parser.add_argument( @@ -3325,13 +3397,20 @@ def add_dataset_arguments( if support_dreambooth: # DreamBooth dataset - parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ") + parser.add_argument( + "--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ" + ) if support_caption: # caption dataset - parser.add_argument("--in_json", type=str, default=None, help="json metadata for dataset / データセットのmetadataのjsonファイル") parser.add_argument( - "--dataset_repeats", type=int, default=1, help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数" + "--in_json", type=str, default=None, help="json metadata for dataset / データセットのmetadataのjsonファイル" + ) + parser.add_argument( + "--dataset_repeats", + type=int, + default=1, + help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数", ) @@ -3469,7 +3548,9 @@ def task(): loop = asyncio.get_event_loop() results = loop.run_until_complete(asyncio.gather(*[download(filename=filename.rfilename) for filename in list_files])) if len(results) == 0: - raise ValueError("No files found in the specified repo id/path/revision / 指定されたリポジトリID/パス/リビジョンにファイルが見つかりませんでした") + raise ValueError( + "No files found in the specified repo id/path/revision / 指定されたリポジトリID/パス/リビジョンにファイルが見つかりませんでした" + ) dirname = os.path.dirname(results[0]) accelerator.load_state(dirname) @@ -3610,7 +3691,9 @@ def get_optimizer(args, trainable_params): elif optimizer_type == "SGDNesterov".lower(): logger.info(f"use SGD with Nesterov optimizer | {optimizer_kwargs}") if "momentum" not in optimizer_kwargs: - logger.info(f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します") + logger.info( + f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します" + ) optimizer_kwargs["momentum"] = 0.9 optimizer_class = torch.optim.SGD @@ -3689,7 +3772,9 @@ def get_optimizer(args, trainable_params): if "relative_step" not in optimizer_kwargs: optimizer_kwargs["relative_step"] = True # default if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False): - logger.info(f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします") + logger.info( + f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします" + ) optimizer_kwargs["relative_step"] = True logger.info(f"use Adafactor optimizer | {optimizer_kwargs}") @@ -3913,7 +3998,9 @@ def prepare_accelerator(args: argparse.Namespace): log_with = args.log_with if log_with in ["tensorboard", "all"]: if logging_dir is None: - raise ValueError("logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください") + raise ValueError( + "logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください" + ) if log_with in ["wandb", "all"]: try: import wandb @@ -3932,9 +4019,13 @@ def prepare_accelerator(args: argparse.Namespace): kwargs_handlers = ( InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout)) if args.ddp_timeout else None, - DistributedDataParallelKwargs(gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph) - if args.ddp_gradient_as_bucket_view or args.ddp_static_graph - else None, + ( + DistributedDataParallelKwargs( + gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph + ) + if args.ddp_gradient_as_bucket_view or args.ddp_static_graph + else None + ), ) kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers)) accelerator = Accelerator( @@ -4083,7 +4174,9 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod # v1: ... の三連を ... へ戻す states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # for i in range(1, args.max_token_length, tokenizer.model_max_length): - states_list.append(encoder_hidden_states[:, i : i + tokenizer.model_max_length - 2]) # の後から の前まで + states_list.append( + encoder_hidden_states[:, i : i + tokenizer.model_max_length - 2] + ) # の後から の前まで states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # encoder_hidden_states = torch.cat(states_list, dim=1) @@ -4666,6 +4759,7 @@ def line_to_prompt_dict(line: str) -> dict: return prompt_dict + def sample_images_common( pipe_class, accelerator: Accelerator, @@ -4704,10 +4798,10 @@ def sample_images_common( logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") return - distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here - + distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here + org_vae_device = vae.device # CPUにいるはず - vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device + vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device # unwrap unet and text_encoder(s) unet = accelerator.unwrap_model(unet) @@ -4769,23 +4863,27 @@ def sample_images_common( cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None except Exception: pass - - if distributed_state.num_processes <= 1: + + if distributed_state.num_processes <= 1: # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. with torch.no_grad(): for prompt_dict in prompts: - sample_image_inference(accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet) + sample_image_inference( + accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet + ) else: # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. - per_process_prompts = [] # list of lists + per_process_prompts = [] # list of lists for i in range(distributed_state.num_processes): - per_process_prompts.append(prompts[i::distributed_state.num_processes]) - + per_process_prompts.append(prompts[i :: distributed_state.num_processes]) + with torch.no_grad(): with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists: for prompt_dict in prompt_dict_lists[0]: - sample_image_inference(accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet) + sample_image_inference( + accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet + ) # clear pipeline and cache to reduce vram usage del pipeline @@ -4794,13 +4892,24 @@ def sample_images_common( # with torch.cuda.device(torch.cuda.current_device()): # torch.cuda.empty_cache() clean_memory_on_device(accelerator.device) - + torch.set_rng_state(rng_state) if cuda_rng_state is not None: torch.cuda.set_rng_state(cuda_rng_state) vae.to(org_vae_device) -def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=None): + +def sample_image_inference( + accelerator: Accelerator, + args: argparse.Namespace, + pipeline, + save_dir, + prompt_dict, + epoch, + steps, + prompt_replacement, + controlnet=None, +): assert isinstance(prompt_dict, dict) negative_prompt = prompt_dict.get("negative_prompt") sample_steps = prompt_dict.get("sample_steps", 30) @@ -4815,7 +4924,7 @@ def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, p if prompt_replacement is not None: prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) if negative_prompt is not None: - negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) if seed is not None: torch.manual_seed(seed) @@ -4857,10 +4966,10 @@ def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, p controlnet=controlnet, controlnet_image=controlnet_image, ) - + with torch.cuda.device(torch.cuda.current_device()): torch.cuda.empty_cache() - + image = pipeline.latents_to_image(latents)[0] # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list @@ -4870,9 +4979,7 @@ def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, p num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" seed_suffix = "" if seed is None else f"_{seed}" i: int = prompt_dict["enum"] - img_filename = ( - f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" - ) + img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" image.save(os.path.join(save_dir, img_filename)) # wandb有効時のみログを送信 @@ -4886,9 +4993,10 @@ def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, p wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) except: # wandb 無効時 pass -# endregion +# endregion + # region 前処理用 From a6f1ed2e140eb4d4d37c0bb0502a7c0fd0621f5f Mon Sep 17 00:00:00 2001 From: tamlog06 Date: Sun, 18 Feb 2024 13:20:47 +0000 Subject: [PATCH 064/103] fix dylora create_modules error --- networks/dylora.py | 31 ++++++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/networks/dylora.py b/networks/dylora.py index e5a55d198..64e39eaf7 100644 --- a/networks/dylora.py +++ b/networks/dylora.py @@ -12,7 +12,9 @@ import math import os import random -from typing import List, Tuple, Union +from typing import Dict, List, Optional, Tuple, Type, Union +from diffusers import AutoencoderKL +from transformers import CLIPTextModel import torch from torch import nn @@ -165,7 +167,15 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs): +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + vae: AutoencoderKL, + text_encoder: Union[CLIPTextModel, List[CLIPTextModel]], + unet, + **kwargs, +): if network_dim is None: network_dim = 4 # default if network_alpha is None: @@ -182,6 +192,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un conv_alpha = 1.0 else: conv_alpha = float(conv_alpha) + if unit is not None: unit = int(unit) else: @@ -306,8 +317,22 @@ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit) loras.append(lora) return loras + + text_encoders = text_encoder if type(text_encoder) == list else [text_encoder] + + self.text_encoder_loras = [] + for i, text_encoder in enumerate(text_encoders): + if len(text_encoders) > 1: + index = i + 1 + print(f"create LoRA for Text Encoder {index}") + else: + index = None + print(f"create LoRA for Text Encoder") + + text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + self.text_encoder_loras.extend(text_encoder_loras) - self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + # self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights From f4132018c568db2a822f1adf2c55e2615128a485 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 24 Feb 2024 19:25:31 +0900 Subject: [PATCH 065/103] fix to work with cpu_count() == 1 closes #1134 --- fine_tune.py | 4 ++-- sdxl_train.py | 4 ++-- sdxl_train_control_net_lllite.py | 4 ++-- sdxl_train_control_net_lllite_old.py | 4 ++-- tools/cache_latents.py | 4 ++-- tools/cache_text_encoder_outputs.py | 4 ++-- train_controlnet.py | 4 ++-- train_db.py | 4 ++-- train_network.py | 4 ++-- train_textual_inversion.py | 4 ++-- train_textual_inversion_XTI.py | 4 ++-- 11 files changed, 22 insertions(+), 22 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 8df896b43..875a91951 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -211,8 +211,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): _, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params) # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( train_dataset_group, batch_size=1, diff --git a/sdxl_train.py b/sdxl_train.py index aa161e8ac..e0df263d6 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -354,8 +354,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( train_dataset_group, batch_size=1, diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index b11999bd6..1e5f92349 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -242,8 +242,8 @@ def train(args): _, _, optimizer = train_util.get_optimizer(args, trainable_params) # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( train_dataset_group, diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 89a1bc8e0..dac56eedd 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -210,8 +210,8 @@ def train(args): _, _, optimizer = train_util.get_optimizer(args, trainable_params) # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( train_dataset_group, diff --git a/tools/cache_latents.py b/tools/cache_latents.py index e25506e41..347db27f7 100644 --- a/tools/cache_latents.py +++ b/tools/cache_latents.py @@ -116,8 +116,8 @@ def cache_to_disk(args: argparse.Namespace) -> None: # dataloaderを準備する train_dataset_group.set_caching_mode("latents") - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( train_dataset_group, diff --git a/tools/cache_text_encoder_outputs.py b/tools/cache_text_encoder_outputs.py index 46bffc4e0..5f1d6d201 100644 --- a/tools/cache_text_encoder_outputs.py +++ b/tools/cache_text_encoder_outputs.py @@ -121,8 +121,8 @@ def cache_to_disk(args: argparse.Namespace) -> None: # dataloaderを準備する train_dataset_group.set_caching_mode("text") - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( train_dataset_group, diff --git a/train_controlnet.py b/train_controlnet.py index 8963a5d62..dc73a91c8 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -239,8 +239,8 @@ def train(args): _, _, optimizer = train_util.get_optimizer(args, trainable_params) # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( train_dataset_group, diff --git a/train_db.py b/train_db.py index c89caaf2c..8d36097a5 100644 --- a/train_db.py +++ b/train_db.py @@ -180,8 +180,8 @@ def train(args): _, _, optimizer = train_util.get_optimizer(args, trainable_params) # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( train_dataset_group, batch_size=1, diff --git a/train_network.py b/train_network.py index af15560ce..e0fa69458 100644 --- a/train_network.py +++ b/train_network.py @@ -349,8 +349,8 @@ def train(self, args): optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( train_dataset_group, diff --git a/train_textual_inversion.py b/train_textual_inversion.py index a78a37b2c..df1d8485a 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -385,8 +385,8 @@ def train(self, args): _, _, optimizer = train_util.get_optimizer(args, trainable_params) # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( train_dataset_group, batch_size=1, diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 3f9155978..695fad2a8 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -305,8 +305,8 @@ def train(args): _, _, optimizer = train_util.get_optimizer(args, trainable_params) # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( train_dataset_group, batch_size=1, From 24092e6f2153c69a8647e0856ec733dd9e9f6ec3 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 24 Feb 2024 19:51:51 +0900 Subject: [PATCH 066/103] update einops to 0.7.0 #1122 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index a399d8bd9..279de350c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ diffusers[torch]==0.25.0 ftfy==6.1.1 # albumentations==1.3.0 opencv-python==4.7.0.68 -einops==0.6.1 +einops==0.7.0 pytorch-lightning==1.9.0 # bitsandbytes==0.39.1 tensorboard==2.10.1 From fb9110bac1814fbfe0dc031efece96ad03dac7f1 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 24 Feb 2024 20:00:57 +0900 Subject: [PATCH 067/103] format by black --- networks/resize_lora.py | 486 +++++++++++++++++++++------------------- 1 file changed, 260 insertions(+), 226 deletions(-) diff --git a/networks/resize_lora.py b/networks/resize_lora.py index c5932a893..37eb3caa3 100644 --- a/networks/resize_lora.py +++ b/networks/resize_lora.py @@ -9,77 +9,81 @@ from library import train_util, model_util import numpy as np from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) MIN_SV = 1e-6 # Model save and load functions + def load_state_dict(file_name, dtype): - if model_util.is_safetensors(file_name): - sd = load_file(file_name) - with safe_open(file_name, framework="pt") as f: - metadata = f.metadata() - else: - sd = torch.load(file_name, map_location='cpu') - metadata = None + if model_util.is_safetensors(file_name): + sd = load_file(file_name) + with safe_open(file_name, framework="pt") as f: + metadata = f.metadata() + else: + sd = torch.load(file_name, map_location="cpu") + metadata = None - for key in list(sd.keys()): - if type(sd[key]) == torch.Tensor: - sd[key] = sd[key].to(dtype) + for key in list(sd.keys()): + if type(sd[key]) == torch.Tensor: + sd[key] = sd[key].to(dtype) - return sd, metadata + return sd, metadata def save_to_file(file_name, model, state_dict, dtype, metadata): - if dtype is not None: - for key in list(state_dict.keys()): - if type(state_dict[key]) == torch.Tensor: - state_dict[key] = state_dict[key].to(dtype) + if dtype is not None: + for key in list(state_dict.keys()): + if type(state_dict[key]) == torch.Tensor: + state_dict[key] = state_dict[key].to(dtype) - if model_util.is_safetensors(file_name): - save_file(model, file_name, metadata) - else: - torch.save(model, file_name) + if model_util.is_safetensors(file_name): + save_file(model, file_name, metadata) + else: + torch.save(model, file_name) # Indexing functions + def index_sv_cumulative(S, target): - original_sum = float(torch.sum(S)) - cumulative_sums = torch.cumsum(S, dim=0)/original_sum - index = int(torch.searchsorted(cumulative_sums, target)) + 1 - index = max(1, min(index, len(S)-1)) + original_sum = float(torch.sum(S)) + cumulative_sums = torch.cumsum(S, dim=0) / original_sum + index = int(torch.searchsorted(cumulative_sums, target)) + 1 + index = max(1, min(index, len(S) - 1)) - return index + return index def index_sv_fro(S, target): - S_squared = S.pow(2) - s_fro_sq = float(torch.sum(S_squared)) - sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq - index = int(torch.searchsorted(sum_S_squared, target**2)) + 1 - index = max(1, min(index, len(S)-1)) + S_squared = S.pow(2) + s_fro_sq = float(torch.sum(S_squared)) + sum_S_squared = torch.cumsum(S_squared, dim=0) / s_fro_sq + index = int(torch.searchsorted(sum_S_squared, target**2)) + 1 + index = max(1, min(index, len(S) - 1)) - return index + return index def index_sv_ratio(S, target): - max_sv = S[0] - min_sv = max_sv/target - index = int(torch.sum(S > min_sv).item()) - index = max(1, min(index, len(S)-1)) + max_sv = S[0] + min_sv = max_sv / target + index = int(torch.sum(S > min_sv).item()) + index = max(1, min(index, len(S) - 1)) - return index + return index # Modified from Kohaku-blueleaf's extract/merge functions def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1): out_size, in_size, kernel_size, _ = weight.size() U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device)) - + param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale) lora_rank = param_dict["new_rank"] @@ -96,17 +100,17 @@ def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1): out_size, in_size = weight.size() - + U, S, Vh = torch.linalg.svd(weight.to(device)) - + param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale) lora_rank = param_dict["new_rank"] - + U = U[:, :lora_rank] S = S[:lora_rank] U = U @ torch.diag(S) Vh = Vh[:lora_rank, :] - + param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu() param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu() del U, S, Vh, weight @@ -117,7 +121,7 @@ def merge_conv(lora_down, lora_up, device): in_rank, in_size, kernel_size, k_ = lora_down.shape out_size, out_rank, _, _ = lora_up.shape assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch" - + lora_down = lora_down.to(device) lora_up = lora_up.to(device) @@ -131,236 +135,266 @@ def merge_linear(lora_down, lora_up, device): in_rank, in_size = lora_down.shape out_size, out_rank = lora_up.shape assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch" - + lora_down = lora_down.to(device) lora_up = lora_up.to(device) - + weight = lora_up @ lora_down del lora_up, lora_down return weight - + # Calculate new rank + def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1): param_dict = {} - if dynamic_method=="sv_ratio": + if dynamic_method == "sv_ratio": # Calculate new dim and alpha based off ratio new_rank = index_sv_ratio(S, dynamic_param) + 1 - new_alpha = float(scale*new_rank) + new_alpha = float(scale * new_rank) - elif dynamic_method=="sv_cumulative": + elif dynamic_method == "sv_cumulative": # Calculate new dim and alpha based off cumulative sum new_rank = index_sv_cumulative(S, dynamic_param) + 1 - new_alpha = float(scale*new_rank) + new_alpha = float(scale * new_rank) - elif dynamic_method=="sv_fro": + elif dynamic_method == "sv_fro": # Calculate new dim and alpha based off sqrt sum of squares new_rank = index_sv_fro(S, dynamic_param) + 1 - new_alpha = float(scale*new_rank) + new_alpha = float(scale * new_rank) else: new_rank = rank - new_alpha = float(scale*new_rank) + new_alpha = float(scale * new_rank) - - if S[0] <= MIN_SV: # Zero matrix, set dim to 1 + if S[0] <= MIN_SV: # Zero matrix, set dim to 1 new_rank = 1 - new_alpha = float(scale*new_rank) - elif new_rank > rank: # cap max rank at rank + new_alpha = float(scale * new_rank) + elif new_rank > rank: # cap max rank at rank new_rank = rank - new_alpha = float(scale*new_rank) - + new_alpha = float(scale * new_rank) # Calculate resize info s_sum = torch.sum(torch.abs(S)) s_rank = torch.sum(torch.abs(S[:new_rank])) - + S_squared = S.pow(2) s_fro = torch.sqrt(torch.sum(S_squared)) s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank])) - fro_percent = float(s_red_fro/s_fro) + fro_percent = float(s_red_fro / s_fro) param_dict["new_rank"] = new_rank param_dict["new_alpha"] = new_alpha - param_dict["sum_retained"] = (s_rank)/s_sum + param_dict["sum_retained"] = (s_rank) / s_sum param_dict["fro_retained"] = fro_percent - param_dict["max_ratio"] = S[0]/S[new_rank - 1] + param_dict["max_ratio"] = S[0] / S[new_rank - 1] return param_dict def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose): - network_alpha = None - network_dim = None - verbose_str = "\n" - fro_list = [] - - # Extract loaded lora dim and alpha - for key, value in lora_sd.items(): - if network_alpha is None and 'alpha' in key: - network_alpha = value - if network_dim is None and 'lora_down' in key and len(value.size()) == 2: - network_dim = value.size()[0] - if network_alpha is not None and network_dim is not None: - break - if network_alpha is None: - network_alpha = network_dim - - scale = network_alpha/network_dim - - if dynamic_method: - logger.info(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}") - - lora_down_weight = None - lora_up_weight = None - - o_lora_sd = lora_sd.copy() - block_down_name = None - block_up_name = None - - with torch.no_grad(): - for key, value in tqdm(lora_sd.items()): - weight_name = None - if 'lora_down' in key: - block_down_name = key.rsplit('.lora_down', 1)[0] - weight_name = key.rsplit(".", 1)[-1] - lora_down_weight = value - else: - continue - - # find corresponding lora_up and alpha - block_up_name = block_down_name - lora_up_weight = lora_sd.get(block_up_name + '.lora_up.' + weight_name, None) - lora_alpha = lora_sd.get(block_down_name + '.alpha', None) - - weights_loaded = (lora_down_weight is not None and lora_up_weight is not None) - - if weights_loaded: - - conv2d = (len(lora_down_weight.size()) == 4) - if lora_alpha is None: - scale = 1.0 - else: - scale = lora_alpha/lora_down_weight.size()[0] - - if conv2d: - full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device) - param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale) - else: - full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device) - param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale) - - if verbose: - max_ratio = param_dict['max_ratio'] - sum_retained = param_dict['sum_retained'] - fro_retained = param_dict['fro_retained'] - if not np.isnan(fro_retained): - fro_list.append(float(fro_retained)) - - verbose_str+=f"{block_down_name:75} | " - verbose_str+=f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}" - - if verbose and dynamic_method: - verbose_str+=f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n" - else: - verbose_str+=f"\n" - - new_alpha = param_dict['new_alpha'] - o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous() - o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous() - o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict['new_alpha']).to(save_dtype) - - block_down_name = None - block_up_name = None - lora_down_weight = None - lora_up_weight = None - weights_loaded = False - del param_dict - - if verbose: - logger.info(verbose_str) - - logger.info(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}") - logger.info("resizing complete") - return o_lora_sd, network_dim, new_alpha + network_alpha = None + network_dim = None + verbose_str = "\n" + fro_list = [] + + # Extract loaded lora dim and alpha + for key, value in lora_sd.items(): + if network_alpha is None and "alpha" in key: + network_alpha = value + if network_dim is None and "lora_down" in key and len(value.size()) == 2: + network_dim = value.size()[0] + if network_alpha is not None and network_dim is not None: + break + if network_alpha is None: + network_alpha = network_dim + + scale = network_alpha / network_dim + + if dynamic_method: + logger.info( + f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}" + ) + + lora_down_weight = None + lora_up_weight = None + + o_lora_sd = lora_sd.copy() + block_down_name = None + block_up_name = None + + with torch.no_grad(): + for key, value in tqdm(lora_sd.items()): + weight_name = None + if "lora_down" in key: + block_down_name = key.rsplit(".lora_down", 1)[0] + weight_name = key.rsplit(".", 1)[-1] + lora_down_weight = value + else: + continue + + # find corresponding lora_up and alpha + block_up_name = block_down_name + lora_up_weight = lora_sd.get(block_up_name + ".lora_up." + weight_name, None) + lora_alpha = lora_sd.get(block_down_name + ".alpha", None) + + weights_loaded = lora_down_weight is not None and lora_up_weight is not None + + if weights_loaded: + + conv2d = len(lora_down_weight.size()) == 4 + if lora_alpha is None: + scale = 1.0 + else: + scale = lora_alpha / lora_down_weight.size()[0] + + if conv2d: + full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device) + param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale) + else: + full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device) + param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale) + + if verbose: + max_ratio = param_dict["max_ratio"] + sum_retained = param_dict["sum_retained"] + fro_retained = param_dict["fro_retained"] + if not np.isnan(fro_retained): + fro_list.append(float(fro_retained)) + + verbose_str += f"{block_down_name:75} | " + verbose_str += ( + f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}" + ) + + if verbose and dynamic_method: + verbose_str += f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n" + else: + verbose_str += f"\n" + + new_alpha = param_dict["new_alpha"] + o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous() + o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous() + o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype) + + block_down_name = None + block_up_name = None + lora_down_weight = None + lora_up_weight = None + weights_loaded = False + del param_dict + + if verbose: + logger.info(verbose_str) + + logger.info(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}") + logger.info("resizing complete") + return o_lora_sd, network_dim, new_alpha def resize(args): - if args.save_to is None or not (args.save_to.endswith('.ckpt') or args.save_to.endswith('.pt') or args.save_to.endswith('.pth') or args.save_to.endswith('.safetensors')): - raise Exception("The --save_to argument must be specified and must be a .ckpt , .pt, .pth or .safetensors file.") - - - def str_to_dtype(p): - if p == 'float': - return torch.float - if p == 'fp16': - return torch.float16 - if p == 'bf16': - return torch.bfloat16 - return None - - if args.dynamic_method and not args.dynamic_param: - raise Exception("If using dynamic_method, then dynamic_param is required") - - merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32 - save_dtype = str_to_dtype(args.save_precision) - if save_dtype is None: - save_dtype = merge_dtype - - logger.info("loading Model...") - lora_sd, metadata = load_state_dict(args.model, merge_dtype) - - logger.info("Resizing Lora...") - state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose) - - # update metadata - if metadata is None: - metadata = {} - - comment = metadata.get("ss_training_comment", "") - - if not args.dynamic_method: - metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}" - metadata["ss_network_dim"] = str(args.new_rank) - metadata["ss_network_alpha"] = str(new_alpha) - else: - metadata["ss_training_comment"] = f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}" - metadata["ss_network_dim"] = 'Dynamic' - metadata["ss_network_alpha"] = 'Dynamic' + if args.save_to is None or not ( + args.save_to.endswith(".ckpt") + or args.save_to.endswith(".pt") + or args.save_to.endswith(".pth") + or args.save_to.endswith(".safetensors") + ): + raise Exception("The --save_to argument must be specified and must be a .ckpt , .pt, .pth or .safetensors file.") + + def str_to_dtype(p): + if p == "float": + return torch.float + if p == "fp16": + return torch.float16 + if p == "bf16": + return torch.bfloat16 + return None + + if args.dynamic_method and not args.dynamic_param: + raise Exception("If using dynamic_method, then dynamic_param is required") + + merge_dtype = str_to_dtype("float") # matmul method above only seems to work in float32 + save_dtype = str_to_dtype(args.save_precision) + if save_dtype is None: + save_dtype = merge_dtype + + logger.info("loading Model...") + lora_sd, metadata = load_state_dict(args.model, merge_dtype) + + logger.info("Resizing Lora...") + state_dict, old_dim, new_alpha = resize_lora_model( + lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose + ) + + # update metadata + if metadata is None: + metadata = {} + + comment = metadata.get("ss_training_comment", "") + + if not args.dynamic_method: + metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}" + metadata["ss_network_dim"] = str(args.new_rank) + metadata["ss_network_alpha"] = str(new_alpha) + else: + metadata["ss_training_comment"] = ( + f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}" + ) + metadata["ss_network_dim"] = "Dynamic" + metadata["ss_network_alpha"] = "Dynamic" - model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) - metadata["sshs_model_hash"] = model_hash - metadata["sshs_legacy_hash"] = legacy_hash + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash - logger.info(f"saving model to: {args.save_to}") - save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) + logger.info(f"saving model to: {args.save_to}") + save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - - parser.add_argument("--save_precision", type=str, default=None, - choices=[None, "float", "fp16", "bf16"], help="precision in saving, float if omitted / 保存時の精度、未指定時はfloat") - parser.add_argument("--new_rank", type=int, default=4, - help="Specify rank of output LoRA / 出力するLoRAのrank (dim)") - parser.add_argument("--save_to", type=str, default=None, - help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors") - parser.add_argument("--model", type=str, default=None, - help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors") - parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") - parser.add_argument("--verbose", action="store_true", - help="Display verbose resizing information / rank変更時の詳細情報を出力する") - parser.add_argument("--dynamic_method", type=str, default=None, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"], - help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank") - parser.add_argument("--dynamic_param", type=float, default=None, - help="Specify target for dynamic reduction") - - return parser - - -if __name__ == '__main__': - parser = setup_parser() - - args = parser.parse_args() - resize(args) + parser = argparse.ArgumentParser() + + parser.add_argument( + "--save_precision", + type=str, + default=None, + choices=[None, "float", "fp16", "bf16"], + help="precision in saving, float if omitted / 保存時の精度、未指定時はfloat", + ) + parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)") + parser.add_argument( + "--save_to", + type=str, + default=None, + help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors", + ) + parser.add_argument( + "--model", + type=str, + default=None, + help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors", + ) + parser.add_argument( + "--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う" + ) + parser.add_argument( + "--verbose", action="store_true", help="Display verbose resizing information / rank変更時の詳細情報を出力する" + ) + parser.add_argument( + "--dynamic_method", + type=str, + default=None, + choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"], + help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank", + ) + parser.add_argument("--dynamic_param", type=float, default=None, help="Specify target for dynamic reduction") + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + resize(args) From 52b379998944a9feac072110b05f77c406a2ec1d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 24 Feb 2024 20:49:41 +0900 Subject: [PATCH 068/103] fix format, add new conv rank to metadata comment --- networks/resize_lora.py | 168 ++++++++++++++++++++++++---------------- 1 file changed, 101 insertions(+), 67 deletions(-) diff --git a/networks/resize_lora.py b/networks/resize_lora.py index 5bf8b3c3a..d697baa4c 100644 --- a/networks/resize_lora.py +++ b/networks/resize_lora.py @@ -5,10 +5,12 @@ import os import argparse import torch -from safetensors.torch import load_file, save_file +from safetensors.torch import load_file, save_file, safe_open from tqdm import tqdm -from library import train_util import numpy as np + +from library import train_util +from library import model_util from library.utils import setup_logging setup_logging() @@ -36,16 +38,18 @@ def load_state_dict(file_name, dtype): return sd, metadata -def save_to_file(file_name, model, state_dict, dtype, metadata): + +def save_to_file(file_name, state_dict, dtype, metadata): if dtype is not None: for key in list(state_dict.keys()): if type(state_dict[key]) == torch.Tensor: state_dict[key] = state_dict[key].to(dtype) if model_util.is_safetensors(file_name): - save_file(model, file_name, metadata) + save_file(state_dict, file_name, metadata) else: - torch.save(model, file_name) + torch.save(state_dict, file_name) + # Indexing functions @@ -62,18 +66,18 @@ def index_sv_cumulative(S, target): def index_sv_fro(S, target): S_squared = S.pow(2) S_fro_sq = float(torch.sum(S_squared)) - sum_S_squared = torch.cumsum(S_squared, dim=0)/S_fro_sq + sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq index = int(torch.searchsorted(sum_S_squared, target**2)) + 1 - index = max(1, min(index, len(S)-1)) + index = max(1, min(index, len(S) - 1)) return index def index_sv_ratio(S, target): max_sv = S[0] - min_sv = max_sv/target + min_sv = max_sv / target index = int(torch.sum(S > min_sv).item()) - index = max(1, min(index, len(S)-1)) + index = max(1, min(index, len(S) - 1)) return index @@ -169,10 +173,10 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1): if S[0] <= MIN_SV: # Zero matrix, set dim to 1 new_rank = 1 - new_alpha = float(scale*new_rank) + new_alpha = float(scale * new_rank) elif new_rank > rank: # cap max rank at rank new_rank = rank - new_alpha = float(scale*new_rank) + new_alpha = float(scale * new_rank) # Calculate resize info s_sum = torch.sum(torch.abs(S)) @@ -200,19 +204,21 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna # Extract loaded lora dim and alpha for key, value in lora_sd.items(): - if network_alpha is None and 'alpha' in key: + if network_alpha is None and "alpha" in key: network_alpha = value - if network_dim is None and 'lora_down' in key and len(value.size()) == 2: + if network_dim is None and "lora_down" in key and len(value.size()) == 2: network_dim = value.size()[0] if network_alpha is not None and network_dim is not None: break if network_alpha is None: network_alpha = network_dim - scale = network_alpha/network_dim + scale = network_alpha / network_dim if dynamic_method: - logger.info(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}") + logger.info( + f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}" + ) lora_down_weight = None lora_up_weight = None @@ -224,8 +230,8 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna with torch.no_grad(): for key, value in tqdm(lora_sd.items()): weight_name = None - if 'lora_down' in key: - block_down_name = key.rsplit('.lora_down', 1)[0] + if "lora_down" in key: + block_down_name = key.rsplit(".lora_down", 1)[0] weight_name = key.rsplit(".", 1)[-1] lora_down_weight = value else: @@ -233,18 +239,18 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna # find corresponding lora_up and alpha block_up_name = block_down_name - lora_up_weight = lora_sd.get(block_up_name + '.lora_up.' + weight_name, None) - lora_alpha = lora_sd.get(block_down_name + '.alpha', None) + lora_up_weight = lora_sd.get(block_up_name + ".lora_up." + weight_name, None) + lora_alpha = lora_sd.get(block_down_name + ".alpha", None) - weights_loaded = (lora_down_weight is not None and lora_up_weight is not None) + weights_loaded = lora_down_weight is not None and lora_up_weight is not None if weights_loaded: - conv2d = (len(lora_down_weight.size()) == 4) + conv2d = len(lora_down_weight.size()) == 4 if lora_alpha is None: scale = 1.0 else: - scale = lora_alpha/lora_down_weight.size()[0] + scale = lora_alpha / lora_down_weight.size()[0] if conv2d: full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device) @@ -254,24 +260,26 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale) if verbose: - max_ratio = param_dict['max_ratio'] - sum_retained = param_dict['sum_retained'] - fro_retained = param_dict['fro_retained'] + max_ratio = param_dict["max_ratio"] + sum_retained = param_dict["sum_retained"] + fro_retained = param_dict["fro_retained"] if not np.isnan(fro_retained): fro_list.append(float(fro_retained)) verbose_str += f"{block_down_name:75} | " - verbose_str += f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}" + verbose_str += ( + f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}" + ) if verbose and dynamic_method: verbose_str += f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n" else: verbose_str += "\n" - new_alpha = param_dict['new_alpha'] + new_alpha = param_dict["new_alpha"] o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous() o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous() - o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict['new_alpha']).to(save_dtype) + o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype) block_down_name = None block_up_name = None @@ -281,38 +289,36 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna del param_dict if verbose: - logger.info(verbose_str) - - logger.info(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}") + print(verbose_str) + print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}") logger.info("resizing complete") return o_lora_sd, network_dim, new_alpha def resize(args): - if ( - args.save_to is None or - not (args.save_to.endswith('.ckpt') or - args.save_to.endswith('.pt') or - args.save_to.endswith('.pth') or - args.save_to.endswith('.safetensors')) - ): + if args.save_to is None or not ( + args.save_to.endswith(".ckpt") + or args.save_to.endswith(".pt") + or args.save_to.endswith(".pth") + or args.save_to.endswith(".safetensors") + ): raise Exception("The --save_to argument must be specified and must be a .ckpt , .pt, .pth or .safetensors file.") args.new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank def str_to_dtype(p): - if p == 'float': + if p == "float": return torch.float - if p == 'fp16': + if p == "fp16": return torch.float16 - if p == 'bf16': + if p == "bf16": return torch.bfloat16 return None if args.dynamic_method and not args.dynamic_param: raise Exception("If using dynamic_method, then dynamic_param is required") - merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32 + merge_dtype = str_to_dtype("float") # matmul method above only seems to work in float32 save_dtype = str_to_dtype(args.save_precision) if save_dtype is None: save_dtype = merge_dtype @@ -321,7 +327,9 @@ def str_to_dtype(p): lora_sd, metadata = load_state_dict(args.model, merge_dtype) logger.info("Resizing Lora...") - state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, args.new_conv_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose) + state_dict, old_dim, new_alpha = resize_lora_model( + lora_sd, args.new_rank, args.new_conv_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose + ) # update metadata if metadata is None: @@ -330,47 +338,73 @@ def str_to_dtype(p): comment = metadata.get("ss_training_comment", "") if not args.dynamic_method: - metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}" + conv_desc = "" if args.new_rank == args.new_conv_rank else f" (conv: {args.new_conv_rank})" + metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}{conv_desc}; {comment}" metadata["ss_network_dim"] = str(args.new_rank) metadata["ss_network_alpha"] = str(new_alpha) else: - metadata["ss_training_comment"] = f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}" - metadata["ss_network_dim"] = 'Dynamic' - metadata["ss_network_alpha"] = 'Dynamic' + metadata["ss_training_comment"] = ( + f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}" + ) + metadata["ss_network_dim"] = "Dynamic" + metadata["ss_network_alpha"] = "Dynamic" model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) metadata["sshs_model_hash"] = model_hash metadata["sshs_legacy_hash"] = legacy_hash logger.info(f"saving model to: {args.save_to}") - save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) + save_to_file(args.save_to, state_dict, save_dtype, metadata) def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() - parser.add_argument("--save_precision", type=str, default=None, - choices=[None, "float", "fp16", "bf16"], help="precision in saving, float if omitted / 保存時の精度、未指定時はfloat") - parser.add_argument("--new_rank", type=int, default=4, - help="Specify rank of output LoRA / 出力するLoRAのrank (dim)") - parser.add_argument("--new_conv_rank", type=int, default=None, - help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ") - parser.add_argument("--save_to", type=str, default=None, - help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors") - parser.add_argument("--model", type=str, default=None, - help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors") - parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") - parser.add_argument("--verbose", action="store_true", - help="Display verbose resizing information / rank変更時の詳細情報を出力する") - parser.add_argument("--dynamic_method", type=str, default=None, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"], - help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank") - parser.add_argument("--dynamic_param", type=float, default=None, - help="Specify target for dynamic reduction") + parser.add_argument( + "--save_precision", + type=str, + default=None, + choices=[None, "float", "fp16", "bf16"], + help="precision in saving, float if omitted / 保存時の精度、未指定時はfloat", + ) + parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)") + parser.add_argument( + "--new_conv_rank", + type=int, + default=None, + help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ", + ) + parser.add_argument( + "--save_to", + type=str, + default=None, + help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors", + ) + parser.add_argument( + "--model", + type=str, + default=None, + help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors", + ) + parser.add_argument( + "--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う" + ) + parser.add_argument( + "--verbose", action="store_true", help="Display verbose resizing information / rank変更時の詳細情報を出力する" + ) + parser.add_argument( + "--dynamic_method", + type=str, + default=None, + choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"], + help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank", + ) + parser.add_argument("--dynamic_param", type=float, default=None, help="Specify target for dynamic reduction") return parser - -if __name__ == '__main__': + +if __name__ == "__main__": parser = setup_parser() args = parser.parse_args() From 8b7c14246ae4e48375ffde13e4fdce2c8cd29b17 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 24 Feb 2024 20:50:00 +0900 Subject: [PATCH 069/103] some log output to print --- networks/check_lora_weights.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/networks/check_lora_weights.py b/networks/check_lora_weights.py index 6ec60a89b..794659c94 100644 --- a/networks/check_lora_weights.py +++ b/networks/check_lora_weights.py @@ -20,16 +20,16 @@ def main(file): for key in keys: if "lora_up" in key or "lora_down" in key: values.append((key, sd[key])) - logger.info(f"number of LoRA modules: {len(values)}") + print(f"number of LoRA modules: {len(values)}") if args.show_all_keys: for key in [k for k in keys if k not in values]: values.append((key, sd[key])) - logger.info(f"number of all modules: {len(values)}") + print(f"number of all modules: {len(values)}") for key, value in values: value = value.to(torch.float32) - logger.info(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}") + print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}") def setup_parser() -> argparse.ArgumentParser: From 81e8af651942c44e7ef2f0068b497de2d12cf0fa Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 24 Feb 2024 20:51:26 +0900 Subject: [PATCH 070/103] fix ipex init --- gen_img.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gen_img.py b/gen_img.py index a24220a0a..daf88d2a1 100644 --- a/gen_img.py +++ b/gen_img.py @@ -20,7 +20,7 @@ import numpy as np import torch -from library.ipex_interop import init_ipex +from library.device_utils import init_ipex, clean_memory, get_preferred_device init_ipex() @@ -338,7 +338,7 @@ def __init__( self.clip_vision_model: CLIPVisionModelWithProjection = None self.clip_vision_processor: CLIPImageProcessor = None self.clip_vision_strength = 0.0 - + # Textual Inversion self.token_replacements_list = [] for _ in range(len(self.text_encoders)): From a21218bdd5d451f431c35577f4cccf8c60862ef8 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 24 Feb 2024 21:09:59 +0900 Subject: [PATCH 071/103] update readme --- README.md | 45 +++++---------------------------------------- 1 file changed, 5 insertions(+), 40 deletions(-) diff --git a/README.md b/README.md index 4a53c1542..e635e5aed 100644 --- a/README.md +++ b/README.md @@ -249,7 +249,7 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum ## Change History -### Working in progress +### Feb 24, 2024 / 2024/2/24: v0.8.4 - The log output has been improved. PR [#905](https://github.com/kohya-ss/sd-scripts/pull/905) Thanks to shirayu! - The log is formatted by default. The `rich` library is required. Please see [Upgrade](#upgrade) and update the library. @@ -260,10 +260,12 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - `--console_log_file` option can be used to output the log to a file. The default is `None` (output to the console). - The sample image generation during multi-GPU training is now done with multiple GPUs. PR [#1061](https://github.com/kohya-ss/sd-scripts/pull/1061) Thanks to DKnight54! - The support for mps devices is improved. PR [#1054](https://github.com/kohya-ss/sd-scripts/pull/1054) Thanks to akx! If mps device exists instead of CUDA, the mps device is used automatically. +- The `--new_conv_rank` option to specify the new rank of Conv2d is added to `networks/resize_lora.py`. PR [#1102](https://github.com/kohya-ss/sd-scripts/pull/1102) Thanks to mgz-dev! - An option `--highvram` to disable the optimization for environments with little VRAM is added to the training scripts. If you specify it when there is enough VRAM, the operation will be faster. - Currently, only the cache part of latents is optimized. - The IPEX support is improved. PR [#1086](https://github.com/kohya-ss/sd-scripts/pull/1086) Thanks to Disty0! - Fixed a bug that `svd_merge_lora.py` crashes in some cases. PR [#1087](https://github.com/kohya-ss/sd-scripts/pull/1087) Thanks to mgz-dev! +- DyLoRA is fixed to work with SDXL. PR [#1126](https://github.com/kohya-ss/sd-scripts/pull/1126) Thanks to tamlog06! - The common image generation script `gen_img.py` for SD 1/2 and SDXL is added. The basic functions are the same as the scripts for SD 1/2 and SDXL, but some new features are added. - External scripts to generate prompts can be supported. It can be called with `--from_module` option. (The documentation will be added later) - The normalization method after prompt weighting can be specified with `--emb_normalize_mode` option. `original` is the original method, `abs` is the normalization with the average of the absolute values, `none` is no normalization. @@ -278,10 +280,12 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - `--console_log_file` でログファイルを出力できます。デフォルトは `None`(コンソールに出力) です。 - 複数 GPU 学習時に学習中のサンプル画像生成を複数 GPU で行うようになりました。 PR [#1061](https://github.com/kohya-ss/sd-scripts/pull/1061) DKnight54 氏に感謝します。 - mps デバイスのサポートが改善されました。 PR [#1054](https://github.com/kohya-ss/sd-scripts/pull/1054) akx 氏に感謝します。CUDA ではなく mps が存在する場合には自動的に mps デバイスを使用します。 +- `networks/resize_lora.py` に Conv2d の新しいランクを指定するオプション `--new_conv_rank` が追加されました。 PR [#1102](https://github.com/kohya-ss/sd-scripts/pull/1102) mgz-dev 氏に感謝します。 - 学習スクリプトに VRAMが少ない環境向け最適化を無効にするオプション `--highvram` を追加しました。VRAM に余裕がある場合に指定すると動作が高速化されます。 - 現在は latents のキャッシュ部分のみ高速化されます。 - IPEX サポートが改善されました。 PR [#1086](https://github.com/kohya-ss/sd-scripts/pull/1086) Disty0 氏に感謝します。 - `svd_merge_lora.py` が場合によってエラーになる不具合が修正されました。 PR [#1087](https://github.com/kohya-ss/sd-scripts/pull/1087) mgz-dev 氏に感謝します。 +- DyLoRA が SDXL で動くよう修正されました。PR [#1126](https://github.com/kohya-ss/sd-scripts/pull/1126) tamlog06 氏に感謝します。 - SD 1/2 および SDXL 共通の生成スクリプト `gen_img.py` を追加しました。基本的な機能は SD 1/2、SDXL 向けスクリプトと同じですが、いくつかの新機能が追加されています。 - プロンプトを動的に生成する外部スクリプトをサポートしました。 `--from_module` で呼び出せます。(ドキュメントはのちほど追加します) - プロンプト重みづけ後の正規化方法を `--emb_normalize_mode` で指定できます。`original` は元の方法、`abs` は絶対値の平均値で正規化、`none` は正規化を行いません。 @@ -358,45 +362,6 @@ network_multiplier = -1.0 ``` -### Jan 17, 2024 / 2024/1/17: v0.8.1 - -- Fixed a bug that the VRAM usage without Text Encoder training is larger than before in training scripts for LoRA etc (`train_network.py`, `sdxl_train_network.py`). - - Text Encoders were not moved to CPU. -- Fixed typos. Thanks to akx! [PR #1053](https://github.com/kohya-ss/sd-scripts/pull/1053) - -- LoRA 等の学習スクリプト(`train_network.py`、`sdxl_train_network.py`)で、Text Encoder を学習しない場合の VRAM 使用量が以前に比べて大きくなっていた不具合を修正しました。 - - Text Encoder が GPU に保持されたままになっていました。 -- 誤字が修正されました。 [PR #1053](https://github.com/kohya-ss/sd-scripts/pull/1053) akx 氏に感謝します。 - -### Jan 15, 2024 / 2024/1/15: v0.8.0 - -- Diffusers, Accelerate, Transformers and other related libraries have been updated. Please update the libraries with [Upgrade](#upgrade). - - Some model files (Text Encoder without position_id) based on the latest Transformers can be loaded. -- `torch.compile` is supported (experimental). PR [#1024](https://github.com/kohya-ss/sd-scripts/pull/1024) Thanks to p1atdev! - - This feature works only on Linux or WSL. - - Please specify `--torch_compile` option in each training script. - - You can select the backend with `--dynamo_backend` option. The default is `"inductor"`. `inductor` or `eager` seems to work. - - Please use `--sdpa` option instead of `--xformers` option. - - PyTorch 2.1 or later is recommended. - - Please see [PR](https://github.com/kohya-ss/sd-scripts/pull/1024) for details. -- The session name for wandb can be specified with `--wandb_run_name` option. PR [#1032](https://github.com/kohya-ss/sd-scripts/pull/1032) Thanks to hopl1t! -- IPEX library is updated. PR [#1030](https://github.com/kohya-ss/sd-scripts/pull/1030) Thanks to Disty0! -- Fixed a bug that Diffusers format model cannot be saved. - -- Diffusers、Accelerate、Transformers 等の関連ライブラリを更新しました。[Upgrade](#upgrade) を参照し更新をお願いします。 - - 最新の Transformers を前提とした一部のモデルファイル(Text Encoder が position_id を持たないもの)が読み込めるようになりました。 -- `torch.compile` がサポートされしました(実験的)。 PR [#1024](https://github.com/kohya-ss/sd-scripts/pull/1024) p1atdev 氏に感謝します。 - - Linux または WSL でのみ動作します。 - - 各学習スクリプトで `--torch_compile` オプションを指定してください。 - - `--dynamo_backend` オプションで使用される backend を選択できます。デフォルトは `"inductor"` です。 `inductor` または `eager` が動作するようです。 - - `--xformers` オプションとは互換性がありません。 代わりに `--sdpa` オプションを使用してください。 - - PyTorch 2.1以降を推奨します。 - - 詳細は [PR](https://github.com/kohya-ss/sd-scripts/pull/1024) をご覧ください。 -- wandb 保存時のセッション名が各学習スクリプトの `--wandb_run_name` オプションで指定できるようになりました。 PR [#1032](https://github.com/kohya-ss/sd-scripts/pull/1032) hopl1t 氏に感謝します。 -- IPEX ライブラリが更新されました。[PR #1030](https://github.com/kohya-ss/sd-scripts/pull/1030) Disty0 氏に感謝します。 -- Diffusers 形式でのモデル保存ができなくなっていた不具合を修正しました。 - - Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates. 最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。 From 5d5f39b6e6bccfc4d4265f8a7ce651acec132f39 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Sun, 25 Feb 2024 00:50:36 +0900 Subject: [PATCH 072/103] Replaced print with logger --- gen_img.py | 194 +++++++++++++++++++++------------------- gen_img_diffusers.py | 22 ++--- library/device_utils.py | 11 ++- networks/dylora.py | 4 +- sdxl_gen_img.py | 32 +++---- 5 files changed, 138 insertions(+), 125 deletions(-) diff --git a/gen_img.py b/gen_img.py index daf88d2a1..4fe898716 100644 --- a/gen_img.py +++ b/gen_img.py @@ -61,6 +61,12 @@ from library.original_unet import FlashAttentionFunction from networks.control_net_lllite import ControlNetLLLite from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) # scheduler: SCHEDULER_LINEAR_START = 0.00085 @@ -82,12 +88,12 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): if mem_eff_attn: - print("Enable memory efficient attention for U-Net") + logger.info("Enable memory efficient attention for U-Net") # これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い unet.set_use_memory_efficient_attention(False, True) elif xformers: - print("Enable xformers for U-Net") + logger.info("Enable xformers for U-Net") try: import xformers.ops except ImportError: @@ -95,7 +101,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio unet.set_use_memory_efficient_attention(True, False) elif sdpa: - print("Enable SDPA for U-Net") + logger.info("Enable SDPA for U-Net") unet.set_use_memory_efficient_attention(False, False) unet.set_use_sdpa(True) @@ -112,7 +118,7 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform def replace_vae_attn_to_memory_efficient(): - print("VAE Attention.forward has been replaced to FlashAttention (not xformers)") + logger.info("VAE Attention.forward has been replaced to FlashAttention (not xformers)") flash_func = FlashAttentionFunction def forward_flash_attn(self, hidden_states, **kwargs): @@ -168,7 +174,7 @@ def forward_flash_attn_0_14(self, hidden_states, **kwargs): def replace_vae_attn_to_xformers(): - print("VAE: Attention.forward has been replaced to xformers") + logger.info("VAE: Attention.forward has been replaced to xformers") import xformers.ops def forward_xformers(self, hidden_states, **kwargs): @@ -224,7 +230,7 @@ def forward_xformers_0_14(self, hidden_states, **kwargs): def replace_vae_attn_to_sdpa(): - print("VAE: Attention.forward has been replaced to sdpa") + logger.info("VAE: Attention.forward has been replaced to sdpa") def forward_sdpa(self, hidden_states, **kwargs): residual = hidden_states @@ -386,10 +392,10 @@ def set_control_net_lllites(self, ctrl_net_lllites): def set_gradual_latent(self, gradual_latent): if gradual_latent is None: - print("gradual_latent is disabled") + logger.info("gradual_latent is disabled") self.gradual_latent = None else: - print(f"gradual_latent is enabled: {gradual_latent}") + logger.info(f"gradual_latent is enabled: {gradual_latent}") self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step) @torch.no_grad() @@ -467,7 +473,7 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 if not do_classifier_free_guidance and negative_scale is not None: - print(f"negative_scale is ignored if guidance scalle <= 1.0") + logger.warning(f"negative_scale is ignored if guidance scalle <= 1.0") negative_scale = None # get unconditional embeddings for classifier free guidance @@ -576,7 +582,7 @@ def __call__( text_pool = text_pool[num_sub_prompts - 1 :: num_sub_prompts] # last subprompt if init_image is not None and self.clip_vision_model is not None: - print(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}") + logger.info(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}") vision_input = self.clip_vision_processor(init_image, return_tensors="pt", device=self.device) pixel_values = vision_input["pixel_values"].to(self.device, dtype=text_embeddings.dtype) @@ -742,8 +748,8 @@ def __call__( enable_gradual_latent = False if self.gradual_latent: if not hasattr(self.scheduler, "set_gradual_latent_params"): - print("gradual_latent is not supported for this scheduler. Ignoring.") - print(self.scheduler.__class__.__name__) + logger.warning("gradual_latent is not supported for this scheduler. Ignoring.") + logger.warning(f"{self.scheduler.__class__.__name__}") else: enable_gradual_latent = True step_elapsed = 1000 @@ -792,7 +798,7 @@ def __call__( if not enabled or ratio >= 1.0: continue if ratio < i / len(timesteps): - print(f"ControlNetLLLite {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})") + logger.info(f"ControlNetLLLite {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})") control_net.set_cond_image(None) each_control_net_enabled[j] = False @@ -1013,7 +1019,7 @@ def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: L if word.strip() == "BREAK": # pad until next multiple of tokenizer's max token length pad_len = tokenizer.model_max_length - (len(text_token) % tokenizer.model_max_length) - print(f"BREAK pad_len: {pad_len}") + logger.info(f"BREAK pad_len: {pad_len}") for i in range(pad_len): # v2のときEOSをつけるべきかどうかわからないぜ # if i == 0: @@ -1043,7 +1049,7 @@ def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: L tokens.append(text_token) weights.append(text_weight) if truncated: - print("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + logger.warning("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") return tokens, weights @@ -1344,7 +1350,7 @@ def handle_dynamic_prompt_variants(prompt, repeat_count): elif len(count_range) == 2: count_range = [int(count_range[0]), int(count_range[1])] else: - print(f"invalid count range: {count_range}") + logger.warning(f"invalid count range: {count_range}") count_range = [1, 1] if count_range[0] > count_range[1]: count_range = [count_range[1], count_range[0]] @@ -1488,9 +1494,9 @@ def main(args): # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" if args.v_parameterization and not args.v2: - print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") + logger.warning("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") if args.v2 and args.clip_skip is not None: - print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") + logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") # モデルを読み込む if not os.path.exists(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う @@ -1510,7 +1516,7 @@ def main(args): else: # if `text_encoder_2` subdirectory exists, sdxl is_sdxl = os.path.isdir(os.path.join(name_or_path, "text_encoder_2")) - print(f"SDXL: {is_sdxl}") + logger.info(f"SDXL: {is_sdxl}") if is_sdxl: if args.clip_skip is None: @@ -1526,10 +1532,10 @@ def main(args): args.clip_skip = 2 if args.v2 else 1 if use_stable_diffusion_format: - print("load StableDiffusion checkpoint") + logger.info("load StableDiffusion checkpoint") text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt) else: - print("load Diffusers pretrained models") + logger.info("load Diffusers pretrained models") loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype) text_encoder = loading_pipe.text_encoder vae = loading_pipe.vae @@ -1553,7 +1559,7 @@ def main(args): # VAEを読み込む if args.vae is not None: vae = model_util.load_vae(args.vae, dtype) - print("additional VAE loaded") + logger.info("additional VAE loaded") # xformers、Hypernetwork対応 if not args.diffusers_xformers: @@ -1562,7 +1568,7 @@ def main(args): replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa) # tokenizerを読み込む - print("loading tokenizer") + logger.info("loading tokenizer") if is_sdxl: tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) tokenizers = [tokenizer1, tokenizer2] @@ -1654,7 +1660,7 @@ def randn(self, shape, device=None, dtype=None, layout=None, generator=None): noise = None if noise == None: - print(f"unexpected noise request: {self.sampler_noise_index}, {shape}") + logger.warning(f"unexpected noise request: {self.sampler_noise_index}, {shape}") noise = torch.randn(shape, dtype=dtype, device=device, generator=generator) self.sampler_noise_index += 1 @@ -1715,7 +1721,7 @@ def __getattr__(self, item): vae_dtype = dtype if args.no_half_vae: - print("set vae_dtype to float32") + logger.info("set vae_dtype to float32") vae_dtype = torch.float32 vae.to(vae_dtype).to(device) vae.eval() @@ -1739,10 +1745,10 @@ def __getattr__(self, item): network_merge = args.network_merge_n_models else: network_merge = 0 - print(f"network_merge: {network_merge}") + logger.info(f"network_merge: {network_merge}") for i, network_module in enumerate(args.network_module): - print("import network module:", network_module) + logger.info("import network module: {network_module}") imported_module = importlib.import_module(network_module) network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] @@ -1760,7 +1766,7 @@ def __getattr__(self, item): raise ValueError("No weight. Weight is required.") network_weight = args.network_weights[i] - print("load network weights from:", network_weight) + logger.info(f"load network weights from: {network_weight}") if model_util.is_safetensors(network_weight) and args.network_show_meta: from safetensors.torch import safe_open @@ -1768,7 +1774,7 @@ def __getattr__(self, item): with safe_open(network_weight, framework="pt") as f: metadata = f.metadata() if metadata is not None: - print(f"metadata for: {network_weight}: {metadata}") + logger.info(f"metadata for: {network_weight}: {metadata}") network, weights_sd = imported_module.create_network_from_weights( network_mul, network_weight, vae, text_encoders, unet, for_inference=True, **net_kwargs @@ -1778,20 +1784,20 @@ def __getattr__(self, item): mergeable = network.is_mergeable() if network_merge and not mergeable: - print("network is not mergiable. ignore merge option.") + logger.warning("network is not mergiable. ignore merge option.") if not mergeable or i >= network_merge: # not merging network.apply_to(text_encoders, unet) info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい - print(f"weights are loaded: {info}") + logger.info(f"weights are loaded: {info}") if args.opt_channels_last: network.to(memory_format=torch.channels_last) network.to(dtype).to(device) if network_pre_calc: - print("backup original weights") + logger.info("backup original weights") network.backup_weights() networks.append(network) @@ -1805,7 +1811,7 @@ def __getattr__(self, item): # upscalerの指定があれば取得する upscaler = None if args.highres_fix_upscaler: - print("import upscaler module:", args.highres_fix_upscaler) + logger.info("import upscaler module: {args.highres_fix_upscaler}") imported_module = importlib.import_module(args.highres_fix_upscaler) us_kwargs = {} @@ -1814,7 +1820,7 @@ def __getattr__(self, item): key, value = net_arg.split("=") us_kwargs[key] = value - print("create upscaler") + logger.info("create upscaler") upscaler = imported_module.create_upscaler(**us_kwargs) upscaler.to(dtype).to(device) @@ -1833,7 +1839,7 @@ def __getattr__(self, item): control_net_lllites: List[Tuple[ControlNetLLLite, float]] = [] if args.control_net_lllite_models: for i, model_file in enumerate(args.control_net_lllite_models): - print(f"loading ControlNet-LLLite: {model_file}") + logger.info(f"loading ControlNet-LLLite: {model_file}") from safetensors.torch import load_file @@ -1867,7 +1873,7 @@ def __getattr__(self, item): ), "ControlNet and ControlNet-LLLite cannot be used at the same time" if args.opt_channels_last: - print(f"set optimizing: channels last") + logger.info(f"set optimizing: channels last") for text_encoder in text_encoders: text_encoder.to(memory_format=torch.channels_last) vae.to(memory_format=torch.channels_last) @@ -1894,7 +1900,7 @@ def __getattr__(self, item): ) pipe.set_control_nets(control_nets) pipe.set_control_net_lllites(control_net_lllites) - print("pipeline is ready.") + logger.info("pipeline is ready.") if args.diffusers_xformers: pipe.enable_xformers_memory_efficient_attention() @@ -1965,7 +1971,7 @@ def __getattr__(self, item): token_ids1 = tokenizers[0].convert_tokens_to_ids(token_strings) token_ids2 = tokenizers[1].convert_tokens_to_ids(token_strings) if is_sdxl else None - print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}") + logger.info(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}") assert ( min(token_ids1) == token_ids1[0] and token_ids1[-1] == token_ids1[0] + len(token_ids1) - 1 ), f"token ids1 is not ordered" @@ -2002,7 +2008,7 @@ def __getattr__(self, item): # promptを取得する prompt_list = None if args.from_file is not None: - print(f"reading prompts from {args.from_file}") + logger.info(f"reading prompts from {args.from_file}") with open(args.from_file, "r", encoding="utf-8") as f: prompt_list = f.read().splitlines() prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"] @@ -2019,7 +2025,7 @@ def load_module_from_path(module_name, file_path): spec.loader.exec_module(module) return module - print(f"reading prompts from module: {args.from_module}") + logger.info(f"reading prompts from module: {args.from_module}") prompt_module = load_module_from_path("prompt_module", args.from_module) prompter = prompt_module.get_prompter(args, pipe, networks) @@ -2050,7 +2056,7 @@ def load_images(path): for p in paths: image = Image.open(p) if image.mode != "RGB": - print(f"convert image to RGB from {image.mode}: {p}") + logger.info(f"convert image to RGB from {image.mode}: {p}") image = image.convert("RGB") images.append(image) @@ -2066,14 +2072,14 @@ def resize_images(imgs, size): return resized if args.image_path is not None: - print(f"load image for img2img: {args.image_path}") + logger.info(f"load image for img2img: {args.image_path}") init_images = load_images(args.image_path) assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}" - print(f"loaded {len(init_images)} images for img2img") + logger.info(f"loaded {len(init_images)} images for img2img") # CLIP Vision if args.clip_vision_strength is not None: - print(f"load CLIP Vision model: {CLIP_VISION_MODEL}") + logger.info(f"load CLIP Vision model: {CLIP_VISION_MODEL}") vision_model = CLIPVisionModelWithProjection.from_pretrained(CLIP_VISION_MODEL, projection_dim=1280) vision_model.to(device, dtype) processor = CLIPImageProcessor.from_pretrained(CLIP_VISION_MODEL) @@ -2081,22 +2087,22 @@ def resize_images(imgs, size): pipe.clip_vision_model = vision_model pipe.clip_vision_processor = processor pipe.clip_vision_strength = args.clip_vision_strength - print(f"CLIP Vision model loaded.") + logger.info(f"CLIP Vision model loaded.") else: init_images = None if args.mask_path is not None: - print(f"load mask for inpainting: {args.mask_path}") + logger.info(f"load mask for inpainting: {args.mask_path}") mask_images = load_images(args.mask_path) assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}" - print(f"loaded {len(mask_images)} mask images for inpainting") + logger.info(f"loaded {len(mask_images)} mask images for inpainting") else: mask_images = None # promptがないとき、画像のPngInfoから取得する if init_images is not None and prompter is None and not args.interactive: - print("get prompts from images' metadata") + logger.info("get prompts from images' metadata") prompt_list = [] for img in init_images: if "prompt" in img.text: @@ -2127,17 +2133,17 @@ def resize_images(imgs, size): h = int(h * args.highres_fix_scale + 0.5) if init_images is not None: - print(f"resize img2img source images to {w}*{h}") + logger.info(f"resize img2img source images to {w}*{h}") init_images = resize_images(init_images, (w, h)) if mask_images is not None: - print(f"resize img2img mask images to {w}*{h}") + logger.info(f"resize img2img mask images to {w}*{h}") mask_images = resize_images(mask_images, (w, h)) regional_network = False if networks and mask_images: # mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応 regional_network = True - print("use mask as region") + logger.info("use mask as region") size = None for i, network in enumerate(networks): @@ -2162,14 +2168,14 @@ def resize_images(imgs, size): prev_image = None # for VGG16 guided if args.guide_image_path is not None: - print(f"load image for ControlNet guidance: {args.guide_image_path}") + logger.info(f"load image for ControlNet guidance: {args.guide_image_path}") guide_images = [] for p in args.guide_image_path: guide_images.extend(load_images(p)) - print(f"loaded {len(guide_images)} guide images for guidance") + logger.info(f"loaded {len(guide_images)} guide images for guidance") if len(guide_images) == 0: - print( + logger.warning( f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}" ) guide_images = None @@ -2200,7 +2206,7 @@ def fixed_seed(*args, **kwargs): max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples for gen_iter in range(args.n_iter): - print(f"iteration {gen_iter+1}/{args.n_iter}") + logger.info(f"iteration {gen_iter+1}/{args.n_iter}") if args.iter_same_seed: iter_seed = seed_random.randint(0, 2**32 - 1) else: @@ -2219,7 +2225,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling - print("process 1st stage") + logger.info("process 1st stage") batch_1st = [] for _, base, ext in batch: @@ -2264,7 +2270,7 @@ def scale_and_round(x): images_1st = process_batch(batch_1st, True, True) # 2nd stageのバッチを作成して以下処理する - print("process 2nd stage") + logger.info("process 2nd stage") width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height if upscaler: @@ -2437,7 +2443,7 @@ def scale_and_round(x): n.restore_weights() for n in networks: n.pre_calculation() - print("pre-calculation... done") + logger.info("pre-calculation... done") images = pipe( prompts, @@ -2520,7 +2526,7 @@ def scale_and_round(x): cv2.waitKey() cv2.destroyAllWindows() except ImportError: - print( + logger.warning( "opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません" ) @@ -2535,7 +2541,7 @@ def scale_and_round(x): # interactive valid = False while not valid: - print("\nType prompt:") + logger.info("\nType prompt:") try: raw_prompt = input() except EOFError: @@ -2595,74 +2601,74 @@ def scale_and_round(x): prompt_args = raw_prompt.strip().split(" --") prompt = prompt_args[0] length = len(prompter) if hasattr(prompter, "__len__") else 0 - print(f"prompt {prompt_index+1}/{length}: {prompt}") + logger.info(f"prompt {prompt_index+1}/{length}: {prompt}") for parg in prompt_args[1:]: try: m = re.match(r"w (\d+)", parg, re.IGNORECASE) if m: width = int(m.group(1)) - print(f"width: {width}") + logger.info(f"width: {width}") continue m = re.match(r"h (\d+)", parg, re.IGNORECASE) if m: height = int(m.group(1)) - print(f"height: {height}") + logger.info(f"height: {height}") continue m = re.match(r"ow (\d+)", parg, re.IGNORECASE) if m: original_width = int(m.group(1)) - print(f"original width: {original_width}") + logger.info(f"original width: {original_width}") continue m = re.match(r"oh (\d+)", parg, re.IGNORECASE) if m: original_height = int(m.group(1)) - print(f"original height: {original_height}") + logger.info(f"original height: {original_height}") continue m = re.match(r"nw (\d+)", parg, re.IGNORECASE) if m: original_width_negative = int(m.group(1)) - print(f"original width negative: {original_width_negative}") + logger.info(f"original width negative: {original_width_negative}") continue m = re.match(r"nh (\d+)", parg, re.IGNORECASE) if m: original_height_negative = int(m.group(1)) - print(f"original height negative: {original_height_negative}") + logger.info(f"original height negative: {original_height_negative}") continue m = re.match(r"ct (\d+)", parg, re.IGNORECASE) if m: crop_top = int(m.group(1)) - print(f"crop top: {crop_top}") + logger.info(f"crop top: {crop_top}") continue m = re.match(r"cl (\d+)", parg, re.IGNORECASE) if m: crop_left = int(m.group(1)) - print(f"crop left: {crop_left}") + logger.info(f"crop left: {crop_left}") continue m = re.match(r"s (\d+)", parg, re.IGNORECASE) if m: # steps steps = max(1, min(1000, int(m.group(1)))) - print(f"steps: {steps}") + logger.info(f"steps: {steps}") continue m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) if m: # seed seeds = [int(d) for d in m.group(1).split(",")] - print(f"seeds: {seeds}") + logger.info(f"seeds: {seeds}") continue m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) if m: # scale scale = float(m.group(1)) - print(f"scale: {scale}") + logger.info(f"scale: {scale}") continue m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) @@ -2671,25 +2677,25 @@ def scale_and_round(x): negative_scale = None else: negative_scale = float(m.group(1)) - print(f"negative scale: {negative_scale}") + logger.info(f"negative scale: {negative_scale}") continue m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) if m: # strength strength = float(m.group(1)) - print(f"strength: {strength}") + logger.info(f"strength: {strength}") continue m = re.match(r"n (.+)", parg, re.IGNORECASE) if m: # negative prompt negative_prompt = m.group(1) - print(f"negative prompt: {negative_prompt}") + logger.info(f"negative prompt: {negative_prompt}") continue m = re.match(r"c (.+)", parg, re.IGNORECASE) if m: # clip prompt clip_prompt = m.group(1) - print(f"clip prompt: {clip_prompt}") + logger.info(f"clip prompt: {clip_prompt}") continue m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) @@ -2697,89 +2703,89 @@ def scale_and_round(x): network_muls = [float(v) for v in m.group(1).split(",")] while len(network_muls) < len(networks): network_muls.append(network_muls[-1]) - print(f"network mul: {network_muls}") + logger.info(f"network mul: {network_muls}") continue # Deep Shrink m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink depth 1 ds_depth_1 = int(m.group(1)) - print(f"deep shrink depth 1: {ds_depth_1}") + logger.info(f"deep shrink depth 1: {ds_depth_1}") continue m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink timesteps 1 ds_timesteps_1 = int(m.group(1)) ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink timesteps 1: {ds_timesteps_1}") + logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}") continue m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink depth 2 ds_depth_2 = int(m.group(1)) ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink depth 2: {ds_depth_2}") + logger.info(f"deep shrink depth 2: {ds_depth_2}") continue m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink timesteps 2 ds_timesteps_2 = int(m.group(1)) ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink timesteps 2: {ds_timesteps_2}") + logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}") continue m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink ratio ds_ratio = float(m.group(1)) ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink ratio: {ds_ratio}") + logger.info(f"deep shrink ratio: {ds_ratio}") continue # Gradual Latent m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent timesteps gl_timesteps = int(m.group(1)) - print(f"gradual latent timesteps: {gl_timesteps}") + logger.info(f"gradual latent timesteps: {gl_timesteps}") continue m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent ratio gl_ratio = float(m.group(1)) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent ratio: {ds_ratio}") + logger.info(f"gradual latent ratio: {ds_ratio}") continue m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent every n steps gl_every_n_steps = int(m.group(1)) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent every n steps: {gl_every_n_steps}") + logger.info(f"gradual latent every n steps: {gl_every_n_steps}") continue m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent ratio step gl_ratio_step = float(m.group(1)) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent ratio step: {gl_ratio_step}") + logger.info(f"gradual latent ratio step: {gl_ratio_step}") continue m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent s noise gl_s_noise = float(m.group(1)) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent s noise: {gl_s_noise}") + logger.info(f"gradual latent s noise: {gl_s_noise}") continue m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) if m: # gradual latent unsharp params gl_unsharp_params = m.group(1) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent unsharp params: {gl_unsharp_params}") + logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") continue except ValueError as ex: - print(f"Exception in parsing / 解析エラー: {parg}") - print(ex) + logger.error(f"Exception in parsing / 解析エラー: {parg}") + logger.error(f"{ex}") # override Deep Shrink if ds_depth_1 is not None: @@ -2825,7 +2831,7 @@ def scale_and_round(x): if seed is None: seed = seed_random.randint(0, 2**32 - 1) if args.interactive: - print(f"seed: {seed}") + logger.info(f"seed: {seed}") # prepare init image, guide image and mask init_image = mask_image = guide_image = None @@ -2841,7 +2847,7 @@ def scale_and_round(x): width = width - width % 32 height = height - height % 32 if width != init_image.size[0] or height != init_image.size[1]: - print( + logger.warning( f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます" ) @@ -2903,12 +2909,14 @@ def scale_and_round(x): process_batch(batch_data, highres_fix) batch_data.clear() - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) + parser.add_argument( "--sdxl", action="store_true", help="load Stable Diffusion XL model / Stable Diffusion XLのモデルを読み込む" ) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 2c5f84a93..2c40f1a06 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -489,10 +489,10 @@ def set_control_nets(self, ctrl_nets): def set_gradual_latent(self, gradual_latent): if gradual_latent is None: - print("gradual_latent is disabled") + logger.info("gradual_latent is disabled") self.gradual_latent = None else: - print(f"gradual_latent is enabled: {gradual_latent}") + logger.info(f"gradual_latent is enabled: {gradual_latent}") self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step) # region xformersとか使う部分:独自に書き換えるので関係なし @@ -971,8 +971,8 @@ def __call__( enable_gradual_latent = False if self.gradual_latent: if not hasattr(self.scheduler, "set_gradual_latent_params"): - print("gradual_latent is not supported for this scheduler. Ignoring.") - print(self.scheduler.__class__.__name__) + logger.info("gradual_latent is not supported for this scheduler. Ignoring.") + logger.info(f'{self.scheduler.__class__.__name__}') else: enable_gradual_latent = True step_elapsed = 1000 @@ -3314,42 +3314,42 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent timesteps gl_timesteps = int(m.group(1)) - print(f"gradual latent timesteps: {gl_timesteps}") + logger.info(f"gradual latent timesteps: {gl_timesteps}") continue m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent ratio gl_ratio = float(m.group(1)) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent ratio: {ds_ratio}") + logger.info(f"gradual latent ratio: {ds_ratio}") continue m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent every n steps gl_every_n_steps = int(m.group(1)) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent every n steps: {gl_every_n_steps}") + logger.info(f"gradual latent every n steps: {gl_every_n_steps}") continue m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent ratio step gl_ratio_step = float(m.group(1)) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent ratio step: {gl_ratio_step}") + logger.info(f"gradual latent ratio step: {gl_ratio_step}") continue m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent s noise gl_s_noise = float(m.group(1)) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent s noise: {gl_s_noise}") + logger.info(f"gradual latent s noise: {gl_s_noise}") continue m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) if m: # gradual latent unsharp params gl_unsharp_params = m.group(1) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent unsharp params: {gl_unsharp_params}") + logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") continue except ValueError as ex: @@ -3369,7 +3369,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): if gl_unsharp_params is not None: unsharp_params = gl_unsharp_params.split(",") us_ksize, us_sigma, us_strength = [float(v) for v in unsharp_params[:3]] - print(unsharp_params) + logger.info(f'{unsharp_params}') us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3])) us_ksize = int(us_ksize) else: diff --git a/library/device_utils.py b/library/device_utils.py index 8823c5d9a..f6641ab5c 100644 --- a/library/device_utils.py +++ b/library/device_utils.py @@ -3,6 +3,11 @@ import torch +from .utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) + try: HAS_CUDA = torch.cuda.is_available() except Exception: @@ -59,7 +64,7 @@ def get_preferred_device() -> torch.device: device = torch.device("mps") else: device = torch.device("cpu") - print(f"get_preferred_device() -> {device}") + logger.info(f"get_preferred_device() -> {device}") return device @@ -77,8 +82,8 @@ def init_ipex(): is_initialized, error_message = ipex_init() if not is_initialized: - print("failed to initialize ipex:", error_message) + logger.error("failed to initialize ipex: {error_message}") else: return except Exception as e: - print("failed to initialize ipex:", e) + logger.error("failed to initialize ipex: {e}") diff --git a/networks/dylora.py b/networks/dylora.py index d71279c55..637f33450 100644 --- a/networks/dylora.py +++ b/networks/dylora.py @@ -327,10 +327,10 @@ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules for i, text_encoder in enumerate(text_encoders): if len(text_encoders) > 1: index = i + 1 - print(f"create LoRA for Text Encoder {index}") + logger.info(f"create LoRA for Text Encoder {index}") else: index = None - print(f"create LoRA for Text Encoder") + logger.info("create LoRA for Text Encoder") text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) self.text_encoder_loras.extend(text_encoder_loras) diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index 641b3209f..d52f85a8f 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -380,10 +380,10 @@ def set_control_nets(self, ctrl_nets): def set_gradual_latent(self, gradual_latent): if gradual_latent is None: - print("gradual_latent is disabled") + logger.info("gradual_latent is disabled") self.gradual_latent = None else: - print(f"gradual_latent is enabled: {gradual_latent}") + logger.info(f"gradual_latent is enabled: {gradual_latent}") self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step) @torch.no_grad() @@ -789,8 +789,8 @@ def __call__( enable_gradual_latent = False if self.gradual_latent: if not hasattr(self.scheduler, "set_gradual_latent_params"): - print("gradual_latent is not supported for this scheduler. Ignoring.") - print(self.scheduler.__class__.__name__) + logger.info("gradual_latent is not supported for this scheduler. Ignoring.") + logger.info(f'{self.scheduler.__class__.__name__}') else: enable_gradual_latent = True step_elapsed = 1000 @@ -2614,84 +2614,84 @@ def scale_and_round(x): m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent timesteps gl_timesteps = int(m.group(1)) - print(f"gradual latent timesteps: {gl_timesteps}") + logger.info(f"gradual latent timesteps: {gl_timesteps}") continue m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent ratio gl_ratio = float(m.group(1)) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent ratio: {ds_ratio}") + logger.info(f"gradual latent ratio: {ds_ratio}") continue m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent every n steps gl_every_n_steps = int(m.group(1)) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent every n steps: {gl_every_n_steps}") + logger.info(f"gradual latent every n steps: {gl_every_n_steps}") continue m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent ratio step gl_ratio_step = float(m.group(1)) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent ratio step: {gl_ratio_step}") + logger.info(f"gradual latent ratio step: {gl_ratio_step}") continue m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent s noise gl_s_noise = float(m.group(1)) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent s noise: {gl_s_noise}") + logger.info(f"gradual latent s noise: {gl_s_noise}") continue m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) if m: # gradual latent unsharp params gl_unsharp_params = m.group(1) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent unsharp params: {gl_unsharp_params}") + logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") continue # Gradual Latent m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent timesteps gl_timesteps = int(m.group(1)) - print(f"gradual latent timesteps: {gl_timesteps}") + logger.info(f"gradual latent timesteps: {gl_timesteps}") continue m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent ratio gl_ratio = float(m.group(1)) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent ratio: {ds_ratio}") + logger.info(f"gradual latent ratio: {ds_ratio}") continue m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent every n steps gl_every_n_steps = int(m.group(1)) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent every n steps: {gl_every_n_steps}") + logger.info(f"gradual latent every n steps: {gl_every_n_steps}") continue m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent ratio step gl_ratio_step = float(m.group(1)) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent ratio step: {gl_ratio_step}") + logger.info(f"gradual latent ratio step: {gl_ratio_step}") continue m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent s noise gl_s_noise = float(m.group(1)) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent s noise: {gl_s_noise}") + logger.info(f"gradual latent s noise: {gl_s_noise}") continue m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) if m: # gradual latent unsharp params gl_unsharp_params = m.group(1) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent unsharp params: {gl_unsharp_params}") + logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") continue except ValueError as ex: From fccbee27277d65a8dcbdeeb81787ed4116b92e0b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 25 Feb 2024 10:43:14 +0900 Subject: [PATCH 073/103] revert logging #1137 --- library/device_utils.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/library/device_utils.py b/library/device_utils.py index f6641ab5c..8823c5d9a 100644 --- a/library/device_utils.py +++ b/library/device_utils.py @@ -3,11 +3,6 @@ import torch -from .utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - try: HAS_CUDA = torch.cuda.is_available() except Exception: @@ -64,7 +59,7 @@ def get_preferred_device() -> torch.device: device = torch.device("mps") else: device = torch.device("cpu") - logger.info(f"get_preferred_device() -> {device}") + print(f"get_preferred_device() -> {device}") return device @@ -82,8 +77,8 @@ def init_ipex(): is_initialized, error_message = ipex_init() if not is_initialized: - logger.error("failed to initialize ipex: {error_message}") + print("failed to initialize ipex:", error_message) else: return except Exception as e: - logger.error("failed to initialize ipex: {e}") + print("failed to initialize ipex:", e) From 577e9913ca241a56631dbca924e17d9012bde116 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 26 Feb 2024 20:01:25 +0900 Subject: [PATCH 074/103] add some new dataset settings --- README.md | 161 ++++++++++++++++++++++++++--------------- library/config_util.py | 6 ++ library/train_util.py | 62 +++++++++++++++- train_network.py | 5 ++ 4 files changed, 175 insertions(+), 59 deletions(-) diff --git a/README.md b/README.md index e635e5aed..e1b6a26c3 100644 --- a/README.md +++ b/README.md @@ -249,6 +249,109 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum ## Change History +### Working in progress + +- `train_network.py` and `sdxl_train_network.py` are modified to record some dataset settings in the metadata of the trained model (`caption_prefix`, `caption_suffix`, `keep_tokens_separator`, `secondary_separator`, `enable_wildcard`). +- Some features are added to the dataset subset settings. + - `secondary_separator` is added to specify the tag separator that is not the target of shuffling or dropping. + - Specify `secondary_separator=";;;"`. When you specify `secondary_separator`, the part is not shuffled or dropped. See the example below. + - `enable_wildcard` is added. When set to `true`, the wildcard notation `{aaa|bbb|ccc}` can be used. See the example below. + - `keep_tokens_separator` is updated to be used twice in the caption. When you specify `keep_tokens_separator="|||"`, the part divided by the second `|||` is not shuffled or dropped and remains at the end. + - The existing features `caption_prefix` and `caption_suffix` can be used together. `caption_prefix` and `caption_suffix` are processed first, and then `enable_wildcard`, `keep_tokens_separator`, shuffling and dropping, and `secondary_separator` are processed in order. + - The examples are [shown below](#example-of-dataset-settings--データセット設定の記述例). + +- `train_network.py` および `sdxl_train_network.py` で、学習したモデルのメタデータに一部のデータセット設定が記録されるよう修正しました(`caption_prefix`、`caption_suffix`、`keep_tokens_separator`、`secondary_separator`、`enable_wildcard`)。 +- データセットのサブセット設定にいくつかの機能を追加しました。 + - シャッフルの対象とならないタグ分割識別子の指定 `secondary_separator` を追加しました。`secondary_separator=";;;"` のように指定します。`secondary_separator` で区切ることで、その部分はシャッフル、drop 時にまとめて扱われます。詳しくは記述例をご覧ください。 + - `enable_wildcard` を追加しました。`true` にするとワイルドカード記法 `{aaa|bbb|ccc}` が使えます。詳しくは記述例をご覧ください。 + - `keep_tokens_separator` をキャプション内に 2 つ使えるようにしました。たとえば `keep_tokens_separator="|||"` と指定したとき、`1girl, hatsune miku, vocaloid ||| stage, mic ||| best quality, rating: general` とキャプションを指定すると、二番目の `|||` で分割された部分はシャッフル、drop されず末尾に残ります。 + - 既存の機能 `caption_prefix` と `caption_suffix` とあわせて使えます。`caption_prefix` と `caption_suffix` は一番最初に処理され、その後、ワイルドカード、`keep_tokens_separator`、シャッフルおよび drop、`secondary_separator` の順に処理されます。 + +#### Example of dataset settings / データセット設定の記述例: + +```toml +[general] +flip_aug = true +color_aug = false +resolution = [1024, 1024] + +[[datasets]] +batch_size = 6 +enable_bucket = true +bucket_no_upscale = true +caption_extension = ".txt" +keep_tokens_separator= "|||" +shuffle_caption = true +caption_tag_dropout_rate = 0.1 +secondary_separator = ";;;" # subset 側に書くこともできます / can be written in the subset side +enable_wildcard = true # 同上 / same as above + + [[datasets.subsets]] + image_dir = "/path/to/image_dir" + num_repeats = 1 + + # ||| の前後はカンマは不要です(自動的に追加されます) / No comma is required before and after ||| (it is added automatically) + caption_prefix = "1girl, hatsune miku, vocaloid |||" + + # ||| の後はシャッフル、drop されず残ります / After |||, it is not shuffled or dropped and remains + # 単純に文字列として連結されるので、カンマなどは自分で入れる必要があります / It is simply concatenated as a string, so you need to put commas yourself + caption_suffix = ", anime screencap ||| masterpiece, rating: general" +``` + +#### Example of caption, secondary_separator notation: `secondary_separator = ";;;"` + +```txt +1girl, hatsune miku, vocaloid, upper body, looking at viewer, sky;;;cloud;;;day, outdoors +``` +The part `sky;;;cloud;;;day` is replaced with `sky,cloud,day` without shuffling or dropping. When shuffling and dropping are enabled, it is processed as a whole (as one tag). For example, it becomes `vocaloid, 1girl, upper body, sky,cloud,day, outdoors, hatsune miku` (shuffled) or `vocaloid, 1girl, outdoors, looking at viewer, upper body, hatsune miku` (dropped). + +#### Example of caption, enable_wildcard notation: `enable_wildcard = true` + +```txt +1girl, hatsune miku, vocaloid, upper body, looking at viewer, {simple|white} background +``` +`simple` or `white` is randomly selected, and it becomes `simple background` or `white background`. + +```txt +1girl, hatsune miku, vocaloid, {{retro style}} +``` +If you want to include `{` or `}` in the tag string, double them like `{{` or `}}` (in this example, the actual caption used for training is `{retro style}`). + +#### Example of caption, `keep_tokens_separator` notation: `keep_tokens_separator = "|||"` + +```txt +1girl, hatsune miku, vocaloid ||| stage, microphone, white shirt, smile ||| best quality, rating: general +``` +It becomes `1girl, hatsune miku, vocaloid, microphone, stage, white shirt, best quality, rating: general` or `1girl, hatsune miku, vocaloid, white shirt, smile, stage, microphone, best quality, rating: general` etc. + + +#### キャプション記述例、secondary_separator 記法:`secondary_separator = ";;;"` の場合 + +```txt +1girl, hatsune miku, vocaloid, upper body, looking at viewer, sky;;;cloud;;;day, outdoors +``` +`sky;;;cloud;;;day` の部分はシャッフル、drop されず `sky,cloud,day` に置換されます。シャッフル、drop が有効な場合、まとめて(一つのタグとして)処理されます。つまり `vocaloid, 1girl, upper body, sky,cloud,day, outdoors, hatsune miku` (シャッフル)や `vocaloid, 1girl, outdoors, looking at viewer, upper body, hatsune miku` (drop されたケース)などになります。 + +#### キャプション記述例、ワイルドカード記法: `enable_wildcard = true` の場合 + +```txt +1girl, hatsune miku, vocaloid, upper body, looking at viewer, {simple|white} background +``` +ランダムに `simple` または `white` が選ばれ、`simple background` または `white background` になります。 + +```txt +1girl, hatsune miku, vocaloid, {{retro style}} +``` +タグ文字列に `{` や `}` そのものを含めたい場合は `{{` や `}}` のように二つ重ねてください(この例では実際に学習に用いられるキャプションは `{retro style}` になります)。 + +#### キャプション記述例、`keep_tokens_separator` 記法: `keep_tokens_separator = "|||"` の場合 + +```txt +1girl, hatsune miku, vocaloid ||| stage, microphone, white shirt, smile ||| best quality, rating: general +``` +`1girl, hatsune miku, vocaloid, microphone, stage, white shirt, best quality, rating: general` や `1girl, hatsune miku, vocaloid, white shirt, smile, stage, microphone, best quality, rating: general` などになります。 + + ### Feb 24, 2024 / 2024/2/24: v0.8.4 - The log output has been improved. PR [#905](https://github.com/kohya-ss/sd-scripts/pull/905) Thanks to shirayu! @@ -304,64 +407,6 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - 複数 GPU での学習時に `network_multiplier` を指定するとクラッシュする不具合が修正されました。 PR [#1084](https://github.com/kohya-ss/sd-scripts/pull/1084) fireicewolf 氏に感謝します。 - ControlNet-LLLite の学習がエラーになる不具合を修正しました。 -### Jan 23, 2024 / 2024/1/23: v0.8.2 - -- [Experimental] The `--fp8_base` option is added to the training scripts for LoRA etc. The base model (U-Net, and Text Encoder when training modules for Text Encoder) can be trained with fp8. PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) Thanks to KohakuBlueleaf! - - Please specify `--fp8_base` in `train_network.py` or `sdxl_train_network.py`. - - PyTorch 2.1 or later is required. - - If you use xformers with PyTorch 2.1, please see [xformers repository](https://github.com/facebookresearch/xformers) and install the appropriate version according to your CUDA version. - - The sample image generation during training consumes a lot of memory. It is recommended to turn it off. - -- [Experimental] The network multiplier can be specified for each dataset in the training scripts for LoRA etc. - - This is an experimental option and may be removed or changed in the future. - - For example, if you train with state A as `1.0` and state B as `-1.0`, you may be able to generate by switching between state A and B depending on the LoRA application rate. - - Also, if you prepare five states and train them as `0.2`, `0.4`, `0.6`, `0.8`, and `1.0`, you may be able to generate by switching the states smoothly depending on the application rate. - - Please specify `network_multiplier` in `[[datasets]]` in `.toml` file. -- Some options are added to `networks/extract_lora_from_models.py` to reduce the memory usage. - - `--load_precision` option can be used to specify the precision when loading the model. If the model is saved in fp16, you can reduce the memory usage by specifying `--load_precision fp16` without losing precision. - - `--load_original_model_to` option can be used to specify the device to load the original model. `--load_tuned_model_to` option can be used to specify the device to load the derived model. The default is `cpu` for both options, but you can specify `cuda` etc. You can reduce the memory usage by loading one of them to GPU. This option is available only for SDXL. - -- The gradient synchronization in LoRA training with multi-GPU is improved. PR [#1064](https://github.com/kohya-ss/sd-scripts/pull/1064) Thanks to KohakuBlueleaf! -- The code for Intel IPEX support is improved. PR [#1060](https://github.com/kohya-ss/sd-scripts/pull/1060) Thanks to akx! -- Fixed a bug in multi-GPU Textual Inversion training. - -- (実験的) LoRA等の学習スクリプトで、ベースモデル(U-Net、および Text Encoder のモジュール学習時は Text Encoder も)の重みを fp8 にして学習するオプションが追加されました。 PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) KohakuBlueleaf 氏に感謝します。 - - `train_network.py` または `sdxl_train_network.py` で `--fp8_base` を指定してください。 - - PyTorch 2.1 以降が必要です。 - - PyTorch 2.1 で xformers を使用する場合は、[xformers のリポジトリ](https://github.com/facebookresearch/xformers) を参照し、CUDA バージョンに応じて適切なバージョンをインストールしてください。 - - 学習中のサンプル画像生成はメモリを大量に消費するため、オフにすることをお勧めします。 -- (実験的) LoRA 等の学習で、データセットごとに異なるネットワーク適用率を指定できるようになりました。 - - 実験的オプションのため、将来的に削除または仕様変更される可能性があります。 - - たとえば状態 A を `1.0`、状態 B を `-1.0` として学習すると、LoRA の適用率に応じて状態 A と B を切り替えつつ生成できるかもしれません。 - - また、五段階の状態を用意し、それぞれ `0.2`、`0.4`、`0.6`、`0.8`、`1.0` として学習すると、適用率でなめらかに状態を切り替えて生成できるかもしれません。 - - `.toml` ファイルで `[[datasets]]` に `network_multiplier` を指定してください。 -- `networks/extract_lora_from_models.py` に使用メモリ量を削減するいくつかのオプションを追加しました。 - - `--load_precision` で読み込み時の精度を指定できます。モデルが fp16 で保存されている場合は `--load_precision fp16` を指定して精度を変えずにメモリ量を削減できます。 - - `--load_original_model_to` で元モデルを読み込むデバイスを、`--load_tuned_model_to` で派生モデルを読み込むデバイスを指定できます。デフォルトは両方とも `cpu` ですがそれぞれ `cuda` 等を指定できます。片方を GPU に読み込むことでメモリ量を削減できます。SDXL の場合のみ有効です。 -- マルチ GPU での LoRA 等の学習時に勾配の同期が改善されました。 PR [#1064](https://github.com/kohya-ss/sd-scripts/pull/1064) KohakuBlueleaf 氏に感謝します。 -- Intel IPEX サポートのコードが改善されました。PR [#1060](https://github.com/kohya-ss/sd-scripts/pull/1060) akx 氏に感謝します。 -- マルチ GPU での Textual Inversion 学習の不具合を修正しました。 - -- `.toml` example for network multiplier / ネットワーク適用率の `.toml` の記述例 - -```toml -[general] -[[datasets]] -resolution = 512 -batch_size = 8 -network_multiplier = 1.0 - -... subset settings ... - -[[datasets]] -resolution = 512 -batch_size = 8 -network_multiplier = -1.0 - -... subset settings ... -``` - - Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates. 最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。 diff --git a/library/config_util.py b/library/config_util.py index fc4b36175..eb652ecf3 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -60,6 +60,8 @@ class BaseSubsetParams: caption_separator: str = (",",) keep_tokens: int = 0 keep_tokens_separator: str = (None,) + secondary_separator: Optional[str] = None + enable_wildcard: bool = False color_aug: bool = False flip_aug: bool = False face_crop_aug_range: Optional[Tuple[float, float]] = None @@ -181,6 +183,8 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence] "shuffle_caption": bool, "keep_tokens": int, "keep_tokens_separator": str, + "secondary_separator": str, + "enable_wildcard": bool, "token_warmup_min": int, "token_warmup_step": Any(float, int), "caption_prefix": str, @@ -504,6 +508,8 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu shuffle_caption: {subset.shuffle_caption} keep_tokens: {subset.keep_tokens} keep_tokens_separator: {subset.keep_tokens_separator} + secondary_separator: {subset.secondary_separator} + enable_wildcard: {subset.enable_wildcard} caption_dropout_rate: {subset.caption_dropout_rate} caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} diff --git a/library/train_util.py b/library/train_util.py index d2b69edb5..b71e4edc6 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -364,6 +364,8 @@ def __init__( caption_separator: str, keep_tokens: int, keep_tokens_separator: str, + secondary_separator: Optional[str], + enable_wildcard: bool, color_aug: bool, flip_aug: bool, face_crop_aug_range: Optional[Tuple[float, float]], @@ -382,6 +384,8 @@ def __init__( self.caption_separator = caption_separator self.keep_tokens = keep_tokens self.keep_tokens_separator = keep_tokens_separator + self.secondary_separator = secondary_separator + self.enable_wildcard = enable_wildcard self.color_aug = color_aug self.flip_aug = flip_aug self.face_crop_aug_range = face_crop_aug_range @@ -410,6 +414,8 @@ def __init__( caption_separator: str, keep_tokens, keep_tokens_separator, + secondary_separator, + enable_wildcard, color_aug, flip_aug, face_crop_aug_range, @@ -431,6 +437,8 @@ def __init__( caption_separator, keep_tokens, keep_tokens_separator, + secondary_separator, + enable_wildcard, color_aug, flip_aug, face_crop_aug_range, @@ -466,6 +474,8 @@ def __init__( caption_separator, keep_tokens, keep_tokens_separator, + secondary_separator, + enable_wildcard, color_aug, flip_aug, face_crop_aug_range, @@ -487,6 +497,8 @@ def __init__( caption_separator, keep_tokens, keep_tokens_separator, + secondary_separator, + enable_wildcard, color_aug, flip_aug, face_crop_aug_range, @@ -519,6 +531,8 @@ def __init__( caption_separator, keep_tokens, keep_tokens_separator, + secondary_separator, + enable_wildcard, color_aug, flip_aug, face_crop_aug_range, @@ -540,6 +554,8 @@ def __init__( caption_separator, keep_tokens, keep_tokens_separator, + secondary_separator, + enable_wildcard, color_aug, flip_aug, face_crop_aug_range, @@ -675,15 +691,41 @@ def process_caption(self, subset: BaseSubset, caption): if is_drop_out: caption = "" else: + # process wildcards + if subset.enable_wildcard: + # wildcard is like '{aaa|bbb|ccc...}' + # escape the curly braces like {{ or }} + replacer1 = "⦅" + replacer2 = "⦆" + while replacer1 in caption or replacer2 in caption: + replacer1 += "⦅" + replacer2 += "⦆" + + caption = caption.replace("{{", replacer1).replace("}}", replacer2) + + # replace the wildcard + def replace_wildcard(match): + return random.choice(match.group(1).split("|")) + + caption = re.sub(r"\{([^}]+)\}", replace_wildcard, caption) + + # unescape the curly braces + caption = caption.replace(replacer1, "{").replace(replacer2, "}") + if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0: fixed_tokens = [] flex_tokens = [] + fixed_suffix_tokens = [] if ( hasattr(subset, "keep_tokens_separator") and subset.keep_tokens_separator and subset.keep_tokens_separator in caption ): fixed_part, flex_part = caption.split(subset.keep_tokens_separator, 1) + if subset.keep_tokens_separator in flex_part: + flex_part, fixed_suffix_part = flex_part.split(subset.keep_tokens_separator, 1) + fixed_suffix_tokens = [t.strip() for t in fixed_suffix_part.split(subset.caption_separator) if t.strip()] + fixed_tokens = [t.strip() for t in fixed_part.split(subset.caption_separator) if t.strip()] flex_tokens = [t.strip() for t in flex_part.split(subset.caption_separator) if t.strip()] else: @@ -718,7 +760,11 @@ def dropout_tags(tokens): flex_tokens = dropout_tags(flex_tokens) - caption = ", ".join(fixed_tokens + flex_tokens) + caption = ", ".join(fixed_tokens + flex_tokens + fixed_suffix_tokens) + + # process secondary separator + if subset.secondary_separator: + caption = caption.replace(subset.secondary_separator, subset.caption_separator) # textual inversion対応 for str_from, str_to in self.replacements.items(): @@ -1774,6 +1820,8 @@ def __init__( subset.caption_separator, subset.keep_tokens, subset.keep_tokens_separator, + subset.secondary_separator, + subset.enable_wildcard, subset.color_aug, subset.flip_aug, subset.face_crop_aug_range, @@ -3284,6 +3332,18 @@ def add_dataset_arguments( help="A custom separator to divide the caption into fixed and flexible parts. Tokens before this separator will not be shuffled. If not specified, '--keep_tokens' will be used to determine the fixed number of tokens." + " / captionを固定部分と可変部分に分けるためのカスタム区切り文字。この区切り文字より前のトークンはシャッフルされない。指定しない場合、'--keep_tokens'が固定部分のトークン数として使用される。", ) + parser.add_argument( + "--secondary_separator", + type=str, + default=None, + help="a secondary separator for caption. This separator is replaced to caption_separator after dropping/shuffling caption" + + " / captionのセカンダリ区切り文字。この区切り文字はcaptionのドロップやシャッフル後にcaption_separatorに置き換えられる", + ) + parser.add_argument( + "--enable_wildcard", + action="store_true", + help="enable wildcard for caption (e.g. '{image|picture|rendition}') / captionのワイルドカードを有効にする(例:'{image|picture|rendition}')", + ) parser.add_argument( "--caption_prefix", type=str, diff --git a/train_network.py b/train_network.py index e0fa69458..e5b26d8a2 100644 --- a/train_network.py +++ b/train_network.py @@ -564,6 +564,11 @@ def train(self, args): "random_crop": bool(subset.random_crop), "shuffle_caption": bool(subset.shuffle_caption), "keep_tokens": subset.keep_tokens, + "keep_tokens_separator": subset.keep_tokens_separator, + "secondary_separator": subset.secondary_separator, + "enable_wildcard": bool(subset.enable_wildcard), + "caption_prefix": subset.caption_prefix, + "caption_suffix": subset.caption_suffix, } image_dir_or_metadata_file = None From 14c9372a38c2d50d8206846fe0aa4406152258de Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 3 Mar 2024 21:47:37 +0900 Subject: [PATCH 075/103] add doc about Colab/rich issue --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index e1b6a26c3..7df6fafcc 100644 --- a/README.md +++ b/README.md @@ -251,6 +251,7 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum ### Working in progress +- Colab seems to stop with log output. Try specifying `--console_log_simple` option in the training script to disable rich logging. - `train_network.py` and `sdxl_train_network.py` are modified to record some dataset settings in the metadata of the trained model (`caption_prefix`, `caption_suffix`, `keep_tokens_separator`, `secondary_separator`, `enable_wildcard`). - Some features are added to the dataset subset settings. - `secondary_separator` is added to specify the tag separator that is not the target of shuffling or dropping. @@ -260,6 +261,8 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - The existing features `caption_prefix` and `caption_suffix` can be used together. `caption_prefix` and `caption_suffix` are processed first, and then `enable_wildcard`, `keep_tokens_separator`, shuffling and dropping, and `secondary_separator` are processed in order. - The examples are [shown below](#example-of-dataset-settings--データセット設定の記述例). + +- Colab での動作時、ログ出力で停止してしまうようです。学習スクリプトに `--console_log_simple` オプションを指定し、rich のロギングを無効してお試しください。 - `train_network.py` および `sdxl_train_network.py` で、学習したモデルのメタデータに一部のデータセット設定が記録されるよう修正しました(`caption_prefix`、`caption_suffix`、`keep_tokens_separator`、`secondary_separator`、`enable_wildcard`)。 - データセットのサブセット設定にいくつかの機能を追加しました。 - シャッフルの対象とならないタグ分割識別子の指定 `secondary_separator` を追加しました。`secondary_separator=";;;"` のように指定します。`secondary_separator` で区切ることで、その部分はシャッフル、drop 時にまとめて扱われます。詳しくは記述例をご覧ください。 From 124ec45876a9f07820b42fda0d7ca9019de773d5 Mon Sep 17 00:00:00 2001 From: Horizon1704 <92718180+Horizon1704@users.noreply.github.com> Date: Sun, 10 Mar 2024 22:53:05 +0800 Subject: [PATCH 076/103] Add "encoding='utf-8'" --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index d2b69edb5..5f23dd13f 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3474,7 +3474,7 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar exit(1) logger.info(f"Loading settings from {config_path}...") - with open(config_path, "r") as f: + with open(config_path, "r", encoding='utf-8') as f: config_dict = toml.load(f) # combine all sections into one From 095b8035e63f7c79a232114d8f0e1ec27f431ebc Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 10 Mar 2024 23:33:38 +0800 Subject: [PATCH 077/103] save state on train end --- fine_tune.py | 2 +- library/train_util.py | 5 +++++ sdxl_train.py | 2 +- sdxl_train_control_net_lllite.py | 2 +- train_controlnet.py | 2 +- train_db.py | 2 +- train_network.py | 2 +- train_textual_inversion.py | 2 +- train_textual_inversion_XTI.py | 2 +- 9 files changed, 13 insertions(+), 8 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 875a91951..46f128287 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -457,7 +457,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): accelerator.end_training() - if args.save_state and is_main_process: + if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) del accelerator # この後メモリを使うのでこれは消す diff --git a/library/train_util.py b/library/train_util.py index d2b69edb5..b3ca15f55 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2890,6 +2890,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: action="store_true", help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する", ) + parser.add_argument( + "--save_state_on_train_end", + action="store_true", + help="save training state additionally (including optimizer states etc.) on train end / optimizerなど学習状態も含めたstateを追加で保存する", + ) parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ") diff --git a/sdxl_train.py b/sdxl_train.py index e0df263d6..107bb9451 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -712,7 +712,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): accelerator.end_training() - if args.save_state: # and is_main_process: + if args.save_state or args.save_state_on_train_end: train_util.save_state_on_train_end(args, accelerator) del accelerator # この後メモリを使うのでこれは消す diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 1e5f92349..e99b4e35c 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -549,7 +549,7 @@ def remove_model(old_ckpt_name): accelerator.end_training() - if is_main_process and args.save_state: + if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) if is_main_process: diff --git a/train_controlnet.py b/train_controlnet.py index dc73a91c8..e44f08853 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -565,7 +565,7 @@ def remove_model(old_ckpt_name): accelerator.end_training() - if is_main_process and args.save_state: + if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) # del accelerator # この後メモリを使うのでこれは消す→printで使うので消さずにおく diff --git a/train_db.py b/train_db.py index 8d36097a5..41a9a7b99 100644 --- a/train_db.py +++ b/train_db.py @@ -444,7 +444,7 @@ def train(args): accelerator.end_training() - if args.save_state and is_main_process: + if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) del accelerator # この後メモリを使うのでこれは消す diff --git a/train_network.py b/train_network.py index e0fa69458..4707d5ae5 100644 --- a/train_network.py +++ b/train_network.py @@ -935,7 +935,7 @@ def remove_model(old_ckpt_name): accelerator.end_training() - if is_main_process and args.save_state: + if is_main_process and args.save_state or args.save_state_on_train_end: train_util.save_state_on_train_end(args, accelerator) if is_main_process: diff --git a/train_textual_inversion.py b/train_textual_inversion.py index df1d8485a..0266bc143 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -732,7 +732,7 @@ def remove_model(old_ckpt_name): accelerator.end_training() - if args.save_state and is_main_process: + if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) if is_main_process: diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 695fad2a8..ad7c267eb 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -586,7 +586,7 @@ def remove_model(old_ckpt_name): accelerator.end_training() - if args.save_state and is_main_process: + if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) updated_embs = text_encoder.get_input_embeddings().weight[token_ids_XTI].data.detach().clone() From d282c450026dcfd5f1fd5856f5087ebaed47be46 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Mon, 11 Mar 2024 23:56:09 +0800 Subject: [PATCH 078/103] Update train_network.py --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 4707d5ae5..3db583f1d 100644 --- a/train_network.py +++ b/train_network.py @@ -935,7 +935,7 @@ def remove_model(old_ckpt_name): accelerator.end_training() - if is_main_process and args.save_state or args.save_state_on_train_end: + if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) if is_main_process: From 948029fe61d9142f88374d6701223bf9f7ee5d47 Mon Sep 17 00:00:00 2001 From: kblueleaf Date: Tue, 12 Mar 2024 19:11:45 +0800 Subject: [PATCH 079/103] random ip_noise_gamma strength --- library/train_util.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index b71e4edc6..aa2d9b90b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3100,6 +3100,13 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: help="enable input perturbation noise. used for regularization. recommended value: around 0.1 (from arxiv.org/abs/2301.11706) " + "/ input perturbation noiseを有効にする。正則化に使用される。推奨値: 0.1程度 (arxiv.org/abs/2301.11706 より)", ) + parser.add_argument( + "--ip_noise_gamma_random_strength", + type=bool, + default=False, + help="Use random strength between 0~ip_noise_gamma for input perturbation noise." + + "/ input perturbation noiseにおいて、0からip_noise_gammaの間でランダムな強度を使用します。", + ) # parser.add_argument( # "--perlin_noise", # type=int, @@ -4673,7 +4680,11 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) if args.ip_noise_gamma: - noisy_latents = noise_scheduler.add_noise(latents, noise + args.ip_noise_gamma * torch.randn_like(latents), timesteps) + if args.ip_noise_gamma_random_strength: + strength = torch.rand(1, device=latents.device) * args.ip_noise_gamma + else: + strength = args.ip_noise_gamma + noisy_latents = noise_scheduler.add_noise(latents, noise + strength * torch.randn_like(latents), timesteps) else: noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) From 86399407b2bb5a93d691846acfa88e7ba38ae70d Mon Sep 17 00:00:00 2001 From: kblueleaf Date: Tue, 12 Mar 2024 19:14:01 +0800 Subject: [PATCH 080/103] random noise_offset strength --- library/train_util.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index aa2d9b90b..5282b5240 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3087,6 +3087,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: default=None, help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)", ) + parser.add_argument( + "--noise_offset_random_strength", + type=bool, + default=False, + help="use random strength between 0~noise_offset for noise offset. / noise offsetにおいて、0からnoise_offsetの間でランダムな強度を使用します。", + ) parser.add_argument( "--multires_noise_iterations", type=int, @@ -4663,7 +4669,11 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) if args.noise_offset: - noise = custom_train_functions.apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale) + if args.noise_offset_random_strength: + noise_offset = torch.rand(1, device=latents.device) * args.noise_offset + else: + noise_offset = args.noise_offset + noise = custom_train_functions.apply_noise_offset(latents, noise, noise_offset, args.adaptive_noise_scale) if args.multires_noise_iterations: noise = custom_train_functions.pyramid_noise_like( noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount From 53954a1e2e05648bae6eb479720402968029cd3d Mon Sep 17 00:00:00 2001 From: kblueleaf Date: Tue, 12 Mar 2024 19:24:27 +0800 Subject: [PATCH 081/103] use correct settings for parser --- library/train_util.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 5282b5240..73a768672 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3089,8 +3089,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: ) parser.add_argument( "--noise_offset_random_strength", - type=bool, - default=False, + action="store_true", help="use random strength between 0~noise_offset for noise offset. / noise offsetにおいて、0からnoise_offsetの間でランダムな強度を使用します。", ) parser.add_argument( @@ -3108,8 +3107,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: ) parser.add_argument( "--ip_noise_gamma_random_strength", - type=bool, - default=False, + action="store_true", help="Use random strength between 0~ip_noise_gamma for input perturbation noise." + "/ input perturbation noiseにおいて、0からip_noise_gammaの間でランダムな強度を使用します。", ) From f811b115ba867554e989fbe1fbbf97f8df0f7852 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 15 Mar 2024 21:05:00 +0900 Subject: [PATCH 082/103] fix sdxl timestep embedding --- README.md | 10 ++++++++++ library/sdxl_original_unet.py | 8 +++++--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index e635e5aed..3639b7be8 100644 --- a/README.md +++ b/README.md @@ -249,6 +249,16 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum ## Change History +### Mar 15, 2024 / 2024/3/15: v0.8.5 + +- Fixed a bug that the value of timestep embedding during SDXL training was incorrect. + - The inference with the generation script is also fixed. + - The impact is unknown, but please update for SDXL training. + +- SDXL 学習時の timestep embedding の値が誤っていたのを修正しました。 + - 生成スクリプトでの推論時についてもあわせて修正しました。 + - 影響の度合いは不明ですが、SDXL の学習時にはアップデートをお願いいたします。 + ### Feb 24, 2024 / 2024/2/24: v0.8.4 - The log output has been improved. PR [#905](https://github.com/kohya-ss/sd-scripts/pull/905) Thanks to shirayu! diff --git a/library/sdxl_original_unet.py b/library/sdxl_original_unet.py index 673cf9f65..17c345a89 100644 --- a/library/sdxl_original_unet.py +++ b/library/sdxl_original_unet.py @@ -31,8 +31,10 @@ from torch.nn import functional as F from einops import rearrange from .utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) IN_CHANNELS: int = 4 @@ -1074,7 +1076,7 @@ def forward(self, x, timesteps=None, context=None, y=None, **kwargs): timesteps = timesteps.expand(x.shape[0]) hs = [] - t_emb = get_timestep_embedding(timesteps, self.model_channels) # , repeat_only=False) + t_emb = get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0) # , repeat_only=False) t_emb = t_emb.to(x.dtype) emb = self.time_embed(t_emb) @@ -1132,7 +1134,7 @@ def __init__(self, original_unet: SdxlUNet2DConditionModel, **kwargs): # call original model's methods def __getattr__(self, name): return getattr(self.delegate, name) - + def __call__(self, *args, **kwargs): return self.delegate(*args, **kwargs) @@ -1164,7 +1166,7 @@ def forward(self, x, timesteps=None, context=None, y=None, **kwargs): timesteps = timesteps.expand(x.shape[0]) hs = [] - t_emb = get_timestep_embedding(timesteps, _self.model_channels) # , repeat_only=False) + t_emb = get_timestep_embedding(timesteps, _self.model_channels, downscale_freq_shift=0) # , repeat_only=False) t_emb = t_emb.to(x.dtype) emb = _self.time_embed(t_emb) From 443f02942cfefd6e1899849f563580508d118ce0 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 15 Mar 2024 21:35:14 +0900 Subject: [PATCH 083/103] fix doc --- README.md | 106 ------------------------------------------------------ 1 file changed, 106 deletions(-) diff --git a/README.md b/README.md index cc5aca505..927d7a42e 100644 --- a/README.md +++ b/README.md @@ -355,112 +355,6 @@ It becomes `1girl, hatsune miku, vocaloid, microphone, stage, white shirt, best `1girl, hatsune miku, vocaloid, microphone, stage, white shirt, best quality, rating: general` や `1girl, hatsune miku, vocaloid, white shirt, smile, stage, microphone, best quality, rating: general` などになります。 -### Working in progress - -- Colab seems to stop with log output. Try specifying `--console_log_simple` option in the training script to disable rich logging. -- `train_network.py` and `sdxl_train_network.py` are modified to record some dataset settings in the metadata of the trained model (`caption_prefix`, `caption_suffix`, `keep_tokens_separator`, `secondary_separator`, `enable_wildcard`). -- Some features are added to the dataset subset settings. - - `secondary_separator` is added to specify the tag separator that is not the target of shuffling or dropping. - - Specify `secondary_separator=";;;"`. When you specify `secondary_separator`, the part is not shuffled or dropped. See the example below. - - `enable_wildcard` is added. When set to `true`, the wildcard notation `{aaa|bbb|ccc}` can be used. See the example below. - - `keep_tokens_separator` is updated to be used twice in the caption. When you specify `keep_tokens_separator="|||"`, the part divided by the second `|||` is not shuffled or dropped and remains at the end. - - The existing features `caption_prefix` and `caption_suffix` can be used together. `caption_prefix` and `caption_suffix` are processed first, and then `enable_wildcard`, `keep_tokens_separator`, shuffling and dropping, and `secondary_separator` are processed in order. - - The examples are [shown below](#example-of-dataset-settings--データセット設定の記述例). - - -- Colab での動作時、ログ出力で停止してしまうようです。学習スクリプトに `--console_log_simple` オプションを指定し、rich のロギングを無効してお試しください。 -- `train_network.py` および `sdxl_train_network.py` で、学習したモデルのメタデータに一部のデータセット設定が記録されるよう修正しました(`caption_prefix`、`caption_suffix`、`keep_tokens_separator`、`secondary_separator`、`enable_wildcard`)。 -- データセットのサブセット設定にいくつかの機能を追加しました。 - - シャッフルの対象とならないタグ分割識別子の指定 `secondary_separator` を追加しました。`secondary_separator=";;;"` のように指定します。`secondary_separator` で区切ることで、その部分はシャッフル、drop 時にまとめて扱われます。詳しくは記述例をご覧ください。 - - `enable_wildcard` を追加しました。`true` にするとワイルドカード記法 `{aaa|bbb|ccc}` が使えます。詳しくは記述例をご覧ください。 - - `keep_tokens_separator` をキャプション内に 2 つ使えるようにしました。たとえば `keep_tokens_separator="|||"` と指定したとき、`1girl, hatsune miku, vocaloid ||| stage, mic ||| best quality, rating: general` とキャプションを指定すると、二番目の `|||` で分割された部分はシャッフル、drop されず末尾に残ります。 - - 既存の機能 `caption_prefix` と `caption_suffix` とあわせて使えます。`caption_prefix` と `caption_suffix` は一番最初に処理され、その後、ワイルドカード、`keep_tokens_separator`、シャッフルおよび drop、`secondary_separator` の順に処理されます。 - -#### Example of dataset settings / データセット設定の記述例: - -```toml -[general] -flip_aug = true -color_aug = false -resolution = [1024, 1024] - -[[datasets]] -batch_size = 6 -enable_bucket = true -bucket_no_upscale = true -caption_extension = ".txt" -keep_tokens_separator= "|||" -shuffle_caption = true -caption_tag_dropout_rate = 0.1 -secondary_separator = ";;;" # subset 側に書くこともできます / can be written in the subset side -enable_wildcard = true # 同上 / same as above - - [[datasets.subsets]] - image_dir = "/path/to/image_dir" - num_repeats = 1 - - # ||| の前後はカンマは不要です(自動的に追加されます) / No comma is required before and after ||| (it is added automatically) - caption_prefix = "1girl, hatsune miku, vocaloid |||" - - # ||| の後はシャッフル、drop されず残ります / After |||, it is not shuffled or dropped and remains - # 単純に文字列として連結されるので、カンマなどは自分で入れる必要があります / It is simply concatenated as a string, so you need to put commas yourself - caption_suffix = ", anime screencap ||| masterpiece, rating: general" -``` - -#### Example of caption, secondary_separator notation: `secondary_separator = ";;;"` - -```txt -1girl, hatsune miku, vocaloid, upper body, looking at viewer, sky;;;cloud;;;day, outdoors -``` -The part `sky;;;cloud;;;day` is replaced with `sky,cloud,day` without shuffling or dropping. When shuffling and dropping are enabled, it is processed as a whole (as one tag). For example, it becomes `vocaloid, 1girl, upper body, sky,cloud,day, outdoors, hatsune miku` (shuffled) or `vocaloid, 1girl, outdoors, looking at viewer, upper body, hatsune miku` (dropped). - -#### Example of caption, enable_wildcard notation: `enable_wildcard = true` - -```txt -1girl, hatsune miku, vocaloid, upper body, looking at viewer, {simple|white} background -``` -`simple` or `white` is randomly selected, and it becomes `simple background` or `white background`. - -```txt -1girl, hatsune miku, vocaloid, {{retro style}} -``` -If you want to include `{` or `}` in the tag string, double them like `{{` or `}}` (in this example, the actual caption used for training is `{retro style}`). - -#### Example of caption, `keep_tokens_separator` notation: `keep_tokens_separator = "|||"` - -```txt -1girl, hatsune miku, vocaloid ||| stage, microphone, white shirt, smile ||| best quality, rating: general -``` -It becomes `1girl, hatsune miku, vocaloid, microphone, stage, white shirt, best quality, rating: general` or `1girl, hatsune miku, vocaloid, white shirt, smile, stage, microphone, best quality, rating: general` etc. - - -#### キャプション記述例、secondary_separator 記法:`secondary_separator = ";;;"` の場合 - -```txt -1girl, hatsune miku, vocaloid, upper body, looking at viewer, sky;;;cloud;;;day, outdoors -``` -`sky;;;cloud;;;day` の部分はシャッフル、drop されず `sky,cloud,day` に置換されます。シャッフル、drop が有効な場合、まとめて(一つのタグとして)処理されます。つまり `vocaloid, 1girl, upper body, sky,cloud,day, outdoors, hatsune miku` (シャッフル)や `vocaloid, 1girl, outdoors, looking at viewer, upper body, hatsune miku` (drop されたケース)などになります。 - -#### キャプション記述例、ワイルドカード記法: `enable_wildcard = true` の場合 - -```txt -1girl, hatsune miku, vocaloid, upper body, looking at viewer, {simple|white} background -``` -ランダムに `simple` または `white` が選ばれ、`simple background` または `white background` になります。 - -```txt -1girl, hatsune miku, vocaloid, {{retro style}} -``` -タグ文字列に `{` や `}` そのものを含めたい場合は `{{` や `}}` のように二つ重ねてください(この例では実際に学習に用いられるキャプションは `{retro style}` になります)。 - -#### キャプション記述例、`keep_tokens_separator` 記法: `keep_tokens_separator = "|||"` の場合 - -```txt -1girl, hatsune miku, vocaloid ||| stage, microphone, white shirt, smile ||| best quality, rating: general -``` -`1girl, hatsune miku, vocaloid, microphone, stage, white shirt, best quality, rating: general` や `1girl, hatsune miku, vocaloid, white shirt, smile, stage, microphone, best quality, rating: general` などになります。 - - ### Mar 15, 2024 / 2024/3/15: v0.8.5 - Fixed a bug that the value of timestep embedding during SDXL training was incorrect. From f9317052edb4ab3b3c531ac3b28825ee78b4a966 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 18 Mar 2024 08:53:23 +0900 Subject: [PATCH 084/103] update readme for timestep embs bug --- README.md | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 3639b7be8..001446b7c 100644 --- a/README.md +++ b/README.md @@ -252,12 +252,15 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum ### Mar 15, 2024 / 2024/3/15: v0.8.5 - Fixed a bug that the value of timestep embedding during SDXL training was incorrect. + - Please update for SDXL training. - The inference with the generation script is also fixed. - - The impact is unknown, but please update for SDXL training. + - This fix appears to resolve an issue where unintended artifacts occurred in trained models under certain conditions. +We would like to express our deep gratitude to Mark Saint (cacoe) from leonardo.ai, for reporting the issue and cooperating with the verification, and to gcem156 for the advice provided in identifying the part of the code that needed to be fixed. - SDXL 学習時の timestep embedding の値が誤っていたのを修正しました。 + - SDXL の学習時にはアップデートをお願いいたします。 - 生成スクリプトでの推論時についてもあわせて修正しました。 - - 影響の度合いは不明ですが、SDXL の学習時にはアップデートをお願いいたします。 + - この修正により、特定の条件下で学習されたモデルに意図しないアーティファクトが発生する問題が解消されるようです。問題を報告いただき、また検証にご協力いただいた leonardo.ai の Mark Saint (cacoe) 氏、および修正点の特定に関するアドバイスをいただいた gcem156 氏に深く感謝いたします。 ### Feb 24, 2024 / 2024/2/24: v0.8.4 From a7dff592d34a5dd9d306de822db70f0028676cab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Mon, 18 Mar 2024 22:29:05 +0800 Subject: [PATCH 085/103] Update tag_images_by_wd14_tagger.py add WDV3 --- finetune/tag_images_by_wd14_tagger.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index b56d921a3..e63ec3eb4 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -86,23 +86,26 @@ def main(args): logger.info(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}") files = FILES if args.onnx: + files = ["selected_tags.csv"] files += FILES_ONNX + else: + for file in SUB_DIR_FILES: + hf_hub_download( + args.repo_id, + file, + subfolder=SUB_DIR, + cache_dir=os.path.join(args.model_dir, SUB_DIR), + force_download=True, + force_filename=file, + ) for file in files: hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file) - for file in SUB_DIR_FILES: - hf_hub_download( - args.repo_id, - file, - subfolder=SUB_DIR, - cache_dir=os.path.join(args.model_dir, SUB_DIR), - force_download=True, - force_filename=file, - ) else: logger.info("using existing wd14 tagger model") # 画像を読み込む if args.onnx: + import torch import onnx import onnxruntime as ort From 5410a8c79b23c594bb340050b4a81e30d95cd7be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Mon, 18 Mar 2024 22:31:00 +0800 Subject: [PATCH 086/103] Update requirements.txt --- requirements.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index 279de350c..326b65b3e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,9 +22,9 @@ huggingface-hub==0.20.1 # for WD14 captioning (tensorflow) # tensorflow==2.10.1 # for WD14 captioning (onnx) -# onnx==1.14.1 -# onnxruntime-gpu==1.16.0 -# onnxruntime==1.16.0 +# onnx==1.15.1 +# onnxruntime-gpu==1.17.1 +# onnxruntime==1.17.1 # this is for onnx: # protobuf==3.20.3 # open clip for SDXL From a71c35ccd9c813821fcbd3f0e00d71fb5e6d91d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Mon, 18 Mar 2024 22:31:59 +0800 Subject: [PATCH 087/103] Update requirements.txt --- requirements.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/requirements.txt b/requirements.txt index 326b65b3e..6898eccf6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,6 +25,9 @@ huggingface-hub==0.20.1 # onnx==1.15.1 # onnxruntime-gpu==1.17.1 # onnxruntime==1.17.1 +# for cuda 12.1(default 11.8) +# onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/ + # this is for onnx: # protobuf==3.20.3 # open clip for SDXL From 6c51c971d135a346d2f9081760f138b1c6515e9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Wed, 20 Mar 2024 09:35:21 +0800 Subject: [PATCH 088/103] fix typo --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 6898eccf6..805f0501d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,7 +22,7 @@ huggingface-hub==0.20.1 # for WD14 captioning (tensorflow) # tensorflow==2.10.1 # for WD14 captioning (onnx) -# onnx==1.15.1 +# onnx==1.15.0 # onnxruntime-gpu==1.17.1 # onnxruntime==1.17.1 # for cuda 12.1(default 11.8) From 80dbbf5e4875f56ff1e0d8aacea4e73b96a14b63 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 20 Mar 2024 16:14:57 +0900 Subject: [PATCH 089/103] tagger now stores model under repo_id subdir --- README.md | 9 ++++- finetune/tag_images_by_wd14_tagger.py | 55 ++++++++++++++++++--------- 2 files changed, 45 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index f0cad611f..d03204037 100644 --- a/README.md +++ b/README.md @@ -260,7 +260,9 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - `keep_tokens_separator` is updated to be used twice in the caption. When you specify `keep_tokens_separator="|||"`, the part divided by the second `|||` is not shuffled or dropped and remains at the end. - The existing features `caption_prefix` and `caption_suffix` can be used together. `caption_prefix` and `caption_suffix` are processed first, and then `enable_wildcard`, `keep_tokens_separator`, shuffling and dropping, and `secondary_separator` are processed in order. - The examples are [shown below](#example-of-dataset-settings--データセット設定の記述例). - +- The support for v3 repositories is added to `tag_image_by_wd14_tagger.py` (`--onnx` option only). PR [#1192](https://github.com/kohya-ss/sd-scripts/pull/1192) Thanks to sdbds! + - Onnx may need to be updated. Onnx is not installed by default, so please install or update it with `pip install onnx==1.15.0 onnxruntime-gpu==1.17.1` etc. Please also check the comments in `requirements.txt`. +- The model is now saved in the subdirectory as `--repo_id` in `tag_image_by_wd14_tagger.py` . This caches multiple repo_id models. Please delete unnecessary files under `--model_dir`. - Colab での動作時、ログ出力で停止してしまうようです。学習スクリプトに `--console_log_simple` オプションを指定し、rich のロギングを無効してお試しください。 - `train_network.py` および `sdxl_train_network.py` で、学習したモデルのメタデータに一部のデータセット設定が記録されるよう修正しました(`caption_prefix`、`caption_suffix`、`keep_tokens_separator`、`secondary_separator`、`enable_wildcard`)。 @@ -269,6 +271,11 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - `enable_wildcard` を追加しました。`true` にするとワイルドカード記法 `{aaa|bbb|ccc}` が使えます。詳しくは記述例をご覧ください。 - `keep_tokens_separator` をキャプション内に 2 つ使えるようにしました。たとえば `keep_tokens_separator="|||"` と指定したとき、`1girl, hatsune miku, vocaloid ||| stage, mic ||| best quality, rating: general` とキャプションを指定すると、二番目の `|||` で分割された部分はシャッフル、drop されず末尾に残ります。 - 既存の機能 `caption_prefix` と `caption_suffix` とあわせて使えます。`caption_prefix` と `caption_suffix` は一番最初に処理され、その後、ワイルドカード、`keep_tokens_separator`、シャッフルおよび drop、`secondary_separator` の順に処理されます。 +- `tag_image_by_wd14_tagger.py` で v3 のリポジトリがサポートされました(`--onnx` 指定時のみ有効)。 PR [#1192](https://github.com/kohya-ss/sd-scripts/pull/1192) sdbds 氏に感謝します。 + - Onnx のバージョンアップが必要になるかもしれません。デフォルトでは Onnx はインストールされていませんので、`pip install onnx==1.15.0 onnxruntime-gpu==1.17.1` 等でインストール、アップデートしてください。`requirements.txt` のコメントもあわせてご確認ください。 +- `tag_image_by_wd14_tagger.py` で、モデルを`--repo_id` のサブディレクトリに保存するようにしました。これにより複数のモデルファイルがキャッシュされます。`--model_dir` 直下の不要なファイルは削除願います。 + + #### Example of dataset settings / データセット設定の記述例: diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index e63ec3eb4..401c6d1ec 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -12,8 +12,10 @@ import library.train_util as train_util from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) # from wd14 tagger @@ -79,10 +81,15 @@ def collate_fn_remove_corrupted(batch): def main(args): + # model location is model_dir + repo_id + # repo id may be like "user/repo" or "user/repo/branch", so we need to remove slash + model_location = os.path.join(args.model_dir, args.repo_id.replace("/", "_")) + # hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする # depreacatedの警告が出るけどなくなったらその時 # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22 - if not os.path.exists(args.model_dir) or args.force_download: + if not os.path.exists(model_location) or args.force_download: + os.makedirs(args.model_dir, exist_ok=True) logger.info(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}") files = FILES if args.onnx: @@ -94,12 +101,12 @@ def main(args): args.repo_id, file, subfolder=SUB_DIR, - cache_dir=os.path.join(args.model_dir, SUB_DIR), + cache_dir=os.path.join(model_location, SUB_DIR), force_download=True, force_filename=file, ) for file in files: - hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file) + hf_hub_download(args.repo_id, file, cache_dir=model_location, force_download=True, force_filename=file) else: logger.info("using existing wd14 tagger model") @@ -109,7 +116,7 @@ def main(args): import onnx import onnxruntime as ort - onnx_path = f"{args.model_dir}/model.onnx" + onnx_path = f"{model_location}/model.onnx" logger.info("Running wd14 tagger with onnx") logger.info(f"loading onnx model: {onnx_path}") @@ -126,7 +133,7 @@ def main(args): except: batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param - if args.batch_size != batch_size and type(batch_size) != str: + if args.batch_size != batch_size and type(batch_size) != str and batch_size > 0: # some rebatch model may use 'N' as dynamic axes logger.warning( f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}" @@ -137,19 +144,19 @@ def main(args): ort_sess = ort.InferenceSession( onnx_path, - providers=["CUDAExecutionProvider"] - if "CUDAExecutionProvider" in ort.get_available_providers() - else ["CPUExecutionProvider"], + providers=( + ["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else ["CPUExecutionProvider"] + ), ) else: from tensorflow.keras.models import load_model - model = load_model(f"{args.model_dir}") + model = load_model(f"{model_location}") # label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv") # 依存ライブラリを増やしたくないので自力で読むよ - with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f: + with open(os.path.join(model_location, CSV_FILE), "r", encoding="utf-8") as f: reader = csv.reader(f) l = [row for row in reader] header = l[0] # tag_id,name,category,count @@ -175,8 +182,8 @@ def run_batch(path_imgs): imgs = np.array([im for _, im in path_imgs]) if args.onnx: - if len(imgs) < args.batch_size: - imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0) + # if len(imgs) < args.batch_size: + # imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0) probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy probs = probs[: len(path_imgs)] else: @@ -317,7 +324,9 @@ def setup_parser() -> argparse.ArgumentParser: help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ", ) parser.add_argument( - "--force_download", action="store_true", help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします" + "--force_download", + action="store_true", + help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします", ) parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") parser.add_argument( @@ -332,8 +341,12 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)", ) - parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子") - parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値") + parser.add_argument( + "--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子" + ) + parser.add_argument( + "--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値" + ) parser.add_argument( "--general_threshold", type=float, @@ -346,7 +359,9 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="threshold of confidence to add a tag for character category, same as --thres if omitted / characterカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ", ) - parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する") + parser.add_argument( + "--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する" + ) parser.add_argument( "--remove_underscore", action="store_true", @@ -359,9 +374,13 @@ def setup_parser() -> argparse.ArgumentParser: default="", help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト", ) - parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する") + parser.add_argument( + "--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する" + ) parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する") - parser.add_argument("--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する") + parser.add_argument( + "--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する" + ) parser.add_argument( "--caption_separator", type=str, From 46331a9e8ef695ea0b5a19686202d011109a56b6 Mon Sep 17 00:00:00 2001 From: Victor Espinoza-Guerra Date: Wed, 20 Mar 2024 00:31:01 -0700 Subject: [PATCH 090/103] English Translation of config_README-ja.md (#1175) * Add files via upload Creating template to work on. * Update config_README-en.md Total Conversion from Japanese to English. * Update config_README-en.md * Update config_README-en.md * Update config_README-en.md --- docs/config_README-en.md | 279 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 279 insertions(+) create mode 100644 docs/config_README-en.md diff --git a/docs/config_README-en.md b/docs/config_README-en.md new file mode 100644 index 000000000..a0727934d --- /dev/null +++ b/docs/config_README-en.md @@ -0,0 +1,279 @@ +Original Source by kohya-ss + +A.I Translation by Model: NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO, editing by Darkstorm2150 + +# Config Readme + +This README is about the configuration files that can be passed with the `--dataset_config` option. + +## Overview + +By passing a configuration file, users can make detailed settings. + +* Multiple datasets can be configured + * For example, by setting `resolution` for each dataset, they can be mixed and trained. + * In training methods that support both the DreamBooth approach and the fine-tuning approach, datasets of the DreamBooth method and the fine-tuning method can be mixed. +* Settings can be changed for each subset + * A subset is a partition of the dataset by image directory or metadata. Several subsets make up a dataset. + * Options such as `keep_tokens` and `flip_aug` can be set for each subset. On the other hand, options such as `resolution` and `batch_size` can be set for each dataset, and their values are common among subsets belonging to the same dataset. More details will be provided later. + +The configuration file format can be JSON or TOML. Considering the ease of writing, it is recommended to use [TOML](https://toml.io/ja/v1.0.0-rc.2). The following explanation assumes the use of TOML. + + +Here is an example of a configuration file written in TOML. + +```toml +[general] +shuffle_caption = true +caption_extension = '.txt' +keep_tokens = 1 + +# This is a DreamBooth-style dataset +[[datasets]] +resolution = 512 +batch_size = 4 +keep_tokens = 2 + + [[datasets.subsets]] + image_dir = 'C:\hoge' + class_tokens = 'hoge girl' + # This subset uses keep_tokens = 2 (the value of the parent datasets) + + [[datasets.subsets]] + image_dir = 'C:\fuga' + class_tokens = 'fuga boy' + keep_tokens = 3 + + [[datasets.subsets]] + is_reg = true + image_dir = 'C:\reg' + class_tokens = 'human' + keep_tokens = 1 + +# This is a fine-tuning dataset +[[datasets]] +resolution = [768, 768] +batch_size = 2 + + [[datasets.subsets]] + image_dir = 'C:\piyo' + metadata_file = 'C:\piyo\piyo_md.json' + # This subset uses keep_tokens = 1 (the value of [general]) +``` + +In this example, three directories are trained as a DreamBooth-style dataset at 512x512 (batch size 4), and one directory is trained as a fine-tuning dataset at 768x768 (batch size 2). + +## Settings for datasets and subsets + +Settings for datasets and subsets are divided into several registration locations. + +* `[general]` + * This is where options that apply to all datasets or all subsets are specified. + * If there are options with the same name in the dataset-specific or subset-specific settings, the dataset-specific or subset-specific settings take precedence. +* `[[datasets]]` + * `datasets` is where settings for datasets are registered. This is where options that apply individually to each dataset are specified. + * If there are subset-specific settings, the subset-specific settings take precedence. +* `[[datasets.subsets]]` + * `datasets.subsets` is where settings for subsets are registered. This is where options that apply individually to each subset are specified. + +Here is an image showing the correspondence between image directories and registration locations in the previous example. + +``` +C:\ +├─ hoge -> [[datasets.subsets]] No.1 ┐ ┐ +├─ fuga -> [[datasets.subsets]] No.2 |-> [[datasets]] No.1 |-> [general] +├─ reg -> [[datasets.subsets]] No.3 ┘ | +└─ piyo -> [[datasets.subsets]] No.4 --> [[datasets]] No.2 ┘ +``` + +The image directory corresponds to each `[[datasets.subsets]]`. Then, multiple `[[datasets.subsets]]` are combined to form one `[[datasets]]`. All `[[datasets]]` and `[[datasets.subsets]]` belong to `[general]`. + +The available options for each registration location may differ, but if the same option is specified, the value in the lower registration location will take precedence. You can check how the `keep_tokens` option is handled in the previous example for better understanding. + +Additionally, the available options may vary depending on the method that the learning approach supports. + +* Options specific to the DreamBooth method +* Options specific to the fine-tuning method +* Options available when using the caption dropout technique + +When using both the DreamBooth method and the fine-tuning method, they can be used together with a learning approach that supports both. +When using them together, a point to note is that the method is determined based on the dataset, so it is not possible to mix DreamBooth method subsets and fine-tuning method subsets within the same dataset. +In other words, if you want to use both methods together, you need to set up subsets of different methods belonging to different datasets. + +In terms of program behavior, if the `metadata_file` option exists, it is determined to be a subset of fine-tuning. Therefore, for subsets belonging to the same dataset, as long as they are either "all have the `metadata_file` option" or "all have no `metadata_file` option," there is no problem. + +Below, the available options will be explained. For options with the same name as the command-line argument, the explanation will be omitted in principle. Please refer to other READMEs. + +### Common options for all learning methods + +These are options that can be specified regardless of the learning method. + +#### Data set specific options + +These are options related to the configuration of the data set. They cannot be described in `datasets.subsets`. + + +| Option Name | Example Setting | `[general]` | `[[datasets]]` | +| ---- | ---- | ---- | ---- | +| `batch_size` | `1` | o | o | +| `bucket_no_upscale` | `true` | o | o | +| `bucket_reso_steps` | `64` | o | o | +| `enable_bucket` | `true` | o | o | +| `max_bucket_reso` | `1024` | o | o | +| `min_bucket_reso` | `128` | o | o | +| `resolution` | `256`, `[512, 512]` | o | o | + +* `batch_size` + * This corresponds to the command-line argument `--train_batch_size`. + +These settings are fixed per dataset. That means that subsets belonging to the same dataset will share these settings. For example, if you want to prepare datasets with different resolutions, you can define them as separate datasets as shown in the example above, and set different resolutions for each. + +#### Options for Subsets + +These options are related to subset configuration. + +| Option Name | Example | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` | +| ---- | ---- | ---- | ---- | ---- | +| `color_aug` | `false` | o | o | o | +| `face_crop_aug_range` | `[1.0, 3.0]` | o | o | o | +| `flip_aug` | `true` | o | o | o | +| `keep_tokens` | `2` | o | o | o | +| `num_repeats` | `10` | o | o | o | +| `random_crop` | `false` | o | o | o | +| `shuffle_caption` | `true` | o | o | o | +| `caption_prefix` | `"masterpiece, best quality, "` | o | o | o | +| `caption_suffix` | `", from side"` | o | o | o | + +* `num_repeats` + * Specifies the number of repeats for images in a subset. This is equivalent to `--dataset_repeats` in fine-tuning but can be specified for any training method. +* `caption_prefix`, `caption_suffix` + * Specifies the prefix and suffix strings to be appended to the captions. Shuffling is performed with these strings included. Be cautious when using `keep_tokens`. + +### DreamBooth-specific options + +DreamBooth-specific options only exist as subsets-specific options. + +#### Subset-specific options + +Options related to the configuration of DreamBooth subsets. + +| Option Name | Example Setting | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` | +| ---- | ---- | ---- | ---- | ---- | +| `image_dir` | `'C:\hoge'` | - | - | o (required) | +| `caption_extension` | `".txt"` | o | o | o | +| `class_tokens` | `"sks girl"` | - | - | o | +| `is_reg` | `false` | - | - | o | + +Firstly, note that for `image_dir`, the path to the image files must be specified as being directly in the directory. Unlike the previous DreamBooth method, where images had to be placed in subdirectories, this is not compatible with that specification. Also, even if you name the folder something like "5_cat", the number of repeats of the image and the class name will not be reflected. If you want to set these individually, you will need to explicitly specify them using `num_repeats` and `class_tokens`. + +* `image_dir` + * Specifies the path to the image directory. This is a required option. + * Images must be placed directly under the directory. +* `class_tokens` + * Sets the class tokens. + * Only used during training when a corresponding caption file does not exist. The determination of whether or not to use it is made on a per-image basis. If `class_tokens` is not specified and a caption file is not found, an error will occur. +* `is_reg` + * Specifies whether the subset images are for normalization. If not specified, it is set to `false`, meaning that the images are not for normalization. + +### Fine-tuning method specific options + +The options for the fine-tuning method only exist for subset-specific options. + +#### Subset-specific options + +These options are related to the configuration of the fine-tuning method's subsets. + +| Option name | Example setting | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` | +| ---- | ---- | ---- | ---- | ---- | +| `image_dir` | `'C:\hoge'` | - | - | o | +| `metadata_file` | `'C:\piyo\piyo_md.json'` | - | - | o (required) | + +* `image_dir` + * Specify the path to the image directory. Unlike the DreamBooth method, specifying it is not mandatory, but it is recommended to do so. + * The case where it is not necessary to specify is when the `--full_path` is added to the command line when generating the metadata file. + * The images must be placed directly under the directory. +* `metadata_file` + * Specify the path to the metadata file used for the subset. This is a required option. + * It is equivalent to the command-line argument `--in_json`. + * Due to the specification that a metadata file must be specified for each subset, it is recommended to avoid creating a metadata file with images from different directories as a single metadata file. It is strongly recommended to prepare a separate metadata file for each image directory and register them as separate subsets. + +### Options available when caption dropout method can be used + +The options available when the caption dropout method can be used exist only for subsets. Regardless of whether it's the DreamBooth method or fine-tuning method, if it supports caption dropout, it can be specified. + +#### Subset-specific options + +Options related to the setting of subsets that caption dropout can be used for. + +| Option Name | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` | +| ---- | ---- | ---- | ---- | +| `caption_dropout_every_n_epochs` | o | o | o | +| `caption_dropout_rate` | o | o | o | +| `caption_tag_dropout_rate` | o | o | o | + +## Behavior when there are duplicate subsets + +In the case of the DreamBooth dataset, if there are multiple `image_dir` directories with the same content, they are considered to be duplicate subsets. For the fine-tuning dataset, if there are multiple `metadata_file` files with the same content, they are considered to be duplicate subsets. If duplicate subsets exist in the dataset, subsequent subsets will be ignored. + +However, if they belong to different datasets, they are not considered duplicates. For example, if you have subsets with the same `image_dir` in different datasets, they will not be considered duplicates. This is useful when you want to train with the same image but with different resolutions. + +```toml +# If data sets exist separately, they are not considered duplicates and are both used for training. + +[[datasets]] +resolution = 512 + + [[datasets.subsets]] + image_dir = 'C:\hoge' + +[[datasets]] +resolution = 768 + + [[datasets.subsets]] + image_dir = 'C:\hoge' +``` + +## Command Line Argument and Configuration File + +There are options in the configuration file that have overlapping roles with command line argument options. + +The following command line argument options are ignored if a configuration file is passed: + +* `--train_data_dir` +* `--reg_data_dir` +* `--in_json` + +The following command line argument options are given priority over the configuration file options if both are specified simultaneously. In most cases, they have the same names as the corresponding options in the configuration file. + +| Command Line Argument Option | Prioritized Configuration File Option | +| ------------------------------- | ------------------------------------- | +| `--bucket_no_upscale` | | +| `--bucket_reso_steps` | | +| `--caption_dropout_every_n_epochs` | | +| `--caption_dropout_rate` | | +| `--caption_extension` | | +| `--caption_tag_dropout_rate` | | +| `--color_aug` | | +| `--dataset_repeats` | `num_repeats` | +| `--enable_bucket` | | +| `--face_crop_aug_range` | | +| `--flip_aug` | | +| `--keep_tokens` | | +| `--min_bucket_reso` | | +| `--random_crop` | | +| `--resolution` | | +| `--shuffle_caption` | | +| `--train_batch_size` | `batch_size` | + +## Error Guide + +Currently, we are using an external library to check if the configuration file is written correctly, but the development has not been completed, and there is a problem that the error message is not clear. In the future, we plan to improve this problem. + +As a temporary measure, we will list common errors and their solutions. If you encounter an error even though it should be correct or if the error content is not understandable, please contact us as it may be a bug. + +* `voluptuous.error.MultipleInvalid: required key not provided @ ...`: This error occurs when a required option is not provided. It is highly likely that you forgot to specify the option or misspelled the option name. + * The error location is indicated by `...` in the error message. For example, if you encounter an error like `voluptuous.error.MultipleInvalid: required key not provided @ data['datasets'][0]['subsets'][0]['image_dir']`, it means that the `image_dir` option does not exist in the 0th `subsets` of the 0th `datasets` setting. +* `voluptuous.error.MultipleInvalid: expected int for dictionary value @ ...`: This error occurs when the specified value format is incorrect. It is highly likely that the value format is incorrect. The `int` part changes depending on the target option. The example configurations in this README may be helpful. +* `voluptuous.error.MultipleInvalid: extra keys not allowed @ ...`: This error occurs when there is an option name that is not supported. It is highly likely that you misspelled the option name or mistakenly included it. + + From 5f6196e4c71763250da316cc0f4ce15db1696017 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 20 Mar 2024 16:35:23 +0900 Subject: [PATCH 091/103] update readme --- README.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index d03204037..f8601c1b2 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,7 @@ Most of the documents are written in Japanese. * [Training guide - common](./docs/train_README-ja.md) : data preparation, options etc... * [Chinese version](./docs/train_README-zh.md) * [Dataset config](./docs/config_README-ja.md) + * [English version](./docs/config_README-en.md) * [DreamBooth training guide](./docs/train_db_README-ja.md) * [Step by Step fine-tuning guide](./docs/fine_tune_README_ja.md): * [training LoRA](./docs/train_network_README-ja.md) @@ -263,6 +264,8 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - The support for v3 repositories is added to `tag_image_by_wd14_tagger.py` (`--onnx` option only). PR [#1192](https://github.com/kohya-ss/sd-scripts/pull/1192) Thanks to sdbds! - Onnx may need to be updated. Onnx is not installed by default, so please install or update it with `pip install onnx==1.15.0 onnxruntime-gpu==1.17.1` etc. Please also check the comments in `requirements.txt`. - The model is now saved in the subdirectory as `--repo_id` in `tag_image_by_wd14_tagger.py` . This caches multiple repo_id models. Please delete unnecessary files under `--model_dir`. +- The options `--noise_offset_random_strength` and `--ip_noise_gamma_random_strength` are added to each training script. These options can be used to vary the noise offset and ip noise gamma in the range of 0 to the specified value. PR [#1177](https://github.com/kohya-ss/sd-scripts/pull/1177) Thanks to KohakuBlueleaf! +- The [English version of the dataset settings documentation](./docs/config_README-en.md) is added. PR [#1175](https://github.com/kohya-ss/sd-scripts/pull/1175) Thanks to darkstorm2150! - Colab での動作時、ログ出力で停止してしまうようです。学習スクリプトに `--console_log_simple` オプションを指定し、rich のロギングを無効してお試しください。 - `train_network.py` および `sdxl_train_network.py` で、学習したモデルのメタデータに一部のデータセット設定が記録されるよう修正しました(`caption_prefix`、`caption_suffix`、`keep_tokens_separator`、`secondary_separator`、`enable_wildcard`)。 @@ -274,7 +277,8 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - `tag_image_by_wd14_tagger.py` で v3 のリポジトリがサポートされました(`--onnx` 指定時のみ有効)。 PR [#1192](https://github.com/kohya-ss/sd-scripts/pull/1192) sdbds 氏に感謝します。 - Onnx のバージョンアップが必要になるかもしれません。デフォルトでは Onnx はインストールされていませんので、`pip install onnx==1.15.0 onnxruntime-gpu==1.17.1` 等でインストール、アップデートしてください。`requirements.txt` のコメントもあわせてご確認ください。 - `tag_image_by_wd14_tagger.py` で、モデルを`--repo_id` のサブディレクトリに保存するようにしました。これにより複数のモデルファイルがキャッシュされます。`--model_dir` 直下の不要なファイルは削除願います。 - +- 各学習スクリプトに、noise offset、ip noise gammaを、それぞれ 0~指定した値の範囲で変動させるオプション `--noise_offset_random_strength` および `--ip_noise_gamma_random_strength` が追加されました。 PR [#1177](https://github.com/kohya-ss/sd-scripts/pull/1177) KohakuBlueleaf 氏に感謝します。 +- データセット設定の[英語版ドキュメント](./docs/config_README-en.md) が追加されました。PR [#1175](https://github.com/kohya-ss/sd-scripts/pull/1175) darkstorm2150 氏に感謝します。 #### Example of dataset settings / データセット設定の記述例: From 3b0db0f17f46148abe345c5cdce76ff707bdccd3 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 20 Mar 2024 17:45:35 +0900 Subject: [PATCH 092/103] update readme --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index f8601c1b2..81c176a79 100644 --- a/README.md +++ b/README.md @@ -266,6 +266,8 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - The model is now saved in the subdirectory as `--repo_id` in `tag_image_by_wd14_tagger.py` . This caches multiple repo_id models. Please delete unnecessary files under `--model_dir`. - The options `--noise_offset_random_strength` and `--ip_noise_gamma_random_strength` are added to each training script. These options can be used to vary the noise offset and ip noise gamma in the range of 0 to the specified value. PR [#1177](https://github.com/kohya-ss/sd-scripts/pull/1177) Thanks to KohakuBlueleaf! - The [English version of the dataset settings documentation](./docs/config_README-en.md) is added. PR [#1175](https://github.com/kohya-ss/sd-scripts/pull/1175) Thanks to darkstorm2150! +- The `.toml` file for the dataset config is now read in UTF-8 encoding. PR [#1167](https://github.com/kohya-ss/sd-scripts/pull/1167) Thanks to Horizon1704! + - Colab での動作時、ログ出力で停止してしまうようです。学習スクリプトに `--console_log_simple` オプションを指定し、rich のロギングを無効してお試しください。 - `train_network.py` および `sdxl_train_network.py` で、学習したモデルのメタデータに一部のデータセット設定が記録されるよう修正しました(`caption_prefix`、`caption_suffix`、`keep_tokens_separator`、`secondary_separator`、`enable_wildcard`)。 @@ -279,7 +281,7 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - `tag_image_by_wd14_tagger.py` で、モデルを`--repo_id` のサブディレクトリに保存するようにしました。これにより複数のモデルファイルがキャッシュされます。`--model_dir` 直下の不要なファイルは削除願います。 - 各学習スクリプトに、noise offset、ip noise gammaを、それぞれ 0~指定した値の範囲で変動させるオプション `--noise_offset_random_strength` および `--ip_noise_gamma_random_strength` が追加されました。 PR [#1177](https://github.com/kohya-ss/sd-scripts/pull/1177) KohakuBlueleaf 氏に感謝します。 - データセット設定の[英語版ドキュメント](./docs/config_README-en.md) が追加されました。PR [#1175](https://github.com/kohya-ss/sd-scripts/pull/1175) darkstorm2150 氏に感謝します。 - +- データセット設定の `.toml` ファイルが UTF-8 encoding で読み込まれるようになりました。PR [#1167](https://github.com/kohya-ss/sd-scripts/pull/1167) Horizon1704 氏に感謝します。 #### Example of dataset settings / データセット設定の記述例: From 855add067b06464eaa47ed55840da0f17d675762 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 20 Mar 2024 18:14:05 +0900 Subject: [PATCH 093/103] update option help and readme --- README.md | 7 +++++-- library/train_util.py | 8 ++++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 81c176a79..804bad84a 100644 --- a/README.md +++ b/README.md @@ -253,6 +253,7 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum ### Working in progress - Colab seems to stop with log output. Try specifying `--console_log_simple` option in the training script to disable rich logging. +- The `.toml` file for the dataset config is now read in UTF-8 encoding. PR [#1167](https://github.com/kohya-ss/sd-scripts/pull/1167) Thanks to Horizon1704! - `train_network.py` and `sdxl_train_network.py` are modified to record some dataset settings in the metadata of the trained model (`caption_prefix`, `caption_suffix`, `keep_tokens_separator`, `secondary_separator`, `enable_wildcard`). - Some features are added to the dataset subset settings. - `secondary_separator` is added to specify the tag separator that is not the target of shuffling or dropping. @@ -266,10 +267,11 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - The model is now saved in the subdirectory as `--repo_id` in `tag_image_by_wd14_tagger.py` . This caches multiple repo_id models. Please delete unnecessary files under `--model_dir`. - The options `--noise_offset_random_strength` and `--ip_noise_gamma_random_strength` are added to each training script. These options can be used to vary the noise offset and ip noise gamma in the range of 0 to the specified value. PR [#1177](https://github.com/kohya-ss/sd-scripts/pull/1177) Thanks to KohakuBlueleaf! - The [English version of the dataset settings documentation](./docs/config_README-en.md) is added. PR [#1175](https://github.com/kohya-ss/sd-scripts/pull/1175) Thanks to darkstorm2150! -- The `.toml` file for the dataset config is now read in UTF-8 encoding. PR [#1167](https://github.com/kohya-ss/sd-scripts/pull/1167) Thanks to Horizon1704! +- The options `--save_state_on_train_end` are added to each training script. PR [#1168](https://github.com/kohya-ss/sd-scripts/pull/1168) Thanks to gesen2egee! - Colab での動作時、ログ出力で停止してしまうようです。学習スクリプトに `--console_log_simple` オプションを指定し、rich のロギングを無効してお試しください。 +- データセット設定の `.toml` ファイルが UTF-8 encoding で読み込まれるようになりました。PR [#1167](https://github.com/kohya-ss/sd-scripts/pull/1167) Horizon1704 氏に感謝します。 - `train_network.py` および `sdxl_train_network.py` で、学習したモデルのメタデータに一部のデータセット設定が記録されるよう修正しました(`caption_prefix`、`caption_suffix`、`keep_tokens_separator`、`secondary_separator`、`enable_wildcard`)。 - データセットのサブセット設定にいくつかの機能を追加しました。 - シャッフルの対象とならないタグ分割識別子の指定 `secondary_separator` を追加しました。`secondary_separator=";;;"` のように指定します。`secondary_separator` で区切ることで、その部分はシャッフル、drop 時にまとめて扱われます。詳しくは記述例をご覧ください。 @@ -281,7 +283,8 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - `tag_image_by_wd14_tagger.py` で、モデルを`--repo_id` のサブディレクトリに保存するようにしました。これにより複数のモデルファイルがキャッシュされます。`--model_dir` 直下の不要なファイルは削除願います。 - 各学習スクリプトに、noise offset、ip noise gammaを、それぞれ 0~指定した値の範囲で変動させるオプション `--noise_offset_random_strength` および `--ip_noise_gamma_random_strength` が追加されました。 PR [#1177](https://github.com/kohya-ss/sd-scripts/pull/1177) KohakuBlueleaf 氏に感謝します。 - データセット設定の[英語版ドキュメント](./docs/config_README-en.md) が追加されました。PR [#1175](https://github.com/kohya-ss/sd-scripts/pull/1175) darkstorm2150 氏に感謝します。 -- データセット設定の `.toml` ファイルが UTF-8 encoding で読み込まれるようになりました。PR [#1167](https://github.com/kohya-ss/sd-scripts/pull/1167) Horizon1704 氏に感謝します。 +- 各学習スクリプトに、学習終了時に state を保存する `--save_state_on_train_end` オプションが追加されました。 PR [#1168](https://github.com/kohya-ss/sd-scripts/pull/1168) gesen2egee 氏に感謝します。 + #### Example of dataset settings / データセット設定の記述例: diff --git a/library/train_util.py b/library/train_util.py index 23961505f..a13985ee2 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2936,13 +2936,13 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument( "--save_state", action="store_true", - help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する", + help="save training state additionally (including optimizer states etc.) when saving model / optimizerなど学習状態も含めたstateをモデル保存時に追加で保存する", ) parser.add_argument( "--save_state_on_train_end", action="store_true", - help="save training state additionally (including optimizer states etc.) on train end / optimizerなど学習状態も含めたstateを追加で保存する", - ) + help="save training state (including optimizer states etc.) on train end / optimizerなど学習状態も含めたstateを学習完了時に保存する", + ) parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ") @@ -3550,7 +3550,7 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar exit(1) logger.info(f"Loading settings from {config_path}...") - with open(config_path, "r", encoding='utf-8') as f: + with open(config_path, "r", encoding="utf-8") as f: config_dict = toml.load(f) # combine all sections into one From d17c0f508416d734360393804732bfa420fe1c27 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 21 Mar 2024 08:31:29 +0900 Subject: [PATCH 094/103] update dataset config doc --- README.md | 88 +--------------------------------------- docs/config_README-en.md | 73 +++++++++++++++++++++++++++++++++ docs/config_README-ja.md | 75 +++++++++++++++++++++++++++++++++- 3 files changed, 148 insertions(+), 88 deletions(-) diff --git a/README.md b/README.md index 804bad84a..dae311325 100644 --- a/README.md +++ b/README.md @@ -261,7 +261,7 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - `enable_wildcard` is added. When set to `true`, the wildcard notation `{aaa|bbb|ccc}` can be used. See the example below. - `keep_tokens_separator` is updated to be used twice in the caption. When you specify `keep_tokens_separator="|||"`, the part divided by the second `|||` is not shuffled or dropped and remains at the end. - The existing features `caption_prefix` and `caption_suffix` can be used together. `caption_prefix` and `caption_suffix` are processed first, and then `enable_wildcard`, `keep_tokens_separator`, shuffling and dropping, and `secondary_separator` are processed in order. - - The examples are [shown below](#example-of-dataset-settings--データセット設定の記述例). + - See [Dataset config](./docs/config_README-en.md) for details. - The support for v3 repositories is added to `tag_image_by_wd14_tagger.py` (`--onnx` option only). PR [#1192](https://github.com/kohya-ss/sd-scripts/pull/1192) Thanks to sdbds! - Onnx may need to be updated. Onnx is not installed by default, so please install or update it with `pip install onnx==1.15.0 onnxruntime-gpu==1.17.1` etc. Please also check the comments in `requirements.txt`. - The model is now saved in the subdirectory as `--repo_id` in `tag_image_by_wd14_tagger.py` . This caches multiple repo_id models. Please delete unnecessary files under `--model_dir`. @@ -278,6 +278,7 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - `enable_wildcard` を追加しました。`true` にするとワイルドカード記法 `{aaa|bbb|ccc}` が使えます。詳しくは記述例をご覧ください。 - `keep_tokens_separator` をキャプション内に 2 つ使えるようにしました。たとえば `keep_tokens_separator="|||"` と指定したとき、`1girl, hatsune miku, vocaloid ||| stage, mic ||| best quality, rating: general` とキャプションを指定すると、二番目の `|||` で分割された部分はシャッフル、drop されず末尾に残ります。 - 既存の機能 `caption_prefix` と `caption_suffix` とあわせて使えます。`caption_prefix` と `caption_suffix` は一番最初に処理され、その後、ワイルドカード、`keep_tokens_separator`、シャッフルおよび drop、`secondary_separator` の順に処理されます。 + - 詳細は [データセット設定](./docs/config_README-ja.md) をご覧ください。 - `tag_image_by_wd14_tagger.py` で v3 のリポジトリがサポートされました(`--onnx` 指定時のみ有効)。 PR [#1192](https://github.com/kohya-ss/sd-scripts/pull/1192) sdbds 氏に感謝します。 - Onnx のバージョンアップが必要になるかもしれません。デフォルトでは Onnx はインストールされていませんので、`pip install onnx==1.15.0 onnxruntime-gpu==1.17.1` 等でインストール、アップデートしてください。`requirements.txt` のコメントもあわせてご確認ください。 - `tag_image_by_wd14_tagger.py` で、モデルを`--repo_id` のサブディレクトリに保存するようにしました。これにより複数のモデルファイルがキャッシュされます。`--model_dir` 直下の不要なファイルは削除願います。 @@ -286,91 +287,6 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - 各学習スクリプトに、学習終了時に state を保存する `--save_state_on_train_end` オプションが追加されました。 PR [#1168](https://github.com/kohya-ss/sd-scripts/pull/1168) gesen2egee 氏に感謝します。 -#### Example of dataset settings / データセット設定の記述例: - -```toml -[general] -flip_aug = true -color_aug = false -resolution = [1024, 1024] - -[[datasets]] -batch_size = 6 -enable_bucket = true -bucket_no_upscale = true -caption_extension = ".txt" -keep_tokens_separator= "|||" -shuffle_caption = true -caption_tag_dropout_rate = 0.1 -secondary_separator = ";;;" # subset 側に書くこともできます / can be written in the subset side -enable_wildcard = true # 同上 / same as above - - [[datasets.subsets]] - image_dir = "/path/to/image_dir" - num_repeats = 1 - - # ||| の前後はカンマは不要です(自動的に追加されます) / No comma is required before and after ||| (it is added automatically) - caption_prefix = "1girl, hatsune miku, vocaloid |||" - - # ||| の後はシャッフル、drop されず残ります / After |||, it is not shuffled or dropped and remains - # 単純に文字列として連結されるので、カンマなどは自分で入れる必要があります / It is simply concatenated as a string, so you need to put commas yourself - caption_suffix = ", anime screencap ||| masterpiece, rating: general" -``` - -#### Example of caption, secondary_separator notation: `secondary_separator = ";;;"` - -```txt -1girl, hatsune miku, vocaloid, upper body, looking at viewer, sky;;;cloud;;;day, outdoors -``` -The part `sky;;;cloud;;;day` is replaced with `sky,cloud,day` without shuffling or dropping. When shuffling and dropping are enabled, it is processed as a whole (as one tag). For example, it becomes `vocaloid, 1girl, upper body, sky,cloud,day, outdoors, hatsune miku` (shuffled) or `vocaloid, 1girl, outdoors, looking at viewer, upper body, hatsune miku` (dropped). - -#### Example of caption, enable_wildcard notation: `enable_wildcard = true` - -```txt -1girl, hatsune miku, vocaloid, upper body, looking at viewer, {simple|white} background -``` -`simple` or `white` is randomly selected, and it becomes `simple background` or `white background`. - -```txt -1girl, hatsune miku, vocaloid, {{retro style}} -``` -If you want to include `{` or `}` in the tag string, double them like `{{` or `}}` (in this example, the actual caption used for training is `{retro style}`). - -#### Example of caption, `keep_tokens_separator` notation: `keep_tokens_separator = "|||"` - -```txt -1girl, hatsune miku, vocaloid ||| stage, microphone, white shirt, smile ||| best quality, rating: general -``` -It becomes `1girl, hatsune miku, vocaloid, microphone, stage, white shirt, best quality, rating: general` or `1girl, hatsune miku, vocaloid, white shirt, smile, stage, microphone, best quality, rating: general` etc. - - -#### キャプション記述例、secondary_separator 記法:`secondary_separator = ";;;"` の場合 - -```txt -1girl, hatsune miku, vocaloid, upper body, looking at viewer, sky;;;cloud;;;day, outdoors -``` -`sky;;;cloud;;;day` の部分はシャッフル、drop されず `sky,cloud,day` に置換されます。シャッフル、drop が有効な場合、まとめて(一つのタグとして)処理されます。つまり `vocaloid, 1girl, upper body, sky,cloud,day, outdoors, hatsune miku` (シャッフル)や `vocaloid, 1girl, outdoors, looking at viewer, upper body, hatsune miku` (drop されたケース)などになります。 - -#### キャプション記述例、ワイルドカード記法: `enable_wildcard = true` の場合 - -```txt -1girl, hatsune miku, vocaloid, upper body, looking at viewer, {simple|white} background -``` -ランダムに `simple` または `white` が選ばれ、`simple background` または `white background` になります。 - -```txt -1girl, hatsune miku, vocaloid, {{retro style}} -``` -タグ文字列に `{` や `}` そのものを含めたい場合は `{{` や `}}` のように二つ重ねてください(この例では実際に学習に用いられるキャプションは `{retro style}` になります)。 - -#### キャプション記述例、`keep_tokens_separator` 記法: `keep_tokens_separator = "|||"` の場合 - -```txt -1girl, hatsune miku, vocaloid ||| stage, microphone, white shirt, smile ||| best quality, rating: general -``` -`1girl, hatsune miku, vocaloid, microphone, stage, white shirt, best quality, rating: general` や `1girl, hatsune miku, vocaloid, white shirt, smile, stage, microphone, best quality, rating: general` などになります。 - - ### Mar 15, 2024 / 2024/3/15: v0.8.5 - Fixed a bug that the value of timestep embedding during SDXL training was incorrect. diff --git a/docs/config_README-en.md b/docs/config_README-en.md index a0727934d..bdcaabfc7 100644 --- a/docs/config_README-en.md +++ b/docs/config_README-en.md @@ -1,7 +1,10 @@ Original Source by kohya-ss +First version: A.I Translation by Model: NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO, editing by Darkstorm2150 +Some parts are manually added. + # Config Readme This README is about the configuration files that can be passed with the `--dataset_config` option. @@ -143,11 +146,23 @@ These options are related to subset configuration. | `shuffle_caption` | `true` | o | o | o | | `caption_prefix` | `"masterpiece, best quality, "` | o | o | o | | `caption_suffix` | `", from side"` | o | o | o | +| `caption_separator` | (not specified) | o | o | o | +| `keep_tokens_separator` | `“|||”` | o | o | o | +| `secondary_separator` | `“;;;”` | o | o | o | +| `enable_wildcard` | `true` | o | o | o | * `num_repeats` * Specifies the number of repeats for images in a subset. This is equivalent to `--dataset_repeats` in fine-tuning but can be specified for any training method. * `caption_prefix`, `caption_suffix` * Specifies the prefix and suffix strings to be appended to the captions. Shuffling is performed with these strings included. Be cautious when using `keep_tokens`. +* `caption_separator` + * Specifies the string to separate the tags. The default is `,`. This option is usually not necessary to set. +* `keep_tokens_separator` + * Specifies the string to separate the parts to be fixed in the caption. For example, if you specify `aaa, bbb ||| ccc, ddd, eee, fff ||| ggg, hhh`, the parts `aaa, bbb` and `ggg, hhh` will remain, and the rest will be shuffled and dropped. The comma in between is not necessary. As a result, the prompt will be `aaa, bbb, eee, ccc, fff, ggg, hhh` or `aaa, bbb, fff, ccc, eee, ggg, hhh`, etc. +* `secondary_separator` + * Specifies an additional separator. The part separated by this separator is treated as one tag and is shuffled and dropped. It is then replaced by `caption_separator`. For example, if you specify `aaa;;;bbb;;;ccc`, it will be replaced by `aaa,bbb,ccc` or dropped together. +* `enable_wildcard` + * Enables wildcard notation. This will be explained later. ### DreamBooth-specific options @@ -276,4 +291,62 @@ As a temporary measure, we will list common errors and their solutions. If you e * `voluptuous.error.MultipleInvalid: expected int for dictionary value @ ...`: This error occurs when the specified value format is incorrect. It is highly likely that the value format is incorrect. The `int` part changes depending on the target option. The example configurations in this README may be helpful. * `voluptuous.error.MultipleInvalid: extra keys not allowed @ ...`: This error occurs when there is an option name that is not supported. It is highly likely that you misspelled the option name or mistakenly included it. +## Miscellaneous + +### Example of configuration file, 設定ファイルの記述例 + +```toml +[general] +flip_aug = true +color_aug = false +resolution = [1024, 1024] + +[[datasets]] +batch_size = 6 +enable_bucket = true +bucket_no_upscale = true +caption_extension = ".txt" +keep_tokens_separator= "|||" +shuffle_caption = true +caption_tag_dropout_rate = 0.1 +secondary_separator = ";;;" # subset 側に書くこともできます / can be written in the subset side +enable_wildcard = true # 同上 / same as above + + [[datasets.subsets]] + image_dir = "/path/to/image_dir" + num_repeats = 1 + + # ||| の前後はカンマは不要です(自動的に追加されます) / No comma is required before and after ||| (it is added automatically) + caption_prefix = "1girl, hatsune miku, vocaloid |||" + + # ||| の後はシャッフル、drop されず残ります / After |||, it is not shuffled or dropped and remains + # 単純に文字列として連結されるので、カンマなどは自分で入れる必要があります / It is simply concatenated as a string, so you need to put commas yourself + caption_suffix = ", anime screencap ||| masterpiece, rating: general" +``` + +### Example of caption, secondary_separator notation: `secondary_separator = ";;;"` + +```txt +1girl, hatsune miku, vocaloid, upper body, looking at viewer, sky;;;cloud;;;day, outdoors +``` +The part `sky;;;cloud;;;day` is replaced with `sky,cloud,day` without shuffling or dropping. When shuffling and dropping are enabled, it is processed as a whole (as one tag). For example, it becomes `vocaloid, 1girl, upper body, sky,cloud,day, outdoors, hatsune miku` (shuffled) or `vocaloid, 1girl, outdoors, looking at viewer, upper body, hatsune miku` (dropped). + +### Example of caption, enable_wildcard notation: `enable_wildcard = true` + +```txt +1girl, hatsune miku, vocaloid, upper body, looking at viewer, {simple|white} background +``` +`simple` or `white` is randomly selected, and it becomes `simple background` or `white background`. + +```txt +1girl, hatsune miku, vocaloid, {{retro style}} +``` +If you want to include `{` or `}` in the tag string, double them like `{{` or `}}` (in this example, the actual caption used for training is `{retro style}`). + +### Example of caption, `keep_tokens_separator` notation: `keep_tokens_separator = "|||"` + +```txt +1girl, hatsune miku, vocaloid ||| stage, microphone, white shirt, smile ||| best quality, rating: general +``` +It becomes `1girl, hatsune miku, vocaloid, microphone, stage, white shirt, best quality, rating: general` or `1girl, hatsune miku, vocaloid, white shirt, smile, stage, microphone, best quality, rating: general` etc. diff --git a/docs/config_README-ja.md b/docs/config_README-ja.md index 69a03f6cf..47bb5c57d 100644 --- a/docs/config_README-ja.md +++ b/docs/config_README-ja.md @@ -1,5 +1,3 @@ -For non-Japanese speakers: this README is provided only in Japanese in the current state. Sorry for inconvenience. We will provide English version in the near future. - `--dataset_config` で渡すことができる設定ファイルに関する説明です。 ## 概要 @@ -140,12 +138,28 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学 | `shuffle_caption` | `true` | o | o | o | | `caption_prefix` | `“masterpiece, best quality, ”` | o | o | o | | `caption_suffix` | `“, from side”` | o | o | o | +| `caption_separator` | (通常は設定しません) | o | o | o | +| `keep_tokens_separator` | `“|||”` | o | o | o | +| `secondary_separator` | `“;;;”` | o | o | o | +| `enable_wildcard` | `true` | o | o | o | * `num_repeats` * サブセットの画像の繰り返し回数を指定します。fine tuning における `--dataset_repeats` に相当しますが、`num_repeats` はどの学習方法でも指定可能です。 * `caption_prefix`, `caption_suffix` * キャプションの前、後に付与する文字列を指定します。シャッフルはこれらの文字列を含めた状態で行われます。`keep_tokens` を指定する場合には注意してください。 +* `caption_separator` + * タグを区切る文字列を指定します。デフォルトは `,` です。このオプションは通常は設定する必要はありません。 + +* `keep_tokens_separator` + * キャプションで固定したい部分を区切る文字列を指定します。たとえば `aaa, bbb ||| ccc, ddd, eee, fff ||| ggg, hhh` のように指定すると、`aaa, bbb` と `ggg, hhh` の部分はシャッフル、drop されず残ります。間のカンマは不要です。結果としてプロンプトは `aaa, bbb, eee, ccc, fff, ggg, hhh` や `aaa, bbb, fff, ccc, eee, ggg, hhh` などになります。 + +* `secondary_separator` + * 追加の区切り文字を指定します。この区切り文字で区切られた部分は一つのタグとして扱われ、シャッフル、drop されます。その後、`caption_separator` に置き換えられます。たとえば `aaa;;;bbb;;;ccc` のように指定すると、`aaa,bbb,ccc` に置き換えられるか、まとめて drop されます。 + +* `enable_wildcard` + * ワイルドカード記法を有効にします。ワイルドカード記法については後述します。 + ### DreamBooth 方式専用のオプション DreamBooth 方式のオプションは、サブセット向けオプションのみ存在します。 @@ -280,4 +294,61 @@ resolution = 768 * `voluptuous.error.MultipleInvalid: expected int for dictionary value @ ...`: 指定する値の形式が不正というエラーです。値の形式が間違っている可能性が高いです。`int` の部分は対象となるオプションによって変わります。この README に載っているオプションの「設定例」が役立つかもしれません。 * `voluptuous.error.MultipleInvalid: extra keys not allowed @ ...`: 対応していないオプション名が存在している場合に発生するエラーです。オプション名を間違って記述しているか、誤って紛れ込んでいる可能性が高いです。 +## その他 + +### Example of configuration file, 設定ファイルの記述例 +```toml +[general] +flip_aug = true +color_aug = false +resolution = [1024, 1024] + +[[datasets]] +batch_size = 6 +enable_bucket = true +bucket_no_upscale = true +caption_extension = ".txt" +keep_tokens_separator= "|||" +shuffle_caption = true +caption_tag_dropout_rate = 0.1 +secondary_separator = ";;;" # subset 側に書くこともできます / can be written in the subset side +enable_wildcard = true # 同上 / same as above + + [[datasets.subsets]] + image_dir = "/path/to/image_dir" + num_repeats = 1 + + # ||| の前後はカンマは不要です(自動的に追加されます) / No comma is required before and after ||| (it is added automatically) + caption_prefix = "1girl, hatsune miku, vocaloid |||" + + # ||| の後はシャッフル、drop されず残ります / After |||, it is not shuffled or dropped and remains + # 単純に文字列として連結されるので、カンマなどは自分で入れる必要があります / It is simply concatenated as a string, so you need to put commas yourself + caption_suffix = ", anime screencap ||| masterpiece, rating: general" +``` + +### キャプション記述例、secondary_separator 記法:`secondary_separator = ";;;"` の場合 + +```txt +1girl, hatsune miku, vocaloid, upper body, looking at viewer, sky;;;cloud;;;day, outdoors +``` +`sky;;;cloud;;;day` の部分はシャッフル、drop されず `sky,cloud,day` に置換されます。シャッフル、drop が有効な場合、まとめて(一つのタグとして)処理されます。つまり `vocaloid, 1girl, upper body, sky,cloud,day, outdoors, hatsune miku` (シャッフル)や `vocaloid, 1girl, outdoors, looking at viewer, upper body, hatsune miku` (drop されたケース)などになります。 + +### キャプション記述例、ワイルドカード記法: `enable_wildcard = true` の場合 + +```txt +1girl, hatsune miku, vocaloid, upper body, looking at viewer, {simple|white} background +``` +ランダムに `simple` または `white` が選ばれ、`simple background` または `white background` になります。 + +```txt +1girl, hatsune miku, vocaloid, {{retro style}} +``` +タグ文字列に `{` や `}` そのものを含めたい場合は `{{` や `}}` のように二つ重ねてください(この例では実際に学習に用いられるキャプションは `{retro style}` になります)。 + +### キャプション記述例、`keep_tokens_separator` 記法: `keep_tokens_separator = "|||"` の場合 + +```txt +1girl, hatsune miku, vocaloid ||| stage, microphone, white shirt, smile ||| best quality, rating: general +``` +`1girl, hatsune miku, vocaloid, microphone, stage, white shirt, best quality, rating: general` や `1girl, hatsune miku, vocaloid, white shirt, smile, stage, microphone, best quality, rating: general` などになります。 From 594c7f70500e402586654e73501e7d8fc74592b8 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 23 Mar 2024 16:11:31 +0900 Subject: [PATCH 095/103] format by black --- finetune/merge_captions_to_metadata.py | 132 ++++++++++++++----------- finetune/merge_dd_tags_to_metadata.py | 118 ++++++++++++---------- 2 files changed, 144 insertions(+), 106 deletions(-) diff --git a/finetune/merge_captions_to_metadata.py b/finetune/merge_captions_to_metadata.py index 60765b863..89f717473 100644 --- a/finetune/merge_captions_to_metadata.py +++ b/finetune/merge_captions_to_metadata.py @@ -6,75 +6,95 @@ import library.train_util as train_util import os from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) + def main(args): - assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" + assert not args.recursive or ( + args.recursive and args.full_path + ), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" - train_data_dir_path = Path(args.train_data_dir) - image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) - logger.info(f"found {len(image_paths)} images.") + train_data_dir_path = Path(args.train_data_dir) + image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) + logger.info(f"found {len(image_paths)} images.") - if args.in_json is None and Path(args.out_json).is_file(): - args.in_json = args.out_json + if args.in_json is None and Path(args.out_json).is_file(): + args.in_json = args.out_json - if args.in_json is not None: - logger.info(f"loading existing metadata: {args.in_json}") - metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8')) - logger.warning("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます") - else: - logger.info("new metadata will be created / 新しいメタデータファイルが作成されます") - metadata = {} + if args.in_json is not None: + logger.info(f"loading existing metadata: {args.in_json}") + metadata = json.loads(Path(args.in_json).read_text(encoding="utf-8")) + logger.warning("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます") + else: + logger.info("new metadata will be created / 新しいメタデータファイルが作成されます") + metadata = {} - logger.info("merge caption texts to metadata json.") - for image_path in tqdm(image_paths): - caption_path = image_path.with_suffix(args.caption_extension) - caption = caption_path.read_text(encoding='utf-8').strip() + logger.info("merge caption texts to metadata json.") + for image_path in tqdm(image_paths): + caption_path = image_path.with_suffix(args.caption_extension) + caption = caption_path.read_text(encoding="utf-8").strip() - if not os.path.exists(caption_path): - caption_path = os.path.join(image_path, args.caption_extension) + if not os.path.exists(caption_path): + caption_path = os.path.join(image_path, args.caption_extension) - image_key = str(image_path) if args.full_path else image_path.stem - if image_key not in metadata: - metadata[image_key] = {} + image_key = str(image_path) if args.full_path else image_path.stem + if image_key not in metadata: + metadata[image_key] = {} - metadata[image_key]['caption'] = caption - if args.debug: - logger.info(f"{image_key} {caption}") + metadata[image_key]["caption"] = caption + if args.debug: + logger.info(f"{image_key} {caption}") - # metadataを書き出して終わり - logger.info(f"writing metadata: {args.out_json}") - Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8') - logger.info("done!") + # metadataを書き出して終わり + logger.info(f"writing metadata: {args.out_json}") + Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding="utf-8") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") - parser.add_argument("--in_json", type=str, - help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)") - parser.add_argument("--caption_extention", type=str, default=None, - help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)") - parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子") - parser.add_argument("--full_path", action="store_true", - help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)") - parser.add_argument("--recursive", action="store_true", - help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す") - parser.add_argument("--debug", action="store_true", help="debug mode") - - return parser - - -if __name__ == '__main__': - parser = setup_parser() - - args = parser.parse_args() - - # スペルミスしていたオプションを復元する - if args.caption_extention is not None: - args.caption_extension = args.caption_extention - - main(args) + parser = argparse.ArgumentParser() + parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") + parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") + parser.add_argument( + "--in_json", + type=str, + help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)", + ) + parser.add_argument( + "--caption_extention", + type=str, + default=None, + help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)", + ) + parser.add_argument( + "--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子" + ) + parser.add_argument( + "--full_path", + action="store_true", + help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)", + ) + parser.add_argument( + "--recursive", + action="store_true", + help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す", + ) + parser.add_argument("--debug", action="store_true", help="debug mode") + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + + # スペルミスしていたオプションを復元する + if args.caption_extention is not None: + args.caption_extension = args.caption_extention + + main(args) diff --git a/finetune/merge_dd_tags_to_metadata.py b/finetune/merge_dd_tags_to_metadata.py index 9ef8f14b0..ce22d990e 100644 --- a/finetune/merge_dd_tags_to_metadata.py +++ b/finetune/merge_dd_tags_to_metadata.py @@ -6,70 +6,88 @@ import library.train_util as train_util import os from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) + def main(args): - assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" + assert not args.recursive or ( + args.recursive and args.full_path + ), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" - train_data_dir_path = Path(args.train_data_dir) - image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) - logger.info(f"found {len(image_paths)} images.") + train_data_dir_path = Path(args.train_data_dir) + image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) + logger.info(f"found {len(image_paths)} images.") - if args.in_json is None and Path(args.out_json).is_file(): - args.in_json = args.out_json + if args.in_json is None and Path(args.out_json).is_file(): + args.in_json = args.out_json - if args.in_json is not None: - logger.info(f"loading existing metadata: {args.in_json}") - metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8')) - logger.warning("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます") - else: - logger.info("new metadata will be created / 新しいメタデータファイルが作成されます") - metadata = {} + if args.in_json is not None: + logger.info(f"loading existing metadata: {args.in_json}") + metadata = json.loads(Path(args.in_json).read_text(encoding="utf-8")) + logger.warning("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます") + else: + logger.info("new metadata will be created / 新しいメタデータファイルが作成されます") + metadata = {} - logger.info("merge tags to metadata json.") - for image_path in tqdm(image_paths): - tags_path = image_path.with_suffix(args.caption_extension) - tags = tags_path.read_text(encoding='utf-8').strip() + logger.info("merge tags to metadata json.") + for image_path in tqdm(image_paths): + tags_path = image_path.with_suffix(args.caption_extension) + tags = tags_path.read_text(encoding="utf-8").strip() - if not os.path.exists(tags_path): - tags_path = os.path.join(image_path, args.caption_extension) + if not os.path.exists(tags_path): + tags_path = os.path.join(image_path, args.caption_extension) - image_key = str(image_path) if args.full_path else image_path.stem - if image_key not in metadata: - metadata[image_key] = {} + image_key = str(image_path) if args.full_path else image_path.stem + if image_key not in metadata: + metadata[image_key] = {} - metadata[image_key]['tags'] = tags - if args.debug: - logger.info(f"{image_key} {tags}") + metadata[image_key]["tags"] = tags + if args.debug: + logger.info(f"{image_key} {tags}") - # metadataを書き出して終わり - logger.info(f"writing metadata: {args.out_json}") - Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8') + # metadataを書き出して終わり + logger.info(f"writing metadata: {args.out_json}") + Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding="utf-8") - logger.info("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") - parser.add_argument("--in_json", type=str, - help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)") - parser.add_argument("--full_path", action="store_true", - help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)") - parser.add_argument("--recursive", action="store_true", - help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す") - parser.add_argument("--caption_extension", type=str, default=".txt", - help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子") - parser.add_argument("--debug", action="store_true", help="debug mode, print tags") - - return parser - - -if __name__ == '__main__': - parser = setup_parser() - - args = parser.parse_args() - main(args) + parser = argparse.ArgumentParser() + parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") + parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") + parser.add_argument( + "--in_json", + type=str, + help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)", + ) + parser.add_argument( + "--full_path", + action="store_true", + help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)", + ) + parser.add_argument( + "--recursive", + action="store_true", + help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す", + ) + parser.add_argument( + "--caption_extension", + type=str, + default=".txt", + help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子", + ) + parser.add_argument("--debug", action="store_true", help="debug mode, print tags") + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + main(args) From f4a4c11cd30a885d1d5ddb86bee609305c5398f3 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 23 Mar 2024 18:51:37 +0900 Subject: [PATCH 096/103] support multiline captions ref #1155 --- README.md | 8 ++++---- docs/config_README-en.md | 30 +++++++++++++++++++++++++++++- docs/config_README-ja.md | 32 ++++++++++++++++++++++++++++++-- library/train_util.py | 38 +++++++++++++++++++++++++++++++------- 4 files changed, 94 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index dae311325..dd000d126 100644 --- a/README.md +++ b/README.md @@ -257,8 +257,8 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - `train_network.py` and `sdxl_train_network.py` are modified to record some dataset settings in the metadata of the trained model (`caption_prefix`, `caption_suffix`, `keep_tokens_separator`, `secondary_separator`, `enable_wildcard`). - Some features are added to the dataset subset settings. - `secondary_separator` is added to specify the tag separator that is not the target of shuffling or dropping. - - Specify `secondary_separator=";;;"`. When you specify `secondary_separator`, the part is not shuffled or dropped. See the example below. - - `enable_wildcard` is added. When set to `true`, the wildcard notation `{aaa|bbb|ccc}` can be used. See the example below. + - Specify `secondary_separator=";;;"`. When you specify `secondary_separator`, the part is not shuffled or dropped. + - `enable_wildcard` is added. When set to `true`, the wildcard notation `{aaa|bbb|ccc}` can be used. The multi-line caption is also enabled. - `keep_tokens_separator` is updated to be used twice in the caption. When you specify `keep_tokens_separator="|||"`, the part divided by the second `|||` is not shuffled or dropped and remains at the end. - The existing features `caption_prefix` and `caption_suffix` can be used together. `caption_prefix` and `caption_suffix` are processed first, and then `enable_wildcard`, `keep_tokens_separator`, shuffling and dropping, and `secondary_separator` are processed in order. - See [Dataset config](./docs/config_README-en.md) for details. @@ -274,8 +274,8 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - データセット設定の `.toml` ファイルが UTF-8 encoding で読み込まれるようになりました。PR [#1167](https://github.com/kohya-ss/sd-scripts/pull/1167) Horizon1704 氏に感謝します。 - `train_network.py` および `sdxl_train_network.py` で、学習したモデルのメタデータに一部のデータセット設定が記録されるよう修正しました(`caption_prefix`、`caption_suffix`、`keep_tokens_separator`、`secondary_separator`、`enable_wildcard`)。 - データセットのサブセット設定にいくつかの機能を追加しました。 - - シャッフルの対象とならないタグ分割識別子の指定 `secondary_separator` を追加しました。`secondary_separator=";;;"` のように指定します。`secondary_separator` で区切ることで、その部分はシャッフル、drop 時にまとめて扱われます。詳しくは記述例をご覧ください。 - - `enable_wildcard` を追加しました。`true` にするとワイルドカード記法 `{aaa|bbb|ccc}` が使えます。詳しくは記述例をご覧ください。 + - シャッフルの対象とならないタグ分割識別子の指定 `secondary_separator` を追加しました。`secondary_separator=";;;"` のように指定します。`secondary_separator` で区切ることで、その部分はシャッフル、drop 時にまとめて扱われます。 + - `enable_wildcard` を追加しました。`true` にするとワイルドカード記法 `{aaa|bbb|ccc}` が使えます。また複数行キャプションも有効になります。 - `keep_tokens_separator` をキャプション内に 2 つ使えるようにしました。たとえば `keep_tokens_separator="|||"` と指定したとき、`1girl, hatsune miku, vocaloid ||| stage, mic ||| best quality, rating: general` とキャプションを指定すると、二番目の `|||` で分割された部分はシャッフル、drop されず末尾に残ります。 - 既存の機能 `caption_prefix` と `caption_suffix` とあわせて使えます。`caption_prefix` と `caption_suffix` は一番最初に処理され、その後、ワイルドカード、`keep_tokens_separator`、シャッフルおよび drop、`secondary_separator` の順に処理されます。 - 詳細は [データセット設定](./docs/config_README-ja.md) をご覧ください。 diff --git a/docs/config_README-en.md b/docs/config_README-en.md index bdcaabfc7..e99fde216 100644 --- a/docs/config_README-en.md +++ b/docs/config_README-en.md @@ -293,7 +293,35 @@ As a temporary measure, we will list common errors and their solutions. If you e ## Miscellaneous -### Example of configuration file, 設定ファイルの記述例 +### Multi-line captions + +By setting `enable_wildcard = true`, multiple-line captions are also enabled. If the caption file consists of multiple lines, one line is randomly selected as the caption. + +```txt +1girl, hatsune miku, vocaloid, upper body, looking at viewer, microphone, stage +a girl with a microphone standing on a stage +detailed digital art of a girl with a microphone on a stage +``` + +It can be combined with wildcard notation. + +In metadata files, you can also specify multiple-line captions. In the `.json` metadata file, use `\n` to represent a line break. If the caption file consists of multiple lines, `merge_captions_to_metadata.py` will create a metadata file in this format. + +The tags in the metadata (`tags`) are added to each line of the caption. + +```json +{ + "/path/to/image.png": { + "caption": "a cartoon of a frog with the word frog on it\ntest multiline caption1\ntest multiline caption2", + "tags": "open mouth, simple background, standing, no humans, animal, black background, frog, animal costume, animal focus" + }, + ... +} +``` + +In this case, the actual caption will be `a cartoon of a frog with the word frog on it, open mouth, simple background ...`, `test multiline caption1, open mouth, simple background ...`, `test multiline caption2, open mouth, simple background ...`, etc. + +### Example of configuration file : `secondary_separator`, wildcard notation, `keep_tokens_separator`, etc. ```toml [general] diff --git a/docs/config_README-ja.md b/docs/config_README-ja.md index 47bb5c57d..b57ae86a7 100644 --- a/docs/config_README-ja.md +++ b/docs/config_README-ja.md @@ -158,7 +158,7 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学 * 追加の区切り文字を指定します。この区切り文字で区切られた部分は一つのタグとして扱われ、シャッフル、drop されます。その後、`caption_separator` に置き換えられます。たとえば `aaa;;;bbb;;;ccc` のように指定すると、`aaa,bbb,ccc` に置き換えられるか、まとめて drop されます。 * `enable_wildcard` - * ワイルドカード記法を有効にします。ワイルドカード記法については後述します。 + * ワイルドカード記法および複数行キャプションを有効にします。ワイルドカード記法、複数行キャプションについては後述します。 ### DreamBooth 方式専用のオプション @@ -296,7 +296,35 @@ resolution = 768 ## その他 -### Example of configuration file, 設定ファイルの記述例 +### 複数行キャプション + +`enable_wildcard = true` を設定することで、複数行キャプションも同時に有効になります。キャプションファイルが複数の行からなる場合、ランダムに一つの行が選ばれてキャプションとして利用されます。 + +```txt +1girl, hatsune miku, vocaloid, upper body, looking at viewer, microphone, stage +a girl with a microphone standing on a stage +detailed digital art of a girl with a microphone on a stage +``` + +ワイルドカード記法と組み合わせることも可能です。 + +メタデータファイルでも同様に複数行キャプションを指定することができます。メタデータの .json 内には、`\n` を使って改行を表現してください。キャプションファイルが複数行からなる場合、`merge_captions_to_metadata.py` を使うと、この形式でメタデータファイルが作成されます。 + +メタデータのタグ (`tags`) は、キャプションの各行に追加されます。 + +```json +{ + "/path/to/image.png": { + "caption": "a cartoon of a frog with the word frog on it\ntest multiline caption1\ntest multiline caption2", + "tags": "open mouth, simple background, standing, no humans, animal, black background, frog, animal costume, animal focus" + }, + ... +} +``` + +この場合、実際のキャプションは `a cartoon of a frog with the word frog on it, open mouth, simple background ...` または `test multiline caption1, open mouth, simple background ...`、 `test multiline caption2, open mouth, simple background ...` 等になります。 + +### 設定ファイルの記述例:追加の区切り文字、ワイルドカード記法、`keep_tokens_separator` 等 ```toml [general] diff --git a/library/train_util.py b/library/train_util.py index a13985ee2..d076cf847 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -693,6 +693,10 @@ def process_caption(self, subset: BaseSubset, caption): else: # process wildcards if subset.enable_wildcard: + # if caption is multiline, random choice one line + if "\n" in caption: + caption = random.choice(caption.split("\n")) + # wildcard is like '{aaa|bbb|ccc...}' # escape the curly braces like {{ or }} replacer1 = "⦅" @@ -711,6 +715,9 @@ def replace_wildcard(match): # unescape the curly braces caption = caption.replace(replacer1, "{").replace(replacer2, "}") + else: + # if caption is multiline, use the first line + caption = caption.split("\n")[0] if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0: fixed_tokens = [] @@ -1446,7 +1453,7 @@ def __init__( self.bucket_reso_steps = None # この情報は使われない self.bucket_no_upscale = False - def read_caption(img_path, caption_extension): + def read_caption(img_path, caption_extension, enable_wildcard): # captionの候補ファイル名を作る base_name = os.path.splitext(img_path)[0] base_name_face_det = base_name @@ -1465,7 +1472,10 @@ def read_caption(img_path, caption_extension): logger.error(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}") raise e assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" - caption = lines[0].strip() + if enable_wildcard: + caption = "\n".join([line.strip() for line in lines if line.strip() != ""]) # 空行を除く、改行で連結 + else: + caption = lines[0].strip() break return caption @@ -1481,7 +1491,7 @@ def load_dreambooth_dir(subset: DreamBoothSubset): captions = [] missing_captions = [] for img_path in img_paths: - cap_for_img = read_caption(img_path, subset.caption_extension) + cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard) if cap_for_img is None and subset.class_tokens is None: logger.warning( f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}" @@ -1657,10 +1667,24 @@ def __init__( caption = img_md.get("caption") tags = img_md.get("tags") if caption is None: - caption = tags - elif tags is not None and len(tags) > 0: - caption = caption + ", " + tags - tags_list.append(tags) + caption = tags # could be multiline + tags = None + + if subset.enable_wildcard: + # tags must be single line + if tags is not None: + tags = tags.replace("\n", subset.caption_separator) + + # add tags to each line of caption + if caption is not None and tags is not None: + caption = "\n".join( + [f"{line}{subset.caption_separator}{tags}" for line in caption.split("\n") if line.strip() != ""] + ) + else: + # use as is + if tags is not None and len(tags) > 0: + caption = caption + subset.caption_separator + tags + tags_list.append(tags) if caption is None: caption = "" From 0c7baea88cfa98c5fab2898551c426f2d4fac4c6 Mon Sep 17 00:00:00 2001 From: feffy380 <114889020+feffy380@users.noreply.github.com> Date: Sat, 23 Mar 2024 17:28:02 +0100 Subject: [PATCH 097/103] register reg images with correct subset --- library/train_util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index d076cf847..b69fb0950 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1554,7 +1554,7 @@ def load_dreambooth_dir(subset: DreamBoothSubset): for img_path, caption in zip(img_paths, captions): info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path) if subset.is_reg: - reg_infos.append(info) + reg_infos.append((info, subset)) else: self.register_image(info, subset) @@ -1575,7 +1575,7 @@ def load_dreambooth_dir(subset: DreamBoothSubset): n = 0 first_loop = True while n < num_train_images: - for info in reg_infos: + for info, subset in reg_infos: if first_loop: self.register_image(info, subset) n += info.num_repeats From 79d1c12ab056e7114257d7079f2f8846e329320e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 24 Mar 2024 11:06:37 +0900 Subject: [PATCH 098/103] disable sample_every_n_xxx if value less than 1 ref #1202 --- library/train_util.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index d076cf847..8fbf3283d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1473,7 +1473,7 @@ def read_caption(img_path, caption_extension, enable_wildcard): raise e assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" if enable_wildcard: - caption = "\n".join([line.strip() for line in lines if line.strip() != ""]) # 空行を除く、改行で連結 + caption = "\n".join([line.strip() for line in lines if line.strip() != ""]) # 空行を除く、改行で連結 else: caption = lines[0].strip() break @@ -3338,6 +3338,18 @@ def verify_training_args(args: argparse.Namespace): + " / zero_terminal_snrが有効ですが、v_parameterizationが有効ではありません。学習結果は想定外になる可能性があります" ) + if args.sample_every_n_epochs is not None and args.sample_every_n_epochs <= 0: + logger.warning( + "sample_every_n_epochs is less than or equal to 0, so it will be disabled / sample_every_n_epochsに0以下の値が指定されたため無効になります" + ) + args.sample_every_n_epochs = None + + if args.sample_every_n_steps is not None and args.sample_every_n_steps <= 0: + logger.warning( + "sample_every_n_steps is less than or equal to 0, so it will be disabled / sample_every_n_stepsに0以下の値が指定されたため無効になります" + ) + args.sample_every_n_steps = None + def add_dataset_arguments( parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool From 691f04322a48566caf62dd67c2834ca2748c064f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 24 Mar 2024 11:10:26 +0900 Subject: [PATCH 099/103] update readme --- README.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index dd000d126..25226dff3 100644 --- a/README.md +++ b/README.md @@ -266,8 +266,9 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - Onnx may need to be updated. Onnx is not installed by default, so please install or update it with `pip install onnx==1.15.0 onnxruntime-gpu==1.17.1` etc. Please also check the comments in `requirements.txt`. - The model is now saved in the subdirectory as `--repo_id` in `tag_image_by_wd14_tagger.py` . This caches multiple repo_id models. Please delete unnecessary files under `--model_dir`. - The options `--noise_offset_random_strength` and `--ip_noise_gamma_random_strength` are added to each training script. These options can be used to vary the noise offset and ip noise gamma in the range of 0 to the specified value. PR [#1177](https://github.com/kohya-ss/sd-scripts/pull/1177) Thanks to KohakuBlueleaf! -- The [English version of the dataset settings documentation](./docs/config_README-en.md) is added. PR [#1175](https://github.com/kohya-ss/sd-scripts/pull/1175) Thanks to darkstorm2150! - The options `--save_state_on_train_end` are added to each training script. PR [#1168](https://github.com/kohya-ss/sd-scripts/pull/1168) Thanks to gesen2egee! +- The options `--sample_every_n_epochs` and `--sample_every_n_steps` in each training script now display a warning and ignore them when a number less than or equal to `0` is specified. Thanks to S-Del for raising the issue. +- The [English version of the dataset settings documentation](./docs/config_README-en.md) is added. PR [#1175](https://github.com/kohya-ss/sd-scripts/pull/1175) Thanks to darkstorm2150! - Colab での動作時、ログ出力で停止してしまうようです。学習スクリプトに `--console_log_simple` オプションを指定し、rich のロギングを無効してお試しください。 @@ -283,8 +284,9 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - Onnx のバージョンアップが必要になるかもしれません。デフォルトでは Onnx はインストールされていませんので、`pip install onnx==1.15.0 onnxruntime-gpu==1.17.1` 等でインストール、アップデートしてください。`requirements.txt` のコメントもあわせてご確認ください。 - `tag_image_by_wd14_tagger.py` で、モデルを`--repo_id` のサブディレクトリに保存するようにしました。これにより複数のモデルファイルがキャッシュされます。`--model_dir` 直下の不要なファイルは削除願います。 - 各学習スクリプトに、noise offset、ip noise gammaを、それぞれ 0~指定した値の範囲で変動させるオプション `--noise_offset_random_strength` および `--ip_noise_gamma_random_strength` が追加されました。 PR [#1177](https://github.com/kohya-ss/sd-scripts/pull/1177) KohakuBlueleaf 氏に感謝します。 -- データセット設定の[英語版ドキュメント](./docs/config_README-en.md) が追加されました。PR [#1175](https://github.com/kohya-ss/sd-scripts/pull/1175) darkstorm2150 氏に感謝します。 - 各学習スクリプトに、学習終了時に state を保存する `--save_state_on_train_end` オプションが追加されました。 PR [#1168](https://github.com/kohya-ss/sd-scripts/pull/1168) gesen2egee 氏に感謝します。 +- 各学習スクリプトで `--sample_every_n_epochs` および `--sample_every_n_steps` オプションに `0` 以下の数値を指定した時、警告を表示するとともにそれらを無視するよう変更しました。問題提起していただいた S-Del 氏に感謝します。 +- データセット設定の[英語版ドキュメント](./docs/config_README-en.md) が追加されました。PR [#1175](https://github.com/kohya-ss/sd-scripts/pull/1175) darkstorm2150 氏に感謝します。 ### Mar 15, 2024 / 2024/3/15: v0.8.5 From 381c44955e72f04b57c74aa9b3d9a43c839c631f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 24 Mar 2024 11:27:18 +0900 Subject: [PATCH 100/103] update readme and typing hint --- README.md | 2 ++ library/train_util.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 25226dff3..a19f7968a 100644 --- a/README.md +++ b/README.md @@ -254,6 +254,7 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - Colab seems to stop with log output. Try specifying `--console_log_simple` option in the training script to disable rich logging. - The `.toml` file for the dataset config is now read in UTF-8 encoding. PR [#1167](https://github.com/kohya-ss/sd-scripts/pull/1167) Thanks to Horizon1704! +- Fixed a bug that the last subset settings are applied to all images when multiple subsets of regularization images are specified in the dataset settings. The settings for each subset are correctly applied to each image. PR [#1205](https://github.com/kohya-ss/sd-scripts/pull/1205) Thanks to feffy380! - `train_network.py` and `sdxl_train_network.py` are modified to record some dataset settings in the metadata of the trained model (`caption_prefix`, `caption_suffix`, `keep_tokens_separator`, `secondary_separator`, `enable_wildcard`). - Some features are added to the dataset subset settings. - `secondary_separator` is added to specify the tag separator that is not the target of shuffling or dropping. @@ -273,6 +274,7 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - Colab での動作時、ログ出力で停止してしまうようです。学習スクリプトに `--console_log_simple` オプションを指定し、rich のロギングを無効してお試しください。 - データセット設定の `.toml` ファイルが UTF-8 encoding で読み込まれるようになりました。PR [#1167](https://github.com/kohya-ss/sd-scripts/pull/1167) Horizon1704 氏に感謝します。 +- データセット設定で、正則化画像のサブセットを複数指定した時、最後のサブセットの各種設定がすべてのサブセットの画像に適用される不具合が修正されました。それぞれのサブセットの設定が、それぞれの画像に正しく適用されます。PR [#1205](https://github.com/kohya-ss/sd-scripts/pull/1205) feffy380 氏に感謝します。 - `train_network.py` および `sdxl_train_network.py` で、学習したモデルのメタデータに一部のデータセット設定が記録されるよう修正しました(`caption_prefix`、`caption_suffix`、`keep_tokens_separator`、`secondary_separator`、`enable_wildcard`)。 - データセットのサブセット設定にいくつかの機能を追加しました。 - シャッフルの対象とならないタグ分割識別子の指定 `secondary_separator` を追加しました。`secondary_separator=";;;"` のように指定します。`secondary_separator` で区切ることで、その部分はシャッフル、drop 時にまとめて扱われます。 diff --git a/library/train_util.py b/library/train_util.py index ce6e09245..99aeea90d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1525,7 +1525,7 @@ def load_dreambooth_dir(subset: DreamBoothSubset): logger.info("prepare images.") num_train_images = 0 num_reg_images = 0 - reg_infos: List[ImageInfo] = [] + reg_infos: List[Tuple[ImageInfo, DreamBoothSubset]] = [] for subset in subsets: if subset.num_repeats < 1: logger.warning( From 1648ade6da549c7def2e21f236453e7938c499cd Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 24 Mar 2024 20:55:48 +0900 Subject: [PATCH 101/103] format by black --- library/config_util.py | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index eb652ecf3..ff4de0921 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -41,12 +41,17 @@ DatasetGroup, ) from .utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) + def add_config_arguments(parser: argparse.ArgumentParser): - parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル") + parser.add_argument( + "--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル" + ) # TODO: inherit Params class in Subset, Dataset @@ -362,7 +367,9 @@ def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> return self.argparse_config_validator(argparse_namespace) except MultipleInvalid: # XXX: this should be a bug - logger.error("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。") + logger.error( + "Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。" + ) raise # NOTE: value would be overwritten by latter dict if there is already the same key @@ -547,11 +554,11 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu " ", ) - logger.info(f'{info}') + logger.info(f"{info}") # make buckets first because it determines the length of dataset # and set the same seed for all datasets - seed = random.randint(0, 2**31) # actual seed is seed + epoch_no + seed = random.randint(0, 2**31) # actual seed is seed + epoch_no for i, dataset in enumerate(datasets): logger.info(f"[Dataset {i}]") dataset.make_buckets() @@ -638,13 +645,17 @@ def load_user_config(file: str) -> dict: with open(file, "r") as f: config = json.load(f) except Exception: - logger.error(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") + logger.error( + f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}" + ) raise elif file.name.lower().endswith(".toml"): try: config = toml.load(file) except Exception: - logger.error(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") + logger.error( + f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}" + ) raise else: raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}") @@ -671,13 +682,13 @@ def load_user_config(file: str) -> dict: train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning) logger.info("[argparse_namespace]") - logger.info(f'{vars(argparse_namespace)}') + logger.info(f"{vars(argparse_namespace)}") user_config = load_user_config(config_args.dataset_config) logger.info("") logger.info("[user_config]") - logger.info(f'{user_config}') + logger.info(f"{user_config}") sanitizer = ConfigSanitizer( config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout @@ -686,10 +697,10 @@ def load_user_config(file: str) -> dict: logger.info("") logger.info("[sanitized_user_config]") - logger.info(f'{sanitized_user_config}') + logger.info(f"{sanitized_user_config}") blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace) logger.info("") logger.info("[blueprint]") - logger.info(f'{blueprint}') + logger.info(f"{blueprint}") From 9bbb28c3619a9ff86a51bdc7ea83584976840663 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 24 Mar 2024 22:06:37 +0900 Subject: [PATCH 102/103] update PyTorch version and reorganize dependencies --- README-ja.md | 55 ++---------- README.md | 149 ++++---------------------------- docs/train_SDXL-en.md | 84 ++++++++++++++++++ requirements.txt | 8 +- sdxl_minimal_inference.py | 30 +++++-- sdxl_train_textual_inversion.py | 1 - 6 files changed, 136 insertions(+), 191 deletions(-) create mode 100644 docs/train_SDXL-en.md diff --git a/README-ja.md b/README-ja.md index 29c33a659..1d83c44f1 100644 --- a/README-ja.md +++ b/README-ja.md @@ -1,7 +1,3 @@ -SDXLがサポートされました。sdxlブランチはmainブランチにマージされました。リポジトリを更新したときにはUpgradeの手順を実行してください。また accelerate のバージョンが上がっていますので、accelerate config を再度実行してください。 - -SDXL学習については[こちら](./README.md#sdxl-training)をご覧ください(英語です)。 - ## リポジトリについて Stable Diffusionの学習、画像生成、その他のスクリプトを入れたリポジトリです。 @@ -21,6 +17,7 @@ GUIやPowerShellスクリプトなど、より使いやすくする機能が[bma * [学習について、共通編](./docs/train_README-ja.md) : データ整備やオプションなど * [データセット設定](./docs/config_README-ja.md) +* [SDXL学習](./docs/train_SDXL-en.md) (英語版) * [DreamBoothの学習について](./docs/train_db_README-ja.md) * [fine-tuningのガイド](./docs/fine_tune_README_ja.md): * [LoRAの学習について](./docs/train_network_README-ja.md) @@ -44,9 +41,7 @@ PowerShellを使う場合、venvを使えるようにするためには以下の ## Windows環境でのインストール -スクリプトはPyTorch 2.0.1でテストしています。PyTorch 1.12.1でも動作すると思われます。 - -以下の例ではPyTorchは2.0.1/CUDA 11.8版をインストールします。CUDA 11.6版やPyTorch 1.12.1を使う場合は適宜書き換えください。 +スクリプトはPyTorch 2.1.1でテストしています。PyTorch 2.0.1、1.12.1でも動作すると思われます。 (なお、python -m venv~の行で「python」とだけ表示された場合、py -m venv~のようにpythonをpyに変更してください。) @@ -59,20 +54,20 @@ cd sd-scripts python -m venv venv .\venv\Scripts\activate -pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118 +pip install torch==2.1.1 torchvision==0.16.1 --index-url https://download.pytorch.org/whl/cu118 pip install --upgrade -r requirements.txt -pip install xformers==0.0.20 +pip install xformers==0.0.23 --index-url https://download.pytorch.org/whl/cu118 accelerate config ``` コマンドプロンプトでも同一です。 -(注:``python -m venv venv`` のほうが ``python -m venv --system-site-packages venv`` より安全そうなため書き換えました。globalなpythonにパッケージがインストールしてあると、後者だといろいろと問題が起きます。) +注:`bitsandbytes==0.43.0`、`prodigyopt==1.0`、`lion-pytorch==0.0.6` は `requirements.txt` に含まれるようになりました。他のバージョンを使う場合は適宜インストールしてください。 -accelerate configの質問には以下のように答えてください。(bf16で学習する場合、最後の質問にはbf16と答えてください。) +この例では PyTorch および xfomers は2.1.1/CUDA 11.8版をインストールします。CUDA 12.1版やPyTorch 1.12.1を使う場合は適宜書き換えください。たとえば CUDA 12.1版の場合は `pip install torch==2.1.1 torchvision==0.16.1 --index-url https://download.pytorch.org/whl/cu121` および `pip install xformers==0.0.23 --index-url https://download.pytorch.org/whl/cu121` としてください。 -※0.15.0から日本語環境では選択のためにカーソルキーを押すと落ちます(……)。数字キーの0、1、2……で選択できますので、そちらを使ってください。 +accelerate configの質問には以下のように答えてください。(bf16で学習する場合、最後の質問にはbf16と答えてください。) ```txt - This machine @@ -87,41 +82,6 @@ accelerate configの質問には以下のように答えてください。(bf1 ※場合によって ``ValueError: fp16 mixed precision requires a GPU`` というエラーが出ることがあるようです。この場合、6番目の質問( ``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]:``)に「0」と答えてください。(id `0`のGPUが使われます。) -### オプション:`bitsandbytes`(8bit optimizer)を使う - -`bitsandbytes`はオプションになりました。Linuxでは通常通りpipでインストールできます(0.41.1または以降のバージョンを推奨)。 - -Windowsでは0.35.0または0.41.1を推奨します。 - -- `bitsandbytes` 0.35.0: 安定しているとみられるバージョンです。AdamW8bitは使用できますが、他のいくつかの8bit optimizer、学習時の`full_bf16`オプションは使用できません。 -- `bitsandbytes` 0.41.1: Lion8bit、PagedAdamW8bit、PagedLion8bitをサポートします。`full_bf16`が使用できます。 - -注:`bitsandbytes` 0.35.0から0.41.0までのバージョンには問題があるようです。 https://github.com/TimDettmers/bitsandbytes/issues/659 - -以下の手順に従い、`bitsandbytes`をインストールしてください。 - -### 0.35.0を使う場合 - -PowerShellの例です。コマンドプロンプトではcpの代わりにcopyを使ってください。 - -```powershell -cd sd-scripts -.\venv\Scripts\activate -pip install bitsandbytes==0.35.0 - -cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\ -cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py -cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py -``` - -### 0.41.1を使う場合 - -jllllll氏の配布されている[こちら](https://github.com/jllllll/bitsandbytes-windows-webui) または他の場所から、Windows用のwhlファイルをインストールしてください。 - -```powershell -python -m pip install bitsandbytes==0.41.1 --prefer-binary --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui -``` - ## アップグレード 新しいリリースがあった場合、以下のコマンドで更新できます。 @@ -151,4 +111,3 @@ Conv2d 3x3への拡大は [cloneofsimo氏](https://github.com/cloneofsimo/lora) [BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause - diff --git a/README.md b/README.md index a19f7968a..ef26acab8 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,3 @@ -__SDXL is now supported. The sdxl branch has been merged into the main branch. If you update the repository, please follow the upgrade instructions. Also, the version of accelerate has been updated, so please run accelerate config again.__ The documentation for SDXL training is [here](./README.md#sdxl-training). - This repository contains training, generation and utility scripts for Stable Diffusion. [__Change History__](#change-history) is moved to the bottom of the page. @@ -20,9 +18,9 @@ This repository contains the scripts for: ## About requirements.txt -These files do not contain requirements for PyTorch. Because the versions of them depend on your environment. Please install PyTorch at first (see installation guide below.) +The file does not contain requirements for PyTorch. Because the version of PyTorch depends on the environment, it is not included in the file. Please install PyTorch first according to the environment. See installation instructions below. -The scripts are tested with Pytorch 2.0.1. 1.12.1 is not tested but should work. +The scripts are tested with Pytorch 2.1.1. 2.0.1 and 1.12.1 is not tested but should work. ## Links to usage documentation @@ -32,12 +30,13 @@ Most of the documents are written in Japanese. * [Training guide - common](./docs/train_README-ja.md) : data preparation, options etc... * [Chinese version](./docs/train_README-zh.md) +* [SDXL training](./docs/train_SDXL-en.md) (English version) * [Dataset config](./docs/config_README-ja.md) * [English version](./docs/config_README-en.md) * [DreamBooth training guide](./docs/train_db_README-ja.md) * [Step by Step fine-tuning guide](./docs/fine_tune_README_ja.md): -* [training LoRA](./docs/train_network_README-ja.md) -* [training Textual Inversion](./docs/train_ti_README-ja.md) +* [Training LoRA](./docs/train_network_README-ja.md) +* [Training Textual Inversion](./docs/train_ti_README-ja.md) * [Image generation](./docs/gen_img_README-ja.md) * note.com [Model conversion](https://note.com/kohya_ss/n/n374f316fe4ad) @@ -65,14 +64,18 @@ cd sd-scripts python -m venv venv .\venv\Scripts\activate -pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118 +pip install torch==2.1.1 torchvision==0.16.1 --index-url https://download.pytorch.org/whl/cu118 pip install --upgrade -r requirements.txt -pip install xformers==0.0.20 +pip install xformers==0.0.23 --index-url https://download.pytorch.org/whl/cu118 accelerate config ``` -__Note:__ Now bitsandbytes is optional. Please install any version of bitsandbytes as needed. Installation instructions are in the following section. +If `python -m venv` shows only `python`, change `python` to `py`. + +__Note:__ Now `bitsandbytes==0.43.0`, `prodigyopt==1.0` and `lion-pytorch==0.0.6` are included in the requirements.txt. If you'd like to use the another version, please install it manually. + +This installation is for CUDA 11.8. If you use a different version of CUDA, please install the appropriate version of PyTorch and xformers. For example, if you use CUDA 12, please install `pip install torch==2.1.1 torchvision==0.16.1 --index-url https://download.pytorch.org/whl/cu121` and `pip install xformers==0.0.23 --index-url https://download.pytorch.org/whl/cu121`.