Skip to content

Commit

Permalink
Fix rotational symmetry with batch
Browse files Browse the repository at this point in the history
  • Loading branch information
mfschubert committed May 23, 2024
1 parent b773386 commit ddffb62
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
9 changes: 7 additions & 2 deletions src/totypes/symmetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,18 @@ def _reflection_e_w(array: jnp.ndarray) -> jnp.ndarray:

def _rotation_180(array: jnp.ndarray) -> jnp.ndarray:
"""Transform `array` to have 180-degree rotational symmetry."""
return (array + jnp.rot90(array, 2)) / 2
return (array + jnp.rot90(array, 2, axes=(-2, -1))) / 2


def _rotation_90(array: jnp.ndarray) -> jnp.ndarray:
"""Transform `array` to have 90-degree rotational symmetry."""
assert array.shape[-2] == array.shape[-1]
return (array + jnp.rot90(array, 1) + jnp.rot90(array, 2) + jnp.rot90(array, 3)) / 4
return (
array
+ jnp.rot90(array, 1, axes=(-2, -1))
+ jnp.rot90(array, 2, axes=(-2, -1))
+ jnp.rot90(array, 3, axes=(-2, -1))
) / 4


SYMMETRY_FNS = {
Expand Down
2 changes: 1 addition & 1 deletion tests/test_symmetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,5 +159,5 @@ def test_multiple_symmetry(self, obj):
]
)
def test_with_batch(self, sym):
arr = jnp.ones((1, 10, 13, 13))
arr = jnp.ones((8, 13, 13))
symmetry.symmetrize(arr, (sym,))

0 comments on commit ddffb62

Please sign in to comment.