Commit 131bfce

Merge branch 'main' into feat/ci-benchmarking
sayakpaul authored Dec 4, 2023
2 parents 93b491b + bf92e74 commit 131bfce
Showing 4 changed files with 42 additions and 9 deletions.
@@ -232,20 +232,26 @@ def parse_args(input_args=None):
         help=(
             "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
             " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
-            " or to a folder containing files that 🤗 Datasets can understand."
+            " or to a folder containing files that 🤗 Datasets can understand. To load custom captions, the training set directory needs to follow the structure of a "
+            "datasets ImageFolder, containing both the images and the corresponding caption for each image. See "
+            "https://huggingface.co/docs/datasets/image_dataset for more information."
         ),
     )
     parser.add_argument(
         "--dataset_config_name",
         type=str,
         default=None,
-        help="The config of the Dataset, leave as None if there's only one config.",
+        help="The config of the Dataset. In some cases, a dataset may have more than one configuration (for example, "
+        "if it contains different subsets of data and you only wish to load a specific subset), in which case specify the desired configuration using --dataset_config_name. Leave as "
+        "None if there's only one config.",
     )
     parser.add_argument(
         "--instance_data_dir",
         type=str,
         default=None,
-        help=("A folder containing the training data. "),
+        help="A path to a local folder containing the training data of instance images. Specify this arg instead of "
+        "--dataset_name if you wish to train using a local folder without custom captions. If you wish to train with custom captions, please specify "
+        "--dataset_name instead.",
     )

     parser.add_argument(
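For readers following the new help text, here is a minimal sketch of an ImageFolder-style training directory with custom captions, per the linked 🤗 Datasets docs. The layout and load_dataset call follow the documented imagefolder convention; the file names and the caption column name ("text") are illustrative and must match whatever caption column the training script expects:

train/
├── metadata.jsonl
├── dog_01.png
└── dog_02.png

Each line of metadata.jsonl is one JSON object; the "file_name" column is what ImageFolder uses to pair a row with its image:

{"file_name": "dog_01.png", "text": "a photo of sks dog on the beach"}
{"file_name": "dog_02.png", "text": "a photo of sks dog wearing a hat"}

Loading such a folder locally mirrors what passing a folder path to --dataset_name does:

from datasets import load_dataset

# "imagefolder" resolves the images and attaches the metadata.jsonl columns
dataset = load_dataset("imagefolder", data_dir="train", split="train")
print(dataset[0]["text"])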
13 changes: 11 additions & 2 deletions examples/community/stable_diffusion_tensorrt_img2img.py
@@ -41,7 +41,7 @@
     save_engine,
 )
 from polygraphy.backend.trt import util as trt_util
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import (
@@ -709,6 +709,7 @@ def __init__(
         scheduler: DDIMScheduler,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPFeatureExtractor,
+        image_encoder: CLIPVisionModelWithProjection = None,
         requires_safety_checker: bool = True,
         stages=["clip", "unet", "vae", "vae_encoder"],
         image_height: int = 512,
@@ -724,7 +725,15 @@
         timing_cache: str = "timing_cache",
     ):
         super().__init__(
-            vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
+            vae,
+            text_encoder,
+            tokenizer,
+            unet,
+            scheduler,
+            safety_checker=safety_checker,
+            feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
+            requires_safety_checker=requires_safety_checker,
         )

         self.vae.forward = self.vae.decode
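The same two-line change (the CLIPVisionModelWithProjection import and the optional image_encoder parameter) is applied to all three TensorRT community pipelines in this commit; the inpaint and txt2img files below are identical in this respect. The switch from a positional to a keyword-argument super().__init__() call is what keeps the pipelines correct once a new parameter is inserted ahead of requires_safety_checker in the parent signature. A self-contained Python sketch (toy classes, not the diffusers API) of the failure mode the keywords avoid:

# Toy stand-in for the parent pipeline's signature after image_encoder is added.
class Base:
    def __init__(self, vae, feature_extractor, image_encoder=None, requires_safety_checker=True):
        self.image_encoder = image_encoder
        self.requires_safety_checker = requires_safety_checker

# Old positional style: the True intended for requires_safety_checker
# now silently binds to the newly inserted image_encoder slot.
broken = Base("vae", "extractor", True)
assert broken.image_encoder is True  # wrong slot

# Keyword style stays correct no matter where new defaulted parameters appear.
fixed = Base("vae", "extractor", requires_safety_checker=True)
assert fixed.image_encoder is None
assert fixed.requires_safety_checker is True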
13 changes: 11 additions & 2 deletions examples/community/stable_diffusion_tensorrt_inpaint.py
@@ -41,7 +41,7 @@
     save_engine,
 )
 from polygraphy.backend.trt import util as trt_util
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import (
@@ -710,6 +710,7 @@ def __init__(
         scheduler: DDIMScheduler,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPFeatureExtractor,
+        image_encoder: CLIPVisionModelWithProjection = None,
         requires_safety_checker: bool = True,
         stages=["clip", "unet", "vae", "vae_encoder"],
         image_height: int = 512,
@@ -725,7 +726,15 @@
         timing_cache: str = "timing_cache",
     ):
         super().__init__(
-            vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
+            vae,
+            text_encoder,
+            tokenizer,
+            unet,
+            scheduler,
+            safety_checker=safety_checker,
+            feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
+            requires_safety_checker=requires_safety_checker,
         )

         self.vae.forward = self.vae.decode
13 changes: 11 additions & 2 deletions examples/community/stable_diffusion_tensorrt_txt2img.py
@@ -40,7 +40,7 @@
     save_engine,
 )
 from polygraphy.backend.trt import util as trt_util
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import (
@@ -624,6 +624,7 @@ def __init__(
         scheduler: DDIMScheduler,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPFeatureExtractor,
+        image_encoder: CLIPVisionModelWithProjection = None,
         requires_safety_checker: bool = True,
         stages=["clip", "unet", "vae"],
         image_height: int = 768,
@@ -639,7 +640,15 @@
         timing_cache: str = "timing_cache",
     ):
         super().__init__(
-            vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
+            vae,
+            text_encoder,
+            tokenizer,
+            unet,
+            scheduler,
+            safety_checker=safety_checker,
+            feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
+            requires_safety_checker=requires_safety_checker,
        )

         self.vae.forward = self.vae.decode
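Since image_encoder defaults to None and is merely forwarded to the parent class, existing callers of these community pipelines keep working unchanged. A hedged usage sketch for the txt2img variant, modeled on the community examples README (model id and scheduler setup are illustrative, and a working TensorRT installation is assumed):

import torch
from diffusers import DDIMScheduler, DiffusionPipeline

# These TensorRT pipelines are typed against DDIMScheduler.
scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="scheduler")

# custom_pipeline loads examples/community/stable_diffusion_tensorrt_txt2img.py;
# image_encoder is not passed here, exercising the new default of None.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",
    custom_pipeline="stable_diffusion_tensorrt_txt2img",
    scheduler=scheduler,
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")
image = pipe("a photo of a beach at sunset").images[0]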
