Skip to content

Commit

Permalink
[Textual inversion] Relax loading textual inversion (huggingface#4903)
Browse files Browse the repository at this point in the history
* [Textual inversion] Relax loading textual inversion

* up
  • Loading branch information
patrickvonplaten authored Sep 6, 2023
1 parent 9119b39 commit bea99fc
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 17 deletions.
34 changes: 22 additions & 12 deletions loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,8 @@ def load_textual_inversion(
self,
pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
token: Optional[Union[str, List[str]]] = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
text_encoder: Optional[PreTrainedModel] = None,
**kwargs,
):
r"""
Expand All @@ -684,6 +686,11 @@ def load_textual_inversion(
token (`str` or `List[str]`, *optional*):
Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a
list, then `token` must also be a list of equal length.
text_encoder ([`~transformers.CLIPTextModel`], *optional*):
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
If not specified, function will take self.text_encoder.
tokenizer ([`~transformers.CLIPTokenizer`], *optional*):
A `CLIPTokenizer` to tokenize text. If not specified, function will take self.tokenizer.
weight_name (`str`, *optional*):
Name of a custom weight file. This should be used when:
Expand Down Expand Up @@ -757,15 +764,18 @@ def load_textual_inversion(
```
"""
if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer):
tokenizer = tokenizer or getattr(self, "tokenizer", None)
text_encoder = text_encoder or getattr(self, "text_encoder", None)

if tokenizer is None:
raise ValueError(
f"{self.__class__.__name__} requires `self.tokenizer` of type `PreTrainedTokenizer` for calling"
f"{self.__class__.__name__} requires `self.tokenizer` or passing a `tokenizer` of type `PreTrainedTokenizer` for calling"
f" `{self.load_textual_inversion.__name__}`"
)

if not hasattr(self, "text_encoder") or not isinstance(self.text_encoder, PreTrainedModel):
if text_encoder is None:
raise ValueError(
f"{self.__class__.__name__} requires `self.text_encoder` of type `PreTrainedModel` for calling"
f"{self.__class__.__name__} requires `self.text_encoder` or passing a `text_encoder` of type `PreTrainedModel` for calling"
f" `{self.load_textual_inversion.__name__}`"
)

Expand Down Expand Up @@ -830,7 +840,7 @@ def load_textual_inversion(
token_ids_and_embeddings = []

for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens):
if not isinstance(pretrained_model_name_or_path, dict):
if not isinstance(pretrained_model_name_or_path, (dict, torch.Tensor)):
# 1. Load textual inversion file
model_file = None
# Let's first try to load .safetensors weights
Expand Down Expand Up @@ -897,18 +907,18 @@ def load_textual_inversion(
else:
token = loaded_token

embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device)
embedding = embedding.to(dtype=text_encoder.dtype, device=text_encoder.device)

# 3. Make sure we don't mess up the tokenizer or text encoder
vocab = self.tokenizer.get_vocab()
vocab = tokenizer.get_vocab()
if token in vocab:
raise ValueError(
f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
)
elif f"{token}_1" in vocab:
multi_vector_tokens = [token]
i = 1
while f"{token}_{i}" in self.tokenizer.added_tokens_encoder:
while f"{token}_{i}" in tokenizer.added_tokens_encoder:
multi_vector_tokens.append(f"{token}_{i}")
i += 1

Expand All @@ -926,16 +936,16 @@ def load_textual_inversion(
embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding]

# add tokens and get ids
self.tokenizer.add_tokens(tokens)
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
tokenizer.add_tokens(tokens)
token_ids = tokenizer.convert_tokens_to_ids(tokens)
token_ids_and_embeddings += zip(token_ids, embeddings)

logger.info(f"Loaded textual inversion embedding for {token}.")

# resize token embeddings and set all new embeddings
self.text_encoder.resize_token_embeddings(len(self.tokenizer))
text_encoder.resize_token_embeddings(len(tokenizer))
for token_id, embedding in token_ids_and_embeddings:
self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
text_encoder.get_input_embeddings().weight.data[token_id] = embedding

# offload back
if is_model_cpu_offload:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
return noise_cfg


class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin):
r"""
Pipeline for text-to-image generation using Stable Diffusion XL.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
return noise_cfg


class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
class StableDiffusionXLImg2ImgPipeline(
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
):
r"""
Pipeline for text-guided image-to-image generation using Stable Diffusion XL.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,9 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
return mask, masked_image


class StableDiffusionXLInpaintPipeline(DiffusionPipeline, LoraLoaderMixin, FromSingleFileMixin):
class StableDiffusionXLInpaintPipeline(
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
):
r"""
Pipeline for text-guided image inpainting using Stable Diffusion XL.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
return noise_cfg


class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
class StableDiffusionXLInstructPix2PixPipeline(
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
):
r"""
Pipeline for pixel-level image editing by following text instructions. Based on Stable Diffusion XL.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,9 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
return noise_cfg


class StableDiffusionXLAdapterPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
class StableDiffusionXLAdapterPipeline(
DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
):
r"""
Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter
https://arxiv.org/abs/2302.08453
Expand Down

0 comments on commit bea99fc

Please sign in to comment.