Skip to content

Commit

Permalink
unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mekhlakapoor committed Oct 11, 2024
1 parent 0a173b0 commit 779af8e
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 62 deletions.
4 changes: 2 additions & 2 deletions src/aind_slims_api/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,9 @@ def model_dump(self, serialize_quantity=True, *args, **kwargs) -> Dict[str, Any]
# Update serialized fields with UnitSpec information
if not serialize_quantity:
for key, value in data.items():
if isinstance(value, dict) and 'amount' in value:
if isinstance(value, dict) and "amount" in value:
# Extract the amount
data[key] = value['amount']
data[key] = value["amount"]
return data

# TODO: Add links - need Record.json_entity['links']['self']
Expand Down
20 changes: 10 additions & 10 deletions src/aind_slims_api/models/ecephys_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

from datetime import datetime
from typing import Annotated, List, Optional, ClassVar
from pydantic import Field, ConfigDict
from slims.slims import Slims
from pydantic import Field

from aind_slims_api.models.base import SlimsBaseModel
from aind_slims_api.models.utils import UnitSpec
Expand Down Expand Up @@ -76,7 +75,7 @@ class SlimsGroupOfSessionsRunStep(SlimsExperimentRunStep):
instrument_pk: Optional[int] = Field(
default=None,
serialization_alias="xprs_cf_fk_instrumentJson",
validation_alias="xprs_cf_fk_instrumentJson"
validation_alias="xprs_cf_fk_instrumentJson",
)
# TODO: add device calibrations once we have an example
# device_calibrations_attachment: Optional[str] = Field(
Expand Down Expand Up @@ -378,8 +377,10 @@ class SlimsEphysInsertionResult(SlimsBaseModel):
"test_name": "test_ephys_insertion",
}


class SlimsInstrumentRdrc(SlimsBaseModel):
"""Model for Instrument Rdrc"""

pk: Optional[int] = Field(
default=None, serialization_alias="rdrc_pk", validation_alias="rdrc_pk"
)
Expand All @@ -396,8 +397,10 @@ class SlimsInstrumentRdrc(SlimsBaseModel):
"rdty_name": "AIND Instruments",
}


class SlimsDomeModuleRdrc(SlimsBaseModel):
"""Model for Dome Module Reference Data"""

pk: Optional[int] = Field(
default=None, serialization_alias="rdrc_pk", validation_alias="rdrc_pk"
)
Expand Down Expand Up @@ -528,15 +531,12 @@ class SlimsDomeModuleRdrc(SlimsBaseModel):

class SlimsBrainStructureRdrc(SlimsBaseModel):
"""Model for Brain Structure Reference Data"""

pk: Optional[int] = Field(
default=None,
serialization_alias="rdrc_pk",
validation_alias="rdrc_pk"
default=None, serialization_alias="rdrc_pk", validation_alias="rdrc_pk"
)
name: Optional[str] = Field(
default=None,
serialization_alias="rdrc_name",
validation_alias="rdrc_name"
default=None, serialization_alias="rdrc_name", validation_alias="rdrc_name"
)
created_on: Optional[datetime] = Field(
default=None,
Expand Down Expand Up @@ -584,7 +584,7 @@ class SlimsRewardDeliveryRdrc(SlimsBaseModel):

pk: Optional[int] = Field(serialization_alias="rdrc_pk", validation_alias="rdrc_pk")
reward_spouts_pk: Optional[int] = Field(
default=[],
default=None,
serialization_alias="rdrc_cf_fk_rewardSpouts",
validation_alias="rdrc_cf_fk_rewardSpouts",
)
Expand Down
3 changes: 2 additions & 1 deletion src/aind_slims_api/models/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Utility functions and classes for working with slims models.
"""

from typing import Optional, Dict, Any
from typing import Optional

from pydantic.fields import FieldInfo

Expand All @@ -20,6 +20,7 @@ def __init__(self, *args, preferred_unit=None):
if preferred_unit is None:
self.preferred_unit = self.units[0]


def _find_unit_spec(field: FieldInfo) -> UnitSpec | None:
"""Given a Pydantic FieldInfo, find the UnitSpec in its metadata"""
metadata = field.metadata
Expand Down
88 changes: 61 additions & 27 deletions src/aind_slims_api/operations/ecephys_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,31 @@
SlimsRewardDeliveryRdrc,
SlimsRewardSpoutsRdrc,
SlimsGroupOfSessionsRunStep,
SlimsMouseSessionRunStep, SlimsBrainStructureRdrc, SlimsInstrumentRdrc,
SlimsMouseSessionRunStep,
SlimsBrainStructureRdrc,
SlimsInstrumentRdrc,
)

logger = logging.getLogger(__name__)


class SlimsRewardDeliveryInfo(BaseModel):
""""""
"""Pydantic Model to store Reward Delivery Info"""

reward_delivery: Optional[SlimsRewardDeliveryRdrc] = []
reward_spouts: Optional[SlimsRewardSpoutsRdrc] = []


class SlimsStreamModule(SlimsDomeModuleRdrc):
""""""
"""DomeModule Wrapper to add linked brain structure models"""
primary_targeted_structure: Optional[SlimsBrainStructureRdrc] = None
secondary_targeted_structures: Optional[list[SlimsBrainStructureRdrc]] = None


class SlimsStream(SlimsStreamsResult):
""""""
stream_modules: List[SlimsStreamModule]
"""Streams wrapper to add linked stream modules"""
stream_modules: Optional[List[SlimsStreamModule]]


class EcephysSession(BaseModel):
"""
Expand All @@ -56,13 +62,20 @@ class EcephysSessionBuilder:
"""Class to build EcephysSession objects from session run steps."""

def __init__(self, client: SlimsClient):
"""Initialize Session Builder"""
self.client = client

def fetch_stream_modules(self, stream) -> List[SlimsStreamModule]:
def fetch_stream_modules(
self, stream_modules_pk: list[int]
) -> List[SlimsStreamModule]:
"""Fetches stream modules and processes structure names."""
stream_modules = (
[self.client.fetch_model(SlimsDomeModuleRdrc, pk=pk)
for pk in stream.stream_modules_pk] if stream.stream_modules_pk else []
[
self.client.fetch_model(SlimsDomeModuleRdrc, pk=pk)
for pk in stream_modules_pk
]
if stream_modules_pk
else []
)

complete_stream_modules = []
Expand All @@ -72,7 +85,7 @@ def fetch_stream_modules(self, stream) -> List[SlimsStreamModule]:
if stream_module.primary_targeted_structure_pk:
primary_structure = self.client.fetch_model(
SlimsBrainStructureRdrc,
pk=stream_module.primary_targeted_structure_pk
pk=stream_module.primary_targeted_structure_pk,
)
if stream_module.secondary_targeted_structures_pk:
secondary_structures = [
Expand All @@ -83,45 +96,63 @@ def fetch_stream_modules(self, stream) -> List[SlimsStreamModule]:
stream_module_model = SlimsStreamModule(
**stream_module.model_dump(serialize_quantity=False),
primary_targeted_structure=primary_structure,
secondary_targeted_structures=secondary_structures
secondary_targeted_structures=secondary_structures,
)
complete_stream_modules.append(stream_module_model)
return complete_stream_modules

def fetch_streams(self, session_pk: int) -> List[SlimsStream]:
"""Fetches and completes stream information with modules."""
streams = self.client.fetch_models(SlimsStreamsResult, mouse_session_pk=session_pk)
streams = self.client.fetch_models(
SlimsStreamsResult, mouse_session_pk=session_pk
)
complete_streams = [
SlimsStream(
**stream.model_dump(serialize_quantity=False),
stream_modules=self.fetch_stream_modules(stream)
) for stream in streams
stream_modules=(
self.fetch_stream_modules(stream.stream_modules_pk)
if stream.stream_modules_pk
else []
),
)
for stream in streams
]
return complete_streams

def fetch_reward_data(self, session) -> SlimsRewardDeliveryInfo:
def fetch_reward_data(self, reward_delivery_pk: int) -> SlimsRewardDeliveryInfo:
"""Fetches reward delivery and spouts data."""
reward_delivery = (
self.client.fetch_model(SlimsRewardDeliveryRdrc, pk=session.reward_delivery_pk)
if session.reward_delivery_pk else None
reward_delivery = self.client.fetch_model(
SlimsRewardDeliveryRdrc, pk=reward_delivery_pk
)
reward_spouts = (
self.client.fetch_model(SlimsRewardSpoutsRdrc, pk=reward_delivery.reward_spouts_pk)
if reward_delivery and reward_delivery.reward_spouts_pk else None
self.client.fetch_model(
SlimsRewardSpoutsRdrc, pk=reward_delivery.reward_spouts_pk
)
if reward_delivery and reward_delivery.reward_spouts_pk
else None
)
return SlimsRewardDeliveryInfo(
reward_delivery=reward_delivery,
reward_spouts=reward_spouts
reward_delivery=reward_delivery, reward_spouts=reward_spouts
)

def _process_single_step(self, group_run_step, session_run_step) -> EcephysSession:
"""Process a single session run step into an EcephysSession."""
session = self.client.fetch_model(SlimsMouseSessionResult, experiment_run_step_pk=session_run_step.pk)
session_instrument = self.client.fetch_model(SlimsInstrumentRdrc, pk=group_run_step.instrument_pk)
stimulus_epochs = self.client.fetch_models(SlimsStimulusEpochsResult, mouse_session_pk=session.pk)
session = self.client.fetch_model(
SlimsMouseSessionResult, experiment_run_step_pk=session_run_step.pk
)
session_instrument = self.client.fetch_model(
SlimsInstrumentRdrc, pk=group_run_step.instrument_pk
)
stimulus_epochs = self.client.fetch_models(
SlimsStimulusEpochsResult, mouse_session_pk=session.pk
)

streams = self.fetch_streams(session.pk)
reward_delivery = self.fetch_reward_data(session)
reward_delivery = (
self.fetch_reward_data(session.reward_delivery_pk)
if session.reward_delivery_pk
else None
)

return EcephysSession(
session_group=group_run_step,
Expand Down Expand Up @@ -151,11 +182,14 @@ def process_session_steps(
List[EcephysSession]
A list of EcephysSession objects containing the processed session data.
"""
return [self._process_single_step(group_run_step, step) for step in session_run_steps]
return [
self._process_single_step(group_run_step, step)
for step in session_run_steps
]


def fetch_ecephys_sessions(
client: SlimsClient, subject_id: str
client: SlimsClient, subject_id: str
) -> List[EcephysSession]:
"""
Fetch and process all electrophysiology (ecephys) run steps for a given subject.
Expand Down
90 changes: 68 additions & 22 deletions tests/test_operations/test_ecephys_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
SlimsMouseSessionRunStep,
SlimsExperimentRunStepContent,
SlimsExperimentRunStep,
SlimsBrainStructureRdrc,
SlimsInstrumentRdrc,
)
from aind_slims_api.operations import EcephysSession, fetch_ecephys_sessions
from aind_slims_api.operations.ecephys_session import EcephysSessionBuilder
Expand All @@ -40,36 +42,72 @@ def setUp(cls, mock_client):
for r in json.load(f)
]
cls.example_fetch_ecephys_session_result = response
with open(
RESOURCES_DIR / "example_fetch_ecephys_streams_result.json", "r"
) as f:
response = [
Record(json_entity=r, slims_api=cls.mock_client.db.slims_api)
for r in json.load(f)
]
cls.example_fetch_ecephys_streams_result = response
cls.operator = EcephysSessionBuilder(client=cls.mock_client)

def test_fetch_streams(self):
self.mock_client.fetch_models.return_value = [SlimsStreamsResult(stream="Stream1"),
SlimsStreamsResult(stream="Stream2")]
streams = self.operator.fetch_streams(session_pk=1)
self.assertEqual(len(streams), 2)
self.assertEqual(streams[0].stream, "Stream1")
"""Tests streams and modules are fetched successfully"""
example_stream = [
SlimsStreamsResult(
pk=12,
mouse_session_pk=2329,
camera_names=["camera1", "camera2"],
stream_modalities=["Ecephys", "Behavior Videos"],
stream_modules_pk=[123, 456],
)
]
self.mock_client.fetch_models.side_effect = [example_stream]
example_module_1 = SlimsDomeModuleRdrc(pk=123)
example_module_2 = SlimsDomeModuleRdrc(pk=456)
self.mock_client.fetch_model.side_effect = [example_module_1, example_module_2]
streams = self.operator.fetch_streams(session_pk=2329)
self.assertEqual(len(streams), 1)
self.assertEqual(len(streams[0].stream_modules), 2)

def test_fetch_stream_modules(self):
"""Tests that stream modules and structures are fetched successfully"""
example_module_1 = SlimsDomeModuleRdrc(
pk=123,
probe_name="ProbeA",
primary_targeted_structure_pk=789,
secondary_targeted_structures_pk=[789],
)
example_structure = SlimsBrainStructureRdrc(
pk=789,
name="Brain Structure A",
)
self.mock_client.fetch_model.side_effect = [
example_module_1,
example_structure,
example_structure,
]
stream_modules = self.operator.fetch_stream_modules(stream_modules_pk=[123])
self.assertEqual(
stream_modules[0].primary_targeted_structure, example_structure
)
self.assertEqual(len(stream_modules[0].secondary_targeted_structures), 1)

def test_fetch_reward_data(self):
"""Tests that reward info is fetched successfully"""
self.mock_client.fetch_model.side_effect = [
SlimsRewardDeliveryRdrc(
pk=1011, reward_solution="Water", reward_spouts_pk="1213"
),
SlimsRewardSpoutsRdrc(
pk=1213,
spout_side="Right",
variable_position=True,
),
]
reward_info = self.operator.fetch_reward_data(reward_delivery_pk=1011)
self.assertEqual(reward_info.reward_delivery.reward_solution, "Water")
self.assertEqual(reward_info.reward_spouts.spout_side, "Right")

def test_fetch_ecephys_sessions_success(self):
"""Tests session info is fetched successfully"""
self.mock_client.fetch_models.side_effect = [
[SlimsExperimentRunStepContent(pk=1, runstep_pk=3, mouse_pk=12345)],
[
SlimsGroupOfSessionsRunStep(
pk=6,
session_type="OptoTagging",
mouse_platform_name="Platform1",
experimentrun_pk=101,
)
],
[SlimsMouseSessionRunStep(pk=7, experimentrun_pk=101)],
None,
[
SlimsStreamsResult(
pk=8,
Expand All @@ -84,7 +122,15 @@ def test_fetch_ecephys_sessions_success(self):
self.mock_client.fetch_model.side_effect = [
SlimsMouseContent.model_construct(pk=12345),
SlimsExperimentRunStep(pk=3, experimentrun_pk=101),
SlimsGroupOfSessionsRunStep(
pk=6,
session_type="OptoTagging",
mouse_platform_name="Platform1",
experimentrun_pk=101,
instrument_pk=18,
),
SlimsMouseSessionResult(pk=12, reward_delivery_pk=14),
SlimsInstrumentRdrc(pk=18, name="323InstrumentA"),
SlimsDomeModuleRdrc(pk=9, probe_name="Probe1", arc_angle=20),
SlimsDomeModuleRdrc(pk=10, probe_name="Probe1", arc_angle=20),
SlimsRewardDeliveryRdrc(
Expand All @@ -105,7 +151,7 @@ def test_fetch_ecephys_sessions_success(self):
self.assertEqual(ecephys_session.session_group.session_type, "OptoTagging")
self.assertEqual(len(ecephys_session.streams), 1)
self.assertEqual(ecephys_session.streams[0].daq_names, ["DAQ1", "DAQ2"])
self.assertEqual(len(ecephys_session.stream_modules), 2)
self.assertEqual(len(ecephys_session.streams[0].stream_modules), 2)
self.assertIsNone(ecephys_session.stimulus_epochs)

def test_fetch_ecephys_sessions_handle_exception(self):
Expand Down

0 comments on commit 779af8e

Please sign in to comment.