diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index da08bc360942..0f849a66eaea 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -130,9 +130,9 @@ def __init__( self.gradient_checkpointing = False - def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: + def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor: r"""The forward method of the `Encoder` class.""" - sample = x + sample = self.conv_in(sample) if self.training and self.gradient_checkpointing: @@ -273,9 +273,11 @@ def __init__( self.gradient_checkpointing = False - def forward(self, z: torch.FloatTensor, latent_embeds: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: + def forward( + self, sample: torch.FloatTensor, latent_embeds: Optional[torch.FloatTensor] = None + ) -> torch.FloatTensor: r"""The forward method of the `Decoder` class.""" - sample = z + sample = self.conv_in(sample) upscale_dtype = next(iter(self.up_blocks.parameters())).dtype