diff --git a/vqgan_jax/modeling_flax_vqgan.py b/vqgan_jax/modeling_flax_vqgan.py index 3647639..1c49553 100644 --- a/vqgan_jax/modeling_flax_vqgan.py +++ b/vqgan_jax/modeling_flax_vqgan.py @@ -273,12 +273,12 @@ def setup(self): dtype=self.dtype) def __call__(self, hidden_states, temb=None, deterministic: bool = True): - for res_block in self.block: + for i, res_block in enumerate(self.block): hidden_states = res_block(hidden_states, temb, deterministic=deterministic) - for attn_block in self.attn: - hidden_states = attn_block(hidden_states) + if self.attn: + hidden_states = self.attn[i](hidden_states) if self.downsample is not None: hidden_states = self.downsample(hidden_states)