Commit

diffusers: restore prompt weighting feature
keturn committed Nov 11, 2022
1 parent d121406 commit 8ec60b3
Showing 2 changed files with 18 additions and 22 deletions.
28 changes: 10 additions & 18 deletions ldm/invoke/generator/diffusers_pipeline.py
@@ -11,6 +11,8 @@
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
+from ldm.modules.encoders.modules import WeightedFrozenCLIPEmbedder
+
 
 @dataclass
 class PipelineIntermediateState:
@@ -76,6 +78,11 @@ def __init__(
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
         )
+        # InvokeAI's interface for text embeddings and whatnot
+        self.clip_embedder = WeightedFrozenCLIPEmbedder(
+            tokenizer=self.tokenizer,
+            transformer=self.text_encoder
+        )
 
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
@@ -312,27 +319,12 @@ def get_text_embeddings(self,
             text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
         return text_embeddings
 
-    def get_learned_conditioning(self, c: List[List[str]], return_tokens=True,
-                                 fragment_weights=None, **kwargs):
+    @torch.inference_mode()
+    def get_learned_conditioning(self, c: List[List[str]], *, return_tokens=True, fragment_weights=None):
         """
         Compatibility function for ldm.models.diffusion.ddpm.LatentDiffusion.
         """
-        assert return_tokens == True
-        if fragment_weights:
-            weights = fragment_weights[0]
-            if any(weight != 1.0 for weight in weights):
-                warnings.warn(f"fragment weights not implemented yet {fragment_weights}", stacklevel=2)
-
-        if kwargs:
-            warnings.warn(f"unsupported args {kwargs}", stacklevel=2)
-
-        text_fragments = c[0]
-        text_input = self._tokenize(text_fragments)
-
-        with torch.inference_mode():
-            token_ids = text_input.input_ids.to(self.text_encoder.device)
-            text_embeddings = self.text_encoder(token_ids)[0]
-            return text_embeddings, text_input.input_ids
+        return self.clip_embedder.encode(c, return_tokens=return_tokens, fragment_weights=fragment_weights)
 
     @torch.inference_mode()
     def _tokenize(self, prompt: Union[str, List[str]]):
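For readers following the diffusers_pipeline.py change: get_learned_conditioning no longer warns and discards fragment weights; it hands them to the WeightedFrozenCLIPEmbedder built in __init__. The sketch below is a minimal, assumed usage of that restored path; the pipeline variable name, the fragment/weight nesting, and the unpacked return value are inferred from the type hints and the old inline return, not spelled out by this diff.

# Hedged sketch (not part of the commit): exercising the restored weighting path.
# `pipe` is assumed to be an instance of the pipeline class defined in
# diffusers_pipeline.py, already constructed with its usual components.
fragments = [['a castle on a hill', 'in the style of a watercolor painting']]
weights = [[1.0, 0.6]]   # assumed per-fragment weights, parallel to `fragments`

embeddings, token_ids = pipe.get_learned_conditioning(
    fragments,
    return_tokens=True,        # keyword-only in the new signature
    fragment_weights=weights,  # now honored via clip_embedder.encode(...)
)
# With return_tokens=True the call is expected to keep the old (embeddings, token ids)
# return shape, since the method is documented as a LatentDiffusion compatibility shim.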
12 changes: 8 additions & 4 deletions ldm/modules/encoders/modules.py
@@ -240,17 +240,17 @@ class FrozenCLIPEmbedder(AbstractEncoder):
     def __init__(
         self,
         version='openai/clip-vit-large-patch14',
-        device=choose_torch_device(),
         max_length=77,
+        tokenizer=None,
+        transformer=None,
     ):
         super().__init__()
-        self.tokenizer = CLIPTokenizer.from_pretrained(
+        self.tokenizer = tokenizer or CLIPTokenizer.from_pretrained(
             version, local_files_only=True
         )
-        self.transformer = CLIPTextModel.from_pretrained(
+        self.transformer = transformer or CLIPTextModel.from_pretrained(
             version, local_files_only=True
         )
-        self.device = device
         self.max_length = max_length
         self.freeze()

@@ -456,6 +456,10 @@ def forward(self, text, **kwargs):
     def encode(self, text, **kwargs):
         return self(text, **kwargs)
 
+    @property
+    def device(self):
+        return self.transformer.device
+
 class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
 
     fragment_weights_key = "fragment_weights"
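The modules.py side of the change lets FrozenCLIPEmbedder accept an injected tokenizer and transformer instead of always loading its own, and replaces the stored device argument with a property that follows the transformer. A minimal sketch, assuming the standard CLIP checkpoint id and that WeightedFrozenCLIPEmbedder keeps the inherited constructor (the pipeline diff above constructs it exactly this way):

# Hedged sketch (not part of the commit): wrapping an already-loaded text encoder.
from transformers import CLIPTextModel, CLIPTokenizer

from ldm.modules.encoders.modules import WeightedFrozenCLIPEmbedder

# Illustrative checkpoint id; InvokeAI's pipeline passes in self.tokenizer and
# self.text_encoder instead of loading fresh weights here.
tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
text_encoder = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14')

embedder = WeightedFrozenCLIPEmbedder(tokenizer=tokenizer, transformer=text_encoder)

# `device` is now a read-only property delegating to the transformer, so the embedder
# tracks wherever the text encoder lives rather than a device captured at __init__.
print(embedder.device)  # same as text_encoder.device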
