Skip to content

Commit

Permalink
serialization context
Browse files Browse the repository at this point in the history
  • Loading branch information
mekhlakapoor committed Oct 15, 2024
1 parent 779af8e commit 4ad5bde
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 29 deletions.
5 changes: 2 additions & 3 deletions src/aind_slims_api/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@ def add_model(
exclude=fields_to_exclude,
**kwargs,
by_alias=True,
context="slims_post",
),
)
return type(model).model_validate(rtn)
Expand All @@ -377,9 +378,7 @@ def update_model(self, model: SlimsBaseModel, *args, **kwargs):
model._slims_table,
model.pk,
model.model_dump(
include=fields_to_include,
by_alias=True,
**kwargs,
include=fields_to_include, by_alias=True, **kwargs, context="slims_post"
),
)
return type(model).model_validate(rtn)
Expand Down
39 changes: 18 additions & 21 deletions src/aind_slims_api/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,14 @@

import logging
from datetime import datetime
from typing import ClassVar, Optional, Dict, Any
from pydantic import BaseModel, ValidationInfo, field_serializer, field_validator
from typing import ClassVar, Optional
from pydantic import (
BaseModel,
ValidationInfo,
field_serializer,
field_validator,
SerializationInfo,
)
from slims.internal import Column as SlimsColumn # type: ignore

from aind_slims_api.models.utils import _find_unit_spec
Expand Down Expand Up @@ -62,31 +68,22 @@ def _validate(cls, value, info: ValidationInfo):
return value

@field_serializer("*")
def _serialize(self, field, info):
"""Serialize a field, accounts for Quantities and datetime"""
def _serialize(self, field, info: SerializationInfo):
"""Serialize a field, accounts for Quantities and datetime."""
unit_spec = _find_unit_spec(self.model_fields[info.field_name])
if unit_spec and field is not None:
quantity = {
"amount": field,
"unit_display": unit_spec.preferred_unit,
}
return quantity
if info.context == "slims_post":
quantity = {
"amount": field,
"unit_display": unit_spec.preferred_unit,
}
return quantity
else:
return field
elif isinstance(field, datetime):
return int(field.timestamp() * 10**3)
else:
return field

def model_dump(self, serialize_quantity=True, *args, **kwargs) -> Dict[str, Any]:
"""Override model_dump to handle UnitSpec serialization."""
data = super().model_dump(*args, **kwargs)

# Update serialized fields with UnitSpec information
if not serialize_quantity:
for key, value in data.items():
if isinstance(value, dict) and "amount" in value:
# Extract the amount
data[key] = value["amount"]
return data

# TODO: Add links - need Record.json_entity['links']['self']
# TODO: Add Table - need Record.json_entity['tableName']
7 changes: 4 additions & 3 deletions src/aind_slims_api/operations/ecephys_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,14 @@ class SlimsRewardDeliveryInfo(BaseModel):

class SlimsStreamModule(SlimsDomeModuleRdrc):
"""DomeModule Wrapper to add linked brain structure models"""

primary_targeted_structure: Optional[SlimsBrainStructureRdrc] = None
secondary_targeted_structures: Optional[list[SlimsBrainStructureRdrc]] = None


class SlimsStream(SlimsStreamsResult):
"""Streams wrapper to add linked stream modules"""

stream_modules: Optional[List[SlimsStreamModule]]


Expand Down Expand Up @@ -94,7 +96,7 @@ def fetch_stream_modules(
]

stream_module_model = SlimsStreamModule(
**stream_module.model_dump(serialize_quantity=False),
**stream_module.model_dump(),
primary_targeted_structure=primary_structure,
secondary_targeted_structures=secondary_structures,
)
Expand All @@ -108,7 +110,7 @@ def fetch_streams(self, session_pk: int) -> List[SlimsStream]:
)
complete_streams = [
SlimsStream(
**stream.model_dump(serialize_quantity=False),
**stream.model_dump(),
stream_modules=(
self.fetch_stream_modules(stream.stream_modules_pk)
if stream.stream_modules_pk
Expand Down Expand Up @@ -236,7 +238,6 @@ def fetch_ecephys_sessions(
)
if group_run_step and session_run_steps:
esb = EcephysSessionBuilder(client=client)
print("processing session steps")
ecephys_sessions = esb.process_session_steps(
group_run_step=group_run_step,
session_run_steps=session_run_steps,
Expand Down
22 changes: 20 additions & 2 deletions tests/test_slimsmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_string_field(self):

self.assertEqual(obj.stringfield, "value")

def test_quantity_field(self):
def test_quantity_field_context(self):
"""Test validation/serialization of a quantity type, with unit"""
obj = self.TestModel()
obj.quantfield = Column(
Expand All @@ -48,11 +48,29 @@ def test_quantity_field(self):

self.assertEqual(obj.quantfield, 28.28)

serialized = obj.model_dump()["quantfield"]
serialized = obj.model_dump(context="slims_post")["quantfield"]
expected = {"amount": 28.28, "unit_display": "um"}

self.assertEqual(serialized, expected)

def test_quantity_field_no_context(self):
"""Test validation/serialization of a quantity type without unit"""
obj = self.TestModel()
obj.quantfield = Column(
{
"datatype": "QUANTITY",
"name": "quantfield",
"value": 28.28,
"unit": "um",
}
)

self.assertEqual(obj.quantfield, 28.28)

serialized = obj.model_dump()["quantfield"]

self.assertEqual(serialized, 28.28)

def test_quantity_wrong_unit(self):
"""Ensure you get an error with an unexpected unit"""
obj = self.TestModel()
Expand Down

0 comments on commit 4ad5bde

Please sign in to comment.