From caafeee7a1869b6e599f8c464111b710a015e1ac Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Sat, 9 Jul 2022 21:16:47 +0000 Subject: [PATCH] =?UTF-8?q?fix:=E2=80=AFencoder=20attn?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- vqgan_jax/modeling_flax_vqgan.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vqgan_jax/modeling_flax_vqgan.py b/vqgan_jax/modeling_flax_vqgan.py index 3647639..1c49553 100644 --- a/vqgan_jax/modeling_flax_vqgan.py +++ b/vqgan_jax/modeling_flax_vqgan.py @@ -273,12 +273,12 @@ def setup(self): dtype=self.dtype) def __call__(self, hidden_states, temb=None, deterministic: bool = True): - for res_block in self.block: + for i, res_block in enumerate(self.block): hidden_states = res_block(hidden_states, temb, deterministic=deterministic) - for attn_block in self.attn: - hidden_states = attn_block(hidden_states) + if self.attn: + hidden_states = self.attn[i](hidden_states) if self.downsample is not None: hidden_states = self.downsample(hidden_states)