From 5329b548793fb4ff225239894798ec56b5d0e0cf Mon Sep 17 00:00:00 2001 From: Martin Schubert Date: Thu, 30 Nov 2023 15:12:17 -0800 Subject: [PATCH] Fix nested custom dataclasses --- src/totypes/json_utils.py | 4 +++- tests/test_json_utils.py | 44 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/src/totypes/json_utils.py b/src/totypes/json_utils.py index ee3b01d..49fac31 100644 --- a/src/totypes/json_utils.py +++ b/src/totypes/json_utils.py @@ -214,7 +214,9 @@ def _convert_array(arr: Union[onp.ndarray, jnp.ndarray]) -> Tuple[str, Dict[str, def _asdict(x: Any) -> Dict[str, Any]: """Converts dataclasses or namedtuples to dictionaries.""" if dataclasses.is_dataclass(x): - return dataclasses.asdict(x) + return dict( + [(field.name, getattr(x, field.name)) for field in dataclasses.fields(x)] + ) try: return x._asdict() # type: ignore[no-any-return] except AttributeError as exc: diff --git a/tests/test_json_utils.py b/tests/test_json_utils.py index 2855e9d..3b668a9 100644 --- a/tests/test_json_utils.py +++ b/tests/test_json_utils.py @@ -222,6 +222,50 @@ class CustomObject(NamedTuple): self.assertEqual(restored.y, obj.y) self.assertEqual(restored.z, obj.z) + def test_serialize_with_custom_namedtuple_having_internal_custom_type(self): + class CustomObject(NamedTuple): + x: types.Density2DArray + y: types.BoundedArray + z: str + + json_utils.register_custom_type(CustomObject) + + obj = CustomObject( + x=types.Density2DArray(array=onp.zeros((5, 5))), + y=types.BoundedArray(onp.ones((3,)), 0, 2), + z="test", + ) + + serialized = json_utils.json_from_pytree(obj) + restored = json_utils.pytree_from_json(serialized) + self.assertIsInstance(restored, CustomObject) + self.assertIsInstance(restored.x, types.Density2DArray) + self.assertIsInstance(restored.y, types.BoundedArray) + self.assertIsInstance(restored.z, str) + + def test_serialize_with_custom_dataclass_having_internal_custom_type(self): + @dataclasses.dataclass + class CustomObject: + x: types.Density2DArray + y: types.BoundedArray + z: str + + json_utils.register_custom_type(CustomObject) + + obj = CustomObject( + x=types.Density2DArray(array=onp.zeros((5, 5))), + y=types.BoundedArray(onp.ones((3,)), 0, 2), + z="test", + ) + + serialized = json_utils.json_from_pytree(obj) + print(serialized) + restored = json_utils.pytree_from_json(serialized) + self.assertIsInstance(restored, CustomObject) + self.assertIsInstance(restored.x, types.Density2DArray) + self.assertIsInstance(restored.y, types.BoundedArray) + self.assertIsInstance(restored.z, str) + def test_serialize_with_registered_custom_dataclass(self): @dataclasses.dataclass class CustomObject: