diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py
index 20caecada0b..6846ff84567 100644
--- a/ldm/invoke/generator/diffusers_pipeline.py
+++ b/ldm/invoke/generator/diffusers_pipeline.py
@@ -1,5 +1,4 @@
 import secrets
-import warnings
 from dataclasses import dataclass
 from typing import List, Optional, Union, Callable
 
@@ -11,6 +10,8 @@
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
+from ldm.modules.encoders.modules import WeightedFrozenCLIPEmbedder
+
 
 @dataclass
 class PipelineIntermediateState:
@@ -76,6 +77,11 @@ def __init__(
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
         )
+        # InvokeAI's interface for text embeddings (including weighted prompt fragments)
+        self.clip_embedder = WeightedFrozenCLIPEmbedder(
+            tokenizer=self.tokenizer,
+            transformer=self.text_encoder
+        )
 
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
@@ -312,27 +318,12 @@ def get_text_embeddings(self,
             text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
         return text_embeddings
 
-    def get_learned_conditioning(self, c: List[List[str]], return_tokens=True,
-                                 fragment_weights=None, **kwargs):
+    @torch.inference_mode()
+    def get_learned_conditioning(self, c: List[List[str]], *, return_tokens=True, fragment_weights=None):
         """
         Compatibility function for ldm.models.diffusion.ddpm.LatentDiffusion.
         """
-        assert return_tokens == True
-        if fragment_weights:
-            weights = fragment_weights[0]
-            if any(weight != 1.0 for weight in weights):
-                warnings.warn(f"fragment weights not implemented yet {fragment_weights}", stacklevel=2)
-
-        if kwargs:
-            warnings.warn(f"unsupported args {kwargs}", stacklevel=2)
-
-        text_fragments = c[0]
-        text_input = self._tokenize(text_fragments)
-
-        with torch.inference_mode():
-            token_ids = text_input.input_ids.to(self.text_encoder.device)
-            text_embeddings = self.text_encoder(token_ids)[0]
-        return text_embeddings, text_input.input_ids
+        return self.clip_embedder.encode(c, return_tokens=return_tokens, fragment_weights=fragment_weights)
 
     @torch.inference_mode()
     def _tokenize(self, prompt: Union[str, List[str]]):
diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py
index cf8644a7fb2..263d00bdb60 100644
--- a/ldm/modules/encoders/modules.py
+++ b/ldm/modules/encoders/modules.py
@@ -240,17 +240,17 @@ class FrozenCLIPEmbedder(AbstractEncoder):
     def __init__(
         self,
         version='openai/clip-vit-large-patch14',
-        device=choose_torch_device(),
         max_length=77,
+        tokenizer=None,
+        transformer=None,
     ):
         super().__init__()
-        self.tokenizer = CLIPTokenizer.from_pretrained(
+        self.tokenizer = tokenizer or CLIPTokenizer.from_pretrained(
             version, local_files_only=True
         )
-        self.transformer = CLIPTextModel.from_pretrained(
+        self.transformer = transformer or CLIPTextModel.from_pretrained(
             version, local_files_only=True
         )
-        self.device = device
         self.max_length = max_length
         self.freeze()
 
@@ -456,6 +456,10 @@ def forward(self, text, **kwargs):
     def encode(self, text, **kwargs):
         return self(text, **kwargs)
 
+    @property
+    def device(self):
+        return self.transformer.device
+
 
 class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
     fragment_weights_key = "fragment_weights"
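
Usage sketch (not part of the diff): the snippet below assumes an already-constructed pipeline object from diffusers_pipeline.py, here called `pipeline`, and the prompt fragments and weights are made-up examples. After this change, get_learned_conditioning runs under torch.inference_mode() and delegates to the pipeline's WeightedFrozenCLIPEmbedder, which reuses the pipeline's own tokenizer and text encoder instead of re-tokenizing and re-encoding by hand:

    # Hypothetical example; `pipeline` is assumed to be a loaded pipeline from
    # diffusers_pipeline.py, and the fragments/weights below are illustrative.
    fragments = [["a watercolor painting of a lighthouse", "stormy sky"]]
    weights = [[1.0, 0.8]]  # per-fragment weights, interpreted by WeightedFrozenCLIPEmbedder

    # Delegates to pipeline.clip_embedder.encode(...); with return_tokens=True the
    # previous implementation returned (embeddings, token_ids), so a tuple is expected here too.
    embeddings, tokens = pipeline.get_learned_conditioning(
        fragments,
        return_tokens=True,
        fragment_weights=weights,
    )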