From 9d28701699daeb3b7d23cd64c11c783d834297fb Mon Sep 17 00:00:00 2001
From: John Doe
Date: Sun, 15 Dec 2024 19:31:03 +0000
Subject: [PATCH] Experimental Redux conditioning for Flux LoRA training

---
 flux_train_network.py       | 14 ++++++++++++++
 library/flux_train_utils.py | 19 +++++++++++++++++++
 library/strategy_flux.py    | 85 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++---
 library/train_util.py       |  7 +++++++
 4 files changed, 122 insertions(+), 3 deletions(-)

diff --git a/flux_train_network.py b/flux_train_network.py
index 75e975bae..3c072bc62 100644
--- a/flux_train_network.py
+++ b/flux_train_network.py
@@ -8,6 +8,8 @@
 
+import random  # used by the vision conditioning dropout below
 from accelerate import Accelerator
 from library.device_utils import clean_memory_on_device, init_ipex
+from library.strategy_flux import move_vision_encoder_to_device
 
 init_ipex()
 
@@ -190,6 +192,8 @@ def get_text_encoder_outputs_caching_strategy(self, args):
                 args.skip_cache_check,
                 is_partial=self.train_clip_l or self.train_t5xxl,
                 apply_t5_attn_mask=args.apply_t5_attn_mask,
+                vision_cond_ratio=args.vision_cond_ratio,
+                redux_path=args.redux_model_path,
             )
         else:
             return None
@@ -250,6 +254,7 @@ def cache_text_encoder_outputs_if_needed(
             text_encoders[0].to("cpu")
             logger.info("move t5XXL back to cpu")
             text_encoders[1].to("cpu")
+            move_vision_encoder_to_device("cpu")
             clean_memory_on_device(accelerator.device)
 
         if not args.lowram:
@@ -372,6 +377,15 @@ def get_noise_pred_and_target(
         if not args.apply_t5_attn_mask:
             t5_attn_mask = None
 
+        if args.vision_cond_dropout < 1.0:
+            vision_encoder_conds = batch.get("vision_encoder_outputs_list", None)
+            # use the cached Redux conditioning for this step unless it is missing or dropped
+            if vision_encoder_conds is not None and random.uniform(0, 1) > args.vision_cond_dropout:
+                vis_t5_out, vis_txt_ids = vision_encoder_conds
+                t5_out = vis_t5_out.to(dtype=t5_out.dtype, device=t5_out.device)
+                txt_ids = vis_txt_ids.to(dtype=txt_ids.dtype, device=txt_ids.device)
+                t5_attn_mask = None
+
         def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
             # if not args.split_mode:
             # normal forward
diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py
index f7f06c5cf..1239997b8 100644
--- a/library/flux_train_utils.py
+++ b/library/flux_train_utils.py
@@ -617,3 +617,22 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
         default=3.0,
         help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
     )
+
+    parser.add_argument(
+        "--redux_model_path",
+        type=str,
+        help="Path to the Redux model (*.sft or *.safetensors); the weights should be float16",
+    )
+    parser.add_argument(
+        "--vision_cond_ratio",
+        type=float,
+        default=0.0,
+        help="Weight of the Redux image embeddings when they are averaged with the T5 text encoder embeddings. 0.0 disables vision conditioning; the maximum is 1.0.",
+    )
+
+    parser.add_argument(
+        "--vision_cond_dropout",
+        type=float,
+        default=1.0,
+        help="Probability of dropping the Redux conditioning for a given training step; the default of 1.0 disables vision conditioning entirely.",
+    )
diff --git a/library/strategy_flux.py b/library/strategy_flux.py
index 5e65927f8..f3abb555e 100644
--- a/library/strategy_flux.py
+++ b/library/strategy_flux.py
@@ -1,9 +1,11 @@
 import os
-import glob
 from typing import Any, List, Optional, Tuple, Union
+
+import safetensors.torch
 import torch
 import numpy as np
-from transformers import CLIPTokenizer, T5TokenizerFast
+import PIL.Image
+from transformers import CLIPTokenizer, T5TokenizerFast, SiglipVisionModel, AutoProcessor
 
 from library import flux_utils, train_util
 from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
@@ -20,6 +22,39 @@
 T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl"
 
 
+# FIXME: this is a very hacky way of handling the encoder model
+siglip_model = None
+siglip_processor = None
+redux_encoder = None
+
+def move_vision_encoder_to_device(device):
+    if siglip_model is not None:
+        siglip_model.to(device)
+    if redux_encoder is not None:
+        redux_encoder.to(device)
+
+
+class ReduxImageEncoder(torch.nn.Module):
+    # projects SigLIP image embeddings into the T5-XXL embedding space consumed by Flux
+    def __init__(
+        self,
+        redux_dim: int = 1152,
+        txt_in_features: int = 4096,
+        device=None,
+        dtype=None,
+    ) -> None:
+        super().__init__()
+        self.redux_dim = redux_dim
+        self.device = device
+        self.dtype = dtype
+        self.redux_up = torch.nn.Linear(redux_dim, txt_in_features * 3, dtype=dtype)
+        self.redux_down = torch.nn.Linear(txt_in_features * 3, txt_in_features, dtype=dtype)
+
+    def forward(self, siglip_embeds) -> torch.Tensor:
+        projected_x = self.redux_down(torch.nn.functional.silu(self.redux_up(siglip_embeds)))
+        return projected_x
+
+
 class FluxTokenizeStrategy(TokenizeStrategy):
     def __init__(self, t5xxl_max_length: int = 512, tokenizer_cache_dir: Optional[str] = None) -> None:
         self.t5xxl_max_length = t5xxl_max_length
@@ -95,10 +130,13 @@ def __init__(
         skip_disk_cache_validity_check: bool,
         is_partial: bool = False,
         apply_t5_attn_mask: bool = False,
+        vision_cond_ratio: float = 0.0,
+        redux_path: Optional[str] = None,
     ) -> None:
         super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
         self.apply_t5_attn_mask = apply_t5_attn_mask
-
+        self.vision_cond_ratio = vision_cond_ratio
+        self.redux_path = redux_path
         self.warn_fp8_weights = False
 
     def get_outputs_npz_path(self, image_abs_path: str) -> str:
@@ -142,6 +180,44 @@ def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
         # apply_t5_attn_mask should be same as self.apply_t5_attn_mask
         return [l_pooled, t5_out, txt_ids, t5_attn_mask]
 
+    def encode_vision(self, infos, ratio, t5_out, txt_ids):
+        global siglip_model
+        global siglip_processor
+        global redux_encoder
+
+        # lazily load the SigLIP vision tower and processor on first use
+        if siglip_model is None:
+            model_id = "google/siglip-so400m-patch14-384"
+            siglip_model = SiglipVisionModel.from_pretrained(
+                model_id, attn_implementation="sdpa", device_map="cuda")
+            siglip_processor = AutoProcessor.from_pretrained(model_id)
+
+        if redux_encoder is None:
+            if self.redux_path is None:
+                raise ValueError("Vision conditioning requires a Redux model, but --redux_model_path was not provided.")
+            model_data = safetensors.torch.load_file(self.redux_path, device="cpu")
+            redux_encoder = ReduxImageEncoder()
+            redux_encoder.load_state_dict(model_data)
+            redux_encoder = redux_encoder.to(device="cuda")
+
+        bsz = txt_ids.shape[0]
+        imgs = [PIL.Image.open(nfo.absolute_path).convert("RGB") for nfo in infos]
+        siglip_in = siglip_processor(images=imgs, padding="max_length", return_tensors="pt")
+        siglip_in = siglip_in.to(device="cuda")
+
+        with torch.no_grad(), torch.autocast("cuda"):
+            siglip_out = siglip_model(**siglip_in)
+            new_embed = redux_encoder(siglip_out.last_hidden_state).float().cpu().numpy()
+            new_ids = np.zeros(shape=(bsz, new_embed.shape[1], txt_ids.shape[2]))
+
+        # zero-pad the T5 sequence to the SigLIP token count, then blend the two embeddings
+        t5_out_ext = np.concatenate([t5_out, np.zeros((bsz, new_embed.shape[1] - t5_out.shape[1], t5_out.shape[2]), dtype=t5_out.dtype)], axis=1)
+        new_embed = new_embed * ratio + t5_out_ext * (1.0 - ratio)
+
+        for i, info in enumerate(infos):
+            info.vision_encoder_outputs = (new_embed[i], new_ids[i])
+
+
     def cache_batch_outputs(
         self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
     ):
@@ -173,6 +249,9 @@ def cache_batch_outputs(
         txt_ids = txt_ids.cpu().numpy()
         t5_attn_mask = tokens_and_masks[2].cpu().numpy()
 
+        if self.vision_cond_ratio > 0.0:
+            self.encode_vision(infos, self.vision_cond_ratio, t5_out, txt_ids)
+
         for i, info in enumerate(infos):
             l_pooled_i = l_pooled[i]
             t5_out_i = t5_out[i]
diff --git a/library/train_util.py b/library/train_util.py
index a35388fee..f25c68f80 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -176,6 +176,8 @@ def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool,
         self.alpha_mask: Optional[torch.Tensor] = None  # alpha mask can be flipped in runtime
 
+        self.vision_encoder_outputs: Optional[Tuple[np.ndarray, np.ndarray]] = None  # cached Redux (embeddings, ids)
+
 
 class BucketManager:
     def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
         if max_size is not None:
@@ -1497,6 +1499,7 @@ def __getitem__(self, index):
         target_sizes_hw = []
         flippeds = []  # 変数名が微妙
         text_encoder_outputs_list = []
+        vision_encoder_outputs_list = []
         custom_attributes = []
 
         for image_key in bucket[image_index : image_index + bucket_batch_size]:
@@ -1621,6 +1624,9 @@
             text_encoder_outputs = None
             input_ids = None
 
+            if image_info.vision_encoder_outputs is not None:
+                vision_encoder_outputs_list.append(image_info.vision_encoder_outputs)
+
             if image_info.text_encoder_outputs is not None:
                 # cached
                 text_encoder_outputs = image_info.text_encoder_outputs
@@ -1676,6 +1682,7 @@ def none_or_stack_elements(tensors_list, converter):
         example["custom_attributes"] = custom_attributes  # may be list of empty dict
         example["loss_weights"] = torch.FloatTensor(loss_weights)
         example["text_encoder_outputs_list"] = none_or_stack_elements(text_encoder_outputs_list, torch.FloatTensor)
+        example["vision_encoder_outputs_list"] = none_or_stack_elements(vision_encoder_outputs_list, torch.FloatTensor)
         example["input_ids_list"] = none_or_stack_elements(input_ids_list, lambda x: x)
 
         # if one of alpha_masks is not None, we need to replace None with ones
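
Note (not part of the diff): the heart of the patch is the blend computed in encode_vision. Below is a minimal standalone sketch of that computation for sanity-checking shapes. It assumes the SigLIP so400m-patch14-384 tower (729 patch tokens of width 1152) and Flux's 4096-wide T5-XXL context; the random tensors and the ratio value are illustrative placeholders, not values taken from this patch.

    import torch

    # Redux projection: SigLIP token width (1152) -> Flux text width (4096), as in ReduxImageEncoder;
    # in the real code these two Linear layers are loaded from --redux_model_path
    redux_dim, txt_in_features = 1152, 4096
    redux_up = torch.nn.Linear(redux_dim, txt_in_features * 3)
    redux_down = torch.nn.Linear(txt_in_features * 3, txt_in_features)

    siglip_tokens = torch.randn(1, 729, redux_dim)  # stands in for siglip_out.last_hidden_state
    vision_embed = redux_down(torch.nn.functional.silu(redux_up(siglip_tokens)))  # (1, 729, 4096)

    # zero-pad the cached T5 sequence (512 tokens) to the SigLIP length, then take a weighted average
    t5_out = torch.randn(1, 512, txt_in_features)  # stands in for the cached T5-XXL output
    pad = torch.zeros(1, vision_embed.shape[1] - t5_out.shape[1], txt_in_features)
    t5_ext = torch.cat([t5_out, pad], dim=1)

    ratio = 0.5  # --vision_cond_ratio
    blended = vision_embed * ratio + t5_ext * (1.0 - ratio)  # stored as vision_encoder_outputs

Usage is unchanged apart from the new flags, e.g. --redux_model_path redux.safetensors --vision_cond_ratio 0.5 --vision_cond_dropout 0.3 (the file name is just an example). The blended embeddings are cached alongside the text encoder outputs, and each training step then substitutes them for the plain T5 conditioning with probability 1 - vision_cond_dropout.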