From 4ad5bde4bf2dffc2ed83678864d26fdd97db7383 Mon Sep 17 00:00:00 2001 From: Mekhla Kapoor <54870020+mekhlakapoor@users.noreply.github.com> Date: Tue, 15 Oct 2024 16:24:06 -0700 Subject: [PATCH] serialization context --- src/aind_slims_api/core.py | 5 +-- src/aind_slims_api/models/base.py | 39 +++++++++---------- .../operations/ecephys_session.py | 7 ++-- tests/test_slimsmodel.py | 22 ++++++++++- 4 files changed, 44 insertions(+), 29 deletions(-) diff --git a/src/aind_slims_api/core.py b/src/aind_slims_api/core.py index 7e44ffd..70c6c83 100644 --- a/src/aind_slims_api/core.py +++ b/src/aind_slims_api/core.py @@ -353,6 +353,7 @@ def add_model( exclude=fields_to_exclude, **kwargs, by_alias=True, + context="slims_post", ), ) return type(model).model_validate(rtn) @@ -377,9 +378,7 @@ def update_model(self, model: SlimsBaseModel, *args, **kwargs): model._slims_table, model.pk, model.model_dump( - include=fields_to_include, - by_alias=True, - **kwargs, + include=fields_to_include, by_alias=True, **kwargs, context="slims_post" ), ) return type(model).model_validate(rtn) diff --git a/src/aind_slims_api/models/base.py b/src/aind_slims_api/models/base.py index 3064b86..61e6571 100644 --- a/src/aind_slims_api/models/base.py +++ b/src/aind_slims_api/models/base.py @@ -3,8 +3,14 @@ import logging from datetime import datetime -from typing import ClassVar, Optional, Dict, Any -from pydantic import BaseModel, ValidationInfo, field_serializer, field_validator +from typing import ClassVar, Optional +from pydantic import ( + BaseModel, + ValidationInfo, + field_serializer, + field_validator, + SerializationInfo, +) from slims.internal import Column as SlimsColumn # type: ignore from aind_slims_api.models.utils import _find_unit_spec @@ -62,31 +68,22 @@ def _validate(cls, value, info: ValidationInfo): return value @field_serializer("*") - def _serialize(self, field, info): - """Serialize a field, accounts for Quantities and datetime""" + def _serialize(self, field, info: SerializationInfo): + """Serialize a field, accounts for Quantities and datetime.""" unit_spec = _find_unit_spec(self.model_fields[info.field_name]) if unit_spec and field is not None: - quantity = { - "amount": field, - "unit_display": unit_spec.preferred_unit, - } - return quantity + if info.context == "slims_post": + quantity = { + "amount": field, + "unit_display": unit_spec.preferred_unit, + } + return quantity + else: + return field elif isinstance(field, datetime): return int(field.timestamp() * 10**3) else: return field - def model_dump(self, serialize_quantity=True, *args, **kwargs) -> Dict[str, Any]: - """Override model_dump to handle UnitSpec serialization.""" - data = super().model_dump(*args, **kwargs) - - # Update serialized fields with UnitSpec information - if not serialize_quantity: - for key, value in data.items(): - if isinstance(value, dict) and "amount" in value: - # Extract the amount - data[key] = value["amount"] - return data - # TODO: Add links - need Record.json_entity['links']['self'] # TODO: Add Table - need Record.json_entity['tableName'] diff --git a/src/aind_slims_api/operations/ecephys_session.py b/src/aind_slims_api/operations/ecephys_session.py index 144b7b8..e35c208 100644 --- a/src/aind_slims_api/operations/ecephys_session.py +++ b/src/aind_slims_api/operations/ecephys_session.py @@ -36,12 +36,14 @@ class SlimsRewardDeliveryInfo(BaseModel): class SlimsStreamModule(SlimsDomeModuleRdrc): """DomeModule Wrapper to add linked brain structure models""" + primary_targeted_structure: Optional[SlimsBrainStructureRdrc] = None secondary_targeted_structures: Optional[list[SlimsBrainStructureRdrc]] = None class SlimsStream(SlimsStreamsResult): """Streams wrapper to add linked stream modules""" + stream_modules: Optional[List[SlimsStreamModule]] @@ -94,7 +96,7 @@ def fetch_stream_modules( ] stream_module_model = SlimsStreamModule( - **stream_module.model_dump(serialize_quantity=False), + **stream_module.model_dump(), primary_targeted_structure=primary_structure, secondary_targeted_structures=secondary_structures, ) @@ -108,7 +110,7 @@ def fetch_streams(self, session_pk: int) -> List[SlimsStream]: ) complete_streams = [ SlimsStream( - **stream.model_dump(serialize_quantity=False), + **stream.model_dump(), stream_modules=( self.fetch_stream_modules(stream.stream_modules_pk) if stream.stream_modules_pk @@ -236,7 +238,6 @@ def fetch_ecephys_sessions( ) if group_run_step and session_run_steps: esb = EcephysSessionBuilder(client=client) - print("processing session steps") ecephys_sessions = esb.process_session_steps( group_run_step=group_run_step, session_run_steps=session_run_steps, diff --git a/tests/test_slimsmodel.py b/tests/test_slimsmodel.py index cb1e919..6e5e280 100644 --- a/tests/test_slimsmodel.py +++ b/tests/test_slimsmodel.py @@ -34,7 +34,7 @@ def test_string_field(self): self.assertEqual(obj.stringfield, "value") - def test_quantity_field(self): + def test_quantity_field_context(self): """Test validation/serialization of a quantity type, with unit""" obj = self.TestModel() obj.quantfield = Column( @@ -48,11 +48,29 @@ def test_quantity_field(self): self.assertEqual(obj.quantfield, 28.28) - serialized = obj.model_dump()["quantfield"] + serialized = obj.model_dump(context="slims_post")["quantfield"] expected = {"amount": 28.28, "unit_display": "um"} self.assertEqual(serialized, expected) + def test_quantity_field_no_context(self): + """Test validation/serialization of a quantity type without unit""" + obj = self.TestModel() + obj.quantfield = Column( + { + "datatype": "QUANTITY", + "name": "quantfield", + "value": 28.28, + "unit": "um", + } + ) + + self.assertEqual(obj.quantfield, 28.28) + + serialized = obj.model_dump()["quantfield"] + + self.assertEqual(serialized, 28.28) + def test_quantity_wrong_unit(self): """Ensure you get an error with an unexpected unit""" obj = self.TestModel()