From b9ca4060d527cfd29b1a1a23cf905e79e5105cd6 Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Mon, 19 Sep 2022 09:28:00 +0000 Subject: [PATCH 01/23] WIP: flax FlaxDiffusionPipeline & FlaxStableDiffusionPipeline --- src/diffusers/models/__init__.py | 1 + src/diffusers/pipeline_flax_utils.py | 428 ++++++++++++++++++ .../pipelines/stable_diffusion/__init__.py | 28 ++ .../pipeline_flax_stable_diffusion.py | 245 ++++++++++ 4 files changed, 702 insertions(+) create mode 100644 src/diffusers/pipeline_flax_utils.py create mode 100644 src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index e0ac5c8d548b..3c3656a572f0 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -14,4 +14,5 @@ from .unet_2d import UNet2DModel from .unet_2d_condition import UNet2DConditionModel +from .unet_2d_condition_flax import FlaxUNet2DConditionModel from .vae import AutoencoderKL, VQModel diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipeline_flax_utils.py new file mode 100644 index 000000000000..ae5b72703548 --- /dev/null +++ b/src/diffusers/pipeline_flax_utils.py @@ -0,0 +1,428 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import inspect +import os +from typing import List, Optional, Union + +import numpy as np + +import diffusers +import flax +import jax.numpy as jnp +import PIL +from huggingface_hub import snapshot_download +from PIL import Image +from tqdm.auto import tqdm + +from .configuration_utils import ConfigMixin +from .modeling_flax_utils import FLAX_WEIGHTS_NAME +from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME +from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, logging + + +INDEX_FILE = "diffusion_flax_model.bin" + + +logger = logging.get_logger(__name__) + + +LOADABLE_CLASSES = { + "diffusers": { + "FlaxModelMixin": ["save_pretrained", "from_pretrained"], + "SchedulerMixin": ["save_config", "from_config"], + "FlaxDiffusionPipeline": ["save_pretrained", "from_pretrained"], + }, + "transformers": { + "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"], + "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"], + "FlaxPreTrainedModel": ["save_pretrained", "from_pretrained"], + "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"], + }, +} + +ALL_IMPORTABLE_CLASSES = {} +for library in LOADABLE_CLASSES: + ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library]) + + +@flax.structs.dataclass +class FlaxImagePipelineOutput(BaseOutput): + """ + Output class for image pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + + +class FlaxDiffusionPipeline(ConfigMixin): + r""" + Base class for all models. + + [`FlaxDiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion + pipelines and handles methods for loading, downloading and saving models as well as a few methods common to all + pipelines to: + + - enabling/disabling the progress bar for the denoising iteration + + Class attributes: + + - **config_name** ([`str`]) -- name of the config file that will store the class and module names of all + components of the diffusion pipeline. + """ + config_name = "model_index.json" + + def register_modules(self, **kwargs): + # import it here to avoid circular import + from diffusers import pipelines + + for name, module in kwargs.items(): + # retrieve library + library = module.__module__.split(".")[0] + + # check if the module is a pipeline module + pipeline_dir = module.__module__.split(".")[-2] + path = module.__module__.split(".") + is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) + + # if library is not in LOADABLE_CLASSES, then it is a custom module. + # Or if it's a pipeline module, then the module is inside the pipeline + # folder so we set the library to module name. + if library not in LOADABLE_CLASSES or is_pipeline_module: + library = pipeline_dir + + # retrieve class_name + class_name = module.__class__.__name__ + + register_dict = {name: (library, class_name)} + + # save model index config + self.register_to_config(**register_dict) + + # set models + setattr(self, name, module) + + def save_pretrained(self, save_directory: Union[str, os.PathLike]): + """ + Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to + a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading + method. The pipeline can easily be re-loaded using the `[`~FlaxDiffusionPipeline.from_pretrained`]` class + method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + """ + self.save_config(save_directory) + + model_index_dict = dict(self.config) + model_index_dict.pop("_class_name") + model_index_dict.pop("_diffusers_version") + model_index_dict.pop("_module", None) + + for pipeline_component_name in model_index_dict.keys(): + sub_model = getattr(self, pipeline_component_name) + model_cls = sub_model.__class__ + + save_method_name = None + # search for the model's base class in LOADABLE_CLASSES + for library_name, library_classes in LOADABLE_CLASSES.items(): + library = importlib.import_module(library_name) + for base_class, save_load_methods in library_classes.items(): + class_candidate = getattr(library, base_class) + if issubclass(model_cls, class_candidate): + # if we found a suitable base class in LOADABLE_CLASSES then grab its save method + save_method_name = save_load_methods[0] + break + if save_method_name is not None: + break + + save_method = getattr(sub_model, save_method_name) + save_method(os.path.join(save_directory, pipeline_component_name)) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + r""" + Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights. + + The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). + + The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning + task. + + The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those + weights are discarded. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *repo id* of a pretrained pipeline hosted inside a model repo on + https://huggingface.co/ Valid repo ids have to be located under a user or organization name, like + `CompVis/ldm-text2im-large-256`. + - A path to a *directory* containing pipeline weights saved using + [`~FlaxDiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`. + dtype (`str` or `jnp.dtype`, *optional*): + Override the default `jnp.dtype` and load the model under this dtype. If `"auto"` is passed the dtype + will be automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. + Please refer to the mirror site for more information. specify the folder name here. + + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the + specific pipeline class. The overritten components are then directly passed to the pipelines `__init__` + method. See example below for more information. + + + + Passing `use_auth_token=True`` is required when you want to use a private model, *e.g.* + `"CompVis/stable-diffusion-v1-4"` + + + + + + Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use + this method in a firewalled environment. + + + + Examples: + + ```py + >>> from diffusers import FlaxDiffusionPipeline + + >>> # Download pipeline from huggingface.co and cache. + >>> pipeline = FlaxDiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256") + + >>> # Download pipeline that requires an authorization token + >>> # For more information on access tokens, please refer to this section + >>> # of the documentation](https://huggingface.co/docs/hub/security-tokens) + >>> pipeline = FlaxDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True) + + >>> # Download pipeline, but overwrite scheduler + >>> from diffusers import LMSDiscreteScheduler + + >>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") + >>> pipeline = FlaxDiffusionPipeline.from_pretrained( + ... "CompVis/stable-diffusion-v1-4", scheduler=scheduler, use_auth_token=True + ... ) + ``` + """ + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + inference_state_dict = kwargs.pop("inference_state_dict", None) + dtype = kwargs.pop("dtype", None) + + # 1. Download the checkpoints and configs + # use snapshot download here to get it working from from_pretrained + if not os.path.isdir(pretrained_model_name_or_path): + config_dict = cls.get_config_dict( + pretrained_model_name_or_path, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + ) + # make sure we only download sub-folders and `diffusers` filenames + folder_names = [k for k in config_dict.keys() if not k.startswith("_")] + allow_patterns = [os.path.join(k, "*") for k in folder_names] + allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name] + + # download all allow_patterns + cached_folder = snapshot_download( + pretrained_model_name_or_path, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + allow_patterns=allow_patterns, + ) + else: + cached_folder = pretrained_model_name_or_path + + config_dict = cls.get_config_dict(cached_folder) + + # 2. Load the pipeline class, if using custom module then load it from the hub + # if we load from explicit class, let's use it + if cls != FlaxDiffusionPipeline: + pipeline_class = cls + else: + diffusers_module = importlib.import_module(cls.__module__.split(".")[0]) + pipeline_class = getattr(diffusers_module, config_dict["_class_name"]) + + # some modules can be passed directly to the init + # in this case they are already instantiated in `kwargs` + # extract them here + expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) + passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} + + init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) + + init_kwargs = {} + + # import it here to avoid circular import + from diffusers import pipelines + + # 3. Load each module in the pipeline + for name, (library_name, class_name) in init_dict.items(): + is_pipeline_module = hasattr(pipelines, library_name) + loaded_sub_model = None + + # if the model is in a pipeline module, then we load it from the pipeline + if name in passed_class_obj: + # 1. check that passed_class_obj has correct parent class + if not is_pipeline_module: + library = importlib.import_module(library_name) + class_obj = getattr(library, class_name) + importable_classes = LOADABLE_CLASSES[library_name] + class_candidates = {c: getattr(library, c) for c in importable_classes.keys()} + + expected_class_obj = None + for class_name, class_candidate in class_candidates.items(): + if issubclass(class_obj, class_candidate): + expected_class_obj = class_candidate + + if not issubclass(passed_class_obj[name].__class__, expected_class_obj): + raise ValueError( + f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be" + f" {expected_class_obj}" + ) + else: + logger.warn( + f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it" + " has the correct type" + ) + + # set passed class object + loaded_sub_model = passed_class_obj[name] + elif is_pipeline_module: + pipeline_module = getattr(pipelines, library_name) + class_obj = getattr(pipeline_module, class_name) + importable_classes = ALL_IMPORTABLE_CLASSES + class_candidates = {c: class_obj for c in importable_classes.keys()} + else: + # else we just import it from the library. + library = importlib.import_module(library_name) + class_obj = getattr(library, class_name) + importable_classes = LOADABLE_CLASSES[library_name] + class_candidates = {c: getattr(library, c) for c in importable_classes.keys()} + + if loaded_sub_model is None: + load_method_name = None + for class_name, class_candidate in class_candidates.items(): + if issubclass(class_obj, class_candidate): + load_method_name = importable_classes[class_name][1] + + load_method = getattr(class_obj, load_method_name) + + loading_kwargs = {} + if issubclass(class_obj, flax.linen.Module): + loading_kwargs["dtype"] = dtype + + # check if the module is in a subdirectory + if os.path.isdir(os.path.join(cached_folder, name)): + if issubclass(class_obj, flax.linen.Module): + loaded_sub_model, loaded_params = load_method( + os.path.join(cached_folder, name), **loading_kwargs + ) + params_key = f"{name}_params" + if params_key not in inference_state_dict: + inference_state_dict[params_key] = loaded_params + else: + loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs) + else: + # else load from the root directory + if issubclass(class_obj, flax.linen.Module): + loaded_sub_model, loaded_params = load_method(cached_folder, **loading_kwargs) + params_key = f"{name}_params" + if params_key not in inference_state_dict: + inference_state_dict[params_key] = loaded_params + else: + loaded_sub_model = load_method(cached_folder, **loading_kwargs) + + init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) + + # 4. Instantiate the pipeline + # TODO: fix hard-coded `StableDifusion.InferenceState`, it should be inferred as `{XYZ_Pipeline}.InferenceState` + from .pipelines.stable_diffusion import InferenceState + + inference_state = InferenceState(**inference_state_dict) + model = pipeline_class(**init_kwargs, dtype=dtype, inference_state=inference_state) + return model + + @staticmethod + def numpy_to_pil(images): + """ + Convert a numpy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] + images = (images * 255).round().astype("uint8") + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + + # TODO: make it compatible with jax.lax + def progress_bar(self, iterable): + if not hasattr(self, "_progress_bar_config"): + self._progress_bar_config = {} + elif not isinstance(self._progress_bar_config, dict): + raise ValueError( + f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." + ) + + return tqdm(iterable, **self._progress_bar_config) + + def set_progress_bar_config(self, **kwargs): + self._progress_bar_config = kwargs diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 5ffda93f1721..e8eeca5cae06 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -3,9 +3,11 @@ import numpy as np +import flax import PIL from PIL import Image +from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState from ...utils import BaseOutput, is_onnx_available, is_transformers_available @@ -27,6 +29,32 @@ class StableDiffusionPipelineOutput(BaseOutput): nsfw_content_detected: List[bool] +@flax.struct.dataclass +class FlaxStableDiffusionPipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + nsfw_content_detected (`List[bool]`) + List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + nsfw_content_detected: List[bool] + + +@flax.struct.dataclass +class InferenceState: + text_encoder_params: flax.core.FrozenDict + unet_params: flax.core.FrozenDict + vae_params: flax.core.FrozenDict + scheduler_state: PNDMSchedulerState + + if is_transformers_available(): from .pipeline_stable_diffusion import StableDiffusionPipeline from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py new file mode 100644 index 000000000000..b3679d33dd75 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -0,0 +1,245 @@ +import inspect +import warnings +from typing import List, Optional, Union + +import jax +import jax.numpy as jnp +from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel + +from ...configuration_utils import FrozenDict +from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel +from ...pipeline_flax_utils import FlaxDiffusionPipeline +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from . import FlaxStableDiffusionPipelineOutput, InferenceState +from .flax_safety_checker import FlaxStableDiffusionSafetyChecker + + +class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`FlaxAutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`FlaxCLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`FlaxStableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offsensive or harmful. + Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def __init__( + self, + vae: FlaxAutoencoderKL, + text_encoder: FlaxCLIPTextModel, + tokenizer: CLIPTokenizer, + unet: FlaxUNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: FlaxStableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + inference_state: InferenceState, + dtype: jnp.dtype = jnp.float32, + ): + super().__init__() + scheduler = scheduler.set_format("np") + self.dtype = dtype + self.inference_state = inference_state + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + warnings.warn( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 istead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file", + DeprecationWarning, + ) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + def __call__( + self, + prompt: Union[str, List[str]], + prng_seed: jax.random.PRNGKey, + height: Optional[int] = 512, + width: Optional[int] = 512, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + eta: Optional[float] = 0.0, + latents: Optional[jnp.array] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + debug: bool = False, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`jnp.array`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of + a plain tuple. + + Returns: + [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. When returning a tuple, the first element is a list with the generated images, and the second + element is a list of `bool`s denoting whether the corresponding generated image likely represents + "not-safe-for-work" (nsfw) content, according to the `safety_checker`. + """ + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + inference_state = self.inference_state + + # get prompt text embeddings + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_embeddings = self.text_encoder(text_input.input_ids, params=inference_state.text_encoder_params)[0] + + # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0` + # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0` + max_length = text_input.input_ids.shape[-1] + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" + ) + uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=inference_state.text_encoder_params)[0] + context = jnp.concatenate([uncond_embeddings, text_embeddings]) + + # TODO: check it because the shape is different from Pytorhc StableDiffusionPipeline + latents_shape = ( + text_input.input_ids.shape[0], + self.unet.sample_size, + self.unet.sample_size, + self.unet.in_channels, + ) + if latents is None: + latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + + def loop_body(step, args): + latents, scheduler_state = args + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + latents_input = jnp.concatenate([latents] * 2) + + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + timestep = jnp.broadcast_to(t, latents_input.shape[0]) + + # predict the noise residual + noise_pred = self.unet.apply( + {"params": inference_state.unet_params}, + jnp.array(latents_input), + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=context, + rngs={}, + ).sample + # perform guidance + noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents) + latents = latents["prev_sample"] + return latents, scheduler_state + + scheduler_state = inference_state.scheduler_state + num_inference_steps = len(scheduler_state.timesteps) + if debug: + # run with python for loop + for i in range(num_inference_steps): + latents, scheduler_state = loop_body(i, (latents, scheduler_state)) + else: + latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state)) + + # scale and decode the image latents with vae + latents = 1 / 0.18215 * latents + # TODO: check when flax vae gets merged into main + image = self.vae.decode(latents, params=inference_state.vae_params).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + + # run safety checker + # TODO: check when flax safety checker gets merged into main + safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np") + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_cheker_input.pixel_values, params=inference_state.safety_params + ) + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return FlaxStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) From 30abc633dca496d1827b8f80bcb2a5e8038cd5e0 Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Mon, 19 Sep 2022 09:32:21 +0000 Subject: [PATCH 02/23] todo comment --- src/diffusers/pipeline_flax_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipeline_flax_utils.py index ae5b72703548..37d4e1418937 100644 --- a/src/diffusers/pipeline_flax_utils.py +++ b/src/diffusers/pipeline_flax_utils.py @@ -122,6 +122,7 @@ def register_modules(self, **kwargs): setattr(self, name, module) def save_pretrained(self, save_directory: Union[str, os.PathLike]): + # TODO: handle inference_state """ Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading From 4b2becb89057aa0f7d21f7715cb0902fb9a08cdc Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Mon, 19 Sep 2022 15:38:50 +0000 Subject: [PATCH 03/23] Fix imports --- src/diffusers/models/__init__.py | 1 + src/diffusers/pipeline_flax_utils.py | 2 +- .../stable_diffusion/pipeline_flax_stable_diffusion.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 3c3656a572f0..a6007d15b9db 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -16,3 +16,4 @@ from .unet_2d_condition import UNet2DConditionModel from .unet_2d_condition_flax import FlaxUNet2DConditionModel from .vae import AutoencoderKL, VQModel +from .vae_flax import FlaxAutoencoderKL \ No newline at end of file diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipeline_flax_utils.py index 37d4e1418937..d61472949a73 100644 --- a/src/diffusers/pipeline_flax_utils.py +++ b/src/diffusers/pipeline_flax_utils.py @@ -60,7 +60,7 @@ ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library]) -@flax.structs.dataclass +@flax.struct.dataclass class FlaxImagePipelineOutput(BaseOutput): """ Output class for image pipelines. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index b3679d33dd75..896ffabbeec3 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -11,7 +11,7 @@ from ...pipeline_flax_utils import FlaxDiffusionPipeline from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from . import FlaxStableDiffusionPipelineOutput, InferenceState -from .flax_safety_checker import FlaxStableDiffusionSafetyChecker +from .safety_checker_flax import FlaxStableDiffusionSafetyChecker class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): From 7f0e4297943ae4c01ccd8505801e2d963b1eecdc Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Mon, 19 Sep 2022 15:45:33 +0000 Subject: [PATCH 04/23] Fix imports --- src/diffusers/__init__.py | 8 ++++++++ src/diffusers/pipelines/__init__.py | 5 ++++- src/diffusers/pipelines/stable_diffusion/__init__.py | 1 + 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 776551c7136d..9b71955a7107 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -74,5 +74,13 @@ FlaxPNDMScheduler, FlaxScoreSdeVeScheduler, ) + from .pipeline_flax_utils import FlaxDiffusionPipeline else: from .utils.dummy_flax_objects import * # noqa F403 + +if is_flax_available() and is_transformers_available(): + from .pipelines import FlaxStableDiffusionPipeline +else: + pass + # TODO: dummy_flax_and_transformers_objects + # from .utils.dummy_flax_and_transformers_objects import * # noqa F403 \ No newline at end of file diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 3e2aeb4fb2b7..0db153c864ed 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -1,4 +1,4 @@ -from ..utils import is_onnx_available, is_transformers_available +from ..utils import is_onnx_available, is_transformers_available, is_flax_available from .ddim import DDIMPipeline from .ddpm import DDPMPipeline from .latent_diffusion_uncond import LDMPipeline @@ -17,3 +17,6 @@ if is_transformers_available() and is_onnx_available(): from .stable_diffusion import StableDiffusionOnnxPipeline + +if is_flax_available(): + from .stable_diffusion import FlaxStableDiffusionPipeline \ No newline at end of file diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index eb2b51155a39..7c30ff0e3d20 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -66,3 +66,4 @@ class InferenceState: if is_transformers_available() and is_flax_available(): from .safety_checker_flax import FlaxStableDiffusionSafetyChecker + from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline From d9e2ae18623686ccb1974a9d1a42af4e515b41ba Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 19 Sep 2022 15:56:17 +0000 Subject: [PATCH 05/23] add dummies --- .../dummy_flax_and_transformers_objects.py | 11 +++++++++++ src/diffusers/utils/dummy_flax_objects.py | 18 ++++++++++++++++-- 2 files changed, 27 insertions(+), 2 deletions(-) create mode 100644 src/diffusers/utils/dummy_flax_and_transformers_objects.py diff --git a/src/diffusers/utils/dummy_flax_and_transformers_objects.py b/src/diffusers/utils/dummy_flax_and_transformers_objects.py new file mode 100644 index 000000000000..51ee3b184816 --- /dev/null +++ b/src/diffusers/utils/dummy_flax_and_transformers_objects.py @@ -0,0 +1,11 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +# flake8: noqa + +from ..utils import DummyObject, requires_backends + + +class FlaxStableDiffusionPipeline(metaclass=DummyObject): + _backends = ["flax", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax", "transformers"]) diff --git a/src/diffusers/utils/dummy_flax_objects.py b/src/diffusers/utils/dummy_flax_objects.py index 9615afb6f920..424e4f3bf6d8 100644 --- a/src/diffusers/utils/dummy_flax_objects.py +++ b/src/diffusers/utils/dummy_flax_objects.py @@ -11,6 +11,20 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) +class FlaxUNet2DConditionModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxAutoencoderKL(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + class FlaxDDIMScheduler(metaclass=DummyObject): _backends = ["flax"] @@ -46,14 +60,14 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) -class FlaxUNet2DConditionModel(metaclass=DummyObject): +class FlaxScoreSdeVeScheduler(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) -class FlaxScoreSdeVeScheduler(metaclass=DummyObject): +class FlaxDiffusionPipeline(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): From d51e8816edbae173c0ece6b8f1bf9495828c5e1d Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Mon, 19 Sep 2022 15:59:49 +0000 Subject: [PATCH 06/23] Fix empty init --- src/diffusers/pipeline_flax_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipeline_flax_utils.py index d61472949a73..0fb3a43e8d2e 100644 --- a/src/diffusers/pipeline_flax_utils.py +++ b/src/diffusers/pipeline_flax_utils.py @@ -258,7 +258,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P local_files_only = kwargs.pop("local_files_only", False) use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) - inference_state_dict = kwargs.pop("inference_state_dict", None) + inference_state_dict = kwargs.pop("inference_state_dict", dict()) dtype = kwargs.pop("dtype", None) # 1. Download the checkpoints and configs From 7aab68d6b18f77881ed5a3415aaf2174dd7329cf Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 19 Sep 2022 20:44:12 +0000 Subject: [PATCH 07/23] make pipeline work --- src/diffusers/__init__.py | 6 +- src/diffusers/models/__init__.py | 2 +- src/diffusers/pipeline_flax_utils.py | 61 +++++++------- src/diffusers/pipelines/__init__.py | 8 +- .../pipelines/stable_diffusion/__init__.py | 2 +- .../pipeline_flax_stable_diffusion.py | 83 +++++++++---------- .../schedulers/scheduling_ddim_flax.py | 61 +++++--------- .../schedulers/scheduling_pndm_flax.py | 3 +- 8 files changed, 100 insertions(+), 126 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 9b71955a7107..acdddaac4d26 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -66,6 +66,7 @@ from .modeling_flax_utils import FlaxModelMixin from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel from .models.vae_flax import FlaxAutoencoderKL + from .pipeline_flax_utils import FlaxDiffusionPipeline from .schedulers import ( FlaxDDIMScheduler, FlaxDDPMScheduler, @@ -74,13 +75,10 @@ FlaxPNDMScheduler, FlaxScoreSdeVeScheduler, ) - from .pipeline_flax_utils import FlaxDiffusionPipeline else: from .utils.dummy_flax_objects import * # noqa F403 if is_flax_available() and is_transformers_available(): from .pipelines import FlaxStableDiffusionPipeline else: - pass - # TODO: dummy_flax_and_transformers_objects - # from .utils.dummy_flax_and_transformers_objects import * # noqa F403 \ No newline at end of file + from .utils.dummy_flax_and_transformers_objects import * # noqa F403 diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index a6007d15b9db..b5fe089e05f0 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -16,4 +16,4 @@ from .unet_2d_condition import UNet2DConditionModel from .unet_2d_condition_flax import FlaxUNet2DConditionModel from .vae import AutoencoderKL, VQModel -from .vae_flax import FlaxAutoencoderKL \ No newline at end of file +from .vae_flax import FlaxAutoencoderKL diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipeline_flax_utils.py index d61472949a73..8092b29bc582 100644 --- a/src/diffusers/pipeline_flax_utils.py +++ b/src/diffusers/pipeline_flax_utils.py @@ -17,24 +17,26 @@ import importlib import inspect import os -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union import numpy as np -import diffusers import flax -import jax.numpy as jnp import PIL +from flax.core.frozen_dict import FrozenDict from huggingface_hub import snapshot_download from PIL import Image from tqdm.auto import tqdm from .configuration_utils import ConfigMixin -from .modeling_flax_utils import FLAX_WEIGHTS_NAME -from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME -from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, logging +from .modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin +from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME, SchedulerMixin +from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, is_transformers_available, logging +if is_transformers_available(): + from transformers import FlaxPreTrainedModel + INDEX_FILE = "diffusion_flax_model.bin" @@ -121,7 +123,7 @@ def register_modules(self, **kwargs): # set models setattr(self, name, module) - def save_pretrained(self, save_directory: Union[str, os.PathLike]): + def save_pretrained(self, save_directory: Union[str, os.PathLike], params: Union[Dict, FrozenDict]): # TODO: handle inference_state """ Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to @@ -258,7 +260,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P local_files_only = kwargs.pop("local_files_only", False) use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) - inference_state_dict = kwargs.pop("inference_state_dict", None) dtype = kwargs.pop("dtype", None) # 1. Download the checkpoints and configs @@ -312,6 +313,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P init_kwargs = {} + # inference_params + params = {} + # import it here to avoid circular import from diffusers import pipelines @@ -373,34 +377,27 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # check if the module is in a subdirectory if os.path.isdir(os.path.join(cached_folder, name)): - if issubclass(class_obj, flax.linen.Module): - loaded_sub_model, loaded_params = load_method( - os.path.join(cached_folder, name), **loading_kwargs - ) - params_key = f"{name}_params" - if params_key not in inference_state_dict: - inference_state_dict[params_key] = loaded_params - else: - loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs) + loadable_folder = os.path.join(cached_folder, name) else: - # else load from the root directory - if issubclass(class_obj, flax.linen.Module): - loaded_sub_model, loaded_params = load_method(cached_folder, **loading_kwargs) - params_key = f"{name}_params" - if params_key not in inference_state_dict: - inference_state_dict[params_key] = loaded_params - else: - loaded_sub_model = load_method(cached_folder, **loading_kwargs) + loaded_sub_model = cached_folder + + if issubclass(class_obj, FlaxModelMixin): + loaded_sub_model, loaded_params = load_method(loadable_folder, **loading_kwargs) + params[name] = loaded_params + elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel): + # make sure we don't initialize the weights to save time + loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False, **loading_kwargs) + params[name] = loaded_params + elif issubclass(class_obj, SchedulerMixin): + loaded_sub_model = load_method(loadable_folder, **loading_kwargs) + params[name] = loaded_sub_model.create_state() + else: + loaded_sub_model = load_method(loadable_folder, **loading_kwargs) init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) - # 4. Instantiate the pipeline - # TODO: fix hard-coded `StableDifusion.InferenceState`, it should be inferred as `{XYZ_Pipeline}.InferenceState` - from .pipelines.stable_diffusion import InferenceState - - inference_state = InferenceState(**inference_state_dict) - model = pipeline_class(**init_kwargs, dtype=dtype, inference_state=inference_state) - return model + model = pipeline_class(**init_kwargs, dtype=dtype) + return model, params @staticmethod def numpy_to_pil(images): diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 0db153c864ed..8e3c8592a258 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -1,4 +1,4 @@ -from ..utils import is_onnx_available, is_transformers_available, is_flax_available +from ..utils import is_flax_available, is_onnx_available, is_torch_available, is_transformers_available from .ddim import DDIMPipeline from .ddpm import DDPMPipeline from .latent_diffusion_uncond import LDMPipeline @@ -7,7 +7,7 @@ from .stochastic_karras_ve import KarrasVePipeline -if is_transformers_available(): +if is_torch_available() and is_transformers_available(): from .latent_diffusion import LDMTextToImagePipeline from .stable_diffusion import ( StableDiffusionImg2ImgPipeline, @@ -18,5 +18,5 @@ if is_transformers_available() and is_onnx_available(): from .stable_diffusion import StableDiffusionOnnxPipeline -if is_flax_available(): - from .stable_diffusion import FlaxStableDiffusionPipeline \ No newline at end of file +if is_transformers_available() and is_flax_available(): + from .stable_diffusion import FlaxStableDiffusionPipeline diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 7c30ff0e3d20..378dd8e9a99c 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -65,5 +65,5 @@ class InferenceState: from .pipeline_stable_diffusion_onnx import StableDiffusionOnnxPipeline if is_transformers_available() and is_flax_available(): - from .safety_checker_flax import FlaxStableDiffusionSafetyChecker from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline + from .safety_checker_flax import FlaxStableDiffusionSafetyChecker diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index 896ffabbeec3..98546e206e32 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -1,9 +1,12 @@ import inspect import warnings -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union + +import numpy as np import jax import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel from ...configuration_utils import FrozenDict @@ -51,13 +54,11 @@ def __init__( scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], safety_checker: FlaxStableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, - inference_state: InferenceState, dtype: jnp.dtype = jnp.float32, ): super().__init__() scheduler = scheduler.set_format("np") self.dtype = dtype - self.inference_state = inference_state if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: warnings.warn( @@ -83,17 +84,29 @@ def __init__( feature_extractor=feature_extractor, ) + def prepare_prompts(self, prompt: Union[str, List[str]]): + if not isinstance(prompt, (str, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + return text_input.input_ids + def __call__( self, - prompt: Union[str, List[str]], + prompt_ids: jnp.array, + params: Union[Dict, FrozenDict], prng_seed: jax.random.PRNGKey, + num_inference_steps: Optional[int] = 50, height: Optional[int] = 512, width: Optional[int] = 512, - num_inference_steps: Optional[int] = 50, guidance_scale: Optional[float] = 7.5, - eta: Optional[float] = 0.0, latents: Optional[jnp.array] = None, - output_type: Optional[str] = "pil", return_dict: bool = True, debug: bool = False, **kwargs, @@ -117,9 +130,6 @@ def __call__( Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. @@ -141,40 +151,26 @@ def __call__( element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - inference_state = self.inference_state - # get prompt text embeddings - text_input = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="np", - ) - text_embeddings = self.text_encoder(text_input.input_ids, params=inference_state.text_encoder_params)[0] + text_embeddings = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0` # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0` - max_length = text_input.input_ids.shape[-1] + batch_size = prompt_ids.shape[0] + + max_length = prompt_ids.shape[-1] uncond_input = self.tokenizer( [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" ) - uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=inference_state.text_encoder_params)[0] + uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0] context = jnp.concatenate([uncond_embeddings, text_embeddings]) # TODO: check it because the shape is different from Pytorhc StableDiffusionPipeline latents_shape = ( - text_input.input_ids.shape[0], + batch_size, self.unet.sample_size, self.unet.sample_size, self.unet.in_channels, @@ -197,7 +193,7 @@ def loop_body(step, args): # predict the noise residual noise_pred = self.unet.apply( - {"params": inference_state.unet_params}, + {"params": params["unet"]}, jnp.array(latents_input), jnp.array(timestep, dtype=jnp.int32), encoder_hidden_states=context, @@ -208,12 +204,11 @@ def loop_body(step, args): noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents) - latents = latents["prev_sample"] + latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() return latents, scheduler_state - scheduler_state = inference_state.scheduler_state - num_inference_steps = len(scheduler_state.timesteps) + scheduler_state = self.scheduler.set_timesteps(params["scheduler"], num_inference_steps=num_inference_steps) + if debug: # run with python for loop for i in range(num_inference_steps): @@ -224,20 +219,18 @@ def loop_body(step, args): # scale and decode the image latents with vae latents = 1 / 0.18215 * latents # TODO: check when flax vae gets merged into main - image = self.vae.decode(latents, params=inference_state.vae_params).sample + image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() + image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) + # image = jnp.asarray(image).transpose(0, 2, 3, 1) # run safety checker # TODO: check when flax safety checker gets merged into main - safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np") - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_cheker_input.pixel_values, params=inference_state.safety_params - ) - - if output_type == "pil": - image = self.numpy_to_pil(image) + # safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np") + # image, has_nsfw_concept = self.safety_checker( + # images=image, clip_input=safety_cheker_input.pixel_values, params=params["safety_params"] + # ) + has_nsfw_concept = False if not return_dict: return (image, has_nsfw_concept) diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py index 015b79b2780d..dd5a87df654a 100644 --- a/src/diffusers/schedulers/scheduling_ddim_flax.py +++ b/src/diffusers/schedulers/scheduling_ddim_flax.py @@ -21,7 +21,6 @@ import flax import jax.numpy as jnp -from jax import random from ..configuration_utils import ConfigMixin, register_to_config from .scheduling_utils import SchedulerMixin, SchedulerOutput @@ -60,11 +59,12 @@ def alpha_bar(time_step): class DDIMSchedulerState: # setable values timesteps: jnp.ndarray + alphas_cumprod: jnp.ndarray num_inference_steps: Optional[int] = None @classmethod - def create(cls, num_train_timesteps: int): - return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1]) + def create(cls, num_train_timesteps: int, alphas_cumprod: jnp.ndarray): + return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1], alphas_cumprod=alphas_cumprod) @dataclass @@ -112,13 +112,9 @@ def __init__( beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", - trained_betas: Optional[jnp.ndarray] = None, - clip_sample: bool = True, set_alpha_to_one: bool = True, steps_offset: int = 0, ): - if trained_betas is not None: - self.betas = jnp.asarray(trained_betas) if beta_schedule == "linear": self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32) elif beta_schedule == "scaled_linear": @@ -131,19 +127,24 @@ def __init__( raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas - self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0) + + # HACK for now - clean up later (PVP) + self._alphas_cumprod = jnp.cumprod(self.alphas, axis=0) # At every step in ddim, we are looking into the previous alphas_cumprod # For the final step, there is no previous alphas_cumprod because we are already at 0 # `set_alpha_to_one` decides whether we set this parameter simply to one or # whether we use the final alpha of the "non-previous" one. - self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else float(self._alphas_cumprod[0]) - self.state = DDIMSchedulerState.create(num_train_timesteps=num_train_timesteps) + def create_state(self): + return DDIMSchedulerState.create( + num_train_timesteps=self.config.num_train_timesteps, alphas_cumprod=self._alphas_cumprod + ) - def _get_variance(self, timestep, prev_timestep): - alpha_prod_t = self.alphas_cumprod[timestep] - alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + def _get_variance(self, timestep, prev_timestep, alphas_cumprod): + alpha_prod_t = alphas_cumprod[timestep] + alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod) beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev @@ -177,9 +178,6 @@ def step( model_output: jnp.ndarray, timestep: int, sample: jnp.ndarray, - key: random.KeyArray, - eta: float = 0.0, - use_clipped_model_output: bool = False, return_dict: bool = True, ) -> Union[FlaxSchedulerOutput, Tuple]: """ @@ -221,41 +219,28 @@ def step( # 1. get previous step value (=t-1) prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps + alphas_cumprod = state.alphas_cumprod + # 2. compute alphas, betas - alpha_prod_t = self.alphas_cumprod[timestep] - alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + alpha_prod_t = alphas_cumprod[timestep] + alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod) beta_prod_t = 1 - alpha_prod_t # 3. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - # 4. Clip "predicted x_0" - if self.config.clip_sample: - pred_original_sample = jnp.clip(pred_original_sample, -1, 1) - - # 5. compute variance: "sigma_t(η)" -> see formula (16) + # 4. compute variance: "sigma_t(η)" -> see formula (16) # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) - variance = self._get_variance(timestep, prev_timestep) - std_dev_t = eta * variance ** (0.5) + variance = self._get_variance(timestep, prev_timestep, alphas_cumprod) + std_dev_t = variance ** (0.5) - if use_clipped_model_output: - # the model_output is always re-derived from the clipped x_0 in Glide - model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) - - # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output - # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + # 6. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction - if eta > 0: - key = random.split(key, num=1) - noise = random.normal(key=key, shape=model_output.shape) - variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise - - prev_sample = prev_sample + variance - if not return_dict: return (prev_sample, state) diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index efc3858ca75a..4c8c43810b6f 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -148,7 +148,8 @@ def __init__( # mainly at formula (9), (12), (13) and the Algorithm 2. self.pndm_order = 4 - self.state = PNDMSchedulerState.create(num_train_timesteps=num_train_timesteps) + def create_state(self): + return PNDMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps) def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int) -> PNDMSchedulerState: """ From 47d77393b7d06f5ce8f5cb3e8ce97a43e82cb8e6 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 19 Sep 2022 20:46:44 +0000 Subject: [PATCH 08/23] up --- .../stable_diffusion/pipeline_flax_stable_diffusion.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index 98546e206e32..c85f2c7faf04 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -1,19 +1,15 @@ -import inspect import warnings from typing import Dict, List, Optional, Union -import numpy as np - import jax import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel -from ...configuration_utils import FrozenDict from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel from ...pipeline_flax_utils import FlaxDiffusionPipeline from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler -from . import FlaxStableDiffusionPipelineOutput, InferenceState +from . import FlaxStableDiffusionPipelineOutput from .safety_checker_flax import FlaxStableDiffusionSafetyChecker From 4dfcf213dd156cfdb9589a4daf118a4f264da47d Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 20 Sep 2022 07:17:14 +0000 Subject: [PATCH 09/23] Allow dtype to be overridden on model load. This may be a temporary solution until #567 is addressed. --- src/diffusers/configuration_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index f5e5d36ffdc7..5792e17e8120 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -154,9 +154,12 @@ def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], ret """ config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs) - init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs) + # Allow dtype to be specified on initialization + if "dtype" in unused_kwargs: + init_dict["dtype"] = unused_kwargs.pop("dtype") + model = cls(**init_dict) if return_unused_kwargs: From d480534d0a185ff91736929a20e2a8a54e06090a Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 20 Sep 2022 07:19:03 +0000 Subject: [PATCH 10/23] Convert params to bfloat16 or fp16 after loading. This deals with the weights, not the model. --- src/diffusers/modeling_flax_utils.py | 5 +++++ .../stable_diffusion/pipeline_flax_stable_diffusion.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index 505c2881cbab..90e9b73a3f87 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -483,6 +483,11 @@ def from_pretrained( "See [`~ModelMixin.to_fp32`] for further information on how to do this." ) + if dtype == jnp.bfloat16: + state = model.to_bf16(state) + if dtype == jnp.float16: + state = model.to_fp16(state) + return model, unflatten_dict(state) def save_pretrained( diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index c85f2c7faf04..62f9b23a6d2e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -172,7 +172,7 @@ def __call__( self.unet.in_channels, ) if latents is None: - latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype) + latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32) else: if latents.shape != latents_shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") From 0c2a868ec4e939a0fca3f590c39280a6eff1a686 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 20 Sep 2022 07:43:26 +0000 Subject: [PATCH 11/23] Use Flax schedulers (typing, docstring) --- .../stable_diffusion/pipeline_flax_stable_diffusion.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index c85f2c7faf04..92c135f872e3 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -8,7 +8,7 @@ from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel from ...pipeline_flax_utils import FlaxDiffusionPipeline -from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...schedulers import FlaxDDIMScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler from . import FlaxStableDiffusionPipelineOutput from .safety_checker_flax import FlaxStableDiffusionSafetyChecker @@ -31,9 +31,9 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): + scheduler ([`FlaxSchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], or [`FlaxPNDMScheduler`]. safety_checker ([`FlaxStableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offsensive or harmful. Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. @@ -47,7 +47,7 @@ def __init__( text_encoder: FlaxCLIPTextModel, tokenizer: CLIPTokenizer, unet: FlaxUNet2DConditionModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + scheduler: Union[FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler], safety_checker: FlaxStableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, dtype: jnp.dtype = jnp.float32, From aa3c010f46352975628a14a3092ba8207ceaa7f4 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 19 Sep 2022 17:29:53 +0000 Subject: [PATCH 12/23] PNDM: replace control flow with jax functions. Otherwise jitting/parallelization don't work properly as they don't know how to deal with traced objects. I temporarily removed `step_prk`. --- .../schedulers/scheduling_pndm_flax.py | 247 +++++++++++------- 1 file changed, 158 insertions(+), 89 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 4c8c43810b6f..ebaa90f4bac5 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -19,6 +19,7 @@ from typing import Optional, Tuple, Union import flax +import jax import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config @@ -151,7 +152,12 @@ def __init__( def create_state(self): return PNDMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps) - def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int) -> PNDMSchedulerState: + def set_timesteps( + self, + state: PNDMSchedulerState, + shape: Tuple, + num_inference_steps: int + ) -> PNDMSchedulerState: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -192,8 +198,11 @@ def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int) -> return state.replace( timesteps=jnp.concatenate([state.prk_timesteps, state.plms_timesteps]).astype(jnp.int64), - ets=jnp.array([]), counter=0, + # Will be zeros, not really empty + cur_model_output = jnp.empty(shape), + cur_sample = jnp.empty(shape), + ets = jnp.empty((4,) + shape), ) def step( @@ -223,73 +232,77 @@ def step( When returning a tuple, the first element is the sample tensor. """ - if state.counter < len(state.prk_timesteps) and not self.config.skip_prk_steps: - return self.step_prk( - state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict - ) - else: - return self.step_plms( - state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict - ) - - def step_prk( - self, - state: PNDMSchedulerState, - model_output: jnp.ndarray, - timestep: int, - sample: jnp.ndarray, - return_dict: bool = True, - ) -> Union[FlaxSchedulerOutput, Tuple]: - """ - Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the - solution to the differential equation. - - Args: - state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. - model_output (`jnp.ndarray`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - sample (`jnp.ndarray`): - current instance of sample being created by diffusion process. - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class - - Returns: - [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. - When returning a tuple, the first element is the sample tensor. - - """ - if state.num_inference_steps is None: - raise ValueError( - "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" - ) - - diff_to_prev = 0 if state.counter % 2 else self.config.num_train_timesteps // state.num_inference_steps // 2 - prev_timestep = timestep - diff_to_prev - timestep = state.prk_timesteps[state.counter // 4 * 4] - - if state.counter % 4 == 0: - state = state.replace( - cur_model_output=state.cur_model_output + 1 / 6 * model_output, - ets=state.ets.append(model_output), - cur_sample=sample, - ) - elif (self.counter - 1) % 4 == 0: - state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output) - elif (self.counter - 2) % 4 == 0: - state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output) - elif (self.counter - 3) % 4 == 0: - model_output = state.cur_model_output + 1 / 6 * model_output - state = state.replace(cur_model_output=0) - - # cur_sample should not be `None` - cur_sample = state.cur_sample if state.cur_sample is not None else sample - - prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output) - state = state.replace(counter=state.counter + 1) - - if not return_dict: - return (prev_sample, state) + return self.step_plms( + state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict + ) - return FlaxSchedulerOutput(prev_sample=prev_sample, state=state) + # if state.counter < len(state.prk_timesteps) and not self.config.skip_prk_steps: + # return self.step_prk( + # state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict + # ) + # else: + # return self.step_plms( + # state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict + # ) + + # def step_prk( + # self, + # state: PNDMSchedulerState, + # model_output: jnp.ndarray, + # timestep: int, + # sample: jnp.ndarray, + # return_dict: bool = True, + # ) -> Union[FlaxSchedulerOutput, Tuple]: + # """ + # Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the + # solution to the differential equation. + + # Args: + # state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. + # model_output (`jnp.ndarray`): direct output from learned diffusion model. + # timestep (`int`): current discrete timestep in the diffusion chain. + # sample (`jnp.ndarray`): + # current instance of sample being created by diffusion process. + # return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + # Returns: + # [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. + # When returning a tuple, the first element is the sample tensor. + + # """ + # if state.num_inference_steps is None: + # raise ValueError( + # "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + # ) + + # diff_to_prev = 0 if state.counter % 2 else self.config.num_train_timesteps // state.num_inference_steps // 2 + # prev_timestep = timestep - diff_to_prev + # timestep = state.prk_timesteps[state.counter // 4 * 4] + + # if state.counter % 4 == 0: + # state = state.replace( + # cur_model_output=state.cur_model_output + 1 / 6 * model_output, + # ets=state.ets.append(model_output), + # cur_sample=sample, + # ) + # elif (self.counter - 1) % 4 == 0: + # state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output) + # elif (self.counter - 2) % 4 == 0: + # state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output) + # elif (self.counter - 3) % 4 == 0: + # model_output = state.cur_model_output + 1 / 6 * model_output + # state = state.replace(cur_model_output=0) + + # # cur_sample should not be `None` + # cur_sample = state.cur_sample if state.cur_sample is not None else sample + + # prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output) + # state = state.replace(counter=state.counter + 1) + + # if not return_dict: + # return (prev_sample, state) + + # return FlaxSchedulerOutput(prev_sample=prev_sample, state=state) def step_plms( self, @@ -330,29 +343,85 @@ def step_plms( ) prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps + prev_timestep = jnp.where(prev_timestep > 0, prev_timestep, 0) + + # Reference: + # if state.counter != 1: + # state.ets.append(model_output) + # else: + # prev_timestep = timestep + # timestep = timestep + self.config.num_train_timesteps // state.num_inference_steps + + prev_timestep = jnp.where(state.counter == 1, timestep, prev_timestep) + timestep = jnp.where(state.counter == 1, timestep + self.config.num_train_timesteps // state.num_inference_steps, timestep) + + # Reference: + # if len(state.ets) == 1 and state.counter == 0: + # model_output = model_output + # state.cur_sample = sample + # elif len(state.ets) == 1 and state.counter == 1: + # model_output = (model_output + state.ets[-1]) / 2 + # sample = state.cur_sample + # state.cur_sample = None + # elif len(state.ets) == 2: + # model_output = (3 * state.ets[-1] - state.ets[-2]) / 2 + # elif len(state.ets) == 3: + # model_output = (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12 + # else: + # model_output = (1 / 24) * (55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4]) + + def counter_0(state: PNDMSchedulerState): + ets = state.ets.at[0].set(model_output) + return state.replace( + ets = ets, + cur_sample = sample, + cur_model_output = jnp.array(model_output, dtype=jnp.float32), + ) - if state.counter != 1: - state = state.replace(ets=state.ets.append(model_output)) - else: - prev_timestep = timestep - timestep = timestep + self.config.num_train_timesteps // state.num_inference_steps - - if len(state.ets) == 1 and state.counter == 0: - model_output = model_output - state = state.replace(cur_sample=sample) - elif len(state.ets) == 1 and state.counter == 1: - model_output = (model_output + state.ets[-1]) / 2 - sample = state.cur_sample - state = state.replace(cur_sample=None) - elif len(state.ets) == 2: - model_output = (3 * state.ets[-1] - state.ets[-2]) / 2 - elif len(state.ets) == 3: - model_output = (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12 - else: - model_output = (1 / 24) * ( - 55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4] + def counter_1(state: PNDMSchedulerState): + return state.replace( + cur_model_output = (model_output + state.ets[0]) / 2, ) + def counter_2(state: PNDMSchedulerState): + ets = state.ets.at[1].set(model_output) + return state.replace( + ets = ets, + cur_model_output = (3 * ets[1] - ets[0]) / 2, + cur_sample = sample, + ) + + def counter_3(state: PNDMSchedulerState): + ets = state.ets.at[2].set(model_output) + return state.replace( + ets = ets, + cur_model_output = (23 * ets[2] - 16 * ets[1] + 5 * ets[0]) / 12, + cur_sample = sample, + ) + + def counter_other(state: PNDMSchedulerState): + ets = state.ets.at[3].set(model_output) + next_model_output = (1 / 24) * (55 * ets[3] - 59 * ets[2] + 37 * ets[1] - 9 * ets[0]) + + ets = ets.at[0].set(ets[1]) + ets = ets.at[1].set(ets[2]) + ets = ets.at[2].set(ets[3]) + + return state.replace( + ets = ets, + cur_model_output = next_model_output, + cur_sample = sample, + ) + + counter = jnp.clip(state.counter, 0, 4) + state = jax.lax.switch( + counter, + [counter_0, counter_1, counter_2, counter_3, counter_other], + state, + ) + + sample = state.cur_sample + model_output = state.cur_model_output prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output) state = state.replace(counter=state.counter + 1) @@ -375,7 +444,7 @@ def _get_prev_sample(self, sample, timestep, prev_timestep, model_output): # model_output -> e_θ(x_t, t) # prev_sample -> x_(t−δ) alpha_prod_t = self.alphas_cumprod[timestep] - alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + alpha_prod_t_prev = jnp.where(prev_timestep >= 0, self.alphas_cumprod[prev_timestep], self.final_alpha_cumprod) beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev From d6dbb89e8cd87095b5b7b1b4ced0977fbd58d3f6 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 20 Sep 2022 09:07:53 +0000 Subject: [PATCH 13/23] Pass latents shape to scheduler set_timesteps() PNDMScheduler uses it to reserve space, other schedulers will just ignore it. --- .../stable_diffusion/pipeline_flax_stable_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index 19171817ace6..ae6295d40dbe 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -203,7 +203,7 @@ def loop_body(step, args): latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() return latents, scheduler_state - scheduler_state = self.scheduler.set_timesteps(params["scheduler"], num_inference_steps=num_inference_steps) + scheduler_state = self.scheduler.set_timesteps(params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape) if debug: # run with python for loop From 69b1d7accd3dcc93e3d32b13dbeecde9fca212eb Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 20 Sep 2022 09:14:45 +0000 Subject: [PATCH 14/23] Wrap model imports inside availability checks. --- src/diffusers/models/__init__.py | 15 ++++++++++----- src/diffusers/utils/dummy_flax_objects.py | 14 +++++++------- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index b5fe089e05f0..d58e4d77ff73 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -12,8 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .unet_2d import UNet2DModel -from .unet_2d_condition import UNet2DConditionModel -from .unet_2d_condition_flax import FlaxUNet2DConditionModel -from .vae import AutoencoderKL, VQModel -from .vae_flax import FlaxAutoencoderKL +from ..utils import is_torch_available, is_flax_available + +if is_torch_available(): + from .unet_2d import UNet2DModel + from .unet_2d_condition import UNet2DConditionModel + from .vae import AutoencoderKL, VQModel + +if is_flax_available(): + from .unet_2d_condition_flax import FlaxUNet2DConditionModel + from .vae_flax import FlaxAutoencoderKL diff --git a/src/diffusers/utils/dummy_flax_objects.py b/src/diffusers/utils/dummy_flax_objects.py index 424e4f3bf6d8..1e3ac002a609 100644 --- a/src/diffusers/utils/dummy_flax_objects.py +++ b/src/diffusers/utils/dummy_flax_objects.py @@ -25,49 +25,49 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) -class FlaxDDIMScheduler(metaclass=DummyObject): +class FlaxDiffusionPipeline(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) -class FlaxDDPMScheduler(metaclass=DummyObject): +class FlaxDDIMScheduler(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) -class FlaxKarrasVeScheduler(metaclass=DummyObject): +class FlaxDDPMScheduler(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) -class FlaxLMSDiscreteScheduler(metaclass=DummyObject): +class FlaxKarrasVeScheduler(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) -class FlaxPNDMScheduler(metaclass=DummyObject): +class FlaxLMSDiscreteScheduler(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) -class FlaxScoreSdeVeScheduler(metaclass=DummyObject): +class FlaxPNDMScheduler(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) -class FlaxDiffusionPipeline(metaclass=DummyObject): +class FlaxScoreSdeVeScheduler(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): From 23f7d7385f5d8401feb697df82cb388b8ad4d91b Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 20 Sep 2022 09:56:21 +0000 Subject: [PATCH 15/23] Optionally return state in from_config. Useful for Flax schedulers. --- src/diffusers/configuration_utils.py | 11 +++++++++-- src/diffusers/schedulers/scheduling_ddim_flax.py | 1 + src/diffusers/schedulers/scheduling_pndm_flax.py | 3 ++- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 5792e17e8120..e64bcb965114 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -160,12 +160,19 @@ def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], ret if "dtype" in unused_kwargs: init_dict["dtype"] = unused_kwargs.pop("dtype") + # Return model and optionally state and/or unused_kwargs model = cls(**init_dict) + return_tuple = (model,) + + # Some components (Flax schedulers) have a state. + if getattr(cls, "has_state", False): # Check for "create_state" in model instead? + state = model.create_state() + return_tuple += (state,) if return_unused_kwargs: - return model, unused_kwargs + return return_tuple + (unused_kwargs,) else: - return model + return return_tuple if len(return_tuple) > 1 else model @classmethod def get_config_dict( diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py index dd5a87df654a..c9c14ad07f7a 100644 --- a/src/diffusers/schedulers/scheduling_ddim_flax.py +++ b/src/diffusers/schedulers/scheduling_ddim_flax.py @@ -104,6 +104,7 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin): `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. """ + has_state = True @register_to_config def __init__( diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index ebaa90f4bac5..49d8f27f358c 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -113,7 +113,8 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin): `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. """ - + has_state = True + @register_to_config def __init__( self, From 163df381baa6a5c70f91513f9aa7db73969fc968 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 20 Sep 2022 10:34:46 +0000 Subject: [PATCH 16/23] Do not convert model weights to dtype. --- src/diffusers/modeling_flax_utils.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index 90e9b73a3f87..505c2881cbab 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -483,11 +483,6 @@ def from_pretrained( "See [`~ModelMixin.to_fp32`] for further information on how to do this." ) - if dtype == jnp.bfloat16: - state = model.to_bf16(state) - if dtype == jnp.float16: - state = model.to_fp16(state) - return model, unflatten_dict(state) def save_pretrained( From 8bc06b0098688381a0b073fbf6b741ccfa6950ef Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 20 Sep 2022 15:35:54 +0000 Subject: [PATCH 17/23] Re-enable PRK steps with functional implementation. Values returned still not verified for correctness. --- .../schedulers/scheduling_pndm_flax.py | 170 ++++++++++-------- 1 file changed, 98 insertions(+), 72 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 49d8f27f358c..ccdbb0ae5fc0 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -233,77 +233,107 @@ def step( When returning a tuple, the first element is the sample tensor. """ - return self.step_plms( - state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict + if self.config.skip_prk_steps: + prev_sample, state = self.step_plms( + state=state, model_output=model_output, timestep=timestep, sample=sample + ) + else: + prev_sample, state = jax.lax.switch( + jnp.where(state.counter < len(state.prk_timesteps), 0, 1), + [self.step_prk, self.step_plms], + state, model_output, timestep, sample + ) + + if not return_dict: + return (prev_sample, state) + + return FlaxSchedulerOutput(prev_sample=prev_sample, state=state) + + + def step_prk( + self, + state: PNDMSchedulerState, + model_output: jnp.ndarray, + timestep: int, + sample: jnp.ndarray, + ) -> Union[FlaxSchedulerOutput, Tuple]: + """ + Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the + solution to the differential equation. + + Args: + state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is the sample tensor. + + """ + if state.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + diff_to_prev = jnp.where(state.counter % 2, 0, self.config.num_train_timesteps // state.num_inference_steps // 2) + prev_timestep = timestep - diff_to_prev + timestep = state.prk_timesteps[state.counter // 4 * 4] + + def remainder_0(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int): + return state.replace( + cur_model_output = state.cur_model_output + 1 / 6 * model_output, + ets = state.ets.at[ets_at].set(model_output), + cur_sample = sample, + ), model_output + + def remainder_1(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int): + return state.replace( + cur_model_output = state.cur_model_output + 1 / 3 * model_output + ), model_output + + def remainder_2(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int): + return state.replace( + cur_model_output = state.cur_model_output + 1 / 3 * model_output + ), model_output + + def remainder_3(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int): + model_output = state.cur_model_output + 1 / 6 * model_output + return state.replace( + cur_model_output = jnp.zeros_like(state.cur_model_output) + ), model_output + + state, model_output = jax.lax.switch( + state.counter % 4, + [remainder_0, remainder_1, remainder_2, remainder_3], + state, model_output, state.counter // 4 ) - # if state.counter < len(state.prk_timesteps) and not self.config.skip_prk_steps: - # return self.step_prk( - # state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict - # ) - # else: - # return self.step_plms( - # state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict + # if state.counter % 4 == 0: + # state = state.replace( + # cur_model_output=state.cur_model_output + 1 / 6 * model_output, + # ets=state.ets.append(model_output), + # cur_sample=sample, # ) + # elif (self.counter - 1) % 4 == 0: + # state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output) + # elif (self.counter - 2) % 4 == 0: + # state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output) + # elif (self.counter - 3) % 4 == 0: + # model_output = state.cur_model_output + 1 / 6 * model_output + # state = state.replace(cur_model_output=0) + + # cur_sample should not be `None` + # cur_sample = state.cur_sample if state.cur_sample is not None else sample + cur_sample = state.cur_sample + + prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output) + state = state.replace(counter=state.counter + 1) - # def step_prk( - # self, - # state: PNDMSchedulerState, - # model_output: jnp.ndarray, - # timestep: int, - # sample: jnp.ndarray, - # return_dict: bool = True, - # ) -> Union[FlaxSchedulerOutput, Tuple]: - # """ - # Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the - # solution to the differential equation. - - # Args: - # state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. - # model_output (`jnp.ndarray`): direct output from learned diffusion model. - # timestep (`int`): current discrete timestep in the diffusion chain. - # sample (`jnp.ndarray`): - # current instance of sample being created by diffusion process. - # return_dict (`bool`): option for returning tuple rather than SchedulerOutput class - - # Returns: - # [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. - # When returning a tuple, the first element is the sample tensor. - - # """ - # if state.num_inference_steps is None: - # raise ValueError( - # "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" - # ) - - # diff_to_prev = 0 if state.counter % 2 else self.config.num_train_timesteps // state.num_inference_steps // 2 - # prev_timestep = timestep - diff_to_prev - # timestep = state.prk_timesteps[state.counter // 4 * 4] - - # if state.counter % 4 == 0: - # state = state.replace( - # cur_model_output=state.cur_model_output + 1 / 6 * model_output, - # ets=state.ets.append(model_output), - # cur_sample=sample, - # ) - # elif (self.counter - 1) % 4 == 0: - # state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output) - # elif (self.counter - 2) % 4 == 0: - # state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output) - # elif (self.counter - 3) % 4 == 0: - # model_output = state.cur_model_output + 1 / 6 * model_output - # state = state.replace(cur_model_output=0) - - # # cur_sample should not be `None` - # cur_sample = state.cur_sample if state.cur_sample is not None else sample - - # prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output) - # state = state.replace(counter=state.counter + 1) - - # if not return_dict: - # return (prev_sample, state) - - # return FlaxSchedulerOutput(prev_sample=prev_sample, state=state) + return (prev_sample, state) def step_plms( self, @@ -311,7 +341,6 @@ def step_plms( model_output: jnp.ndarray, timestep: int, sample: jnp.ndarray, - return_dict: bool = True, ) -> Union[FlaxSchedulerOutput, Tuple]: """ Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple @@ -426,10 +455,7 @@ def counter_other(state: PNDMSchedulerState): prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output) state = state.replace(counter=state.counter + 1) - if not return_dict: - return (prev_sample, state) - - return FlaxSchedulerOutput(prev_sample=prev_sample, state=state) + return (prev_sample, state) def _get_prev_sample(self, sample, timestep, prev_timestep, model_output): # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf From 8a9ccf2bbe0cb64b33c36d1c88830adaba49cb30 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 21 Sep 2022 20:57:42 +0000 Subject: [PATCH 18/23] Remove left over has_state var. --- src/diffusers/schedulers/scheduling_ddim_flax.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py index 46f8da34e81b..d81d66607147 100644 --- a/src/diffusers/schedulers/scheduling_ddim_flax.py +++ b/src/diffusers/schedulers/scheduling_ddim_flax.py @@ -104,7 +104,6 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin): `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. """ - has_state = True @property def has_state(self): From cf6cd7aafaef09e989b7a4c7f334e29104448a74 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 21 Sep 2022 21:04:11 +0000 Subject: [PATCH 19/23] make style --- .../pipelines/stable_diffusion/__init__.py | 2 +- .../pipeline_flax_stable_diffusion.py | 4 +- .../schedulers/scheduling_pndm_flax.py | 86 ++++++++++--------- 3 files changed, 48 insertions(+), 44 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 03599d56b19a..e3b8e2f0f30c 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -56,6 +56,6 @@ class FlaxStableDiffusionPipelineOutput(BaseOutput): images: Union[List[PIL.Image.Image], np.ndarray] nsfw_content_detected: List[bool] + from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline from .safety_checker_flax import FlaxStableDiffusionSafetyChecker - from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index 47e1e05a68d7..974d77547e56 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -187,7 +187,9 @@ def loop_body(step, args): latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() return latents, scheduler_state - scheduler_state = self.scheduler.set_timesteps(params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape) + scheduler_state = self.scheduler.set_timesteps( + params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape + ) if debug: # run with python for loop diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index b454fb0c5069..03e005140030 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -156,12 +156,7 @@ def __init__( def create_state(self): return PNDMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps) - def set_timesteps( - self, - state: PNDMSchedulerState, - shape: Tuple, - num_inference_steps: int - ) -> PNDMSchedulerState: + def set_timesteps(self, state: PNDMSchedulerState, shape: Tuple, num_inference_steps: int) -> PNDMSchedulerState: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -204,9 +199,9 @@ def set_timesteps( timesteps=jnp.concatenate([state.prk_timesteps, state.plms_timesteps]).astype(jnp.int64), counter=0, # Will be zeros, not really empty - cur_model_output = jnp.empty(shape), - cur_sample = jnp.empty(shape), - ets = jnp.empty((4,) + shape), + cur_model_output=jnp.empty(shape), + cur_sample=jnp.empty(shape), + ets=jnp.empty((4,) + shape), ) def step( @@ -244,7 +239,11 @@ def step( prev_sample, state = jax.lax.switch( jnp.where(state.counter < len(state.prk_timesteps), 0, 1), [self.step_prk, self.step_plms], - state, model_output, timestep, sample + # Args to either branch + state, + model_output, + timestep, + sample, ) if not return_dict: @@ -252,7 +251,6 @@ def step( return FlaxSchedulerOutput(prev_sample=prev_sample, state=state) - def step_prk( self, state: PNDMSchedulerState, @@ -282,37 +280,39 @@ def step_prk( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) - diff_to_prev = jnp.where(state.counter % 2, 0, self.config.num_train_timesteps // state.num_inference_steps // 2) + diff_to_prev = jnp.where( + state.counter % 2, 0, self.config.num_train_timesteps // state.num_inference_steps // 2 + ) prev_timestep = timestep - diff_to_prev timestep = state.prk_timesteps[state.counter // 4 * 4] def remainder_0(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int): - return state.replace( - cur_model_output = state.cur_model_output + 1 / 6 * model_output, - ets = state.ets.at[ets_at].set(model_output), - cur_sample = sample, - ), model_output + return ( + state.replace( + cur_model_output=state.cur_model_output + 1 / 6 * model_output, + ets=state.ets.at[ets_at].set(model_output), + cur_sample=sample, + ), + model_output, + ) def remainder_1(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int): - return state.replace( - cur_model_output = state.cur_model_output + 1 / 3 * model_output - ), model_output + return state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output), model_output def remainder_2(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int): - return state.replace( - cur_model_output = state.cur_model_output + 1 / 3 * model_output - ), model_output + return state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output), model_output def remainder_3(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int): model_output = state.cur_model_output + 1 / 6 * model_output - return state.replace( - cur_model_output = jnp.zeros_like(state.cur_model_output) - ), model_output + return state.replace(cur_model_output=jnp.zeros_like(state.cur_model_output)), model_output state, model_output = jax.lax.switch( state.counter % 4, [remainder_0, remainder_1, remainder_2, remainder_3], - state, model_output, state.counter // 4 + # Args to either branch + state, + model_output, + state.counter // 4, ) # if state.counter % 4 == 0: @@ -386,7 +386,9 @@ def step_plms( # timestep = timestep + self.config.num_train_timesteps // state.num_inference_steps prev_timestep = jnp.where(state.counter == 1, timestep, prev_timestep) - timestep = jnp.where(state.counter == 1, timestep + self.config.num_train_timesteps // state.num_inference_steps, timestep) + timestep = jnp.where( + state.counter == 1, timestep + self.config.num_train_timesteps // state.num_inference_steps, timestep + ) # Reference: # if len(state.ets) == 1 and state.counter == 0: @@ -406,31 +408,31 @@ def step_plms( def counter_0(state: PNDMSchedulerState): ets = state.ets.at[0].set(model_output) return state.replace( - ets = ets, - cur_sample = sample, - cur_model_output = jnp.array(model_output, dtype=jnp.float32), + ets=ets, + cur_sample=sample, + cur_model_output=jnp.array(model_output, dtype=jnp.float32), ) def counter_1(state: PNDMSchedulerState): return state.replace( - cur_model_output = (model_output + state.ets[0]) / 2, + cur_model_output=(model_output + state.ets[0]) / 2, ) def counter_2(state: PNDMSchedulerState): ets = state.ets.at[1].set(model_output) return state.replace( - ets = ets, - cur_model_output = (3 * ets[1] - ets[0]) / 2, - cur_sample = sample, + ets=ets, + cur_model_output=(3 * ets[1] - ets[0]) / 2, + cur_sample=sample, ) def counter_3(state: PNDMSchedulerState): ets = state.ets.at[2].set(model_output) return state.replace( - ets = ets, - cur_model_output = (23 * ets[2] - 16 * ets[1] + 5 * ets[0]) / 12, - cur_sample = sample, - ) + ets=ets, + cur_model_output=(23 * ets[2] - 16 * ets[1] + 5 * ets[0]) / 12, + cur_sample=sample, + ) def counter_other(state: PNDMSchedulerState): ets = state.ets.at[3].set(model_output) @@ -441,9 +443,9 @@ def counter_other(state: PNDMSchedulerState): ets = ets.at[2].set(ets[3]) return state.replace( - ets = ets, - cur_model_output = next_model_output, - cur_sample = sample, + ets=ets, + cur_model_output=next_model_output, + cur_sample=sample, ) counter = jnp.clip(state.counter, 0, 4) From f974a41a420e711f2a288cd0aa1df0f26b299f7e Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Thu, 22 Sep 2022 13:32:29 +0200 Subject: [PATCH 20/23] Apply suggestion list -> tuple Co-authored-by: Suraj Patil --- src/diffusers/schedulers/scheduling_pndm_flax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 03e005140030..bd9c65f78076 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -238,7 +238,7 @@ def step( else: prev_sample, state = jax.lax.switch( jnp.where(state.counter < len(state.prk_timesteps), 0, 1), - [self.step_prk, self.step_plms], + (self.step_prk, self.step_plms), # Args to either branch state, model_output, From ce0a327dc42b68d75e8ae46fcf87957e262dc937 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Thu, 22 Sep 2022 13:32:49 +0200 Subject: [PATCH 21/23] Apply suggestion list -> tuple Co-authored-by: Suraj Patil --- src/diffusers/schedulers/scheduling_pndm_flax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index bd9c65f78076..5b1564c38705 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -308,7 +308,7 @@ def remainder_3(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: in state, model_output = jax.lax.switch( state.counter % 4, - [remainder_0, remainder_1, remainder_2, remainder_3], + (remainder_0, remainder_1, remainder_2, remainder_3), # Args to either branch state, model_output, From 7fcbc328cdc8afe8e84dcf5a1a0592a28c153c40 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Thu, 22 Sep 2022 11:42:30 +0000 Subject: [PATCH 22/23] Remove unused comments. --- .../schedulers/scheduling_pndm_flax.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 5b1564c38705..0a989397db23 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -315,24 +315,7 @@ def remainder_3(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: in state.counter // 4, ) - # if state.counter % 4 == 0: - # state = state.replace( - # cur_model_output=state.cur_model_output + 1 / 6 * model_output, - # ets=state.ets.append(model_output), - # cur_sample=sample, - # ) - # elif (self.counter - 1) % 4 == 0: - # state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output) - # elif (self.counter - 2) % 4 == 0: - # state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output) - # elif (self.counter - 3) % 4 == 0: - # model_output = state.cur_model_output + 1 / 6 * model_output - # state = state.replace(cur_model_output=0) - - # cur_sample should not be `None` - # cur_sample = state.cur_sample if state.cur_sample is not None else sample cur_sample = state.cur_sample - prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output) state = state.replace(counter=state.counter + 1) From cd17c560c1b3c343f8f4a2f30adf9a7ec94441e5 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Thu, 22 Sep 2022 11:42:47 +0000 Subject: [PATCH 23/23] Use zeros instead of empty. --- src/diffusers/schedulers/scheduling_pndm_flax.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 0a989397db23..4b4172213fa7 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -198,10 +198,10 @@ def set_timesteps(self, state: PNDMSchedulerState, shape: Tuple, num_inference_s return state.replace( timesteps=jnp.concatenate([state.prk_timesteps, state.plms_timesteps]).astype(jnp.int64), counter=0, - # Will be zeros, not really empty - cur_model_output=jnp.empty(shape), - cur_sample=jnp.empty(shape), - ets=jnp.empty((4,) + shape), + # Reserve space for the state variables + cur_model_output=jnp.zeros(shape), + cur_sample=jnp.zeros(shape), + ets=jnp.zeros((4,) + shape), ) def step(