diff --git a/src/aind_data_access_api/helpers/data_schema.py b/src/aind_data_access_api/helpers/data_schema.py index 7488324..875178c 100644 --- a/src/aind_data_access_api/helpers/data_schema.py +++ b/src/aind_data_access_api/helpers/data_schema.py @@ -2,7 +2,7 @@ from aind_data_access_api.document_db import MetadataDbClient from aind_data_access_api.helpers.docdb import ( - get_field_from_docdb, + get_field_by_id, get_id_from_name, ) from aind_data_schema.core.quality_control import QualityControl @@ -26,9 +26,7 @@ def get_quality_control_by_id( allow_invalid : bool, optional return invalid QualityControl as dict if True, by default False """ - record = get_field_from_docdb( - client, _id=_id, field="quality_control" - ) + record = get_field_by_id(client, _id=_id, field="quality_control") if not record: raise ValueError(f"No record found with id {_id}") diff --git a/src/aind_data_access_api/helpers/docdb.py b/src/aind_data_access_api/helpers/docdb.py index a660f39..25a75ff 100644 --- a/src/aind_data_access_api/helpers/docdb.py +++ b/src/aind_data_access_api/helpers/docdb.py @@ -4,37 +4,30 @@ from aind_data_access_api.document_db import MetadataDbClient -def get_record_from_docdb( +def get_record_by_id( client: MetadataDbClient, _id: str, ) -> Optional[dict]: - """ - Download a record from docdb using the record _id. + """Download a record from docdb using the record _id. Parameters ---------- - docdb_client : MongoClient - db_name : str - collection_name : str + client : MetadataDbClient _id : str Returns ------- Optional[dict] - None if record does not exist. Otherwise, it will return the record as - a dict. - + _description_ """ - records = client.retrieve_docdb_records( - filter_query={"_id": _id}, limit=1 - ) + records = client.retrieve_docdb_records(filter_query={"_id": _id}, limit=1) if len(records) > 0: return records[0] else: return None -def get_projected_record_from_docdb( +def get_projection_by_id( client: MetadataDbClient, _id: str, projection: dict, @@ -42,11 +35,11 @@ def get_projected_record_from_docdb( """ Download a record from docdb using the record _id and a projection. + Projections return fields set to 1 {"field": 1} + Parameters ---------- - docdb_client : MongoClient - db_name : str - collection_name : str + client : MetadataDbClient _id : str projection : dict @@ -65,12 +58,12 @@ def get_projected_record_from_docdb( return None -def get_field_from_docdb( +def get_field_by_id( client: MetadataDbClient, _id: str, field: str, ) -> Optional[dict]: - """Download a single field from a docdb record + """Download a single field from docdb using the record _id Parameters ---------- @@ -83,9 +76,7 @@ def get_field_from_docdb( Optional[dict] None if a record does not exist. Otherwise returns the field in a dict. """ - return get_projected_record_from_docdb( - client, _id=_id, projection={field: 1} - ) + return get_projection_by_id(client, _id=_id, projection={field: 1}) def get_id_from_name( @@ -93,13 +84,11 @@ def get_id_from_name( name: str, ) -> Optional[str]: """ - Get the _id of a record in DocDb using the name field. + Get the _id of a record in DocDb from its name field. Parameters ---------- - docdb_client : MongoClient - db_name : str - collection_name : str + client : MetadataDbClient name : str Returns @@ -109,8 +98,14 @@ def get_id_from_name( the record. """ records = client.retrieve_docdb_records( - filter_query={"name": name}, projection={"_id": 1}, limit=1 + filter_query={"name": name}, projection={"_id": 1}, limit=0 ) + + if len(records) > 1: + print( + "Warning: multiple records share the name {name}, only the first record will be returned." + ) + if len(records) > 0: return records[0]["_id"] else: diff --git a/tests/test_util_docdb.py b/tests/test_util_docdb.py index 736648f..9d78209 100644 --- a/tests/test_util_docdb.py +++ b/tests/test_util_docdb.py @@ -4,10 +4,10 @@ from unittest.mock import MagicMock from aind_data_access_api.helpers.docdb import ( - get_record_from_docdb, + get_record_by_id, get_id_from_name, - get_projected_record_from_docdb, - get_field_from_docdb, + get_projection_by_id, + get_field_by_id, ) @@ -26,16 +26,21 @@ def test_get_record_from_docdb(self): """Tests get_record_from_docdb""" client = MagicMock() client.retrieve_docdb_records.return_value = [{"_id": "abcd"}] - record = get_record_from_docdb(client, _id="abcd") + record = get_record_by_id(client, _id="abcd") self.assertEqual({"_id": "abcd"}, record) + # test the empty case + client.retrieve_docdb_records.return_value = [] + record = get_record_by_id(client, _id="abcd") + self.assertIsNone(record) + def test_get_projected_record_from_docdb(self): """Tests get_projected_record_from_docdb""" client = MagicMock() client.retrieve_docdb_records.return_value = [ {"quality_control": {"a": 1}} ] - record = get_projected_record_from_docdb( + record = get_projection_by_id( client, _id="abcd", projection={"quality_control": 1} ) self.assertEqual({"quality_control": {"a": 1}}, record) @@ -46,7 +51,5 @@ def test_get_field_from_docdb(self): client.retrieve_docdb_records.return_value = [ {"quality_control": {"a": 1}} ] - field = get_field_from_docdb( - client, _id="abcd", field="quality_control" - ) + field = get_field_by_id(client, _id="abcd", field="quality_control") self.assertEqual({"quality_control": {"a": 1}}, field)