feat: do not convert docdb records to DataAssetRecord #67

Merged · 3 commits · Jul 19, 2024
2 changes: 1 addition & 1 deletion docs/source/UserGuide.rst
@@ -66,7 +66,7 @@ REST API (Read-Only)
filter = {"subject.subject_id": "123456"}
limit = 1000
paginate_batch_size = 100
response = docdb_api_client.retrieve_data_asset_records(
response = docdb_api_client.retrieve_docdb_records(
filter_query=filter,
limit=limit,
paginate_batch_size=paginate_batch_size
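For context, a slightly fuller call to the renamed method might look like the sketch below. It is based only on the new docstring added in this PR; the projection and sort field names are illustrative and not taken from the user guide.

# Sketch only: the projection and sort field names are hypothetical.
filter = {"subject.subject_id": "123456"}
projection = {"name": 1, "created": 1, "location": 1}
sort = {"created": -1}
response = docdb_api_client.retrieve_docdb_records(
    filter_query=filter,
    projection=projection,
    sort=sort,
    limit=0,  # 0 returns all matching records
    paginate_batch_size=500,
)
print(len(response))  # a list of plain dicts, not DataAssetRecord objects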
221 changes: 220 additions & 1 deletion src/aind_data_access_api/document_db.py
@@ -2,6 +2,7 @@

import json
import logging
import warnings
from functools import cached_property
from sys import getsizeof
from typing import List, Optional, Tuple
@@ -13,6 +14,7 @@
from requests import Response

from aind_data_access_api.models import DataAssetRecord
from aind_data_access_api.utils import is_dict_corrupt


class Client:
@@ -226,6 +228,96 @@ def _bulk_write(self, operations: List[dict]) -> Response:
class MetadataDbClient(Client):
"""Class to manage reading and writing to metadata db"""

def retrieve_docdb_records(
self,
filter_query: Optional[dict] = None,
projection: Optional[dict] = None,
sort: Optional[dict] = None,
limit: int = 0,
paginate: bool = True,
paginate_batch_size: int = 10,
paginate_max_iterations: int = 20000,
) -> List[dict]:
"""
Retrieve raw json records from DocDB API Gateway as a list of dicts.

Parameters
----------
filter_query : Optional[dict]
Filter to apply to the records being returned. Default is None.
projection : Optional[dict]
Subset of document fields to return. Default is None.
sort : Optional[dict]
Sort records when returned. Default is None.
limit : int
Return a smaller set of records. 0 for all records. Default is 0.
paginate : bool
If set to true, will batch the queries to the API Gateway. It may
be faster to set to false if the number of records expected to be
returned is small.
paginate_batch_size : int
Number of records to return at a time. Default is 10.
paginate_max_iterations : int
Max number of iterations to run to prevent indefinite calls to the
API Gateway. Default is 20000.

Returns
-------
List[dict]

"""
if paginate is False:
records = self._get_records(
filter_query=filter_query,
projection=projection,
sort=sort,
limit=limit,
)
else:
# Get record count
record_counts = self._count_records(filter_query)
total_record_count = record_counts["total_record_count"]
filtered_record_count = record_counts["filtered_record_count"]
if filtered_record_count <= paginate_batch_size:
records = self._get_records(
filter_query=filter_query, projection=projection, sort=sort
)
else:
records = []
errors = []
num_of_records_collected = 0
limit = filtered_record_count if limit == 0 else limit
skip = 0
iter_count = 0
while (
skip < total_record_count
and num_of_records_collected
< min(filtered_record_count, limit)
and iter_count < paginate_max_iterations
):
try:
batched_records = self._get_records(
filter_query=filter_query,
projection=projection,
sort=sort,
limit=paginate_batch_size,
skip=skip,
)
num_of_records_collected += len(batched_records)
records.extend(batched_records)
except Exception as e:
errors.append(repr(e))
skip = skip + paginate_batch_size
iter_count += 1
# TODO: Add optional progress bar?
records = records[0:limit]
if len(errors) > 0:
logging.error(
f"There were errors retrieving records. {errors}"
)
return records

# TODO: remove this method
def retrieve_data_asset_records(
self,
filter_query: Optional[dict] = None,
@@ -237,6 +329,9 @@ def retrieve_data_asset_records(
paginate_max_iterations: int = 20000,
) -> List[DataAssetRecord]:
"""
DEPRECATED: This method is deprecated. Use `retrieve_docdb_records`
instead.

Retrieve data asset records

Parameters
@@ -264,6 +359,13 @@
List[DataAssetRecord]

"""
warnings.warn(
"retrieve_data_asset_records is deprecated. "
"Use retrieve_docdb_records instead."
"",
DeprecationWarning,
stacklevel=2,
)
if paginate is False:
records = self._get_records(
filter_query=filter_query,
@@ -318,11 +420,35 @@ def retrieve_data_asset_records(
data_asset_records.append(DataAssetRecord(**record))
return data_asset_records

def upsert_one_docdb_record(self, record: dict) -> Response:
"""Upsert one record if the record is not corrupt"""
if record.get("_id") is None:
raise ValueError("Record does not have an _id field.")
if is_dict_corrupt(record):
raise ValueError("Record is corrupt and cannot be upserted.")
response = self._upsert_one_record(
record_filter={"_id": record["_id"]},
update={"$set": json.loads(json.dumps(record, default=str))},
)
return response

# TODO: remove this method
def upsert_one_record(
self, data_asset_record: DataAssetRecord
) -> Response:
"""Upsert one record"""
"""
DEPRECATED: This method is deprecated. Use `upsert_one_docdb_record`
instead.

Upsert one record
"""
warnings.warn(
"upsert_one_record is deprecated. "
"Use upsert_one_docdb_record instead."
"",
DeprecationWarning,
stacklevel=2,
)
response = self._upsert_one_record(
record_filter={"_id": data_asset_record.id},
update={
@@ -362,12 +488,98 @@ def _record_to_operation(record: str, record_id: str) -> dict:
}
}

def upsert_list_of_docdb_records(
self,
records: List[dict],
max_payload_size: int = 2e6,
) -> List[Response]:
"""
Upsert a list of records. There's a limit to the size of the
request that can be sent, so we chunk the requests.

Parameters
----------

records : List[dict]
List of records to upsert into the DocDB database
max_payload_size : int
Chunk requests into smaller lists no bigger than this value in bytes.
If a single record is larger than this value in bytes, an attempt
will be made to upsert it, but the request will most likely receive
a 413 status code. The default is 2e6 bytes. The max payload for the API
Gateway including headers is 10MB.

Returns
-------
List[Response]
A list of responses from the API Gateway.

"""
if len(records) == 0:
return []
else:
# check no record is corrupt or missing _id
for record in records:
if record.get("_id") is None:
raise ValueError("A record does not have an _id field.")
if is_dict_corrupt(record):
raise ValueError(
"A record is corrupt and cannot be upserted."
)
# chunk records
first_index = 0
end_index = len(records)
second_index = 1
responses = []
record_json = json.dumps(records[first_index], default=str)
total_size = getsizeof(record_json)
operations = [
self._record_to_operation(
record=record_json,
record_id=records[first_index].get("_id"),
)
]
while second_index < end_index + 1:
if second_index == end_index:
response = self._bulk_write(operations)
responses.append(response)
else:
record_json = json.dumps(
records[second_index], default=str
)
record_size = getsizeof(record_json)
if total_size + record_size > max_payload_size:
response = self._bulk_write(operations)
responses.append(response)
first_index = second_index
operations = [
self._record_to_operation(
record=record_json,
record_id=records[first_index].get("_id"),
)
]
total_size = record_size
else:
operations.append(
self._record_to_operation(
record=record_json,
record_id=records[second_index].get("_id"),
)
)
total_size += record_size
second_index = second_index + 1
return responses

# TODO: remove this method
def upsert_list_of_records(
self,
data_asset_records: List[DataAssetRecord],
max_payload_size: int = 2e6,
) -> List[Response]:
"""
DEPRECATED: This method is deprecated. Use
`upsert_list_of_docdb_records` instead.

Upsert a list of records. There's a limit to the size of the
request that can be sent, so we chunk the requests.

@@ -389,6 +601,13 @@ def upsert_list_of_records(
A list of responses from the API Gateway.

"""
warnings.warn(
"upsert_list_of_records is deprecated. "
"Use upsert_list_of_docdb_records instead."
"",
DeprecationWarning,
stacklevel=2,
)
if len(data_asset_records) == 0:
return []
else:
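A minimal sketch of calling the new single-record upsert, assuming a docdb_api_client constructed as in the user guide. The record contents are hypothetical; the only requirements visible in this diff are a non-None _id field and a record that passes the is_dict_corrupt check.

record = {
    "_id": "abc-123",  # hypothetical identifier; a missing _id raises ValueError
    "name": "ecephys_123456_2024-07-01_10-00-00",
    "location": "s3://example-bucket/ecephys_123456_2024-07-01_10-00-00",
}
response = docdb_api_client.upsert_one_docdb_record(record)
print(response.status_code)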
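For bulk writes, the list variant chunks the bulk-write operations so each request stays under max_payload_size bytes (2e6 by default, well below the API Gateway's 10 MB cap). A sketch with hypothetical records:

records = [
    {"_id": "abc-123", "name": "ecephys_123456_2024-07-01_10-00-00"},
    {"_id": "def-456", "name": "ecephys_654321_2024-07-02_11-00-00"},
]
# One Response is returned per chunked bulk-write request.
responses = docdb_api_client.upsert_list_of_docdb_records(
    records, max_payload_size=2e6
)
print([r.status_code for r in responses])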
3 changes: 3 additions & 0 deletions src/aind_data_access_api/document_db_ssh.py
@@ -69,6 +69,9 @@ def from_secrets_manager(
class DocumentDbSSHClient:
"""Class to establish a Document Store client with SSH tunneling."""

# TODO: add retrieve_docdb_records, upsert_one_docdb_record,
# and upsert_list_of_docdb_records methods

def __init__(self, credentials: DocumentDbSSHCredentials):
"""
Construct a client to interface with a Document Database.
2 changes: 2 additions & 0 deletions src/aind_data_access_api/document_store.py
@@ -12,6 +12,7 @@
from aind_data_access_api.models import DataAssetRecord


# TODO: deprecate this class
class DocumentStoreCredentials(CoreCredentials):
"""Document Store credentials"""

@@ -31,6 +32,7 @@ class DocumentStoreCredentials(CoreCredentials):
database: str = Field(...)


# TODO: deprecate this client
class Client:
"""Class to establish a document store client."""

1 change: 1 addition & 0 deletions src/aind_data_access_api/models.py
@@ -5,6 +5,7 @@
from pydantic import BaseModel, Extra, Field


# TODO: remove this model
class DataAssetRecord(BaseModel):
"""The records in the Data Asset Collection need to contain certain fields
to easily query and index the data."""
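The old DataAssetRecord-based methods remain for now but emit a DeprecationWarning before running the same query path, and the model itself is marked for removal. Because Python filters DeprecationWarning by default in most contexts, callers who want to flag remaining uses can opt in explicitly; a small sketch, again assuming a docdb_api_client as in the user guide:

import warnings

warnings.simplefilter("always", DeprecationWarning)

# Emits: "retrieve_data_asset_records is deprecated. Use retrieve_docdb_records instead."
records = docdb_api_client.retrieve_data_asset_records(
    filter_query={"subject.subject_id": "123456"},
    paginate_batch_size=100,
)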