Skip to content

Commit

Permalink
Enh/2/behavior sessions 2 (#16)
Browse files Browse the repository at this point in the history
* switch to typevar to avoid mypy errors for returning subclass

* use class to validate

* update column names

* move fetch_models to core

* add instrument fetching

* behavior session testing

* add tests and more finalized version of behavior_session and instrument
  • Loading branch information
mochic authored Jul 9, 2024
1 parent a531a01 commit 8e5d33b
Show file tree
Hide file tree
Showing 9 changed files with 1,140 additions and 3 deletions.
123 changes: 123 additions & 0 deletions src/aind_slims_api/behavior_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""Contains a model for the behavior session content events, a method for
fetching it and writing it.
"""

import logging
from typing import Any
from datetime import datetime

from pydantic import Field

from aind_slims_api.core import SlimsBaseModel, SlimsClient, SLIMSTABLES

logger = logging.getLogger()


class SlimsBehaviorSessionContentEvent(SlimsBaseModel):
"""Model for an instance of the Behavior Session ContentEvent"""

pk: int | None = Field(default=None, alias="cnvn_pk")
mouse_pk: int | None = Field(
default=None, alias="cnvn_fk_content"
) # used as reference to mouse
notes: str | None = Field(default=None, alias="cnvn_cf_notes")
task_stage: str | None = Field(default=None, alias="cnvn_cf_taskStage")
instrument: int | None = Field(default=None, alias="cnvn_cf_fk_instrument")
trainers: list[int] = Field(default=[], alias="cnvn_cf_fk_trainer")
task: str | None = Field(default=None, alias="cnvn_cf_task")
is_curriculum_suggestion: bool | None = Field(
default=None, alias="cnvn_cf_stageIsOnCurriculum"
)
task_schema_version: str | None = Field(
default=None, alias="cnvn_cf_taskSchemaVersion"
)
software_version: str | None = Field(default=None, alias="cnvn_cf_softwareVersion")
date: datetime | None = Field(..., alias="cnvn_cf_scheduledDate")

cnvn_fk_contentEventType: int = 10 # pk of Behavior Session ContentEvent

_slims_table: SLIMSTABLES = "ContentEvent"


SlimsSingletonFetchReturn = SlimsBaseModel | dict[str, Any] | None


def _resolve_pk(
model: SlimsSingletonFetchReturn,
primary_key_name: str = "pk",
) -> int:
"""Utility function shared across read/write
Notes
-----
- TODO: Change return type of fetch_mouse_content to match pattern in
fetch_behavior_session_content_events, or the other way around?
- TODO: Move to core to have better centralized control of when references
are resolved
"""
if isinstance(model, dict):
logger.warning("Extracting primary key from unvalidated dict.")
return model[primary_key_name]
elif isinstance(model, SlimsBaseModel):
return getattr(model, primary_key_name)
elif model is None:
raise ValueError(f"Cannot resolve primary key from {model}")
else:
raise ValueError("Unexpected type for model: %s" % type(model))


def fetch_behavior_session_content_events(
client: SlimsClient,
mouse: SlimsSingletonFetchReturn,
) -> tuple[list[SlimsBehaviorSessionContentEvent], list[dict[str, Any]]]:
"""Fetches behavior sessions for a mouse with labtracks id {mouse_name}
Returns
-------
tuple:
list:
Validated SlimsBehaviorSessionContentEvent objects
list:
Dictionaries representations of objects that failed validation
"""
return client.fetch_models(
SlimsBehaviorSessionContentEvent,
cnvn_fk_content=_resolve_pk(mouse),
cnvt_name="Behavior Session",
sort=["cnvn_cf_scheduledDate"],
)


def write_behavior_session_content_events(
client: SlimsClient,
mouse: SlimsSingletonFetchReturn,
instrument: SlimsSingletonFetchReturn,
trainers: list[SlimsSingletonFetchReturn],
*behavior_sessions: SlimsBehaviorSessionContentEvent,
) -> list[SlimsBehaviorSessionContentEvent]:
"""Writes behavior sessions for a mouse with labtracks id {mouse_name}
Notes
-----
- All supplied `behavior_sessions` will have their `mouse_name` field set
to the value supplied as `mouse_name` to this function
"""
mouse_pk = _resolve_pk(mouse)
logger.debug(f"Mouse pk: {mouse_pk}")
instrument_pk = _resolve_pk(instrument)
logger.debug(f"Instrument pk: {instrument_pk}")
trainer_pks = [_resolve_pk(trainer) for trainer in trainers]
logger.debug(f"Trainer pks: {trainer_pks}")
added = []
for behavior_session in behavior_sessions:
updated = behavior_session.model_copy(
update={
"mouse_pk": mouse_pk,
"instrument": instrument_pk,
"trainers": trainer_pks,
},
)
logger.debug(f"Resolved behavior session: {updated}")
added.append(client.add_model(updated))

return added
49 changes: 47 additions & 2 deletions src/aind_slims_api/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
from pydantic import (
BaseModel,
ValidationInfo,
ValidationError,
field_serializer,
field_validator,
)
from pydantic.fields import FieldInfo
import logging
from typing import Literal, Optional
from typing import Any, Literal, Optional, Type, TypeVar

from slims.slims import Slims, _SlimsApiException
from slims.internal import (
Expand All @@ -41,6 +42,7 @@
"Test",
"User",
"Groups",
"Instrument",
]


Expand Down Expand Up @@ -132,6 +134,9 @@ def _serialize(self, field, info):
# TODO: Support attachments


SlimsBaseModelTypeVar = TypeVar("SlimsBaseModelTypeVar", bound=SlimsBaseModel)


class SlimsClient:
"""Wrapper around slims-python-api client with convenience methods"""

Expand Down Expand Up @@ -199,6 +204,44 @@ def fetch(

return records

def fetch_models(
self,
model: Type[SlimsBaseModelTypeVar],
*args,
sort: Optional[str | list[str]] = None,
start: Optional[int] = None,
end: Optional[int] = None,
**kwargs,
) -> tuple[list[SlimsBaseModelTypeVar], list[dict[str, Any]]]:
"""Fetch records from SLIMS and return them as SlimsBaseModel objects
Returns
-------
tuple:
list:
Validated SlimsBaseModel objects
list:
Dictionaries representations of objects that failed validation
"""
response = self.fetch(
model._slims_table.default, # TODO: consider changing fetch method
*args,
sort=sort,
start=start,
end=end,
**kwargs,
)
validated = []
unvalidated = []
for record in response:
try:
validated.append(model.model_validate(record))
except ValidationError as e:
logger.error(f"SLIMS data validation failed, {repr(e)}")
unvalidated.append(record.json_entity)

return validated, unvalidated

@lru_cache(maxsize=None)
def fetch_pk(self, table: SLIMSTABLES, *args, **kwargs) -> int | None:
"""SlimsClient.fetch but returns the pk of the first returned record"""
Expand Down Expand Up @@ -233,7 +276,9 @@ def rest_link(self, table: SLIMSTABLES, **kwargs):
queries = [f"?{k}={v}" for k, v in kwargs.items()]
return base_url + "".join(queries)

def add_model(self, model: SlimsBaseModel, *args, **kwargs) -> SlimsBaseModel:
def add_model(
self, model: SlimsBaseModelTypeVar, *args, **kwargs
) -> SlimsBaseModelTypeVar:
"""Given a SlimsBaseModel object, add it to SLIMS
Args
model (SlimsBaseModel): object to add
Expand Down
66 changes: 66 additions & 0 deletions src/aind_slims_api/instrument.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""Contains a model for the instrument content, and a method for fetching it"""

import logging
from typing import Any

from pydantic import Field

from aind_slims_api.core import SlimsBaseModel, SlimsClient, SLIMSTABLES

logger = logging.getLogger()


class SlimsInstrument(SlimsBaseModel):
"""Model for an instance of the Behavior Session ContentEvent"""

name: str = Field(..., alias="nstr_name")
pk: int = Field(..., alias="nstr_pk")
_slims_table: SLIMSTABLES = "Instrument"

# todo add more useful fields


def fetch_instrument_content(
client: SlimsClient,
instrument_name: str,
) -> SlimsInstrument | dict[str, Any] | None:
"""Fetches behavior sessions for a mouse with labtracks id {mouse_name}
Returns
-------
tuple:
list:
Validated SlimsInstrument objects
list:
Dictionaries representations of objects that failed validation
Notes
-----
- Todo: add partial name match or some other type of filtering
- TODO: reconsider this pattern, consider just returning all records or
having number returned be a parameter or setting
"""
validated, unvalidated = client.fetch_models(
SlimsInstrument,
nstr_name=instrument_name,
)
if len(validated) > 0:
validated_details = validated[0]
if len(validated) > 1:
logger.warning(
f"Warning, Multiple instruments in SLIMS with name {instrument_name}, "
f"using pk={validated_details.pk}"
)
return validated_details
else:
if len(unvalidated) > 0:
unvalidated_details = unvalidated[0]
if len(unvalidated) > 1:
logger.warning(
"Warning, Multiple instruments in SLIMS with name "
f"{instrument_name}, "
f"using pk={unvalidated_details['pk']}"
)
return unvalidated[0]

return None
2 changes: 1 addition & 1 deletion src/aind_slims_api/mouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def fetch_mouse_content(
)
else:
logger.warning("Warning, Mouse not in SLIMS")
return
return None

try:
mouse = SlimsMouseContent.model_validate(mouse_details)
Expand Down
Loading

0 comments on commit 8e5d33b

Please sign in to comment.