diff --git a/convert_pt_model_to_jax.py b/convert_pt_model_to_jax.py
deleted file mode 100644
index af03fd4..0000000
--- a/convert_pt_model_to_jax.py
+++ /dev/null
@@ -1,110 +0,0 @@
-import re
-
-import jax.numpy as jnp
-from flax.traverse_util import flatten_dict, unflatten_dict
-
-import torch
-
-from modeling_flax_vqgan import VQModel
-from configuration_vqgan import VQGANConfig
-
-
-regex = r"\w+[.]\d+"
-
-
-def rename_key(key):
-    pats = re.findall(regex, key)
-    for pat in pats:
-        key = key.replace(pat, "_".join(pat.split(".")))
-    return key
-
-
-# Adapted from https://github.com/huggingface/transformers/blob/ff5cdc086be1e0c3e2bbad8e3469b34cffb55a85/src/transformers/modeling_flax_pytorch_utils.py#L61
-def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
-    # convert pytorch tensor to numpy
-    pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
-
-    random_flax_state_dict = flatten_dict(flax_model.params)
-    flax_state_dict = {}
-
-    remove_base_model_prefix = (flax_model.base_model_prefix not in flax_model.params) and (
-        flax_model.base_model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
-    )
-    add_base_model_prefix = (flax_model.base_model_prefix in flax_model.params) and (
-        flax_model.base_model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
-    )
-
-    # Need to change some parameter names to match Flax names so that we don't have to fork any layer
-    for pt_key, pt_tensor in pt_state_dict.items():
-        pt_tuple_key = tuple(pt_key.split("."))
-
-        has_base_model_prefix = pt_tuple_key[0] == flax_model.base_model_prefix
-        require_base_model_prefix = (flax_model.base_model_prefix,) + pt_tuple_key in random_flax_state_dict
-
-        if remove_base_model_prefix and has_base_model_prefix:
-            pt_tuple_key = pt_tuple_key[1:]
-        elif add_base_model_prefix and require_base_model_prefix:
-            pt_tuple_key = (flax_model.base_model_prefix,) + pt_tuple_key
-
-        # Correctly rename weight parameters
-        if (
-            "norm" in pt_key
-            and (pt_tuple_key[-1] == "bias")
-            and (pt_tuple_key[:-1] + ("bias",) in random_flax_state_dict)
-        ):
-            pt_tensor = pt_tensor[None, None, None, :]
-        elif (
-            "norm" in pt_key
-            and (pt_tuple_key[-1] == "bias")
-            and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)
-        ):
-            pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
-            pt_tensor = pt_tensor[None, None, None, :]
-        elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
-            pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
-            pt_tensor = pt_tensor[None, None, None, :]
-        if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
-            pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
-        elif pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4 and pt_tuple_key not in random_flax_state_dict:
-            # conv layer
-            pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
-            pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
-        elif pt_tuple_key[-1] == "weight" and pt_tuple_key not in random_flax_state_dict:
-            # linear layer
-            pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
-            pt_tensor = pt_tensor.T
-        elif pt_tuple_key[-1] == "gamma":
-            pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
-        elif pt_tuple_key[-1] == "beta":
-            pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
-
-        if pt_tuple_key in random_flax_state_dict:
-            if pt_tensor.shape != random_flax_state_dict[pt_tuple_key].shape:
-                raise ValueError(
-                    f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
-                    f"{random_flax_state_dict[pt_tuple_key].shape}, but is {pt_tensor.shape}."
-                )
-
-        # also add unexpected weight so that a warning is thrown
-        flax_state_dict[pt_tuple_key] = jnp.asarray(pt_tensor)
-
-    return unflatten_dict(flax_state_dict)
-
-
-def convert_model(config_path, pt_state_dict_path, save_path):
-    config = VQGANConfig.from_pretrained(config_path)
-    model = VQModel(config)
-
-    state_dict = torch.load(pt_state_dict_path, map_location="cpu")["state_dict"]
-    keys = list(state_dict.keys())
-    for key in keys:
-        if key.startswith("loss"):
-            state_dict.pop(key)
-            continue
-        renamed_key = rename_key(key)
-        state_dict[renamed_key] = state_dict.pop(key)
-
-    state = convert_pytorch_state_dict_to_flax(state_dict, model)
-    model.params = state
-    model.save_pretrained(save_path)
-    return model
diff --git a/modeling_flax_vqgan.py b/modeling_flax_vqgan.py
deleted file mode 100644
index fc176ab..0000000
--- a/modeling_flax_vqgan.py
+++ /dev/null
@@ -1,609 +0,0 @@
-# JAX implementation of VQGAN from taming-transformers https://github.com/CompVis/taming-transformers
-
-from functools import partial
-from typing import Tuple
-import math
-
-import jax
-import jax.numpy as jnp
-import numpy as np
-import flax.linen as nn
-from flax.core.frozen_dict import FrozenDict
-
-from transformers.modeling_flax_utils import FlaxPreTrainedModel
-
-from configuration_vqgan import VQGANConfig
-
-
-class Upsample(nn.Module):
-    in_channels: int
-    with_conv: bool
-    dtype: jnp.dtype = jnp.float32
-
-    def setup(self):
-        if self.with_conv:
-            self.conv = nn.Conv(
-                self.in_channels,
-                kernel_size=(3, 3),
-                strides=(1, 1),
-                padding=((1, 1), (1, 1)),
-                dtype=self.dtype,
-            )
-
-    def __call__(self, hidden_states):
-        batch, height, width, channels = hidden_states.shape
-        hidden_states = jax.image.resize(
-            hidden_states,
-            shape=(batch, height * 2, width * 2, channels),
-            method="nearest",
-        )
-        if self.with_conv:
-            hidden_states = self.conv(hidden_states)
-        return hidden_states
-
-
-class Downsample(nn.Module):
-    in_channels: int
-    with_conv: bool
-    dtype: jnp.dtype = jnp.float32
-
-    def setup(self):
-        if self.with_conv:
-            self.conv = nn.Conv(
-                self.in_channels,
-                kernel_size=(3, 3),
-                strides=(2, 2),
-                padding="VALID",
-                dtype=self.dtype,
-            )
-
-    def __call__(self, hidden_states):
-        if self.with_conv:
-            pad = ((0, 0), (0, 1), (0, 1), (0, 0))  # pad height and width dim
-            hidden_states = jnp.pad(hidden_states, pad_width=pad)
-            hidden_states = self.conv(hidden_states)
-        else:
-            hidden_states = nn.avg_pool(hidden_states, window_shape=(2, 2), strides=(2, 2), padding="VALID")
-        return hidden_states
-
-
-class ResnetBlock(nn.Module):
-    in_channels: int
-    out_channels: int = None
-    use_conv_shortcut: bool = False
-    temb_channels: int = 512
-    dropout_prob: float = 0.0
-    dtype: jnp.dtype = jnp.float32
-
-    def setup(self):
-        self.out_channels_ = self.in_channels if self.out_channels is None else self.out_channels
-
-        self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
-        self.conv1 = nn.Conv(
-            self.out_channels_,
-            kernel_size=(3, 3),
-            strides=(1, 1),
-            padding=((1, 1), (1, 1)),
-            dtype=self.dtype,
-        )
-
-        if self.temb_channels:
-            self.temb_proj = nn.Dense(self.out_channels_, dtype=self.dtype)
-
-        self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
-        self.dropout = nn.Dropout(self.dropout_prob)
-        self.conv2 = nn.Conv(
-            self.out_channels_,
-            kernel_size=(3, 3),
-            strides=(1, 1),
-            padding=((1, 1), (1, 1)),
-            dtype=self.dtype,
-        )
-
-        if self.in_channels != self.out_channels_:
-            if self.use_conv_shortcut:
-                self.conv_shortcut = nn.Conv(
-                    self.out_channels_,
-                    kernel_size=(3, 3),
-                    strides=(1, 1),
-                    padding=((1, 1), (1, 1)),
-                    dtype=self.dtype,
-                )
-            else:
-                self.nin_shortcut = nn.Conv(
-                    self.out_channels_,
-                    kernel_size=(1, 1),
-                    strides=(1, 1),
-                    padding="VALID",
-                    dtype=self.dtype,
-                )
-
-    def __call__(self, hidden_states, temb=None, deterministic: bool = True):
-        residual = hidden_states
-        hidden_states = self.norm1(hidden_states)
-        hidden_states = nn.swish(hidden_states)
-        hidden_states = self.conv1(hidden_states)
-
-        if temb is not None:
-            hidden_states = hidden_states + self.temb_proj(nn.swish(temb))[:, :, None, None]  # TODO: check shapes
-
-        hidden_states = self.norm2(hidden_states)
-        hidden_states = nn.swish(hidden_states)
-        hidden_states = self.dropout(hidden_states, deterministic)
-        hidden_states = self.conv2(hidden_states)
-
-        if self.in_channels != self.out_channels_:
-            if self.use_conv_shortcut:
-                residual = self.conv_shortcut(residual)
-            else:
-                residual = self.nin_shortcut(residual)
-
-        return hidden_states + residual
-
-
-class AttnBlock(nn.Module):
-    in_channels: int
-    dtype: jnp.dtype = jnp.float32
-
-    def setup(self):
-        conv = partial(
-            nn.Conv, self.in_channels, kernel_size=(1, 1), strides=(1, 1), padding="VALID", dtype=self.dtype
-        )
-
-        self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-6)
-        self.q, self.k, self.v = conv(), conv(), conv()
-        self.proj_out = conv()
-
-    def __call__(self, hidden_states):
-        residual = hidden_states
-        hidden_states = self.norm(hidden_states)
-
-        query = self.q(hidden_states)
-        key = self.k(hidden_states)
-        value = self.v(hidden_states)
-
-        # compute attention weights
-        batch, height, width, channels = query.shape
-        query = query.reshape((batch, height * width, channels))
-        key = key.reshape((batch, height * width, channels))
-        attn_weights = jnp.einsum("...qc,...kc->...qk", query, key)
-        attn_weights = attn_weights * (int(channels) ** -0.5)
-        attn_weights = nn.softmax(attn_weights, axis=2)
-
-        ## attend to values
-        value = value.reshape((batch, height * width, channels))
-        hidden_states = jnp.einsum("...kc,...qk->...qc", value, attn_weights)
-        hidden_states = hidden_states.reshape((batch, height, width, channels))
-
-        hidden_states = self.proj_out(hidden_states)
-        hidden_states = hidden_states + residual
-        return hidden_states
-
-
-class UpsamplingBlock(nn.Module):
-    config: VQGANConfig
-    curr_res: int
-    block_idx: int
-    dtype: jnp.dtype = jnp.float32
-
-    def setup(self):
-        if self.block_idx == self.config.num_resolutions - 1:
-            block_in = self.config.ch * self.config.ch_mult[-1]
-        else:
-            block_in = self.config.ch * self.config.ch_mult[self.block_idx + 1]
-
-        block_out = self.config.ch * self.config.ch_mult[self.block_idx]
-        self.temb_ch = 0
-
-        res_blocks = []
-        attn_blocks = []
-        for _ in range(self.config.num_res_blocks + 1):
-            res_blocks.append(
-                ResnetBlock(
-                    block_in, block_out, temb_channels=self.temb_ch, dropout_prob=self.config.dropout, dtype=self.dtype
-                )
-            )
-            block_in = block_out
-            if self.curr_res in self.config.attn_resolutions:
-                attn_blocks.append(AttnBlock(block_in, dtype=self.dtype))
-
-        self.block = res_blocks
-        self.attn = attn_blocks
-
-        self.upsample = None
-        if self.block_idx != 0:
-            self.upsample = Upsample(block_in, self.config.resamp_with_conv, dtype=self.dtype)
-
-    def __call__(self, hidden_states, temb=None, deterministic: bool = True):
-        for res_block in self.block:
-            hidden_states = res_block(hidden_states, temb, deterministic=deterministic)
-            for attn_block in self.attn:
-                hidden_states = attn_block(hidden_states)
-
-        if self.upsample is not None:
-            hidden_states = self.upsample(hidden_states)
-
-        return hidden_states
-
-
-class DownsamplingBlock(nn.Module):
-    config: VQGANConfig
-    curr_res: int
-    block_idx: int
-    dtype: jnp.dtype = jnp.float32
-
-    def setup(self):
-        in_ch_mult = (1,) + tuple(self.config.ch_mult)
-        block_in = self.config.ch * in_ch_mult[self.block_idx]
-        block_out = self.config.ch * self.config.ch_mult[self.block_idx]
-        self.temb_ch = 0
-
-        res_blocks = []
-        attn_blocks = []
-        for _ in range(self.config.num_res_blocks):
-            res_blocks.append(
-                ResnetBlock(
-                    block_in, block_out, temb_channels=self.temb_ch, dropout_prob=self.config.dropout, dtype=self.dtype
-                )
-            )
-            block_in = block_out
-            if self.curr_res in self.config.attn_resolutions:
-                attn_blocks.append(AttnBlock(block_in, dtype=self.dtype))
-
-        self.block = res_blocks
-        self.attn = attn_blocks
-
-        self.downsample = None
-        if self.block_idx != self.config.num_resolutions - 1:
-            self.downsample = Downsample(block_in, self.config.resamp_with_conv, dtype=self.dtype)
-
-    def __call__(self, hidden_states, temb=None, deterministic: bool = True):
-        for res_block in self.block:
-            hidden_states = res_block(hidden_states, temb, deterministic=deterministic)
-            for attn_block in self.attn:
-                hidden_states = attn_block(hidden_states)
-
-        if self.downsample is not None:
-            hidden_states = self.downsample(hidden_states)
-
-        return hidden_states
-
-
-class MidBlock(nn.Module):
-    in_channels: int
-    temb_channels: int
-    dropout: float
-    dtype: jnp.dtype = jnp.float32
-
-    def setup(self):
-        self.block_1 = ResnetBlock(
-            self.in_channels,
-            self.in_channels,
-            temb_channels=self.temb_channels,
-            dropout_prob=self.dropout,
-            dtype=self.dtype,
-        )
-        self.attn_1 = AttnBlock(self.in_channels, dtype=self.dtype)
-        self.block_2 = ResnetBlock(
-            self.in_channels,
-            self.in_channels,
-            temb_channels=self.temb_channels,
-            dropout_prob=self.dropout,
-            dtype=self.dtype,
-        )
-
-    def __call__(self, hidden_states, temb=None, deterministic: bool = True):
-        hidden_states = self.block_1(hidden_states, temb, deterministic=deterministic)
-        hidden_states = self.attn_1(hidden_states)
-        hidden_states = self.block_2(hidden_states, temb, deterministic=deterministic)
-        return hidden_states
-
-
-class Encoder(nn.Module):
-    config: VQGANConfig
-    dtype: jnp.dtype = jnp.float32
-
-    def setup(self):
-        self.temb_ch = 0
-
-        # downsampling
-        self.conv_in = nn.Conv(
-            self.config.ch,
-            kernel_size=(3, 3),
-            strides=(1, 1),
-            padding=((1, 1), (1, 1)),
-            dtype=self.dtype,
-        )
-
-        curr_res = self.config.resolution
-        downsample_blocks = []
-        for i_level in range(self.config.num_resolutions):
-            downsample_blocks.append(DownsamplingBlock(self.config, curr_res, block_idx=i_level, dtype=self.dtype))
-
-            if i_level != self.config.num_resolutions - 1:
-                curr_res = curr_res // 2
-        self.down = downsample_blocks
-
-        # middle
-        mid_channels = self.config.ch * self.config.ch_mult[-1]
-        self.mid = MidBlock(mid_channels, self.temb_ch, self.config.dropout, dtype=self.dtype)
-
-        # end
-        self.norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
-        self.conv_out = nn.Conv(
-            2 * self.config.z_channels if self.config.double_z else self.config.z_channels,
-            kernel_size=(3, 3),
-            strides=(1, 1),
-            padding=((1, 1), (1, 1)),
-            dtype=self.dtype,
-        )
-
-    def __call__(self, pixel_values, deterministic: bool = True):
-        # timestep embedding
-        temb = None
-
-        # downsampling
-        hidden_states = self.conv_in(pixel_values)
-        for block in self.down:
-            hidden_states = block(hidden_states, temb, deterministic=deterministic)
-
-        # middle
-        hidden_states = self.mid(hidden_states, temb, deterministic=deterministic)
-
-        # end
-        hidden_states = self.norm_out(hidden_states)
-        hidden_states = nn.swish(hidden_states)
-        hidden_states = self.conv_out(hidden_states)
-
-        return hidden_states
-
-
-class Decoder(nn.Module):
-    config: VQGANConfig
-    dtype: jnp.dtype = jnp.float32
-
-    def setup(self):
-        self.temb_ch = 0
-
-        # compute in_ch_mult, block_in and curr_res at lowest res
-        block_in = self.config.ch * self.config.ch_mult[self.config.num_resolutions - 1]
-        curr_res = self.config.resolution // 2 ** (self.config.num_resolutions - 1)
-        self.z_shape = (1, self.config.z_channels, curr_res, curr_res)
-        print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
-
-        # z to block_in
-        self.conv_in = nn.Conv(
-            block_in,
-            kernel_size=(3, 3),
-            strides=(1, 1),
-            padding=((1, 1), (1, 1)),
-            dtype=self.dtype,
-        )
-
-        # middle
-        self.mid = MidBlock(block_in, self.temb_ch, self.config.dropout, dtype=self.dtype)
-
-        # upsampling
-        upsample_blocks = []
-        for i_level in reversed(range(self.config.num_resolutions)):
-            upsample_blocks.append(UpsamplingBlock(self.config, curr_res, block_idx=i_level, dtype=self.dtype))
-            if i_level != 0:
-                curr_res = curr_res * 2
-        self.up = list(reversed(upsample_blocks))  # reverse to get consistent order
-
-        # end
-        self.norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
-        self.conv_out = nn.Conv(
-            self.config.out_ch,
-            kernel_size=(3, 3),
-            strides=(1, 1),
-            padding=((1, 1), (1, 1)),
-            dtype=self.dtype,
-        )
-
-    def __call__(self, hidden_states, deterministic: bool = True):
-        # timestep embedding
-        temb = None
-
-        # z to block_in
-        hidden_states = self.conv_in(hidden_states)
-
-        # middle
-        hidden_states = self.mid(hidden_states, temb, deterministic=deterministic)
-
-        # upsampling
-        for block in reversed(self.up):
-            hidden_states = block(hidden_states, temb, deterministic=deterministic)
-
-        # end
-        if self.config.give_pre_end:
-            return hidden_states
-
-        hidden_states = self.norm_out(hidden_states)
-        hidden_states = nn.swish(hidden_states)
-        hidden_states = self.conv_out(hidden_states)
-
-        return hidden_states
-
-
-class VectorQuantizer(nn.Module):
-    """
-    see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
-    ____________________________________________
-    Discretization bottleneck part of the VQ-VAE.
-    Inputs:
-    - n_e : number of embeddings
-    - e_dim : dimension of embedding
-    - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
-    _____________________________________________
-    """
-
-    config: VQGANConfig
-    dtype: jnp.dtype = jnp.float32
-
-    def setup(self):
-        self.embedding = nn.Embed(self.config.n_embed, self.config.embed_dim, dtype=self.dtype)  # TODO: init
-
-    def __call__(self, hidden_states):
-        """
-        Inputs the output of the encoder network z and maps it to a discrete
-        one-hot vector that is the index of the closest embedding vector e_j
-        z (continuous) -> z_q (discrete)
-        z.shape = (batch, channel, height, width)
-        quantization pipeline:
-            1. get encoder input (B,C,H,W)
-            2. flatten input to (B*H*W,C)
-        """
-        # flatten
-        hidden_states_flattened = hidden_states.reshape((-1, self.config.embed_dim))
-
-        # dummy op to init the weights, so we can access them below
-        self.embedding(jnp.ones((1, 1), dtype="i4"))
-
-        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
-        emb_weights = self.variables["params"]["embedding"]["embedding"]
-        distance = (
-            jnp.sum(hidden_states_flattened ** 2, axis=1, keepdims=True)
-            + jnp.sum(emb_weights ** 2, axis=1)
-            - 2 * jnp.dot(hidden_states_flattened, emb_weights.T)
-        )
-
-        # get quantized latent vectors
-        min_encoding_indices = jnp.argmin(distance, axis=1)
-        z_q = self.embedding(min_encoding_indices).reshape(hidden_states.shape)
-
-        # reshape to (batch, num_tokens)
-        min_encoding_indices = min_encoding_indices.reshape(hidden_states.shape[0], -1)
-
-        # compute the codebook_loss (q_loss) outside the model
-        # here we return the embeddings and indices
-        return z_q, min_encoding_indices
-
-    def get_codebook_entry(self, indices, shape=None):
-        # indices are expected to be of shape (batch, num_tokens)
-        # get quantized latent vectors
-        batch, num_tokens = indices.shape
-        z_q = self.embedding(indices)
-        z_q = z_q.reshape(batch, int(math.sqrt(num_tokens)), int(math.sqrt(num_tokens)), -1)
-        return z_q
-
-
-class VQModule(nn.Module):
-    config: VQGANConfig
-    dtype: jnp.dtype = jnp.float32
-
-    def setup(self):
-        self.encoder = Encoder(self.config, dtype=self.dtype)
-        self.decoder = Decoder(self.config, dtype=self.dtype)
-        self.quantize = VectorQuantizer(self.config, dtype=self.dtype)
-        self.quant_conv = nn.Conv(
-            self.config.embed_dim,
-            kernel_size=(1, 1),
-            strides=(1, 1),
-            padding="VALID",
-            dtype=self.dtype,
-        )
-        self.post_quant_conv = nn.Conv(
-            self.config.z_channels,
-            kernel_size=(1, 1),
-            strides=(1, 1),
-            padding="VALID",
-            dtype=self.dtype,
-        )
-
-    def encode(self, pixel_values, deterministic: bool = True):
-        hidden_states = self.encoder(pixel_values, deterministic=deterministic)
-        hidden_states = self.quant_conv(hidden_states)
-        quant_states, indices = self.quantize(hidden_states)
-        return quant_states, indices
-
-    def decode(self, hidden_states, deterministic: bool = True):
-        hidden_states = self.post_quant_conv(hidden_states)
-        hidden_states = self.decoder(hidden_states, deterministic=deterministic)
-        return hidden_states
-
-    def decode_code(self, code_b):
-        hidden_states = self.quantize.get_codebook_entry(code_b)
-        hidden_states = self.decode(hidden_states)
-        return hidden_states
-
-    def __call__(self, pixel_values, deterministic: bool = True):
-        quant_states, indices = self.encode(pixel_values, deterministic)
-        hidden_states = self.decode(quant_states, deterministic)
-        return hidden_states, indices
-
-
-class VQGANPreTrainedModel(FlaxPreTrainedModel):
-    """
-    An abstract class to handle weights initialization and a simple interface
-    for downloading and loading pretrained models.
-    """
-
-    config_class = VQGANConfig
-    base_model_prefix = "model"
-    module_class: nn.Module = None
-
-    def __init__(
-        self,
-        config: VQGANConfig,
-        input_shape: Tuple = (1, 256, 256, 3),
-        seed: int = 0,
-        dtype: jnp.dtype = jnp.float32,
-        **kwargs,
-    ):
-        module = self.module_class(config=config, dtype=dtype, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
-
-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
-        # init input tensors
-        pixel_values = jnp.zeros(input_shape, dtype=jnp.float32)
-        params_rng, dropout_rng = jax.random.split(rng)
-        rngs = {"params": params_rng, "dropout": dropout_rng}
-
-        return self.module.init(rngs, pixel_values)["params"]
-
-    def encode(self, pixel_values, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train: bool = False):
-        # Handle any PRNG if needed
-        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
-
-        return self.module.apply(
-            {"params": params or self.params}, jnp.array(pixel_values), not train, rngs=rngs, method=self.module.encode
-        )
-
-    def decode(self, hidden_states, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train: bool = False):
-        # Handle any PRNG if needed
-        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
-
-        return self.module.apply(
-            {"params": params or self.params},
-            jnp.array(hidden_states),
-            not train,
-            rngs=rngs,
-            method=self.module.decode,
-        )
-
-    def decode_code(self, indices, params: dict = None):
-        return self.module.apply(
-            {"params": params or self.params}, jnp.array(indices, dtype="i4"), method=self.module.decode_code
-        )
-
-    def __call__(
-        self,
-        pixel_values,
-        params: dict = None,
-        dropout_rng: jax.random.PRNGKey = None,
-        train: bool = False,
-    ):
-        # Handle any PRNG if needed
-        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
-
-        return self.module.apply(
-            {"params": params or self.params},
-            jnp.array(pixel_values),
-            not train,
-            rngs=rngs,
-        )
-
-
-class VQModel(VQGANPreTrainedModel):
-    module_class = VQModule
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..8383db7
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,19 @@
+import setuptools
+# To use a consistent encoding
+from codecs import open
+import os
+
+here = os.path.abspath(os.path.dirname(__file__))
+
+with open(os.path.join(here, 'README.md'), encoding='utf-8') as f:
+    long_description = f.read()
+
+setuptools.setup(
+    name='vqgan-jax',
+    version='0.0.1',
+    description='JAX implementation of VQGAN',
+    long_description=long_description,
+    long_description_content_type='text/markdown',
+    packages=setuptools.find_packages(),
+    install_requires=['jax>=0.2.6', 'flax', 'transformers'],
+)
diff --git a/vqgan_jax/__init__.py b/vqgan_jax/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/configuration_vqgan.py b/vqgan_jax/configuration_vqgan.py
similarity index 100%
rename from configuration_vqgan.py
rename to vqgan_jax/configuration_vqgan.py
diff --git a/vqgan_jax/convert_pt_model_to_jax.py b/vqgan_jax/convert_pt_model_to_jax.py
new file mode 100644
index 0000000..a48f434
--- /dev/null
+++ b/vqgan_jax/convert_pt_model_to_jax.py
@@ -0,0 +1,111 @@
+import re
+
+import jax.numpy as jnp
+from flax.traverse_util import flatten_dict, unflatten_dict
+
+import torch
+
+from .modeling_flax_vqgan import VQModel
+from .configuration_vqgan import VQGANConfig
+
+regex = r"\w+[.]\d+"
+
+
+def rename_key(key):
+    pats = re.findall(regex, key)
+    for pat in pats:
+        key = key.replace(pat, "_".join(pat.split(".")))
+    return key
+
+
+# Adapted from https://github.com/huggingface/transformers/blob/ff5cdc086be1e0c3e2bbad8e3469b34cffb55a85/src/transformers/modeling_flax_pytorch_utils.py#L61
+def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
+    # convert pytorch tensor to numpy
+    pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
+
+    random_flax_state_dict = flatten_dict(flax_model.params)
+    flax_state_dict = {}
+
+    remove_base_model_prefix = (
+        flax_model.base_model_prefix
+        not in flax_model.params) and (flax_model.base_model_prefix in set(
+            [k.split(".")[0] for k in pt_state_dict.keys()]))
+    add_base_model_prefix = (flax_model.base_model_prefix in flax_model.params
+                             ) and (flax_model.base_model_prefix not in set(
+                                 [k.split(".")[0]
+                                  for k in pt_state_dict.keys()]))
+
+    # Need to change some parameter names to match Flax names so that we don't have to fork any layer
+    for pt_key, pt_tensor in pt_state_dict.items():
+        pt_tuple_key = tuple(pt_key.split("."))
+
+        has_base_model_prefix = pt_tuple_key[0] == flax_model.base_model_prefix
+        require_base_model_prefix = (flax_model.base_model_prefix,
+                                     ) + pt_tuple_key in random_flax_state_dict
+
+        if remove_base_model_prefix and has_base_model_prefix:
+            pt_tuple_key = pt_tuple_key[1:]
+        elif add_base_model_prefix and require_base_model_prefix:
+            pt_tuple_key = (flax_model.base_model_prefix, ) + pt_tuple_key
+
+        # Correctly rename weight parameters
+        if ("norm" in pt_key and (pt_tuple_key[-1] == "bias")
+                and (pt_tuple_key[:-1] + ("bias", ) in random_flax_state_dict)):
+            pt_tensor = pt_tensor[None, None, None, :]
+        elif ("norm" in pt_key and (pt_tuple_key[-1] == "bias")
+              and (pt_tuple_key[:-1] + ("scale", ) in random_flax_state_dict)):
+            pt_tuple_key = pt_tuple_key[:-1] + ("scale", )
+            pt_tensor = pt_tensor[None, None, None, :]
+        elif pt_tuple_key[-1] in [
+                "weight", "gamma"
+        ] and pt_tuple_key[:-1] + ("scale", ) in random_flax_state_dict:
+            pt_tuple_key = pt_tuple_key[:-1] + ("scale", )
+            pt_tensor = pt_tensor[None, None, None, :]
+        if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + (
+                "embedding", ) in random_flax_state_dict:
+            pt_tuple_key = pt_tuple_key[:-1] + ("embedding", )
+        elif pt_tuple_key[
+                -1] == "weight" and pt_tensor.ndim == 4 and pt_tuple_key not in random_flax_state_dict:
+            # conv layer
+            pt_tuple_key = pt_tuple_key[:-1] + ("kernel", )
+            pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
+        elif pt_tuple_key[
+                -1] == "weight" and pt_tuple_key not in random_flax_state_dict:
+            # linear layer
+            pt_tuple_key = pt_tuple_key[:-1] + ("kernel", )
+            pt_tensor = pt_tensor.T
+        elif pt_tuple_key[-1] == "gamma":
+            pt_tuple_key = pt_tuple_key[:-1] + ("weight", )
+        elif pt_tuple_key[-1] == "beta":
+            pt_tuple_key = pt_tuple_key[:-1] + ("bias", )
+
+        if pt_tuple_key in random_flax_state_dict:
+            if pt_tensor.shape != random_flax_state_dict[pt_tuple_key].shape:
+                raise ValueError(
+                    f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
+                    f"{random_flax_state_dict[pt_tuple_key].shape}, but is {pt_tensor.shape}."
+                )
+
+        # also add unexpected weight so that a warning is thrown
+        flax_state_dict[pt_tuple_key] = jnp.asarray(pt_tensor)
+
+    return unflatten_dict(flax_state_dict)
+
+
+def convert_model(config_path, pt_state_dict_path, save_path):
+    config = VQGANConfig.from_pretrained(config_path)
+    model = VQModel(config)
+
+    state_dict = torch.load(pt_state_dict_path, map_location="cpu")["state_dict"]
+    keys = list(state_dict.keys())
+    for key in keys:
+        if key.startswith("loss"):
+            state_dict.pop(key)
+            continue
+        renamed_key = rename_key(key)
+        state_dict[renamed_key] = state_dict.pop(key)
+
+    state = convert_pytorch_state_dict_to_flax(state_dict, model)
+    model.params = state
+    model.save_pretrained(save_path)
+    return model
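Usage sketch for the conversion entry point (all three paths are placeholders; convert_model assumes the PyTorch checkpoint stores its weights under a "state_dict" key, as taming-transformers checkpoints do):

    from vqgan_jax.convert_pt_model_to_jax import convert_model

    # Placeholder paths for a taming-transformers VQGAN checkpoint.
    model = convert_model(
        config_path="./config.json",         # VQGANConfig in from_pretrained format
        pt_state_dict_path="./last.ckpt",    # torch checkpoint with a "state_dict" key
        save_path="./jax_vqgan_checkpoint",  # output directory for save_pretrained
    )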
diff --git a/vqgan_jax/modeling_flax_vqgan.py b/vqgan_jax/modeling_flax_vqgan.py
new file mode 100644
index 0000000..3792960
--- /dev/null
+++ b/vqgan_jax/modeling_flax_vqgan.py
@@ -0,0 +1,666 @@
+# JAX implementation of VQGAN from taming-transformers https://github.com/CompVis/taming-transformers
+
+from functools import partial
+from typing import Tuple
+import math
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import flax.linen as nn
+from flax.core.frozen_dict import FrozenDict
+
+from transformers.modeling_flax_utils import FlaxPreTrainedModel
+
+from .configuration_vqgan import VQGANConfig
+
+
+class Upsample(nn.Module):
+    in_channels: int
+    with_conv: bool
+    dtype: jnp.dtype = jnp.float32
+
+    def setup(self):
+        if self.with_conv:
+            self.conv = nn.Conv(
+                self.in_channels,
+                kernel_size=(3, 3),
+                strides=(1, 1),
+                padding=((1, 1), (1, 1)),
+                dtype=self.dtype,
+            )
+
+    def __call__(self, hidden_states):
+        batch, height, width, channels = hidden_states.shape
+        hidden_states = jax.image.resize(
+            hidden_states,
+            shape=(batch, height * 2, width * 2, channels),
+            method="nearest",
+        )
+        if self.with_conv:
+            hidden_states = self.conv(hidden_states)
+        return hidden_states
+
+
+class Downsample(nn.Module):
+    in_channels: int
+    with_conv: bool
+    dtype: jnp.dtype = jnp.float32
+
+    def setup(self):
+        if self.with_conv:
+            self.conv = nn.Conv(
+                self.in_channels,
+                kernel_size=(3, 3),
+                strides=(2, 2),
+                padding="VALID",
+                dtype=self.dtype,
+            )
+
+    def __call__(self, hidden_states):
+        if self.with_conv:
+            pad = ((0, 0), (0, 1), (0, 1), (0, 0))  # pad height and width dim
+            hidden_states = jnp.pad(hidden_states, pad_width=pad)
+            hidden_states = self.conv(hidden_states)
+        else:
+            hidden_states = nn.avg_pool(hidden_states,
+                                        window_shape=(2, 2),
+                                        strides=(2, 2),
+                                        padding="VALID")
+        return hidden_states
+
+
+class ResnetBlock(nn.Module):
+    in_channels: int
+    out_channels: int = None
+    use_conv_shortcut: bool = False
+    temb_channels: int = 512
+    dropout_prob: float = 0.0
+    dtype: jnp.dtype = jnp.float32
+
+    def setup(self):
+        self.out_channels_ = self.in_channels if self.out_channels is None else self.out_channels
+
+        self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+        self.conv1 = nn.Conv(
+            self.out_channels_,
+            kernel_size=(3, 3),
+            strides=(1, 1),
+            padding=((1, 1), (1, 1)),
+            dtype=self.dtype,
+        )
+
+        if self.temb_channels:
+            self.temb_proj = nn.Dense(self.out_channels_, dtype=self.dtype)
+
+        self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+        self.dropout = nn.Dropout(self.dropout_prob)
+        self.conv2 = nn.Conv(
+            self.out_channels_,
+            kernel_size=(3, 3),
+            strides=(1, 1),
+            padding=((1, 1), (1, 1)),
+            dtype=self.dtype,
+        )
+
+        if self.in_channels != self.out_channels_:
+            if self.use_conv_shortcut:
+                self.conv_shortcut = nn.Conv(
+                    self.out_channels_,
+                    kernel_size=(3, 3),
+                    strides=(1, 1),
+                    padding=((1, 1), (1, 1)),
+                    dtype=self.dtype,
+                )
+            else:
+                self.nin_shortcut = nn.Conv(
+                    self.out_channels_,
+                    kernel_size=(1, 1),
+                    strides=(1, 1),
+                    padding="VALID",
+                    dtype=self.dtype,
+                )
+
+    def __call__(self, hidden_states, temb=None, deterministic: bool = True):
+        residual = hidden_states
+        hidden_states = self.norm1(hidden_states)
+        hidden_states = nn.swish(hidden_states)
+        hidden_states = self.conv1(hidden_states)
+
+        if temb is not None:
+            hidden_states = hidden_states + self.temb_proj(
+                nn.swish(temb))[:, :, None, None]  # TODO: check shapes
+
+        hidden_states = self.norm2(hidden_states)
+        hidden_states = nn.swish(hidden_states)
+        hidden_states = self.dropout(hidden_states, deterministic)
+        hidden_states = self.conv2(hidden_states)
+
+        if self.in_channels != self.out_channels_:
+            if self.use_conv_shortcut:
+                residual = self.conv_shortcut(residual)
+            else:
+                residual = self.nin_shortcut(residual)
+
+        return hidden_states + residual
+
+
+class AttnBlock(nn.Module):
+    in_channels: int
+    dtype: jnp.dtype = jnp.float32
+
+    def setup(self):
+        conv = partial(nn.Conv,
+                       self.in_channels,
+                       kernel_size=(1, 1),
+                       strides=(1, 1),
+                       padding="VALID",
+                       dtype=self.dtype)
+
+        self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+        self.q, self.k, self.v = conv(), conv(), conv()
+        self.proj_out = conv()
+
+    def __call__(self, hidden_states):
+        residual = hidden_states
+        hidden_states = self.norm(hidden_states)
+
+        query = self.q(hidden_states)
+        key = self.k(hidden_states)
+        value = self.v(hidden_states)
+
+        # compute attention weights
+        batch, height, width, channels = query.shape
+        query = query.reshape((batch, height * width, channels))
+        key = key.reshape((batch, height * width, channels))
+        attn_weights = jnp.einsum("...qc,...kc->...qk", query, key)
+        attn_weights = attn_weights * (int(channels)**-0.5)
+        attn_weights = nn.softmax(attn_weights, axis=2)
+
+        ## attend to values
+        value = value.reshape((batch, height * width, channels))
+        hidden_states = jnp.einsum("...kc,...qk->...qc", value, attn_weights)
+        hidden_states = hidden_states.reshape((batch, height, width, channels))
+
+        hidden_states = self.proj_out(hidden_states)
+        hidden_states = hidden_states + residual
+        return hidden_states
+
+
+class UpsamplingBlock(nn.Module):
+    config: VQGANConfig
+    curr_res: int
+    block_idx: int
+    dtype: jnp.dtype = jnp.float32
+
+    def setup(self):
+        if self.block_idx == self.config.num_resolutions - 1:
+            block_in = self.config.ch * self.config.ch_mult[-1]
+        else:
+            block_in = self.config.ch * self.config.ch_mult[self.block_idx + 1]
+
+        block_out = self.config.ch * self.config.ch_mult[self.block_idx]
+        self.temb_ch = 0
+
+        res_blocks = []
+        attn_blocks = []
+        for _ in range(self.config.num_res_blocks + 1):
+            res_blocks.append(
+                ResnetBlock(block_in,
+                            block_out,
+                            temb_channels=self.temb_ch,
+                            dropout_prob=self.config.dropout,
+                            dtype=self.dtype))
+            block_in = block_out
+            if self.curr_res in self.config.attn_resolutions:
+                attn_blocks.append(AttnBlock(block_in, dtype=self.dtype))
+
+        self.block = res_blocks
+        self.attn = attn_blocks
+
+        self.upsample = None
+        if self.block_idx != 0:
+            self.upsample = Upsample(block_in,
+                                     self.config.resamp_with_conv,
+                                     dtype=self.dtype)
+
+    def __call__(self, hidden_states, temb=None, deterministic: bool = True):
+        for res_block in self.block:
+            hidden_states = res_block(hidden_states,
+                                      temb,
+                                      deterministic=deterministic)
+            for attn_block in self.attn:
+                hidden_states = attn_block(hidden_states)
+
+        if self.upsample is not None:
+            hidden_states = self.upsample(hidden_states)
+
+        return hidden_states
+
+
+class DownsamplingBlock(nn.Module):
+    config: VQGANConfig
+    curr_res: int
+    block_idx: int
+    dtype: jnp.dtype = jnp.float32
+
+    def setup(self):
+        in_ch_mult = (1, ) + tuple(self.config.ch_mult)
+        block_in = self.config.ch * in_ch_mult[self.block_idx]
+        block_out = self.config.ch * self.config.ch_mult[self.block_idx]
+        self.temb_ch = 0
+
+        res_blocks = []
+        attn_blocks = []
+        for _ in range(self.config.num_res_blocks):
+            res_blocks.append(
+                ResnetBlock(block_in,
+                            block_out,
+                            temb_channels=self.temb_ch,
+                            dropout_prob=self.config.dropout,
+                            dtype=self.dtype))
+            block_in = block_out
+            if self.curr_res in self.config.attn_resolutions:
+                attn_blocks.append(AttnBlock(block_in, dtype=self.dtype))
+
+        self.block = res_blocks
+        self.attn = attn_blocks
+
+        self.downsample = None
+        if self.block_idx != self.config.num_resolutions - 1:
+            self.downsample = Downsample(block_in,
+                                         self.config.resamp_with_conv,
+                                         dtype=self.dtype)
+
+    def __call__(self, hidden_states, temb=None, deterministic: bool = True):
+        for res_block in self.block:
+            hidden_states = res_block(hidden_states,
+                                      temb,
+                                      deterministic=deterministic)
+            for attn_block in self.attn:
+                hidden_states = attn_block(hidden_states)
+
+        if self.downsample is not None:
+            hidden_states = self.downsample(hidden_states)
+
+        return hidden_states
+
+
+class MidBlock(nn.Module):
+    in_channels: int
+    temb_channels: int
+    dropout: float
+    dtype: jnp.dtype = jnp.float32
+
+    def setup(self):
+        self.block_1 = ResnetBlock(
+            self.in_channels,
+            self.in_channels,
+            temb_channels=self.temb_channels,
+            dropout_prob=self.dropout,
+            dtype=self.dtype,
+        )
+        self.attn_1 = AttnBlock(self.in_channels, dtype=self.dtype)
+        self.block_2 = ResnetBlock(
+            self.in_channels,
+            self.in_channels,
+            temb_channels=self.temb_channels,
+            dropout_prob=self.dropout,
+            dtype=self.dtype,
+        )
+
+    def __call__(self, hidden_states, temb=None, deterministic: bool = True):
+        hidden_states = self.block_1(hidden_states,
+                                     temb,
+                                     deterministic=deterministic)
+        hidden_states = self.attn_1(hidden_states)
+        hidden_states = self.block_2(hidden_states,
+                                     temb,
+                                     deterministic=deterministic)
+        return hidden_states
+
+
+class Encoder(nn.Module):
+    config: VQGANConfig
+    dtype: jnp.dtype = jnp.float32
+
+    def setup(self):
+        self.temb_ch = 0
+
+        # downsampling
+        self.conv_in = nn.Conv(
+            self.config.ch,
+            kernel_size=(3, 3),
+            strides=(1, 1),
+            padding=((1, 1), (1, 1)),
+            dtype=self.dtype,
+        )
+
+        curr_res = self.config.resolution
+        downsample_blocks = []
+        for i_level in range(self.config.num_resolutions):
+            downsample_blocks.append(
+                DownsamplingBlock(self.config,
+                                  curr_res,
+                                  block_idx=i_level,
+                                  dtype=self.dtype))
+
+            if i_level != self.config.num_resolutions - 1:
+                curr_res = curr_res // 2
+        self.down = downsample_blocks
+
+        # middle
+        mid_channels = self.config.ch * self.config.ch_mult[-1]
+        self.mid = MidBlock(mid_channels,
+                            self.temb_ch,
+                            self.config.dropout,
+                            dtype=self.dtype)
+
+        # end
+        self.norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+        self.conv_out = nn.Conv(
+            2 * self.config.z_channels
+            if self.config.double_z else self.config.z_channels,
+            kernel_size=(3, 3),
+            strides=(1, 1),
+            padding=((1, 1), (1, 1)),
+            dtype=self.dtype,
+        )
+
+    def __call__(self, pixel_values, deterministic: bool = True):
+        # timestep embedding
+        temb = None
+
+        # downsampling
+        hidden_states = self.conv_in(pixel_values)
+        for block in self.down:
+            hidden_states = block(hidden_states, temb, deterministic=deterministic)
+
+        # middle
+        hidden_states = self.mid(hidden_states, temb, deterministic=deterministic)
+
+        # end
+        hidden_states = self.norm_out(hidden_states)
+        hidden_states = nn.swish(hidden_states)
+        hidden_states = self.conv_out(hidden_states)
+
+        return hidden_states
+
+
+class Decoder(nn.Module):
+    config: VQGANConfig
+    dtype: jnp.dtype = jnp.float32
+
+    def setup(self):
+        self.temb_ch = 0
+
+        # compute in_ch_mult, block_in and curr_res at lowest res
+        block_in = self.config.ch * self.config.ch_mult[
+            self.config.num_resolutions - 1]
+        curr_res = self.config.resolution // 2**(self.config.num_resolutions - 1)
+        self.z_shape = (1, self.config.z_channels, curr_res, curr_res)
+        print("Working with z of shape {} = {} dimensions.".format(
+            self.z_shape, np.prod(self.z_shape)))
+
+        # z to block_in
+        self.conv_in = nn.Conv(
+            block_in,
+            kernel_size=(3, 3),
+            strides=(1, 1),
+            padding=((1, 1), (1, 1)),
+            dtype=self.dtype,
+        )
+
+        # middle
+        self.mid = MidBlock(block_in,
+                            self.temb_ch,
+                            self.config.dropout,
+                            dtype=self.dtype)
+
+        # upsampling
+        upsample_blocks = []
+        for i_level in reversed(range(self.config.num_resolutions)):
+            upsample_blocks.append(
+                UpsamplingBlock(self.config,
+                                curr_res,
+                                block_idx=i_level,
+                                dtype=self.dtype))
+            if i_level != 0:
+                curr_res = curr_res * 2
+        self.up = list(
+            reversed(upsample_blocks))  # reverse to get consistent order
+
+        # end
+        self.norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+        self.conv_out = nn.Conv(
+            self.config.out_ch,
+            kernel_size=(3, 3),
+            strides=(1, 1),
+            padding=((1, 1), (1, 1)),
+            dtype=self.dtype,
+        )
+
+    def __call__(self, hidden_states, deterministic: bool = True):
+        # timestep embedding
+        temb = None
+
+        # z to block_in
+        hidden_states = self.conv_in(hidden_states)
+
+        # middle
+        hidden_states = self.mid(hidden_states, temb, deterministic=deterministic)
+
+        # upsampling
+        for block in reversed(self.up):
+            hidden_states = block(hidden_states, temb, deterministic=deterministic)
+
+        # end
+        if self.config.give_pre_end:
+            return hidden_states
+
+        hidden_states = self.norm_out(hidden_states)
+        hidden_states = nn.swish(hidden_states)
+        hidden_states = self.conv_out(hidden_states)
+
+        return hidden_states
+
+
+class VectorQuantizer(nn.Module):
+    """
+    see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
+    ____________________________________________
+    Discretization bottleneck part of the VQ-VAE.
+    Inputs:
+    - n_e : number of embeddings
+    - e_dim : dimension of embedding
+    - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
+    _____________________________________________
+    """
+
+    config: VQGANConfig
+    dtype: jnp.dtype = jnp.float32
+
+    def setup(self):
+        self.embedding = nn.Embed(self.config.n_embed,
+                                  self.config.embed_dim,
+                                  dtype=self.dtype)  # TODO: init
+
+    def __call__(self, hidden_states):
+        """
+        Inputs the output of the encoder network z and maps it to a discrete
+        one-hot vector that is the index of the closest embedding vector e_j
+        z (continuous) -> z_q (discrete)
+        z.shape = (batch, channel, height, width)
+        quantization pipeline:
+            1. get encoder input (B,C,H,W)
+            2. flatten input to (B*H*W,C)
+        """
+        # flatten
+        hidden_states_flattened = hidden_states.reshape(
+            (-1, self.config.embed_dim))
+
+        # dummy op to init the weights, so we can access them below
+        self.embedding(jnp.ones((1, 1), dtype="i4"))
+
+        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+        emb_weights = self.variables["params"]["embedding"]["embedding"]
+        distance = (jnp.sum(hidden_states_flattened**2, axis=1, keepdims=True) +
+                    jnp.sum(emb_weights**2, axis=1) -
+                    2 * jnp.dot(hidden_states_flattened, emb_weights.T))
+
+        # get quantized latent vectors
+        min_encoding_indices = jnp.argmin(distance, axis=1)
+        z_q = self.embedding(min_encoding_indices).reshape(hidden_states.shape)
+
+        # reshape to (batch, num_tokens)
+        min_encoding_indices = min_encoding_indices.reshape(
+            hidden_states.shape[0], -1)
+
+        # compute the codebook_loss (q_loss) outside the model
+        # here we return the embeddings and indices
+        return z_q, min_encoding_indices
+
+    def get_codebook_entry(self, indices, shape=None):
+        # indices are expected to be of shape (batch, num_tokens)
+        # get quantized latent vectors
+        batch, num_tokens = indices.shape
+        z_q = self.embedding(indices)
+        z_q = z_q.reshape(batch, int(math.sqrt(num_tokens)),
+                          int(math.sqrt(num_tokens)), -1)
+        return z_q
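The distance term in VectorQuantizer.__call__ above uses the expansion ||z - e||^2 = ||z||^2 + ||e||^2 - 2 z.e, so all latent/codebook pairs are compared in one broadcasted expression instead of a loop. A self-contained sketch of the same trick with made-up sizes:

    import jax.numpy as jnp

    z = jnp.arange(32.0).reshape(4, 8)   # 4 flattened latents, embed_dim = 8
    codebook = jnp.ones((16, 8))         # codebook with n_embed = 16 entries
    # (4, 1) + (16,) - (4, 16) broadcasts to a (4, 16) matrix of squared distances
    distance = (jnp.sum(z**2, axis=1, keepdims=True) +
                jnp.sum(codebook**2, axis=1) -
                2 * jnp.dot(z, codebook.T))
    nearest = jnp.argmin(distance, axis=1)  # closest code index per latent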
+ """ + + config_class = VQGANConfig + base_model_prefix = "model" + module_class: nn.Module = None + + def __init__( + self, + config: VQGANConfig, + input_shape: Tuple = (1, 256, 256, 3), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__(config, + module, + input_shape=input_shape, + seed=seed, + dtype=dtype) + + def init_weights(self, rng: jax.random.PRNGKey, + input_shape: Tuple) -> FrozenDict: + # init input tensors + pixel_values = jnp.zeros(input_shape, dtype=jnp.float32) + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + return self.module.init(rngs, pixel_values)["params"] + + def encode(self, + pixel_values, + params: dict = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False): + # Handle any PRNG if needed + rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} + + return self.module.apply({"params": params or self.params}, + jnp.array(pixel_values), + not train, + rngs=rngs, + method=self.module.encode) + + def decode(self, + hidden_states, + params: dict = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False): + # Handle any PRNG if needed + rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} + + return self.module.apply( + {"params": params or self.params}, + jnp.array(hidden_states), + not train, + rngs=rngs, + method=self.module.decode, + ) + + def decode_code(self, indices, params: dict = None): + return self.module.apply({"params": params or self.params}, + jnp.array(indices, dtype="i4"), + method=self.module.decode_code) + + def __call__( + self, + pixel_values, + params: dict = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + ): + # Handle any PRNG if needed + rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} + + return self.module.apply( + {"params": params or self.params}, + jnp.array(pixel_values), + not train, + rngs=rngs, + ) + + +class VQModel(VQGANPreTrainedModel): + module_class = VQModule