Skip to content

Commit

Permalink
fix StableDiffusionTensorRT super args error (#6009)
Browse files Browse the repository at this point in the history
  • Loading branch information
gujingit authored Dec 4, 2023
1 parent b785a15 commit bf92e74
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 6 deletions.
13 changes: 11 additions & 2 deletions examples/community/stable_diffusion_tensorrt_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
save_engine,
)
from polygraphy.backend.trt import util as trt_util
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import (
Expand Down Expand Up @@ -709,6 +709,7 @@ def __init__(
scheduler: DDIMScheduler,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
stages=["clip", "unet", "vae", "vae_encoder"],
image_height: int = 512,
Expand All @@ -724,7 +725,15 @@ def __init__(
timing_cache: str = "timing_cache",
):
super().__init__(
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
vae,
text_encoder,
tokenizer,
unet,
scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
requires_safety_checker=requires_safety_checker,
)

self.vae.forward = self.vae.decode
Expand Down
13 changes: 11 additions & 2 deletions examples/community/stable_diffusion_tensorrt_inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
save_engine,
)
from polygraphy.backend.trt import util as trt_util
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import (
Expand Down Expand Up @@ -710,6 +710,7 @@ def __init__(
scheduler: DDIMScheduler,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
stages=["clip", "unet", "vae", "vae_encoder"],
image_height: int = 512,
Expand All @@ -725,7 +726,15 @@ def __init__(
timing_cache: str = "timing_cache",
):
super().__init__(
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
vae,
text_encoder,
tokenizer,
unet,
scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
requires_safety_checker=requires_safety_checker,
)

self.vae.forward = self.vae.decode
Expand Down
13 changes: 11 additions & 2 deletions examples/community/stable_diffusion_tensorrt_txt2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
save_engine,
)
from polygraphy.backend.trt import util as trt_util
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import (
Expand Down Expand Up @@ -624,6 +624,7 @@ def __init__(
scheduler: DDIMScheduler,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
stages=["clip", "unet", "vae"],
image_height: int = 768,
Expand All @@ -639,7 +640,15 @@ def __init__(
timing_cache: str = "timing_cache",
):
super().__init__(
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
vae,
text_encoder,
tokenizer,
unet,
scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
requires_safety_checker=requires_safety_checker,
)

self.vae.forward = self.vae.decode
Expand Down

0 comments on commit bf92e74

Please sign in to comment.