diff --git a/.bumpversion.toml b/.bumpversion.toml index 7ee8bec..0a15ce5 100644 --- a/.bumpversion.toml +++ b/.bumpversion.toml @@ -1,5 +1,5 @@ [tool.bumpversion] -current_version = "v0.4.0" +current_version = "v0.4.1" commit = true commit_args = "--no-verify" tag = true diff --git a/README.md b/README.md index 89c5dbd..07e86fd 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ # totypes - Custom types for topology optimization -`v0.4.0` +`v0.4.1` ## Overview diff --git a/pyproject.toml b/pyproject.toml index 4dab19f..8505e87 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "totypes" -version = "v0.4.0" +version = "v0.4.1" description = "Custom datatypes useful in a topology optimization context" keywords = ["topology", "optimization", "jax", "inverse design"] readme = "README.md" diff --git a/src/totypes/__init__.py b/src/totypes/__init__.py index a1572fe..0625a5b 100644 --- a/src/totypes/__init__.py +++ b/src/totypes/__init__.py @@ -3,7 +3,7 @@ Copyright (c) 2023 The INVRS-IO authors. """ -__version__ = "v0.4.0" +__version__ = "v0.4.1" __author__ = "Martin F. Schubert " __all__ = ["json_utils", "symmetry", "types"] diff --git a/src/totypes/json_utils.py b/src/totypes/json_utils.py index ee3b01d..e393153 100644 --- a/src/totypes/json_utils.py +++ b/src/totypes/json_utils.py @@ -214,7 +214,7 @@ def _convert_array(arr: Union[onp.ndarray, jnp.ndarray]) -> Tuple[str, Dict[str, def _asdict(x: Any) -> Dict[str, Any]: """Converts dataclasses or namedtuples to dictionaries.""" if dataclasses.is_dataclass(x): - return dataclasses.asdict(x) + return {field.name: getattr(x, field.name) for field in dataclasses.fields(x)} try: return x._asdict() # type: ignore[no-any-return] except AttributeError as exc: diff --git a/tests/test_json_utils.py b/tests/test_json_utils.py index 2855e9d..fa0fa5f 100644 --- a/tests/test_json_utils.py +++ b/tests/test_json_utils.py @@ -222,6 +222,49 @@ class CustomObject(NamedTuple): self.assertEqual(restored.y, obj.y) self.assertEqual(restored.z, obj.z) + def test_serialize_with_custom_namedtuple_having_internal_custom_type(self): + class CustomObject(NamedTuple): + x: types.Density2DArray + y: types.BoundedArray + z: str + + json_utils.register_custom_type(CustomObject) + + obj = CustomObject( + x=types.Density2DArray(array=onp.zeros((5, 5))), + y=types.BoundedArray(onp.ones((3,)), 0, 2), + z="test", + ) + + serialized = json_utils.json_from_pytree(obj) + restored = json_utils.pytree_from_json(serialized) + self.assertIsInstance(restored, CustomObject) + self.assertIsInstance(restored.x, types.Density2DArray) + self.assertIsInstance(restored.y, types.BoundedArray) + self.assertIsInstance(restored.z, str) + + def test_serialize_with_custom_dataclass_having_internal_custom_type(self): + @dataclasses.dataclass + class CustomObject: + x: types.Density2DArray + y: types.BoundedArray + z: str + + json_utils.register_custom_type(CustomObject) + + obj = CustomObject( + x=types.Density2DArray(array=onp.zeros((5, 5))), + y=types.BoundedArray(onp.ones((3,)), 0, 2), + z="test", + ) + + serialized = json_utils.json_from_pytree(obj) + restored = json_utils.pytree_from_json(serialized) + self.assertIsInstance(restored, CustomObject) + self.assertIsInstance(restored.x, types.Density2DArray) + self.assertIsInstance(restored.y, types.BoundedArray) + self.assertIsInstance(restored.z, str) + def test_serialize_with_registered_custom_dataclass(self): @dataclasses.dataclass class CustomObject: @@ -270,7 +313,6 @@ class MyClass123: ["MyClass123" in key for key in json_utils._CUSTOM_TYPE_REGISTRY.keys()] ) ) - print(type(MyClass123())) with self.assertRaisesRegex( ValueError, "`custom_type` must be a type, but got" ):