
Initial release

vballoli released this on 05 Oct 16:43
5459e26

Vision Transformer in Flax

This repository implements the Vision Transformer (ViT) in Flax. ViT was introduced in the ICLR 2021 submission "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" and explained further by Yannic Kilcher. This repository is heavily inspired by lucidrains' implementation.
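ViT treats an image as a sequence of tokens by cutting it into non-overlapping square patches and flattening each patch. Below is a minimal JAX sketch of that patching step. It assumes NHWC input whose height and width are divisible by the patch size. `patchify` is a hypothetical helper for illustration, not a function exported by vit_flax.

```python
import jax
import jax.numpy as jnp


def patchify(images: jnp.ndarray, patch_size: int) -> jnp.ndarray:
    """Split NHWC images into flattened, non-overlapping patches.

    Hypothetical helper (not part of vit_flax). Returns an array of shape
    (batch, num_patches, patch_size * patch_size * channels).
    """
    b, h, w, c = images.shape
    assert h % patch_size == 0 and w % patch_size == 0, "image size must be divisible by patch_size"
    gh, gw = h // patch_size, w // patch_size
    # (b, gh, p, gw, p, c): split each spatial axis into (grid index, offset inside patch)
    x = images.reshape(b, gh, patch_size, gw, patch_size, c)
    # (b, gh, gw, p, p, c): bring the two grid axes next to each other
    x = x.transpose(0, 1, 3, 2, 4, 5)
    # (b, gh * gw, p * p * c): one flattened vector per patch, in row-major grid order
    return x.reshape(b, gh * gw, patch_size * patch_size * c)


img = jax.random.uniform(jax.random.PRNGKey(0), (1, 256, 256, 3))
patches = patchify(img, patch_size=32)
print(patches.shape)  # (1, 64, 3072)
```

With the values used in the usage example (img_size=256, patch_size=32), this gives 8 x 8 = 64 patches of 32 * 32 * 3 = 3072 values each. In the ViT paper, each flattened patch is then linearly projected to the model dimension.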

Install

pip install vit-flax

Usage

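import jax
from jax import numpy as jnp
from flax import nn  # legacy (pre-Linen) Flax module API
from vit_flax import ViT

# Use separate PRNG keys for parameter initialization and for the dummy input
rng = jax.random.PRNGKey(0)
init_rng, data_rng = jax.random.split(rng)

# Bind the ViT hyperparameters; input images here are 256x256
module = ViT.partial(patch_size=32, dim=1024, depth=6, num_heads=8, dense_dims=(2048, 2048), img_size=256, num_classes=10)

# Initialize parameters from an input shape: a batch of one 256x256 RGB image (NHWC)
_, initial_params = module.init_by_shape(
  init_rng, [((1, 256, 256, 3), jnp.float32)]
)
model = nn.Model(module, initial_params)

# Forward pass on a random image
img = jax.random.uniform(data_rng, (1, 256, 256, 3))
output = model(img)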
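The `dim` and `num_heads` arguments above are standard Transformer hyperparameters. The following is a hedged illustration of how they interact, not vit_flax's actual code: plain multi-head scaled dot-product self-attention in jax.numpy. Each of the `num_heads` heads attends over `dim // num_heads` features, so `dim` should be divisible by `num_heads` (1024 / 8 = 128 here). The weight matrices `w_qkv` and `w_out` are hypothetical, randomly initialized stand-ins. The 64 tokens correspond to the 64 patches of a 256x256 image with patch_size=32, with any class token omitted.

```python
import jax
import jax.numpy as jnp


def multi_head_self_attention(x, w_qkv, w_out, num_heads):
    """Scaled dot-product multi-head self-attention (illustrative sketch).

    x: (batch, tokens, dim); w_qkv: (dim, 3 * dim); w_out: (dim, dim).
    """
    b, n, dim = x.shape
    assert dim % num_heads == 0, "dim must be divisible by num_heads"
    head_dim = dim // num_heads
    # Project to queries, keys and values with one matmul, then split: each (b, n, dim)
    q, k, v = jnp.split(x @ w_qkv, 3, axis=-1)

    def split_heads(t):
        # (b, n, dim) -> (b, heads, n, head_dim)
        return t.reshape(b, n, num_heads, head_dim).transpose(0, 2, 1, 3)

    q, k, v = split_heads(q), split_heads(k), split_heads(v)
    # Attention scores q.k / sqrt(head_dim), softmax over the key axis: (b, heads, n, n)
    scores = q @ jnp.swapaxes(k, -1, -2) / jnp.sqrt(head_dim)
    weights = jax.nn.softmax(scores, axis=-1)
    out = weights @ v  # (b, heads, n, head_dim)
    # Merge heads back: (b, heads, n, head_dim) -> (b, n, dim), then output projection
    out = out.transpose(0, 2, 1, 3).reshape(b, n, dim)
    return out @ w_out


key = jax.random.PRNGKey(0)
k_x, k_qkv, k_out = jax.random.split(key, 3)
dim, num_heads, num_tokens = 1024, 8, 64
x = jax.random.normal(k_x, (1, num_tokens, dim))
w_qkv = 0.02 * jax.random.normal(k_qkv, (dim, 3 * dim))
w_out = 0.02 * jax.random.normal(k_out, (dim, dim))
out = multi_head_self_attention(x, w_qkv, w_out, num_heads)
print(out.shape)  # (1, 64, 1024)
```

A full encoder block, repeated `depth` times, also wraps this attention in layer normalization and residual connections and follows it with an MLP; those parts are left out to keep the sketch focused on the head split.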

Note: This repository is still in its initial stages. Feel free to contact me or open issues or pull requests with suggestions, improvements, or bug reports.