Skip to content

Commit

Permalink
Merge pull request #31 from invrs-io/sym
Browse files Browse the repository at this point in the history
Fix batch dims for symmetry
  • Loading branch information
mfschubert authored Mar 5, 2024
2 parents faa5c1f + 7aaba99 commit acf2737
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "v0.6.3"
current_version = "v0.6.4"
commit = true
commit_args = "--no-verify"
tag = true
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# totypes - Custom types for topology optimization
`v0.6.3`
`v0.6.4`

## Overview

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]

name = "totypes"
version = "v0.6.3"
version = "v0.6.4"
description = "Custom datatypes useful in a topology optimization context"
keywords = ["topology", "optimization", "jax", "inverse design"]
readme = "README.md"
Expand Down
2 changes: 1 addition & 1 deletion src/totypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Copyright (c) 2023 The INVRS-IO authors.
"""

__version__ = "v0.6.3"
__version__ = "v0.6.4"
__author__ = "Martin F. Schubert <[email protected]>"

__all__ = ["json_utils", "symmetry", "types"]
Expand Down
4 changes: 2 additions & 2 deletions src/totypes/symmetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@ def symmetrize(
def _reflection_ne_sw(array: jnp.ndarray) -> jnp.ndarray:
"""Transform `array` to have reflection symmetry about the ne-sw axis."""
assert array.shape[-2] == array.shape[-1]
return (array + array[..., ::-1, ::-1].T) / 2
return (array + jnp.swapaxes(array[..., ::-1, ::-1], -2, -1)) / 2


def _reflection_nw_se(array: jnp.ndarray) -> jnp.ndarray:
"""Transform `array` to have reflection symmetry about the nw-se axis."""
assert array.shape[-2] == array.shape[-1]
return (array + array.T) / 2
return (array + jnp.swapaxes(array, -2, -1)) / 2


def _reflection_n_s(array: jnp.ndarray) -> jnp.ndarray:
Expand Down
14 changes: 14 additions & 0 deletions tests/test_symmetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,17 @@ def test_multiple_symmetry(self, obj):
]
)
_assert_array_equal(result, expected)

@parameterized.expand(
[
[symmetry.REFLECTION_E_W],
[symmetry.REFLECTION_N_S],
[symmetry.REFLECTION_NE_SW],
[symmetry.REFLECTION_NW_SE],
[symmetry.ROTATION_180],
[symmetry.ROTATION_90],
]
)
def test_with_batch(self, sym):
arr = jnp.ones((1, 10, 13, 13))
symmetry.symmetrize(arr, (sym,))

0 comments on commit acf2737

Please sign in to comment.