model_cache: add ability to load a diffusers model pipeline
and update associated things in Generate & Generator to not instantly fail when that happens
keturn committed Nov 10, 2022
1 parent 9f5e496 commit b39d04d
Showing 5 changed files with 124 additions and 16 deletions.
53 changes: 51 additions & 2 deletions ldm/generate.py
@@ -18,6 +18,8 @@
import hashlib
import cv2
import skimage
from diffusers import DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, \
    EulerAncestralDiscreteScheduler

from omegaconf import OmegaConf
from ldm.invoke.generator.base import downsampling
@@ -386,7 +388,10 @@ def process_image(image,seed):
        width = width or self.width
        height = height or self.height

        configure_model_padding(model, seamless, seamless_axes)
        if isinstance(model, DiffusionPipeline):
            configure_model_padding(model.unet, seamless, seamless_axes)
        else:
            configure_model_padding(model, seamless, seamless_axes)

        assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0'
        assert threshold >= 0.0, '--threshold must be >=0.0'
@@ -930,9 +935,15 @@ def sample_to_image(self, samples):
    def sample_to_lowres_estimated_image(self, samples):
        return self._make_base().sample_to_lowres_estimated_image(samples)

    def _set_sampler(self):
        if isinstance(self.model, DiffusionPipeline):
            return self._set_scheduler()
        else:
            return self._set_sampler_legacy()

    # very repetitive code - can this be simplified? The KSampler names are
    # consistent, at least
    def _set_sampler(self):
    def _set_sampler_legacy(self):
        msg = f'>> Setting Sampler to {self.sampler_name}'
        if self.sampler_name == 'plms':
            self.sampler = PLMSSampler(self.model, device=self.device)
@@ -956,6 +967,44 @@ def _set_sampler(self):

        print(msg)

    def _set_scheduler(self):
        msg = f'>> Setting Sampler to {self.sampler_name}'
        default = self.model.scheduler
        # TODO: Test me! Not all schedulers take the same args.
        scheduler_args = dict(
            num_train_timesteps=default.num_train_timesteps,
            beta_start=default.beta_start,
            beta_end=default.beta_end,
            beta_schedule=default.beta_schedule,
        )
        trained_betas = getattr(self.model.scheduler, 'trained_betas', None)
        if trained_betas is not None:
            scheduler_args.update(trained_betas=trained_betas)
        if self.sampler_name == 'plms':
            raise NotImplementedError("What's the diffusers implementation of PLMS?")
        elif self.sampler_name == 'ddim':
            self.sampler = DDIMScheduler(**scheduler_args)
        elif self.sampler_name == 'k_dpm_2_a':
            raise NotImplementedError("no diffusers implementation of dpm_2 samplers")
        elif self.sampler_name == 'k_dpm_2':
            raise NotImplementedError("no diffusers implementation of dpm_2 samplers")
        elif self.sampler_name == 'k_euler_a':
            self.sampler = EulerAncestralDiscreteScheduler(**scheduler_args)
        elif self.sampler_name == 'k_euler':
            self.sampler = EulerDiscreteScheduler(**scheduler_args)
        elif self.sampler_name == 'k_heun':
            raise NotImplementedError("no diffusers implementation of Heun's sampler")
        elif self.sampler_name == 'k_lms':
            self.sampler = LMSDiscreteScheduler(**scheduler_args)
        else:
            msg = f'>> Unsupported Sampler: {self.sampler_name}, Defaulting to {default}'

        print(msg)

        if not hasattr(self.sampler, 'uses_inpainting_model'):
            # FIXME: terrible kludge!
            self.sampler.uses_inpainting_model = lambda: False

    def _load_img(self, img)->Image:
        if isinstance(img, Image.Image):
            image = img
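A note on what _set_scheduler feeds into: with diffusers, choosing a sampler means replacing the pipeline's scheduler object before inference. A minimal sketch of the same swap, assuming a stock StableDiffusionPipeline and using the library's from_config helper instead of copying the beta parameters by hand as the method above does:

# Minimal sketch: swap the sampler on a diffusers pipeline.
from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
# Build the replacement from the current scheduler's config; this is what
# the hand-rolled scheduler_args dict in _set_scheduler approximates.
pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
# Subsequent pipeline calls now denoise with Euler discrete steps.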
5 changes: 3 additions & 2 deletions ldm/invoke/generator/base.py
@@ -9,6 +9,7 @@
import numpy as np
import torch
from PIL import Image, ImageFilter
from diffusers import DiffusionPipeline
from einops import rearrange
from pytorch_lightning import seed_everything
from tqdm import trange
@@ -24,9 +25,9 @@ class Generator:
    downsampling_factor: int
    latent_channels: int
    precision: str
    model: DiffusionWrapper
    model: DiffusionWrapper | DiffusionPipeline

    def __init__(self, model: DiffusionWrapper, precision: str):
    def __init__(self, model: DiffusionWrapper | DiffusionPipeline, precision: str):
        self.model = model
        self.precision = precision
        self.seed = None
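One portability note on the annotation above: the DiffusionWrapper | DiffusionPipeline union uses PEP 604 syntax, which only evaluates at runtime on Python 3.10+. A sketch of two equivalent spellings for older interpreters; the DiffusionWrapper import path is assumed from the CompVis ldm layout:

from __future__ import annotations  # option 1: defer annotation evaluation

from typing import Union

from diffusers import DiffusionPipeline
from ldm.models.diffusion.ddpm import DiffusionWrapper  # path assumed


class Generator:
    model: Union[DiffusionWrapper, DiffusionPipeline]  # option 2: typing.Union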
28 changes: 28 additions & 0 deletions ldm/invoke/generator/diffusers_pipeline.py
@@ -1,4 +1,5 @@
import secrets
import warnings
from dataclasses import dataclass
from typing import List, Optional, Union, Callable

@@ -309,6 +310,28 @@ def get_text_embeddings(self,
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
        return text_embeddings

    def get_learned_conditioning(self, c: List[List[str]], return_tokens=True,
                                 fragment_weights=None, **kwargs):
        """
        Compatibility function for ldm.models.diffusion.ddpm.LatentDiffusion.
        """
        assert return_tokens == True
        if fragment_weights:
            weights = fragment_weights[0]
            if any(weight != 1.0 for weight in weights):
                warnings.warn(f"fragment weights not implemented yet {fragment_weights}", stacklevel=2)

        if kwargs:
            warnings.warn(f"unsupported args {kwargs}", stacklevel=2)

        text_fragments = c[0]
        text_input = self._tokenize(text_fragments)

        with torch.inference_mode():
            token_ids = text_input.input_ids.to(self.text_encoder.device)
            text_embeddings = self.text_encoder(token_ids)[0]
        return text_embeddings, text_input.input_ids

    @torch.inference_mode()
    def _tokenize(self, prompt: Union[str, List[str]]):
        return self.tokenizer(
@@ -319,6 +342,11 @@ def _tokenize(self, prompt: Union[str, List[str]]):
            return_tensors="pt",
        )

    @property
    def channels(self) -> int:
        """Compatible with DiffusionWrapper"""
        return self.unet.in_channels

    def prepare_latents(self, latents, batch_size, height, width, generator, dtype):
        # get the initial random noise unless the user supplied it
        # Unlike in other pipelines, latents need to be generated in the target device
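To make the compatibility surface concrete, a usage sketch for the two shims above, assuming pipeline is a loaded StableDiffusionGeneratorPipeline; the prompt text and the SD 1.x embedding shape are illustrative:

# Legacy ldm callers pass a list of fragment lists; the shim tokenizes the
# fragments and returns the CLIP embeddings plus the raw token ids.
embeddings, token_ids = pipeline.get_learned_conditioning(
    [["a watercolor lighthouse at dusk"]]
)
print(embeddings.shape)   # e.g. torch.Size([1, 77, 768]) for SD 1.x
print(pipeline.channels)  # 4: the unet's latent in_channels, as DiffusionWrapper reports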
13 changes: 2 additions & 11 deletions ldm/invoke/generator/txt2img.py
@@ -24,17 +24,8 @@ def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
        self.perlin = perlin
        uc, c, extra_conditioning_info = conditioning

        # FIXME: this should probably be either passed in to __init__ instead of model & precision,
        # or be constructed in __init__ from those inputs.
        pipeline = StableDiffusionGeneratorPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            revision="fp16", torch_dtype=torch.float16,
            safety_checker=None, # TODO
            # scheduler=sampler + ddim_eta, # TODO
            # TODO: local_files_only=True
        )
        pipeline.unet.to("cuda")
        pipeline.vae.to("cuda")
        pipeline = self.model
        # TODO: customize a new pipeline for the given sampler (Scheduler)

        def make_image(x_T) -> PIL.Image.Image:
            # FIXME: restore free_gpu_mem functionality
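The net effect of this hunk is architectural: Txt2Img no longer constructs (and re-downloads) a pipeline on every generation; it reuses whatever pipeline the model cache handed to Generator.__init__. A rough sketch of the calling convention this builds toward, using the stock StableDiffusionPipeline entry point; the custom pipeline's actual methods may differ:

import torch
from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
x_T = torch.randn(1, 4, 64, 64)  # caller-supplied initial noise for a 512x512 image
result = pipeline(
    prompt="a watercolor lighthouse at dusk",
    num_inference_steps=50,   # steps
    guidance_scale=7.5,       # cfg_scale
    latents=x_T,
)
image = result.images[0]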
41 changes: 40 additions & 1 deletion ldm/invoke/model_cache.py
@@ -4,6 +4,7 @@
below a preset minimum, the least recently used model will be
cleared and loaded from disk when next needed.
'''
from pathlib import Path

import torch
import os
@@ -18,6 +19,8 @@
from sys import getrefcount
from omegaconf import OmegaConf
from omegaconf.errors import ConfigAttributeError

from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
from ldm.util import instantiate_from_config

DEFAULT_MAX_MODELS=2
@@ -268,7 +271,43 @@ def _load_ckpt_model(self, mconfig):
        return model, width, height, model_hash

    def _load_diffusers_model(self, mconfig):
        raise NotImplementedError() # return pipeline, width, height, model_hash
        pipeline_args = {}

        if 'repo_name' in mconfig:
            name_or_path = mconfig['repo_name']
            model_hash = "FIXME"
            # model_hash = huggingface_hub.get_hf_file_metadata(url).commit_hash
        elif 'path' in mconfig:
            name_or_path = Path(mconfig['path'])
            # FIXME: What should the model_hash be? A hash of the unet weights? Of all files of all
            # the submodels hashed together? The commit ID from the repo?
            model_hash = "FIXME TOO"
        else:
            raise ValueError("Model config must specify either repo_name or path.")

        print(f'>> Loading diffusers model from {name_or_path}')

        if self.precision == 'float16':
            print(' | Using faster float16 precision')
            pipeline_args.update(revision="fp16", torch_dtype=torch.float16)
        else:
            # TODO: more accurately, "using the model's default precision."
            # How do we find out what that is?
            print(' | Using more accurate float32 precision')

        pipeline = StableDiffusionGeneratorPipeline.from_pretrained(
            name_or_path,
            safety_checker=None, # TODO
            # TODO: alternate VAE
            # TODO: local_files_only=True
            **pipeline_args
        )
        pipeline.to(self.device)

        width = pipeline.vae.sample_size
        height = pipeline.vae.sample_size

        return pipeline, width, height, model_hash

    def offload_model(self, model_name:str):
        '''
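For reference, a sketch of the kind of config entry _load_diffusers_model consumes. Only repo_name and path are keys the loader above actually reads; the description field mirrors the existing ckpt stanzas, and how the cache decides to dispatch to this loader is outside this hunk:

from omegaconf import OmegaConf

# Hub-hosted pipeline: repo_name takes the first branch of the loader.
mconfig = OmegaConf.create({
    'repo_name': 'runwayml/stable-diffusion-v1-5',
    'description': 'Stable Diffusion 1.5 as a diffusers pipeline',
})
# A local snapshot would use 'path' instead:
# mconfig = OmegaConf.create({'path': '/home/user/models/stable-diffusion-v1-5'})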
