initial tests (models) and workflow
homerjed committed Sep 29, 2024
1 parent 06aa45e commit 40268eb
Showing 8 changed files with 296 additions and 101 deletions.
32 changes: 32 additions & 0 deletions .github/workflows/tests.yml
@@ -0,0 +1,32 @@
+name: Run tests
+
+on:
+  push:
+
+jobs:
+  run-test:
+    strategy:
+      matrix:
+        python-version: [ 3.12 ]
+        os: [ ubuntu-latest ]
+      fail-fast: false
+
+    runs-on: ${{ matrix.os }}
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v2
+
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          python -m pip install -r ./tests/requirements.txt
+
+      - name: Test with pytest
+        run: |
+          python -m pip install .
+          python -m pytest
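
For local runs, the same steps can be mirrored before pushing. A minimal sketch, assuming it is executed from the repository root (the script name run_checks.py is hypothetical):

# run_checks.py: hypothetical local mirror of the CI steps above.
import subprocess
import sys

steps = [
    [sys.executable, "-m", "pip", "install", "--upgrade", "pip"],
    [sys.executable, "-m", "pip", "install", "-r", "tests/requirements.txt"],
    [sys.executable, "-m", "pip", "install", "."],
    [sys.executable, "-m", "pytest"],
]

for step in steps:
    # check=True stops at the first failing step, as the workflow would.
    subprocess.run(step, check=True)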
3 changes: 2 additions & 1 deletion .gitignore
@@ -3,4 +3,5 @@ __pycache__/
 imgs/
 exps/
 _fisher.py
-_set_transformer.py
\ No newline at end of file
+_set_transformer.py
+.pytest_cache/
8 changes: 5 additions & 3 deletions sbgm/models/_mixer.py
@@ -147,9 +147,11 @@ def __init__(
         num_patches = (height // patch_size) * (width // patch_size)
         inkey, outkey, *bkeys = jr.split(key, 2 + num_blocks)
 
+        _input_size = input_size + q_dim if q_dim is not None else input_size
+        _context_dim = embedding_dim + a_dim if a_dim is not None else embedding_dim
+
         self.conv_in = eqx.nn.Conv2d(
-            # input_size + 1, # Time is tiled along with time as channels
-            input_size, # Time is tiled along with time as channels
+            _input_size,
             hidden_size,
             patch_size,
             stride=patch_size,
@@ -168,7 +170,7 @@ def __init__(
                 hidden_size,
                 mix_patch_size,
                 mix_hidden_size,
-                context_dim=a_dim + embedding_dim,
+                context_dim=_context_dim,
                 key=bkey
             )
             for bkey in bkeys
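
The new `_input_size` and `_context_dim` suggest that a spatial conditioning map `q` is stacked onto the input channels while a vector conditioning `a` is appended to the time embedding. A minimal sketch of that pattern, assuming channel-first arrays (the mixer's forward pass is outside this diff, so the exact usage is an assumption):

import jax.numpy as jnp

def combine_conditioning(x, t_emb, q=None, a=None):
    # Assumed: q has q_dim channels, so stacking gives _input_size channels.
    if q is not None:
        x = jnp.concatenate([x, q], axis=0)
    # Assumed: a has a_dim entries, so appending gives _context_dim features.
    if a is not None:
        t_emb = jnp.concatenate([t_emb, a], axis=-1)
    return x, t_emb

Precomputing the widened sizes in __init__ lets both conditionings stay optional at call time without reshaping logic in the forward pass.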
4 changes: 2 additions & 2 deletions sbgm/models/_mlp.py
@@ -78,8 +78,8 @@ def __call__(
         self,
         t: Union[float, Array],
         x: Array,
-        q: Array,
-        a: Array,
+        q: Optional[Array],
+        a: Optional[Array],
         *,
         key: Key = None
     ) -> Array:
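
With `q` and `a` now `Optional`, callers can omit either conditioning. A plausible input-assembly helper consistent with the new signature (the body of `__call__` is outside this hunk, so this is an assumption):

import jax.numpy as jnp

def _build_inputs(t, x, q=None, a=None):
    # Only the conditioning actually supplied is concatenated with the
    # time scalar and the flattened data vector.
    features = [jnp.atleast_1d(t), x.ravel()]
    if q is not None:
        features.append(q.ravel())
    if a is not None:
        features.append(a.ravel())
    return jnp.concatenate(features)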
99 changes: 4 additions & 95 deletions sbgm/models/_unet.py
@@ -5,7 +5,8 @@
 import jax.numpy as jnp
 import jax.random as jr
 import equinox as eqx
-from jaxtyping import Key, Array
+from jaxtyping import Key, Array, jaxtyped
+from beartype import beartype as typechecker
 from einops import rearrange
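
The new imports enable runtime shape checking. The usual pairing of jaxtyping with beartype decorates a function so its array annotations are verified on every call; a small sketch of assumed usage (how the decorators are applied elsewhere in this file is not shown in the diff):

import jax.numpy as jnp
from jaxtyping import Array, Float, jaxtyped
from beartype import beartype as typechecker

@jaxtyped(typechecker=typechecker)
def scale(x: Float[Array, "c h w"], gain: float) -> Float[Array, "c h w"]:
    # Dtype and the "c h w" shape pattern are both checked at call time.
    return gain * x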


Expand All @@ -19,7 +20,7 @@ def __init__(self, dim: int):

def __call__(self, x: Array) -> Array:
emb = x * self.emb
emb = jnp.concatenate((jnp.sin(emb), jnp.cos(emb)), axis=-1)
emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1)
return emb
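
For context, the surrounding module is the standard sinusoidal time embedding: a fixed bank of frequencies multiplies the time, which is then mapped through sine and cosine. A self-contained sketch of the same computation (the frequency schedule is an assumption, since `self.emb` is initialised outside this hunk):

import jax.numpy as jnp

def sinusoidal_embedding(t, dim, max_period=10_000.0):
    # Geometrically spaced frequencies; half sine, half cosine features.
    half = dim // 2
    freqs = jnp.exp(-jnp.log(max_period) * jnp.arange(half) / half)
    emb = t * freqs
    return jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1)

The switch from a tuple to a list argument in `jnp.concatenate` is purely stylistic; both are accepted.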


@@ -533,96 +534,4 @@ def __call__(
 
         for layer in self.final_conv_layers:
             h = layer(h)
-        return self.final_activation(h) if self.final_activation is not None else h
-
-
-if __name__ == "__main__":
-
-    key = jr.key(0)
-
-    x = jnp.ones((1, 32, 32))
-    t = jnp.ones((1,))
-
-    q = 1
-    a = 2
-
-    unet = UNet(
-        x.shape,
-        is_biggan=False,
-        dim_mults=[1, 1, 1],
-        hidden_size=32,
-        heads=2,
-        dim_head=32,
-        dropout_rate=0.1,
-        num_res_blocks=2,
-        attn_resolutions=[8, 16, 32],
-        q_dim=q,
-        a_dim=a,
-        key=jr.key(0)
-    )
-
-    q = jnp.ones((1, 32, 32))
-    a = jnp.ones((2,))
-    print(unet(t, x, q=q, a=a, key=key).shape)
-
-    q = None
-    a = None
-
-    unet = UNet(
-        x.shape,
-        is_biggan=False,
-        dim_mults=[1, 1, 1],
-        hidden_size=32,
-        heads=2,
-        dim_head=32,
-        dropout_rate=0.1,
-        num_res_blocks=2,
-        attn_resolutions=[8, 16, 32],
-        q_dim=q,
-        a_dim=a,
-        key=jr.key(0)
-    )
-
-    print(unet(t, x, key=key).shape)
-
-    q = 1
-    a = None
-
-    unet = UNet(
-        x.shape,
-        is_biggan=False,
-        dim_mults=[1, 1, 1],
-        hidden_size=32,
-        heads=2,
-        dim_head=32,
-        dropout_rate=0.1,
-        num_res_blocks=2,
-        attn_resolutions=[8, 16, 32],
-        q_dim=q,
-        a_dim=a,
-        key=jr.key(0)
-    )
-
-    q = jnp.ones((1, 32, 32))
-    print(unet(t, x, q=q, key=key).shape)
-
-    q = None
-    a = 2
-
-    unet = UNet(
-        x.shape,
-        is_biggan=False,
-        dim_mults=[1, 1, 1],
-        hidden_size=32,
-        heads=2,
-        dim_head=32,
-        dropout_rate=0.1,
-        num_res_blocks=2,
-        attn_resolutions=[8, 16, 32],
-        q_dim=q,
-        a_dim=a,
-        key=jr.key(0)
-    )
-
-    a = jnp.ones((2,))
-    print(unet(t, x, a=a, key=key).shape)
+        return self.final_activation(h) if self.final_activation is not None else h
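
The deleted `__main__` block exercised the four combinations of `q`/`a` conditioning; given the commit message, those checks presumably move into the new test suite. A sketch of the same coverage in pytest style (the import path and the assumption that the UNet preserves input shape are not confirmed by the diff):

import jax.numpy as jnp
import jax.random as jr
import pytest

from sbgm.models import UNet  # assumed import path

@pytest.mark.parametrize(
    "q_dim, a_dim", [(1, 2), (None, None), (1, None), (None, 2)]
)
def test_unet_shapes(q_dim, a_dim):
    key = jr.key(0)
    x = jnp.ones((1, 32, 32))
    t = jnp.ones((1,))
    unet = UNet(
        x.shape,
        is_biggan=False,
        dim_mults=[1, 1, 1],
        hidden_size=32,
        heads=2,
        dim_head=32,
        dropout_rate=0.1,
        num_res_blocks=2,
        attn_resolutions=[8, 16, 32],
        q_dim=q_dim,
        a_dim=a_dim,
        key=key,
    )
    q = jnp.ones((q_dim, 32, 32)) if q_dim is not None else None
    a = jnp.ones((a_dim,)) if a_dim is not None else None
    out = unet(t, x, q=q, a=a, key=key)
    assert out.shape == x.shape  # assumed: the score network preserves shape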
1 change: 1 addition & 0 deletions tests/requirements.txt
@@ -0,0 +1 @@
+pytest
11 changes: 11 additions & 0 deletions tests/test.py
@@ -0,0 +1,11 @@
+import jax.numpy as jnp
+
+"""
+Notes on testing setup:
+- should pytest (and pytest-cov?) go in the pyproject.toml requirements?
+- test files: test_*.py
+- test functions: test_*()
+"""
+
+def test():
+    assert jnp.square(2) == 4
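
One caveat on discovery: by default pytest only collects files matching test_*.py or *_test.py, so tests/test.py may be skipped by a bare python -m pytest. A follow-up test in a conventionally named file, sketched here (the file name is hypothetical):

# tests/test_basic.py: hypothetical follow-up matching the test_*.py pattern.
import jax.numpy as jnp
import pytest

@pytest.mark.parametrize("x, expected", [(2, 4), (3, 9), (-2, 4)])
def test_square(x, expected):
    assert jnp.square(x) == expected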
Loading
