From 9a818838d0d0fd63760edffcc62889208188f2ff Mon Sep 17 00:00:00 2001 From: "Keto D. Zhang" Date: Fri, 24 May 2024 23:04:54 -0700 Subject: [PATCH] Add test if type with numpy array serializes back to ndarray --- tests/patterns/numpy_type_test.py | 48 +++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 tests/patterns/numpy_type_test.py diff --git a/tests/patterns/numpy_type_test.py b/tests/patterns/numpy_type_test.py new file mode 100644 index 0000000..ecce787 --- /dev/null +++ b/tests/patterns/numpy_type_test.py @@ -0,0 +1,48 @@ +import asdf +import numpy as np +from asdf.extension import Extension +from numpy.typing import NDArray + +from asdf_pydantic import AsdfPydanticConverter, AsdfPydanticModel + + +class ArrayContainer(AsdfPydanticModel): + _tag = "asdf://asdf-pydantic/examples/tags/array-container-1.0.0" + + array: NDArray # Equivalently np.ndarray + + +def setup_module(): + """Register the ArrayContainer model with the AsdfPydanticConverter. + + Pytest will run this function before the tests in this module. + """ + AsdfPydanticConverter.add_models(ArrayContainer) + + class TestExtension(Extension): + extension_uri = "asdf://asdf-pydantic/examples/extensions/test-1.0.0" + + converters = [AsdfPydanticConverter()] # type: ignore + tags = [*AsdfPydanticConverter().tags] # type: ignore + + asdf.get_config().add_extension(TestExtension()) + + +######################################################################################## +# Test Cases +######################################################################################## + + +def test_convert_ArrayContainer_to_asdf(tmp_path): + """When writing ArrayContainer to an ASDF file, the array field should be + serialized to the original numpy array. + """ + af = asdf.AsdfFile({"data": ArrayContainer(array=np.array([1, 2, 3]))}).write_to( + tmp_path / "test.asdf" + ) + + with asdf.open(tmp_path / "test.asdf") as af: + assert isinstance(af.tree["array"], np.ndarray), ( + f"Expected {type(np.ndarray)}, " f"got {type(af.tree['array'])}" + ) + assert np.all(af.tree["array"] == np.array([1, 2, 3]))