diff --git a/src/aind_slims_api/core.py b/src/aind_slims_api/core.py index 7e44ffd..70c6c83 100644 --- a/src/aind_slims_api/core.py +++ b/src/aind_slims_api/core.py @@ -353,6 +353,7 @@ def add_model( exclude=fields_to_exclude, **kwargs, by_alias=True, + context="slims_post", ), ) return type(model).model_validate(rtn) @@ -377,9 +378,7 @@ def update_model(self, model: SlimsBaseModel, *args, **kwargs): model._slims_table, model.pk, model.model_dump( - include=fields_to_include, - by_alias=True, - **kwargs, + include=fields_to_include, by_alias=True, **kwargs, context="slims_post" ), ) return type(model).model_validate(rtn) diff --git a/src/aind_slims_api/models/base.py b/src/aind_slims_api/models/base.py index b893013..61e6571 100644 --- a/src/aind_slims_api/models/base.py +++ b/src/aind_slims_api/models/base.py @@ -4,8 +4,13 @@ import logging from datetime import datetime from typing import ClassVar, Optional - -from pydantic import BaseModel, ValidationInfo, field_serializer, field_validator +from pydantic import ( + BaseModel, + ValidationInfo, + field_serializer, + field_validator, + SerializationInfo, +) from slims.internal import Column as SlimsColumn # type: ignore from aind_slims_api.models.utils import _find_unit_spec @@ -63,15 +68,18 @@ def _validate(cls, value, info: ValidationInfo): return value @field_serializer("*") - def _serialize(self, field, info): - """Serialize a field, accounts for Quantities and datetime""" + def _serialize(self, field, info: SerializationInfo): + """Serialize a field, accounts for Quantities and datetime.""" unit_spec = _find_unit_spec(self.model_fields[info.field_name]) if unit_spec and field is not None: - quantity = { - "amount": field, - "unit_display": unit_spec.preferred_unit, - } - return quantity + if info.context == "slims_post": + quantity = { + "amount": field, + "unit_display": unit_spec.preferred_unit, + } + return quantity + else: + return field elif isinstance(field, datetime): return int(field.timestamp() * 10**3) else: diff --git a/src/aind_slims_api/models/ecephys_session.py b/src/aind_slims_api/models/ecephys_session.py index ce9eefe..e264220 100644 --- a/src/aind_slims_api/models/ecephys_session.py +++ b/src/aind_slims_api/models/ecephys_session.py @@ -72,6 +72,11 @@ class SlimsGroupOfSessionsRunStep(SlimsExperimentRunStep): serialization_alias="xprs_cf_activeMousePlatform", validation_alias="xprs_cf_activeMousePlatform", ) + instrument_pk: Optional[int] = Field( + default=None, + serialization_alias="xprs_cf_fk_instrumentJson", + validation_alias="xprs_cf_fk_instrumentJson", + ) # TODO: add device calibrations once we have an example # device_calibrations_attachment: Optional[str] = Field( # default=None, @@ -373,6 +378,26 @@ class SlimsEphysInsertionResult(SlimsBaseModel): } +class SlimsInstrumentRdrc(SlimsBaseModel): + """Model for Instrument Rdrc""" + + pk: Optional[int] = Field( + default=None, serialization_alias="rdrc_pk", validation_alias="rdrc_pk" + ) + name: Optional[str] = Field( + default=None, serialization_alias="rdrc_name", validation_alias="rdrc_name" + ) + created_on: Optional[datetime] = Field( + default=None, + serialization_alias="rdrc_createdOn", + validation_alias="rdrc_createdOn", + ) + _slims_table = "ReferenceDataRecord" + _base_fetch_filters: ClassVar[dict[str, str]] = { + "rdty_name": "AIND Instruments", + } + + class SlimsDomeModuleRdrc(SlimsBaseModel): """Model for Dome Module Reference Data""" @@ -392,15 +417,15 @@ class SlimsDomeModuleRdrc(SlimsBaseModel): serialization_alias="rdrc_cf_probeName", validation_alias="rdrc_cf_probeName", ) - primary_targeted_structure: Optional[str] = Field( + primary_targeted_structure_pk: Optional[int] = Field( default=None, - serialization_alias="rdrc_cf_fk_primaryTargetedStructure_display", - validation_alias="rdrc_cf_fk_primaryTargetedStructure_display", + serialization_alias="rdrc_cf_fk_primaryTargetedStructure", + validation_alias="rdrc_cf_fk_primaryTargetedStructure", ) - secondary_targeted_structures: Optional[List] = Field( + secondary_targeted_structures_pk: Optional[List] = Field( default=None, - serialization_alias="rdrc_cf_fk_secondaryTargetedStructures2_display", - validation_alias="rdrc_cf_fk_secondaryTargetedStructures2_display", + serialization_alias="rdrc_cf_fk_secondaryTargetedStructures", + validation_alias="rdrc_cf_fk_secondaryTargetedStructures", ) arc_angle: Annotated[float | None, UnitSpec("degree", "°")] = Field( default=None, @@ -504,6 +529,26 @@ class SlimsDomeModuleRdrc(SlimsBaseModel): } +class SlimsBrainStructureRdrc(SlimsBaseModel): + """Model for Brain Structure Reference Data""" + + pk: Optional[int] = Field( + default=None, serialization_alias="rdrc_pk", validation_alias="rdrc_pk" + ) + name: Optional[str] = Field( + default=None, serialization_alias="rdrc_name", validation_alias="rdrc_name" + ) + created_on: Optional[datetime] = Field( + default=None, + serialization_alias="rdrc_createdOn", + validation_alias="rdrc_createdOn", + ) + _slims_table = "ReferenceDataRecord" + _base_fetch_filters: ClassVar[dict[str, str]] = { + "rdty_name": "CCF brain structures", + } + + class SlimsFiberConnectionsRdrc(SlimsBaseModel): """Model for Fiber Connections Reference Data""" @@ -539,7 +584,7 @@ class SlimsRewardDeliveryRdrc(SlimsBaseModel): pk: Optional[int] = Field(serialization_alias="rdrc_pk", validation_alias="rdrc_pk") reward_spouts_pk: Optional[int] = Field( - default=[], + default=None, serialization_alias="rdrc_cf_fk_rewardSpouts", validation_alias="rdrc_cf_fk_rewardSpouts", ) diff --git a/src/aind_slims_api/operations/ecephys_session.py b/src/aind_slims_api/operations/ecephys_session.py index d75c6bb..e35c208 100644 --- a/src/aind_slims_api/operations/ecephys_session.py +++ b/src/aind_slims_api/operations/ecephys_session.py @@ -5,6 +5,7 @@ import logging from typing import List, Optional from pydantic import BaseModel + from aind_slims_api import SlimsClient from aind_slims_api.exceptions import SlimsRecordNotFound from aind_slims_api.models.mouse import SlimsMouseContent @@ -19,97 +20,174 @@ SlimsRewardSpoutsRdrc, SlimsGroupOfSessionsRunStep, SlimsMouseSessionRunStep, + SlimsBrainStructureRdrc, + SlimsInstrumentRdrc, ) logger = logging.getLogger(__name__) +class SlimsRewardDeliveryInfo(BaseModel): + """Pydantic Model to store Reward Delivery Info""" + + reward_delivery: Optional[SlimsRewardDeliveryRdrc] = [] + reward_spouts: Optional[SlimsRewardSpoutsRdrc] = [] + + +class SlimsStreamModule(SlimsDomeModuleRdrc): + """DomeModule Wrapper to add linked brain structure models""" + + primary_targeted_structure: Optional[SlimsBrainStructureRdrc] = None + secondary_targeted_structures: Optional[list[SlimsBrainStructureRdrc]] = None + + +class SlimsStream(SlimsStreamsResult): + """Streams wrapper to add linked stream modules""" + + stream_modules: Optional[List[SlimsStreamModule]] + + class EcephysSession(BaseModel): """ Pydantic model encapsulating all session-related responses. """ session_group: SlimsExperimentRunStep - session_result: Optional[SlimsMouseSessionResult] - streams: Optional[List[SlimsStreamsResult]] = [] - stream_modules: Optional[List[SlimsDomeModuleRdrc]] = [] - reward_delivery: Optional[SlimsRewardDeliveryRdrc] = None - reward_spouts: Optional[SlimsRewardSpoutsRdrc] = None + session_instrument: Optional[SlimsInstrumentRdrc] = None + session_result: Optional[SlimsMouseSessionResult] = None + streams: Optional[List[SlimsStream]] = [] + reward_delivery: Optional[SlimsRewardDeliveryInfo] = None stimulus_epochs: Optional[List[SlimsStimulusEpochsResult]] = [] -def _process_session_steps( - client: SlimsClient, - group_run_step: SlimsGroupOfSessionsRunStep, - session_run_steps: List[SlimsMouseSessionRunStep], -) -> List[EcephysSession]: - """ - Process session run steps and encapsulate related data into EcephysSession objects. - Iterates through each run step in the provided session run steps, - gathers the necessary data, and creates a list of EcephysSession objects. - - Parameters - ---------- - client : SlimsClient - An instance of SlimsClient used to retrieve additional session data. - group_run_step : SlimsGroupOfSessionsRunStep - The group run step containing session metadata and run information. - session_run_steps : List[SlimsMouseSessionRunStep] - A list of individual session run steps to be processed and encapsulated. +class EcephysSessionBuilder: + """Class to build EcephysSession objects from session run steps.""" + + def __init__(self, client: SlimsClient): + """Initialize Session Builder""" + self.client = client + + def fetch_stream_modules( + self, stream_modules_pk: list[int] + ) -> List[SlimsStreamModule]: + """Fetches stream modules and processes structure names.""" + stream_modules = ( + [ + self.client.fetch_model(SlimsDomeModuleRdrc, pk=pk) + for pk in stream_modules_pk + ] + if stream_modules_pk + else [] + ) - Returns - ------- - List[EcephysSession] - A list of EcephysSession objects containing the processed session data. + complete_stream_modules = [] + for stream_module in stream_modules: + primary_structure, secondary_structures = None, [] - """ - ecephys_sessions = [] + if stream_module.primary_targeted_structure_pk: + primary_structure = self.client.fetch_model( + SlimsBrainStructureRdrc, + pk=stream_module.primary_targeted_structure_pk, + ) + if stream_module.secondary_targeted_structures_pk: + secondary_structures = [ + self.client.fetch_model(SlimsBrainStructureRdrc, pk=pk) + for pk in stream_module.secondary_targeted_structures_pk + ] + + stream_module_model = SlimsStreamModule( + **stream_module.model_dump(), + primary_targeted_structure=primary_structure, + secondary_targeted_structures=secondary_structures, + ) + complete_stream_modules.append(stream_module_model) + return complete_stream_modules - for step in session_run_steps: - # retrieve session, streams, and epochs from Results table - session = client.fetch_model( - SlimsMouseSessionResult, experiment_run_step_pk=step.pk - ) - streams = client.fetch_models(SlimsStreamsResult, mouse_session_pk=session.pk) - stimulus_epochs = client.fetch_models( - SlimsStimulusEpochsResult, mouse_session_pk=session.pk + def fetch_streams(self, session_pk: int) -> List[SlimsStream]: + """Fetches and completes stream information with modules.""" + streams = self.client.fetch_models( + SlimsStreamsResult, mouse_session_pk=session_pk ) - - # retrieve modules and reward info from ReferenceDataRecord table - stream_modules = [ - client.fetch_model(SlimsDomeModuleRdrc, pk=stream_module_pk) + complete_streams = [ + SlimsStream( + **stream.model_dump(), + stream_modules=( + self.fetch_stream_modules(stream.stream_modules_pk) + if stream.stream_modules_pk + else [] + ), + ) for stream in streams - if stream.stream_modules_pk - for stream_module_pk in stream.stream_modules_pk ] + return complete_streams - reward_delivery = ( - client.fetch_model(SlimsRewardDeliveryRdrc, pk=session.reward_delivery_pk) - if session.reward_delivery_pk - else None + def fetch_reward_data(self, reward_delivery_pk: int) -> SlimsRewardDeliveryInfo: + """Fetches reward delivery and spouts data.""" + reward_delivery = self.client.fetch_model( + SlimsRewardDeliveryRdrc, pk=reward_delivery_pk ) - reward_spouts = ( - client.fetch_model( + self.client.fetch_model( SlimsRewardSpoutsRdrc, pk=reward_delivery.reward_spouts_pk ) if reward_delivery and reward_delivery.reward_spouts_pk else None ) + return SlimsRewardDeliveryInfo( + reward_delivery=reward_delivery, reward_spouts=reward_spouts + ) - # encapsulate all info for a single session - ecephys_session = EcephysSession( + def _process_single_step(self, group_run_step, session_run_step) -> EcephysSession: + """Process a single session run step into an EcephysSession.""" + session = self.client.fetch_model( + SlimsMouseSessionResult, experiment_run_step_pk=session_run_step.pk + ) + session_instrument = self.client.fetch_model( + SlimsInstrumentRdrc, pk=group_run_step.instrument_pk + ) + stimulus_epochs = self.client.fetch_models( + SlimsStimulusEpochsResult, mouse_session_pk=session.pk + ) + + streams = self.fetch_streams(session.pk) + reward_delivery = ( + self.fetch_reward_data(session.reward_delivery_pk) + if session.reward_delivery_pk + else None + ) + + return EcephysSession( session_group=group_run_step, + session_instrument=session_instrument or None, session_result=session, streams=streams or None, - stream_modules=stream_modules or None, reward_delivery=reward_delivery, - reward_spouts=reward_spouts, stimulus_epochs=stimulus_epochs or None, ) - ecephys_sessions.append(ecephys_session) - return ecephys_sessions + def process_session_steps( + self, + group_run_step: SlimsGroupOfSessionsRunStep, + session_run_steps: List[SlimsMouseSessionRunStep], + ) -> List[EcephysSession]: + """ + Processes all session run steps into EcephysSession objects. + Parameters + ---------- + group_run_step : SlimsGroupOfSessionsRunStep + The group run step containing session metadata and run information. + session_run_steps : List[SlimsMouseSessionRunStep] + A list of individual session run steps to be processed and encapsulated. + + Returns + ------- + List[EcephysSession] + A list of EcephysSession objects containing the processed session data. + """ + return [ + self._process_single_step(group_run_step, step) + for step in session_run_steps + ] def fetch_ecephys_sessions( @@ -150,7 +228,7 @@ def fetch_ecephys_sessions( ) # retrieve group and mouse sessions in the experiment run - group_run_step = client.fetch_models( + group_run_step = client.fetch_model( SlimsGroupOfSessionsRunStep, experimentrun_pk=content_run_step.experimentrun_pk, ) @@ -159,9 +237,9 @@ def fetch_ecephys_sessions( experimentrun_pk=content_run_step.experimentrun_pk, ) if group_run_step and session_run_steps: - ecephys_sessions = _process_session_steps( - client=client, - group_run_step=group_run_step[0], + esb = EcephysSessionBuilder(client=client) + ecephys_sessions = esb.process_session_steps( + group_run_step=group_run_step, session_run_steps=session_run_steps, ) ecephys_sessions_list.extend(ecephys_sessions) diff --git a/tests/test_operations/test_ecephys_session.py b/tests/test_operations/test_ecephys_session.py index e46e9ac..9efbeae 100644 --- a/tests/test_operations/test_ecephys_session.py +++ b/tests/test_operations/test_ecephys_session.py @@ -18,8 +18,11 @@ SlimsMouseSessionRunStep, SlimsExperimentRunStepContent, SlimsExperimentRunStep, + SlimsBrainStructureRdrc, + SlimsInstrumentRdrc, ) from aind_slims_api.operations import EcephysSession, fetch_ecephys_sessions +from aind_slims_api.operations.ecephys_session import EcephysSessionBuilder RESOURCES_DIR = Path(os.path.dirname(os.path.realpath(__file__))) / ".." / "resources" @@ -39,20 +42,72 @@ def setUp(cls, mock_client): for r in json.load(f) ] cls.example_fetch_ecephys_session_result = response + cls.operator = EcephysSessionBuilder(client=cls.mock_client) + + def test_fetch_streams(self): + """Tests streams and modules are fetched successfully""" + example_stream = [ + SlimsStreamsResult( + pk=12, + mouse_session_pk=2329, + camera_names=["camera1", "camera2"], + stream_modalities=["Ecephys", "Behavior Videos"], + stream_modules_pk=[123, 456], + ) + ] + self.mock_client.fetch_models.side_effect = [example_stream] + example_module_1 = SlimsDomeModuleRdrc(pk=123) + example_module_2 = SlimsDomeModuleRdrc(pk=456) + self.mock_client.fetch_model.side_effect = [example_module_1, example_module_2] + streams = self.operator.fetch_streams(session_pk=2329) + self.assertEqual(len(streams), 1) + self.assertEqual(len(streams[0].stream_modules), 2) + + def test_fetch_stream_modules(self): + """Tests that stream modules and structures are fetched successfully""" + example_module_1 = SlimsDomeModuleRdrc( + pk=123, + probe_name="ProbeA", + primary_targeted_structure_pk=789, + secondary_targeted_structures_pk=[789], + ) + example_structure = SlimsBrainStructureRdrc( + pk=789, + name="Brain Structure A", + ) + self.mock_client.fetch_model.side_effect = [ + example_module_1, + example_structure, + example_structure, + ] + stream_modules = self.operator.fetch_stream_modules(stream_modules_pk=[123]) + self.assertEqual( + stream_modules[0].primary_targeted_structure, example_structure + ) + self.assertEqual(len(stream_modules[0].secondary_targeted_structures), 1) + + def test_fetch_reward_data(self): + """Tests that reward info is fetched successfully""" + self.mock_client.fetch_model.side_effect = [ + SlimsRewardDeliveryRdrc( + pk=1011, reward_solution="Water", reward_spouts_pk="1213" + ), + SlimsRewardSpoutsRdrc( + pk=1213, + spout_side="Right", + variable_position=True, + ), + ] + reward_info = self.operator.fetch_reward_data(reward_delivery_pk=1011) + self.assertEqual(reward_info.reward_delivery.reward_solution, "Water") + self.assertEqual(reward_info.reward_spouts.spout_side, "Right") def test_fetch_ecephys_sessions_success(self): """Tests session info is fetched successfully""" self.mock_client.fetch_models.side_effect = [ [SlimsExperimentRunStepContent(pk=1, runstep_pk=3, mouse_pk=12345)], - [ - SlimsGroupOfSessionsRunStep( - pk=6, - session_type="OptoTagging", - mouse_platform_name="Platform1", - experimentrun_pk=101, - ) - ], [SlimsMouseSessionRunStep(pk=7, experimentrun_pk=101)], + None, [ SlimsStreamsResult( pk=8, @@ -67,7 +122,15 @@ def test_fetch_ecephys_sessions_success(self): self.mock_client.fetch_model.side_effect = [ SlimsMouseContent.model_construct(pk=12345), SlimsExperimentRunStep(pk=3, experimentrun_pk=101), + SlimsGroupOfSessionsRunStep( + pk=6, + session_type="OptoTagging", + mouse_platform_name="Platform1", + experimentrun_pk=101, + instrument_pk=18, + ), SlimsMouseSessionResult(pk=12, reward_delivery_pk=14), + SlimsInstrumentRdrc(pk=18, name="323InstrumentA"), SlimsDomeModuleRdrc(pk=9, probe_name="Probe1", arc_angle=20), SlimsDomeModuleRdrc(pk=10, probe_name="Probe1", arc_angle=20), SlimsRewardDeliveryRdrc( @@ -88,7 +151,7 @@ def test_fetch_ecephys_sessions_success(self): self.assertEqual(ecephys_session.session_group.session_type, "OptoTagging") self.assertEqual(len(ecephys_session.streams), 1) self.assertEqual(ecephys_session.streams[0].daq_names, ["DAQ1", "DAQ2"]) - self.assertEqual(len(ecephys_session.stream_modules), 2) + self.assertEqual(len(ecephys_session.streams[0].stream_modules), 2) self.assertIsNone(ecephys_session.stimulus_epochs) def test_fetch_ecephys_sessions_handle_exception(self): diff --git a/tests/test_slimsmodel.py b/tests/test_slimsmodel.py index cb1e919..6e5e280 100644 --- a/tests/test_slimsmodel.py +++ b/tests/test_slimsmodel.py @@ -34,7 +34,7 @@ def test_string_field(self): self.assertEqual(obj.stringfield, "value") - def test_quantity_field(self): + def test_quantity_field_context(self): """Test validation/serialization of a quantity type, with unit""" obj = self.TestModel() obj.quantfield = Column( @@ -48,11 +48,29 @@ def test_quantity_field(self): self.assertEqual(obj.quantfield, 28.28) - serialized = obj.model_dump()["quantfield"] + serialized = obj.model_dump(context="slims_post")["quantfield"] expected = {"amount": 28.28, "unit_display": "um"} self.assertEqual(serialized, expected) + def test_quantity_field_no_context(self): + """Test validation/serialization of a quantity type without unit""" + obj = self.TestModel() + obj.quantfield = Column( + { + "datatype": "QUANTITY", + "name": "quantfield", + "value": 28.28, + "unit": "um", + } + ) + + self.assertEqual(obj.quantfield, 28.28) + + serialized = obj.model_dump()["quantfield"] + + self.assertEqual(serialized, 28.28) + def test_quantity_wrong_unit(self): """Ensure you get an error with an unexpected unit""" obj = self.TestModel()