Skip to content

Commit

Permalink
Fix nested custom dataclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
mfschubert committed Nov 30, 2023
1 parent 7e53924 commit 5329b54
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/totypes/json_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,9 @@ def _convert_array(arr: Union[onp.ndarray, jnp.ndarray]) -> Tuple[str, Dict[str,
def _asdict(x: Any) -> Dict[str, Any]:
"""Converts dataclasses or namedtuples to dictionaries."""
if dataclasses.is_dataclass(x):
return dataclasses.asdict(x)
return dict(
[(field.name, getattr(x, field.name)) for field in dataclasses.fields(x)]
)
try:
return x._asdict() # type: ignore[no-any-return]
except AttributeError as exc:
Expand Down
44 changes: 44 additions & 0 deletions tests/test_json_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,50 @@ class CustomObject(NamedTuple):
self.assertEqual(restored.y, obj.y)
self.assertEqual(restored.z, obj.z)

def test_serialize_with_custom_namedtuple_having_internal_custom_type(self):
class CustomObject(NamedTuple):
x: types.Density2DArray
y: types.BoundedArray
z: str

json_utils.register_custom_type(CustomObject)

obj = CustomObject(
x=types.Density2DArray(array=onp.zeros((5, 5))),
y=types.BoundedArray(onp.ones((3,)), 0, 2),
z="test",
)

serialized = json_utils.json_from_pytree(obj)
restored = json_utils.pytree_from_json(serialized)
self.assertIsInstance(restored, CustomObject)
self.assertIsInstance(restored.x, types.Density2DArray)
self.assertIsInstance(restored.y, types.BoundedArray)
self.assertIsInstance(restored.z, str)

def test_serialize_with_custom_dataclass_having_internal_custom_type(self):
@dataclasses.dataclass
class CustomObject:
x: types.Density2DArray
y: types.BoundedArray
z: str

json_utils.register_custom_type(CustomObject)

obj = CustomObject(
x=types.Density2DArray(array=onp.zeros((5, 5))),
y=types.BoundedArray(onp.ones((3,)), 0, 2),
z="test",
)

serialized = json_utils.json_from_pytree(obj)
print(serialized)
restored = json_utils.pytree_from_json(serialized)
self.assertIsInstance(restored, CustomObject)
self.assertIsInstance(restored.x, types.Density2DArray)
self.assertIsInstance(restored.y, types.BoundedArray)
self.assertIsInstance(restored.z, str)

def test_serialize_with_registered_custom_dataclass(self):
@dataclasses.dataclass
class CustomObject:
Expand Down

0 comments on commit 5329b54

Please sign in to comment.