Skip to content

Commit

Permalink
feat: add retrieve_schema_records
Browse files Browse the repository at this point in the history
  • Loading branch information
micahwoodard committed Jul 25, 2024
1 parent 2e874be commit 151871c
Showing 1 changed file with 198 additions and 51 deletions.
249 changes: 198 additions & 51 deletions src/aind_data_access_api/document_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ class Client:
Gateway."""

def __init__(
self,
host: str,
database: str,
collection: str,
version: str = "v1",
boto_session=None,
self,
host: str,
database: str,
collection: str,
version: str = "v1",
boto_session=None,
):
"""Class constructor."""
self.host = host.strip("/")
Expand All @@ -39,11 +39,36 @@ def __init__(
@property
def _base_url(self):
    """Construct base url to interface with a collection in a database.

    Returns
    -------
    str
        URL for this client's configured database and collection.
    """
    # Delegate to _create_url(), which falls back to self.database and
    # self.collection when called with no arguments.
    return self._create_url()

def _create_url(self,
database: Optional[str] = None,
collection: Optional[str] = None):
"""
Create url based on input database and collection
----------
database : Optional[str]
Database of url. Default is None
collection : Optional[str]
Collection of url. Default is None
Returns
-------
str
String of url in https://{self.host}/{self.version}/{database}/{collection} format
"""

database = database if database is not None else self.database
collection = collection if collection is not None else self.collection

return (
f"https://{self.host}/{self.version}/{self.database}/"
f"{self.collection}"
f"https://{self.host}/{self.version}/{database}/"
f"{collection}"
)


@property
def _update_one_url(self):
"""Url to update one record"""
Expand Down Expand Up @@ -84,7 +109,7 @@ def __boto_session(self):
return self._boto_session

def _signed_request(
self, url: str, method: str, data: Optional[str] = None
self, url: str, method: str, data: Optional[str] = None
) -> AWSRequest:
"""Create a signed request to write to the document store.
Permissions are managed through AWS."""
Expand All @@ -101,11 +126,18 @@ def _signed_request(
).add_auth(aws_request)
return aws_request

def _count_records(self, filter_query: Optional[dict] = None):
def _count_records(self,
database: Optional[str] = None,
collection: Optional[str] = None,
filter_query: Optional[dict] = None):
"""
Methods to count the number of records in a collection.
Parameters
----------
database : Optional[str]
Database of the records being counted. Default is None
collection : Optional[str]
Collection of the records being counted. Default is None
filter_query : Optional[dict]
If passed, will return the number of records and number of records
returned by the filter query.
Expand All @@ -121,23 +153,30 @@ def _count_records(self, filter_query: Optional[dict] = None):
}
if filter_query is not None:
params["filter"] = json.dumps(filter_query)
response = requests.get(self._base_url, params=params)
url = self._create_url(database, collection)
response = requests.get(url, params=params)
response_json = response.json()
body = response_json.get("body")
return json.loads(body)

def _get_records(
self,
filter_query: Optional[dict] = None,
projection: Optional[dict] = None,
sort: Optional[List[Tuple[str, int]]] = None,
limit: int = 0,
skip: int = 0,
self,
database: Optional[str] = None,
collection: Optional[str] = None,
filter_query: Optional[dict] = None,
projection: Optional[dict] = None,
sort: Optional[List[Tuple[str, int]]] = None,
limit: int = 0,
skip: int = 0,
) -> List[dict]:
"""
Retrieve records from collection.
Parameters
----------
database : Optional[str]
Database of the records being returned. Default is None
collection : Optional[str]
Collection of the records being returned. Default is None
filter_query : Optional[dict]
Filter to apply to the records being returned. Default is None.
projection : Optional[dict]
Expand All @@ -163,7 +202,8 @@ def _get_records(
if sort is not None:
params["sort"] = str(sort)

response = requests.get(self._base_url, params=params)
url = self._create_url(database, collection)
response = requests.get(url, params=params)
response_json = response.json()
body = response_json.get("body")
if body is None:
Expand All @@ -172,7 +212,7 @@ def _get_records(
return json.loads(body)

def _upsert_one_record(
self, record_filter: dict, update: dict
self, record_filter: dict, update: dict
) -> Response:
"""Upsert a single record into the collection."""
data = json.dumps(
Expand Down Expand Up @@ -229,14 +269,14 @@ class MetadataDbClient(Client):
"""Class to manage reading and writing to metadata db"""

def retrieve_docdb_records(
self,
filter_query: Optional[dict] = None,
projection: Optional[dict] = None,
sort: Optional[dict] = None,
limit: int = 0,
paginate: bool = True,
paginate_batch_size: int = 10,
paginate_max_iterations: int = 20000,
self,
filter_query: Optional[dict] = None,
projection: Optional[dict] = None,
sort: Optional[dict] = None,
limit: int = 0,
paginate: bool = True,
paginate_batch_size: int = 10,
paginate_max_iterations: int = 20000,
) -> List[dict]:
"""
Retrieve raw json records from DocDB API Gateway as a list of dicts.
Expand Down Expand Up @@ -290,13 +330,117 @@ def retrieve_docdb_records(
skip = 0
iter_count = 0
while (
skip < total_record_count
and num_of_records_collected
< min(filtered_record_count, limit)
and iter_count < paginate_max_iterations
skip < total_record_count
and num_of_records_collected
< min(filtered_record_count, limit)
and iter_count < paginate_max_iterations
):
try:
batched_records = self._get_records(
filter_query=filter_query,
projection=projection,
sort=sort,
limit=paginate_batch_size,
skip=skip,
)
num_of_records_collected += len(batched_records)
records.extend(batched_records)
except Exception as e:
errors.append(repr(e))
skip = skip + paginate_batch_size
iter_count += 1
# TODO: Add optional progress bar?
records = records[0:limit]
if len(errors) > 0:
logging.error(
f"There were errors retrieving records. {errors}"
)
return records

def retrieve_schema_records(
self,
schema_type: Optional[str] = None,
filter_query: Optional[dict] = None,
projection: Optional[dict] = None,
sort: Optional[dict] = None,
limit: int = 0,
paginate: bool = True,
paginate_batch_size: int = 10,
paginate_max_iterations: int = 20000,
) -> List[dict]:
"""
Retrieve schemas records from DocDB API Gateway as a list of dicts.
Parameters
----------
schema_type : Optional[str]
Type of schema to retrieve. Default is None.
filter_query : Optional[dict]
Filter to apply to the records being returned. Default is None.
projection : Optional[dict]
Subset of document fields to return. Default is None.
sort : Optional[dict]
Sort records when returned. Default is None.
limit : int
Return a smaller set of records. 0 for all records. Default is 0.
paginate : bool
If set to true, will batch the queries to the API Gateway. It may
be faster to set to false if the number of records expected to be
returned is small.
paginate_batch_size : int
Number of records to return at a time. Default is 10.
paginate_max_iterations : int
Max number of iterations to run to prevent indefinite calls to the
API Gateway. Default is 20000.
Returns
-------
List[dict]
"""
if paginate is False:
records = self._get_records(
database='schemas',
collection=schema_type,
filter_query=filter_query,
projection=projection,
sort=sort,
limit=limit,
)
else:
# Get record count
record_counts = self._count_records(
database='schemas',
collection=schema_type,
filter_query=filter_query,
)
total_record_count = record_counts["total_record_count"]
filtered_record_count = record_counts["filtered_record_count"]
if filtered_record_count <= paginate_batch_size:
records = self._get_records(
database='schemas',
collection=schema_type,
filter_query=filter_query,
projection=projection,
sort=sort
)
else:
records = []
errors = []
num_of_records_collected = 0
limit = filtered_record_count if limit == 0 else limit
skip = 0
iter_count = 0
while (
skip < total_record_count
and num_of_records_collected
< min(filtered_record_count, limit)
and iter_count < paginate_max_iterations
):
try:
batched_records = self._get_records(
database='schemas',
collection=schema_type,
filter_query=filter_query,
projection=projection,
sort=sort,
Expand All @@ -319,14 +463,16 @@ def retrieve_docdb_records(

# TODO: remove this method
def retrieve_data_asset_records(
self,
filter_query: Optional[dict] = None,
projection: Optional[dict] = None,
sort: Optional[dict] = None,
limit: int = 0,
paginate: bool = True,
paginate_batch_size: int = 10,
paginate_max_iterations: int = 20000,
self,
# TODO: add schema type (string) parameter
# add collection
filter_query: Optional[dict] = None,
projection: Optional[dict] = None,
sort: Optional[dict] = None,
limit: int = 0,
paginate: bool = True,
paginate_batch_size: int = 10,
paginate_max_iterations: int = 20000,
) -> List[DataAssetRecord]:
"""
DEPRECATED: This method is deprecated. Use `retrieve_docdb_records`
Expand Down Expand Up @@ -367,6 +513,7 @@ def retrieve_data_asset_records(
stacklevel=2,
)
if paginate is False:
# and add to count records
records = self._get_records(
filter_query=filter_query,
projection=projection,
Expand All @@ -390,10 +537,10 @@ def retrieve_data_asset_records(
skip = 0
iter_count = 0
while (
skip < total_record_count
and num_of_records_collected
< min(filtered_record_count, limit)
and iter_count < paginate_max_iterations
skip < total_record_count
and num_of_records_collected
< min(filtered_record_count, limit)
and iter_count < paginate_max_iterations
):
try:
batched_records = self._get_records(
Expand Down Expand Up @@ -434,7 +581,7 @@ def upsert_one_docdb_record(self, record: dict) -> Response:

# TODO: remove this method
def upsert_one_record(
self, data_asset_record: DataAssetRecord
self, data_asset_record: DataAssetRecord
) -> Response:
"""
DEPRECATED: This method is deprecated. Use `upsert_one_docdb_record`
Expand Down Expand Up @@ -468,7 +615,7 @@ def delete_one_record(self, data_asset_record_id: str) -> Response:
return response

def delete_many_records(
self, data_asset_record_ids: List[str]
self, data_asset_record_ids: List[str]
) -> Response:
"""Delete many records by their ids"""

Expand All @@ -489,9 +636,9 @@ def _record_to_operation(record: str, record_id: str) -> dict:
}

def upsert_list_of_docdb_records(
self,
records: List[dict],
max_payload_size: int = 2e6,
self,
records: List[dict],
max_payload_size: int = 2e6,
) -> List[Response]:
"""
Upsert a list of records. There's a limit to the size of the
Expand Down Expand Up @@ -572,9 +719,9 @@ def upsert_list_of_docdb_records(

# TODO: remove this method
def upsert_list_of_records(
self,
data_asset_records: List[DataAssetRecord],
max_payload_size: int = 2e6,
self,
data_asset_records: List[DataAssetRecord],
max_payload_size: int = 2e6,
) -> List[Response]:
"""
DEPRECATED: This method is deprecated. Use
Expand Down

0 comments on commit 151871c

Please sign in to comment.