From ddffb62ff1ede61f30116f0f9fd0141508095b62 Mon Sep 17 00:00:00 2001 From: Martin Schubert Date: Thu, 23 May 2024 13:46:16 -0700 Subject: [PATCH] Fix rotational symmetry with batch --- src/totypes/symmetry.py | 9 +++++++-- tests/test_symmetry.py | 2 +- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/totypes/symmetry.py b/src/totypes/symmetry.py index d2c2795..d44e693 100644 --- a/src/totypes/symmetry.py +++ b/src/totypes/symmetry.py @@ -57,13 +57,18 @@ def _reflection_e_w(array: jnp.ndarray) -> jnp.ndarray: def _rotation_180(array: jnp.ndarray) -> jnp.ndarray: """Transform `array` to have 180-degree rotational symmetry.""" - return (array + jnp.rot90(array, 2)) / 2 + return (array + jnp.rot90(array, 2, axes=(-2, -1))) / 2 def _rotation_90(array: jnp.ndarray) -> jnp.ndarray: """Transform `array` to have 90-degree rotational symmetry.""" assert array.shape[-2] == array.shape[-1] - return (array + jnp.rot90(array, 1) + jnp.rot90(array, 2) + jnp.rot90(array, 3)) / 4 + return ( + array + + jnp.rot90(array, 1, axes=(-2, -1)) + + jnp.rot90(array, 2, axes=(-2, -1)) + + jnp.rot90(array, 3, axes=(-2, -1)) + ) / 4 SYMMETRY_FNS = { diff --git a/tests/test_symmetry.py b/tests/test_symmetry.py index 83bb4ef..17263a0 100644 --- a/tests/test_symmetry.py +++ b/tests/test_symmetry.py @@ -159,5 +159,5 @@ def test_multiple_symmetry(self, obj): ] ) def test_with_batch(self, sym): - arr = jnp.ones((1, 10, 13, 13)) + arr = jnp.ones((8, 13, 13)) symmetry.symmetrize(arr, (sym,))