Skip to content

Commit

Permalink
refactor/tests/docs: various fixes
Browse files Browse the repository at this point in the history
Improved naming conventions
Fixed some incorrect docstrings
Added missing test
  • Loading branch information
dbirman committed Oct 24, 2024
1 parent 5e75e34 commit 1c2b7e0
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 38 deletions.
6 changes: 2 additions & 4 deletions src/aind_data_access_api/helpers/data_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from aind_data_access_api.document_db import MetadataDbClient
from aind_data_access_api.helpers.docdb import (
get_field_from_docdb,
get_field_by_id,
get_id_from_name,
)
from aind_data_schema.core.quality_control import QualityControl
Expand All @@ -26,9 +26,7 @@ def get_quality_control_by_id(
allow_invalid : bool, optional
return invalid QualityControl as dict if True, by default False
"""
record = get_field_from_docdb(
client, _id=_id, field="quality_control"
)
record = get_field_by_id(client, _id=_id, field="quality_control")
if not record:
raise ValueError(f"No record found with id {_id}")

Expand Down
47 changes: 21 additions & 26 deletions src/aind_data_access_api/helpers/docdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,49 +4,42 @@
from aind_data_access_api.document_db import MetadataDbClient


def get_record_from_docdb(
def get_record_by_id(
client: MetadataDbClient,
_id: str,
) -> Optional[dict]:
"""
Download a record from docdb using the record _id.
"""Download a record from docdb using the record _id.
Parameters
----------
docdb_client : MongoClient
db_name : str
collection_name : str
client : MetadataDbClient
_id : str
Returns
-------
Optional[dict]
None if record does not exist. Otherwise, it will return the record as
a dict.
_description_
"""
records = client.retrieve_docdb_records(
filter_query={"_id": _id}, limit=1
)
records = client.retrieve_docdb_records(filter_query={"_id": _id}, limit=1)
if len(records) > 0:
return records[0]
else:
return None


def get_projected_record_from_docdb(
def get_projection_by_id(
client: MetadataDbClient,
_id: str,
projection: dict,
) -> Optional[dict]:
"""
Download a record from docdb using the record _id and a projection.
Projections return fields set to 1 {"field": 1}
Parameters
----------
docdb_client : MongoClient
db_name : str
collection_name : str
client : MetadataDbClient
_id : str
projection : dict
Expand All @@ -65,12 +58,12 @@ def get_projected_record_from_docdb(
return None


def get_field_from_docdb(
def get_field_by_id(
client: MetadataDbClient,
_id: str,
field: str,
) -> Optional[dict]:
"""Download a single field from a docdb record
"""Download a single field from docdb using the record _id
Parameters
----------
Expand All @@ -83,23 +76,19 @@ def get_field_from_docdb(
Optional[dict]
None if a record does not exist. Otherwise returns the field in a dict.
"""
return get_projected_record_from_docdb(
client, _id=_id, projection={field: 1}
)
return get_projection_by_id(client, _id=_id, projection={field: 1})


def get_id_from_name(
client: MetadataDbClient,
name: str,
) -> Optional[str]:
"""
Get the _id of a record in DocDb using the name field.
Get the _id of a record in DocDb from its name field.
Parameters
----------
docdb_client : MongoClient
db_name : str
collection_name : str
client : MetadataDbClient
name : str
Returns
Expand All @@ -109,8 +98,14 @@ def get_id_from_name(
the record.
"""
records = client.retrieve_docdb_records(
filter_query={"name": name}, projection={"_id": 1}, limit=1
filter_query={"name": name}, projection={"_id": 1}, limit=0
)

if len(records) > 1:
print(
"Warning: multiple records share the name {name}, only the first record will be returned."
)

if len(records) > 0:
return records[0]["_id"]
else:
Expand Down
19 changes: 11 additions & 8 deletions tests/test_util_docdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from unittest.mock import MagicMock

from aind_data_access_api.helpers.docdb import (
get_record_from_docdb,
get_record_by_id,
get_id_from_name,
get_projected_record_from_docdb,
get_field_from_docdb,
get_projection_by_id,
get_field_by_id,
)


Expand All @@ -26,16 +26,21 @@ def test_get_record_from_docdb(self):
"""Tests get_record_from_docdb"""
client = MagicMock()
client.retrieve_docdb_records.return_value = [{"_id": "abcd"}]
record = get_record_from_docdb(client, _id="abcd")
record = get_record_by_id(client, _id="abcd")
self.assertEqual({"_id": "abcd"}, record)

# test the empty case
client.retrieve_docdb_records.return_value = []
record = get_record_by_id(client, _id="abcd")
self.assertIsNone(record)

def test_get_projected_record_from_docdb(self):
"""Tests get_projected_record_from_docdb"""
client = MagicMock()
client.retrieve_docdb_records.return_value = [
{"quality_control": {"a": 1}}
]
record = get_projected_record_from_docdb(
record = get_projection_by_id(
client, _id="abcd", projection={"quality_control": 1}
)
self.assertEqual({"quality_control": {"a": 1}}, record)
Expand All @@ -46,7 +51,5 @@ def test_get_field_from_docdb(self):
client.retrieve_docdb_records.return_value = [
{"quality_control": {"a": 1}}
]
field = get_field_from_docdb(
client, _id="abcd", field="quality_control"
)
field = get_field_by_id(client, _id="abcd", field="quality_control")
self.assertEqual({"quality_control": {"a": 1}}, field)

0 comments on commit 1c2b7e0

Please sign in to comment.