From bea99fc96438dda560983e9efcdfab0cb8ecd2bd Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Wed, 6 Sep 2023 10:39:44 +0200
Subject: [PATCH] [Textual inversion] Relax loading textual inversion (#4903)

* [Textual inversion] Relax loading textual inversion

* up
---
 loaders.py | 34 ++++++++++++-------
 .../pipeline_stable_diffusion_xl.py | 2 +-
 .../pipeline_stable_diffusion_xl_img2img.py | 4 ++-
 .../pipeline_stable_diffusion_xl_inpaint.py | 4 ++-
 ...ne_stable_diffusion_xl_instruct_pix2pix.py | 4 ++-
 .../pipeline_stable_diffusion_xl_adapter.py | 4 ++-
 6 files changed, 35 insertions(+), 17 deletions(-)

diff --git a/loaders.py b/loaders.py
index 40d8804893e2..1de899cad927 100644
--- a/loaders.py
+++ b/loaders.py
@@ -663,6 +663,8 @@ def load_textual_inversion(
         self,
         pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
         token: Optional[Union[str, List[str]]] = None,
+        tokenizer: Optional[PreTrainedTokenizer] = None,
+        text_encoder: Optional[PreTrainedModel] = None,
         **kwargs,
     ):
         r"""
@@ -684,6 +686,11 @@
             token (`str` or `List[str]`, *optional*):
                 Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a
                 list, then `token` must also be a list of equal length.
+            text_encoder ([`~transformers.CLIPTextModel`], *optional*):
+                Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+                If not specified, the function will use `self.text_encoder`.
+            tokenizer ([`~transformers.CLIPTokenizer`], *optional*):
+                A `CLIPTokenizer` to tokenize text. If not specified, the function will use `self.tokenizer`.
             weight_name (`str`, *optional*):
                 Name of a custom weight file. This should be used when:
@@ -757,15 +764,18 @@
         ```
         """
-        if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer):
+        tokenizer = tokenizer or getattr(self, "tokenizer", None)
+        text_encoder = text_encoder or getattr(self, "text_encoder", None)
+
+        if tokenizer is None:
             raise ValueError(
-                f"{self.__class__.__name__} requires `self.tokenizer` of type `PreTrainedTokenizer` for calling"
+                f"{self.__class__.__name__} requires `self.tokenizer` or passing a `tokenizer` of type `PreTrainedTokenizer` for calling"
                 f" `{self.load_textual_inversion.__name__}`"
             )

-        if not hasattr(self, "text_encoder") or not isinstance(self.text_encoder, PreTrainedModel):
+        if text_encoder is None:
             raise ValueError(
-                f"{self.__class__.__name__} requires `self.text_encoder` of type `PreTrainedModel` for calling"
+                f"{self.__class__.__name__} requires `self.text_encoder` or passing a `text_encoder` of type `PreTrainedModel` for calling"
                 f" `{self.load_textual_inversion.__name__}`"
             )
@@ -830,7 +840,7 @@
         token_ids_and_embeddings = []

         for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens):
-            if not isinstance(pretrained_model_name_or_path, dict):
+            if not isinstance(pretrained_model_name_or_path, (dict, torch.Tensor)):
                 # 1. Load textual inversion file
                 model_file = None
                 # Let's first try to load .safetensors weights
@@ -897,10 +907,10 @@
             else:
                 token = loaded_token

-            embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device)
+            embedding = embedding.to(dtype=text_encoder.dtype, device=text_encoder.device)

             # 3. Make sure we don't mess up the tokenizer or text encoder
-            vocab = self.tokenizer.get_vocab()
+            vocab = tokenizer.get_vocab()
             if token in vocab:
                 raise ValueError(
                     f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
                 )
@@ -908,7 +918,7 @@
             elif f"{token}_1" in vocab:
                 multi_vector_tokens = [token]
                 i = 1
-                while f"{token}_{i}" in self.tokenizer.added_tokens_encoder:
+                while f"{token}_{i}" in tokenizer.added_tokens_encoder:
                     multi_vector_tokens.append(f"{token}_{i}")
                     i += 1
@@ -926,16 +936,16 @@
                 embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding]

             # add tokens and get ids
-            self.tokenizer.add_tokens(tokens)
-            token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
+            tokenizer.add_tokens(tokens)
+            token_ids = tokenizer.convert_tokens_to_ids(tokens)
             token_ids_and_embeddings += zip(token_ids, embeddings)

             logger.info(f"Loaded textual inversion embedding for {token}.")

         # resize token embeddings and set all new embeddings
-        self.text_encoder.resize_token_embeddings(len(self.tokenizer))
+        text_encoder.resize_token_embeddings(len(tokenizer))
         for token_id, embedding in token_ids_and_embeddings:
-            self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
+            text_encoder.get_input_embeddings().weight.data[token_id] = embedding

         # offload back
         if is_model_cpu_offload:
diff --git a/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index 1c155fa0aabd..7b7755085ed6 100644
--- a/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -84,7 +84,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
     return noise_cfg


-class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
+class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion XL.

diff --git a/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
index 4c7aa1ff4668..04902234d54e 100644
--- a/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
+++ b/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -84,7 +84,9 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
     return noise_cfg


-class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
+class StableDiffusionXLImg2ImgPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion XL.
diff --git a/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
index 39a7788631de..1d86dff702ef 100644
--- a/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
+++ b/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -230,7 +230,9 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
     return mask, masked_image


-class StableDiffusionXLInpaintPipeline(DiffusionPipeline, LoraLoaderMixin, FromSingleFileMixin):
+class StableDiffusionXLInpaintPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion XL.

diff --git a/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
index 45c90b0d8d66..b7633acaffa4 100644
--- a/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
+++ b/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
@@ -62,7 +62,9 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
     return noise_cfg


-class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
+class StableDiffusionXLInstructPix2PixPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
     r"""
     Pipeline for pixel-level image editing by following text instructions. Based on Stable Diffusion XL.

diff --git a/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
index 175316e72cd4..9bb8569e331d 100644
--- a/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
+++ b/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
@@ -123,7 +123,9 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
     return noise_cfg


-class StableDiffusionXLAdapterPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
+class StableDiffusionXLAdapterPipeline(
+    DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter https://arxiv.org/abs/2302.08453
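
A minimal usage sketch of the relaxed loader, for context: because SDXL pipelines carry two tokenizer/text-encoder pairs, the new `tokenizer` and `text_encoder` arguments let an embedding be loaded into each pair explicitly, and the relaxed `isinstance` check lets a raw `torch.Tensor` be passed directly. The checkpoint id below is the public SDXL base model; the embedding file name, its "clip_l"/"clip_g" keys, and the `<my-concept>` token are illustrative assumptions, not taken from this patch.

    import torch
    from diffusers import StableDiffusionXLPipeline
    from safetensors.torch import load_file

    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    ).to("cuda")

    # Hypothetical local file assumed to hold one embedding per SDXL text encoder.
    state_dict = load_file("my_sdxl_embedding.safetensors")

    # Load the embedding for the first tokenizer/text-encoder pair ...
    pipe.load_textual_inversion(
        state_dict["clip_l"],
        token="<my-concept>",
        tokenizer=pipe.tokenizer,
        text_encoder=pipe.text_encoder,
    )
    # ... and the matching embedding for the second pair, which the previous
    # `self.tokenizer` / `self.text_encoder` fallback alone could not reach.
    pipe.load_textual_inversion(
        state_dict["clip_g"],
        token="<my-concept>",
        tokenizer=pipe.tokenizer_2,
        text_encoder=pipe.text_encoder_2,
    )

    image = pipe("a photo of <my-concept>").images[0]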