diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
new file mode 100644
index 0000000..efac170
--- /dev/null
+++ b/.github/workflows/tests.yml
@@ -0,0 +1,32 @@
+name: Run tests
+
+on:
+  push:
+
+jobs:
+  run-test:
+    strategy:
+      matrix:
+        python-version: [ 3.12 ]
+        os: [ ubuntu-latest ]
+      fail-fast: false
+
+    runs-on: ${{ matrix.os }}
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v2
+
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          python -m pip install -r ./tests/requirements.txt
+
+      - name: Test with pytest
+        run: |
+          python -m pip install .
+          python -m pytest
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 5b3ed6f..6a03f8f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,4 +3,5 @@ __pycache__/
 imgs/
 exps/
 _fisher.py
-_set_transformer.py
\ No newline at end of file
+_set_transformer.py
+.pytest_cache/
\ No newline at end of file
diff --git a/sbgm/models/_mixer.py b/sbgm/models/_mixer.py
index 63545b5..9981b99 100644
--- a/sbgm/models/_mixer.py
+++ b/sbgm/models/_mixer.py
@@ -147,9 +147,11 @@ def __init__(
         num_patches = (height // patch_size) * (width // patch_size)
         inkey, outkey, *bkeys = jr.split(key, 2 + num_blocks)
 
+        _input_size = input_size + q_dim if q_dim is not None else input_size
+        _context_dim = embedding_dim + a_dim if a_dim is not None else embedding_dim
+
         self.conv_in = eqx.nn.Conv2d(
-            # input_size + 1, # Time is tiled along with time as channels
-            input_size, # Time is tiled along with time as channels
+            _input_size,
             hidden_size,
             patch_size,
             stride=patch_size,
@@ -168,7 +170,7 @@ def __init__(
                 hidden_size,
                 mix_patch_size,
                 mix_hidden_size,
-                context_dim=a_dim + embedding_dim,
+                context_dim=_context_dim,
                 key=bkey
             )
             for bkey in bkeys
diff --git a/sbgm/models/_mlp.py b/sbgm/models/_mlp.py
index 78e2cb2..3304695 100644
--- a/sbgm/models/_mlp.py
+++ b/sbgm/models/_mlp.py
@@ -78,8 +78,8 @@ def __call__(
         self,
         t: Union[float, Array],
         x: Array,
-        q: Array,
-        a: Array,
+        q: Optional[Array],
+        a: Optional[Array],
         *,
         key: Key = None
     ) -> Array:
diff --git a/sbgm/models/_unet.py b/sbgm/models/_unet.py
index 669bde4..e73cfe3 100644
--- a/sbgm/models/_unet.py
+++ b/sbgm/models/_unet.py
@@ -5,7 +5,8 @@
 import jax.numpy as jnp
 import jax.random as jr
 import equinox as eqx
-from jaxtyping import Key, Array
+from jaxtyping import Key, Array, jaxtyped
+from beartype import beartype as typechecker
 
 from einops import rearrange
 
@@ -19,7 +20,7 @@ def __init__(self, dim: int):
 
     def __call__(self, x: Array) -> Array:
         emb = x * self.emb
-        emb = jnp.concatenate((jnp.sin(emb), jnp.cos(emb)), axis=-1)
+        emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1)
         return emb
 
 
@@ -533,96 +534,4 @@ def __call__(
         for layer in self.final_conv_layers:
             h = layer(h)
 
-        return self.final_activation(h) if self.final_activation is not None else h
-
-
-if __name__ == "__main__":
-
-    key = jr.key(0)
-
-    x = jnp.ones((1, 32, 32))
-    t = jnp.ones((1,))
-
-    q = 1
-    a = 2
-
-    unet = UNet(
-        x.shape,
-        is_biggan=False,
-        dim_mults=[1, 1, 1],
-        hidden_size=32,
-        heads=2,
-        dim_head=32,
-        dropout_rate=0.1,
-        num_res_blocks=2,
-        attn_resolutions=[8, 16, 32],
-        q_dim=q,
-        a_dim=a,
-        key=jr.key(0)
-    )
-
-    q = jnp.ones((1, 32, 32))
-    a = jnp.ones((2,))
-    print(unet(t, x, q=q, a=a, key=key).shape)
-
-    q = None
-    a = None
-
-    unet = UNet(
-        x.shape,
-        is_biggan=False,
-        dim_mults=[1, 1, 1],
-        hidden_size=32,
-        heads=2,
-        dim_head=32,
-        dropout_rate=0.1,
-        num_res_blocks=2,
-        attn_resolutions=[8, 16, 32],
-        q_dim=q,
-        a_dim=a,
-        key=jr.key(0)
-    )
-
-    print(unet(t, x, key=key).shape)
-
-    q = 1
-    a = None
-
-    unet = UNet(
-        x.shape,
-        is_biggan=False,
-        dim_mults=[1, 1, 1],
-        hidden_size=32,
-        heads=2,
-        dim_head=32,
-        dropout_rate=0.1,
-        num_res_blocks=2,
-        attn_resolutions=[8, 16, 32],
-        q_dim=q,
-        a_dim=a,
-        key=jr.key(0)
-    )
-
-    q = jnp.ones((1, 32, 32))
-    print(unet(t, x, q=q, key=key).shape)
-
-    q = None
-    a = 2
-
-    unet = UNet(
-        x.shape,
-        is_biggan=False,
-        dim_mults=[1, 1, 1],
-        hidden_size=32,
-        heads=2,
-        dim_head=32,
-        dropout_rate=0.1,
-        num_res_blocks=2,
-        attn_resolutions=[8, 16, 32],
-        q_dim=q,
-        a_dim=a,
-        key=jr.key(0)
-    )
-
-    a = jnp.ones((2,))
-    print(unet(t, x, a=a, key=key).shape)
\ No newline at end of file
+        return self.final_activation(h) if self.final_activation is not None else h
\ No newline at end of file
diff --git a/tests/requirements.txt b/tests/requirements.txt
new file mode 100644
index 0000000..55b033e
--- /dev/null
+++ b/tests/requirements.txt
@@ -0,0 +1 @@
+pytest
\ No newline at end of file
diff --git a/tests/test.py b/tests/test.py
new file mode 100644
index 0000000..9206193
--- /dev/null
+++ b/tests/test.py
@@ -0,0 +1,11 @@
+import jax.numpy as jnp
+
+"""
+    Testing testing
+    - pytest, pytest-cov in pyproject.toml requirements?
+    - test files: test_*.py
+    - test funcs in files: test_*():
+"""
+
+def test():
+    assert jnp.square(2) == 4
\ No newline at end of file
diff --git a/tests/test_models.py b/tests/test_models.py
new file mode 100644
index 0000000..145a6b8
--- /dev/null
+++ b/tests/test_models.py
@@ -0,0 +1,239 @@
+import jax
+import jax.numpy as jnp
+import jax.random as jr
+
+from sbgm.models import UNet, ResidualNetwork, Mixer2d
+
+
+def test_resnet():
+
+    key = jr.key(0)
+
+    x = jnp.ones((5,))
+    t = jnp.ones((1,))
+    a = jnp.ones((3,))
+    q = None
+
+    net = ResidualNetwork(
+        x.size,
+        width_size=32,
+        depth=2,
+        a_dim=a.size,
+        dropout_p=0.1,
+        activation=jax.nn.tanh,
+        key=key
+    )
+
+    out = net(t, x, q=q, a=a, key=key)
+    assert out.shape == x.shape
+
+    net = ResidualNetwork(
+        x.size,
+        width_size=32,
+        depth=2,
+        dropout_p=0.1,
+        activation=jax.nn.tanh,
+        key=key
+    )
+
+    out = net(t, x, q=q, a=a, key=key)
+    assert out.shape == x.shape
+
+
+def test_mixer():
+
+    key = jr.key(0)
+
+    x = jnp.ones((1, 32, 32))
+    t = jnp.ones((1,))
+
+    q = jnp.ones((1, 32, 32))
+    a = jnp.ones((5,))
+
+    q_dim = 1
+    a_dim = 5
+
+    net = Mixer2d(
+        x.shape,
+        patch_size=2,
+        hidden_size=256,
+        mix_patch_size=4,
+        mix_hidden_size=256,
+        num_blocks=3,
+        t1=1.0,
+        embedding_dim=8,
+        q_dim=q_dim,
+        a_dim=a_dim,
+        key=key
+    )
+
+    out = net(t, x, q=q, a=a, key=key)
+    assert out.shape == x.shape
+
+    q = None
+    a = jnp.ones((5,))
+
+    q_dim = None
+    a_dim = 5
+
+    net = Mixer2d(
+        x.shape,
+        patch_size=2,
+        mix_patch_size=4,
+        hidden_size=256,
+        mix_hidden_size=256,
+        num_blocks=3,
+        t1=1.0,
+        embedding_dim=8,
+        q_dim=q_dim,
+        a_dim=a_dim,
+        key=key
+    )
+
+    out = net(t, x, q=q, a=a, key=key)
+    assert out.shape == x.shape
+
+    q = jnp.ones((1, 32, 32))
+    a = None
+
+    q_dim = 1
+    a_dim = None
+
+    net = Mixer2d(
+        x.shape,
+        patch_size=2,
+        mix_patch_size=4,
+        hidden_size=256,
+        mix_hidden_size=256,
+        num_blocks=3,
+        t1=1.0,
+        embedding_dim=8,
+        q_dim=q_dim,
+        a_dim=a_dim,
+        key=key
+    )
+
+    out = net(t, x, q=q, a=a, key=key)
+    assert out.shape == x.shape
+
+    q = None
+    a = None
+
+    q_dim = None
+    a_dim = None
+
+    net = Mixer2d(
+        x.shape,
+        patch_size=2,
+        mix_patch_size=4,
+        hidden_size=256,
+        mix_hidden_size=256,
+        num_blocks=3,
+        t1=1.0,
+        embedding_dim=8,
+        q_dim=q_dim,
+        a_dim=a_dim,
+        key=key
+    )
+
+    out = net(t, x, q=q, a=a, key=key)
+    assert out.shape == x.shape
+
+
+
+def test_unet():
+
+    key = jr.key(0)
+
+    x = jnp.ones((1, 32, 32))
+    t = jnp.ones((1,))
+
+    q_dim = 1
+    a_dim = 2
+
+    unet = UNet(
+        x.shape,
+        is_biggan=False,
+        dim_mults=[1, 1, 1],
+        hidden_size=32,
+        heads=2,
+        dim_head=32,
+        dropout_rate=0.1,
+        num_res_blocks=2,
+        attn_resolutions=[8, 16, 32],
+        q_dim=q_dim,
+        a_dim=a_dim,
+        key=key
+    )
+
+    q = jnp.ones((1, 32, 32))
+    a = jnp.ones((2,))
+
+    out = unet(t, x, q=q, a=a, key=key)
+    assert out.shape == x.shape
+
+    q_dim = None
+    a_dim = None
+
+    unet = UNet(
+        x.shape,
+        is_biggan=False,
+        dim_mults=[1, 1, 1],
+        hidden_size=32,
+        heads=2,
+        dim_head=32,
+        dropout_rate=0.1,
+        num_res_blocks=2,
+        attn_resolutions=[8, 16, 32],
+        q_dim=q_dim,
+        a_dim=a_dim,
+        key=key
+    )
+
+    out = unet(t, x, key=key)
+    assert out.shape == x.shape
+
+    q_dim = 1
+    a_dim = None
+
+    unet = UNet(
+        x.shape,
+        is_biggan=False,
+        dim_mults=[1, 1, 1],
+        hidden_size=32,
+        heads=2,
+        dim_head=32,
+        dropout_rate=0.1,
+        num_res_blocks=2,
+        attn_resolutions=[8, 16, 32],
+        q_dim=q_dim,
+        a_dim=a_dim,
+        key=key
+    )
+
+    q = jnp.ones((1, 32, 32))
+
+    out = unet(t, x, q=q, key=key)
+    assert out.shape == x.shape
+
+    q_dim = None
+    a_dim = 2
+
+    unet = UNet(
+        x.shape,
+        is_biggan=False,
+        dim_mults=[1, 1, 1],
+        hidden_size=32,
+        heads=2,
+        dim_head=32,
+        dropout_rate=0.1,
+        num_res_blocks=2,
+        attn_resolutions=[8, 16, 32],
+        q_dim=q_dim,
+        a_dim=a_dim,
+        key=key
+    )
+
+    a = jnp.ones((2,))
+    out = unet(t, x, a=a, key=key)
+    assert out.shape == x.shape
\ No newline at end of file
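
Note on running the suite locally: the workflow above only installs the package and its test requirements before invoking pytest, so the CI job can be reproduced from the repository root with the same commands it runs (assuming a Python 3.12 environment):

    python -m pip install --upgrade pip
    python -m pip install -r ./tests/requirements.txt
    python -m pip install .
    python -m pytest

pytest then discovers tests/test.py and tests/test_models.py automatically through the test_*.py / test_*() naming conventions recorded in the tests/test.py docstring.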
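On the _mixer.py change: the previous context_dim=a_dim + embedding_dim raises a TypeError as soon as a_dim is None, and the conv_in channel count must grow by q_dim when a conditioning map q is supplied, presumably because q is concatenated with x along the channel axis (Mixer2d's __call__ is not part of this patch). A minimal sketch of the new bookkeeping, with illustrative values taken from test_mixer:

    # Conditional case, values as in test_mixer: x is (1, 32, 32), q is (1, 32, 32),
    # the time embedding has width 8 and the conditioning vector has size 5.
    input_size, q_dim = 1, 1
    embedding_dim, a_dim = 8, 5

    # Exact expressions added by the patch:
    _input_size = input_size + q_dim if q_dim is not None else input_size
    _context_dim = embedding_dim + a_dim if a_dim is not None else embedding_dim
    assert (_input_size, _context_dim) == (2, 13)

    # Unconditional case: both dims fall back cleanly instead of raising TypeError.
    q_dim = a_dim = None
    _input_size = input_size + q_dim if q_dim is not None else input_size
    _context_dim = embedding_dim + a_dim if a_dim is not None else embedding_dim
    assert (_input_size, _context_dim) == (1, 8)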