Skip to content

Commit

Permalink
fix-copies
Browse files Browse the repository at this point in the history
  • Loading branch information
charchit7 committed Dec 8, 2023
1 parent 3c70b1e commit 6002126
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,11 +478,22 @@ def encode_image(self, image, device, num_images_per_prompt):
image = self.feature_extractor(image, return_tensors="pt").pixel_values

image = image.to(device=device, dtype=dtype)
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
if output_hidden_states:
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_enc_hidden_states = self.image_encoder(
torch.zeros_like(image), output_hidden_states=True
).hidden_states[-2]
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
num_images_per_prompt, dim=0
)
return image_enc_hidden_states, uncond_image_enc_hidden_states
else:
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)

uncond_image_embeds = torch.zeros_like(image_embeds)
return image_embeds, uncond_image_embeds
return image_embeds, uncond_image_embeds

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
Expand Down

0 comments on commit 6002126

Please sign in to comment.