Add ScicatClient.query_datasets
jl-wynen committed May 27, 2024
1 parent 1b554e0 commit 5236160
Showing 3 changed files with 327 additions and 3 deletions.
docs/release-notes.rst (2 changes: 2 additions & 0 deletions)
@@ -41,6 +41,8 @@ Security
Features
~~~~~~~~

* Added experimental :meth:`client.ScicatClient.query_datasets` for querying datasets by field.

Breaking changes
~~~~~~~~~~~~~~~~

src/scitacean/client.py (118 changes: 115 additions & 3 deletions)
@@ -6,6 +6,7 @@

import dataclasses
import datetime
import json
import re
import warnings
from collections.abc import Callable, Iterable, Iterator
@@ -15,6 +16,7 @@
from urllib.parse import quote_plus

import httpx
import pydantic

from . import model
from ._base_model import convert_download_to_user_model
@@ -708,6 +710,109 @@ def get_dataset_model(
            **dset_json,
        )

    def query_datasets(
        self,
        fields: dict[str, Any],
        *,
        limit: int | None = None,
        order: str | None = None,
        strict_validation: bool = False,
    ) -> list[model.DownloadDataset]:
        """Query for datasets in SciCat.

        Attention
        ---------
        This function is experimental and may change or be removed in the future.
        It is currently unclear how best to implement querying because SciCat
        provides multiple, very different APIs, and there are plans to support
        queries in the Mongo query language directly.
        See `issue #177 <https://github.com/SciCatProject/scitacean/issues/177>`_
        for a discussion.

        Parameters
        ----------
        fields:
            Fields to query for.
            Returned datasets must match all fields exactly.
            See the examples below.
        limit:
            Maximum number of results to return.
            Requires ``order`` to be specified.
            If not given, all matching datasets are returned.
        order:
            Specify the order of results.
            For example, ``"creationTime:asc"`` and ``"creationTime:desc"`` return
            results in ascending or descending order of creation time, respectively.
        strict_validation:
            If ``True``, the datasets must pass validation.
            If ``False``, datasets are still returned if validation fails.
            Note that some dataset fields may have a bad value or type.
            A warning will be logged if validation fails.

        Returns
        -------
        :
            A list of dataset models that match the query.

        Examples
        --------
        Get all datasets belonging to proposal ``abc.123``:

        .. code-block:: python

            scicat_client.query_datasets({'proposalId': 'abc.123'})

        Get all datasets that belong to proposal ``abc.123``
        **and** have the name ``"ds name"`` (both fields must match exactly):

        .. code-block:: python

            scicat_client.query_datasets({'proposalId': 'abc.123', 'name': 'ds name'})

        Return only the newest 5 datasets for proposal ``bc.123``:

        .. code-block:: python

            scicat_client.query_datasets(
                {'proposalId': 'bc.123'},
                limit=5,
                order="creationTime:desc",
            )
        """
        # Use a pydantic model to support serializing custom types to JSON.
        params_model = pydantic.create_model(
            "QueryParams", **{key: (type(field), ...) for key, field in fields.items()}
        )
        params = {"fields": params_model(**fields).model_dump_json()}

        limits = {}
        if order is not None:
            limits["order"] = order
        if limit is not None:
            if order is None:
                raise ValueError("`order` is required when `limit` is specified.")
            limits["limit"] = limit
        if limits:
            params["limits"] = json.dumps(limits)

        dsets_json = self._call_endpoint(
            cmd="get",
            url="datasets/fullquery",
            params=params,
            operation="query_datasets",
        )
        if not dsets_json:
            return []
        return [
            model.construct(
                model.DownloadDataset,
                _strict_validation=strict_validation,
                **dset_json,
            )
            for dset_json in dsets_json
        ]
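For illustration, here is a minimal sketch (not part of the commit) of what the pydantic trick above produces. `pydantic.create_model` builds a throwaway model whose field types are inferred from the query values, so any value type that pydantic can serialize ends up as proper JSON. The field names and values below are made up, and whether the backend matches a datetime given in this form is an assumption:

import datetime
import json

import pydantic

fields = {
    "proposalId": "abc.123",
    "creationTime": datetime.datetime(2024, 5, 27, 12, 0),
}
# Field types are inferred from the values; `...` marks each field as required.
QueryParams = pydantic.create_model(
    "QueryParams", **{key: (type(value), ...) for key, value in fields.items()}
)
params = {
    "fields": QueryParams(**fields).model_dump_json(),
    "limits": json.dumps({"order": "creationTime:desc", "limit": 5}),
}
print(params["fields"])  # {"proposalId":"abc.123","creationTime":"2024-05-27T12:00:00"}
print(params["limits"])  # {"order": "creationTime:desc", "limit": 5}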

def get_orig_datablocks(
self, pid: PID, strict_validation: bool = False
) -> list[model.DownloadOrigDatablock]:
@@ -1010,7 +1115,12 @@ def validate_dataset_model(
raise ValueError(f"Dataset {dset} did not pass validation in SciCat.")

    def _send_to_scicat(
-        self, *, cmd: str, url: str, data: model.BaseModel | None = None
+        self,
+        *,
+        cmd: str,
+        url: str,
+        data: model.BaseModel | None = None,
+        params: dict[str, str] | None = None,
    ) -> httpx.Response:
        if self._token is not None:
            token = self._token.get_str()
@@ -1029,6 +1139,7 @@ def _send_to_scicat(
            content=data.model_dump_json(exclude_none=True)
            if data is not None
            else None,
            params=params,
            headers=headers,
            timeout=self._timeout.seconds,
        )
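For context, a standalone sketch (not part of the commit) of how the new `params` argument behaves once it reaches httpx: the dict is URL-encoded into the query string of the request. The URL and token below are placeholders, not values from this commit:

import httpx

response = httpx.request(
    "get",
    "https://scicat.example.com/api/v3/datasets/fullquery",  # placeholder URL
    params={"fields": '{"proposalId": "abc.123"}'},  # pre-serialized JSON string
    headers={"Authorization": "Bearer <token>"},  # placeholder token
    timeout=10.0,
)
# httpx appends the URL-encoded params to the URL, e.g.
# .../datasets/fullquery?fields=%7B%22proposalId%22...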
@@ -1047,14 +1158,15 @@ def _call_endpoint(
        *,
        cmd: str,
        url: str,
-        data: model.BaseModel | None = None,
        operation: str,
+        data: model.BaseModel | None = None,
+        params: dict[str, str] | None = None,
    ) -> Any:
        full_url = _url_concat(self._base_url, url)
        logger = get_logger()
        logger.info("Calling SciCat API at %s for operation '%s'", full_url, operation)

-        response = self._send_to_scicat(cmd=cmd, url=full_url, data=data)
+        response = self._send_to_scicat(cmd=cmd, url=full_url, data=data, params=params)
        if not response.is_success:
            logger.error(
                "SciCat API call to %s failed: %s %s: %s",
tests/client/query_client_test.py (210 changes: 210 additions & 0 deletions)
@@ -0,0 +1,210 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2024 SciCat Project (https://github.com/SciCatProject/scitacean)

import pytest
from dateutil.parser import parse as parse_datetime

from scitacean import Client, DatasetType, RemotePath, model
from scitacean.testing.backend import skip_if_not_backend
from scitacean.testing.backend.config import SciCatAccess

UPLOAD_DATASETS = {
    "raw1": model.UploadRawDataset(
        ownerGroup="PLACEHOLDER",
        accessGroups=["uu", "faculty"],
        contactEmail="[email protected]",
        creationTime=parse_datetime("2004-06-13T01:45:28.100Z"),
        datasetName="dataset 1",
        numberOfFiles=0,
        numberOfFilesArchived=0,
        owner="PLACEHOLDER",
        sourceFolder=RemotePath("/hex/raw1"),
        type=DatasetType.RAW,
        principalInvestigator="investigator 1",
        creationLocation="UU",
        proposalId="p0124",
    ),
    "raw2": model.UploadRawDataset(
        ownerGroup="PLACEHOLDER",
        accessGroups=["uu", "faculty"],
        contactEmail="[email protected]",
        creationTime=parse_datetime("2004-06-14T14:00:30Z"),
        datasetName="dataset 2",
        numberOfFiles=0,
        numberOfFilesArchived=0,
        owner="PLACEHOLDER",
        sourceFolder=RemotePath("/hex/raw2"),
        type=DatasetType.RAW,
        principalInvestigator="investigator 2",
        creationLocation="UU",
        proposalId="p0124",
    ),
    "raw3": model.UploadRawDataset(
        ownerGroup="PLACEHOLDER",
        accessGroups=["uu", "faculty"],
        contactEmail="[email protected]",
        creationTime=parse_datetime("2004-06-10T00:13:13Z"),
        datasetName="dataset 3",
        numberOfFiles=0,
        numberOfFilesArchived=0,
        owner="PLACEHOLDER",
        sourceFolder=RemotePath("/hex/raw3"),
        type=DatasetType.RAW,
        principalInvestigator="investigator 1",
        creationLocation="UU",
        proposalId="p0124",
    ),
    "raw4": model.UploadRawDataset(
        ownerGroup="PLACEHOLDER",
        accessGroups=["uu", "faculty"],
        contactEmail="[email protected]",
        creationTime=parse_datetime("2005-11-03T21:56:02Z"),
        datasetName="dataset 1",
        numberOfFiles=0,
        numberOfFilesArchived=0,
        owner="PLACEHOLDER",
        sourceFolder=RemotePath("/hex/raw4"),
        type=DatasetType.RAW,
        principalInvestigator="investigator X",
        creationLocation="UU",
    ),
    "derived1": model.UploadDerivedDataset(
        ownerGroup="PLACEHOLDER",
        accessGroups=["uu", "faculty"],
        contactEmail="[email protected]",
        creationTime=parse_datetime("2004-10-02T08:47:33Z"),
        datasetName="dataset 1",
        numberOfFiles=0,
        numberOfFilesArchived=0,
        owner="PLACEHOLDER",
        sourceFolder=RemotePath("/hex/derived1"),
        type=DatasetType.DERIVED,
        investigator="investigator 1",
        inputDatasets=[],
        usedSoftware=["scitacean"],
    ),
    "derived2": model.UploadDerivedDataset(
        ownerGroup="PLACEHOLDER",
        accessGroups=["uu", "faculty"],
        contactEmail="[email protected]",
        creationTime=parse_datetime("2004-10-14T09:18:58Z"),
        datasetName="derived dataset 2",
        numberOfFiles=0,
        numberOfFilesArchived=0,
        owner="PLACEHOLDER",
        sourceFolder=RemotePath("/hex/derived2"),
        type=DatasetType.DERIVED,
        investigator="investigator 1",
        inputDatasets=[],
        usedSoftware=["scitacean"],
    ),
}
SEED = {}


@pytest.fixture(scope="module", autouse=True)
def seed_database(request: pytest.FixtureRequest, scicat_access: SciCatAccess) -> None:
    skip_if_not_backend(request)

    client = Client.from_credentials(
        url=scicat_access.url,
        **scicat_access.user.credentials,  # type: ignore[arg-type]
    )
    for key, dset in UPLOAD_DATASETS.items():
        dset.ownerGroup = scicat_access.user.group
        dset.owner = scicat_access.user.username
        SEED[key] = client.scicat.create_dataset_model(dset)


def test_query_dataset_multiple_by_single_field(real_client, seed_database):
    datasets = real_client.scicat.query_datasets({"proposalId": "p0124"})
    actual = {ds.pid: ds for ds in datasets}
    expected = {SEED[key].pid: SEED[key] for key in ("raw1", "raw2", "raw3")}
    assert actual == expected


def test_query_dataset_no_match(real_client, seed_database):
    datasets = real_client.scicat.query_datasets({"owner": "librarian"})
    assert not datasets


def test_query_dataset_multiple_by_multiple_fields(real_client, seed_database):
    datasets = real_client.scicat.query_datasets(
        {"proposalId": "p0124", "principalInvestigator": "investigator 1"},
    )
    actual = {ds.pid: ds for ds in datasets}
    expected = {SEED[key].pid: SEED[key] for key in ("raw1", "raw3")}
    assert actual == expected


def test_query_dataset_multiple_by_derived_field(real_client, seed_database):
    datasets = real_client.scicat.query_datasets(
        {"investigator": "investigator 1"},
    )
    actual = {ds.pid: ds for ds in datasets}
    expected = {SEED[key].pid: SEED[key] for key in ("derived1", "derived2")}
    assert actual == expected


def test_query_dataset_uses_conjunction_of_fields(real_client, seed_database):
    datasets = real_client.scicat.query_datasets(
        {"proposalId": "p0124", "investigator": "investigator X"},
    )
    assert not datasets


def test_query_dataset_can_use_custom_type(real_client, seed_database):
    datasets = real_client.scicat.query_datasets(
        {"sourceFolder": RemotePath("/hex/raw4")},
    )
    expected = [SEED["raw4"]]
    assert datasets == expected


def test_query_dataset_set_order(real_client, seed_database):
    datasets = real_client.scicat.query_datasets(
        {"proposalId": "p0124"},
        order="creationTime:desc",
    )
    # This test uses a list to check the order.
    expected = [SEED[key] for key in ("raw2", "raw1", "raw3")]
    assert datasets == expected


def test_query_dataset_limit_ascending_creation_time(real_client, seed_database):
    datasets = real_client.scicat.query_datasets(
        {"proposalId": "p0124"},
        limit=2,
        order="creationTime:asc",
    )
    actual = {ds.pid: ds for ds in datasets}
    expected = {SEED[key].pid: SEED[key] for key in ("raw1", "raw3")}
    assert actual == expected


def test_query_dataset_limit_descending_creation_time(real_client, seed_database):
    datasets = real_client.scicat.query_datasets(
        {"proposalId": "p0124"},
        limit=2,
        order="creationTime:desc",
    )
    actual = {ds.pid: ds for ds in datasets}
    expected = {SEED[key].pid: SEED[key] for key in ("raw1", "raw2")}
    assert actual == expected


def test_query_dataset_limit_needs_order(real_client, seed_database):
    with pytest.raises(ValueError, match="limit"):
        real_client.scicat.query_datasets(
            {"proposalId": "p0124"},
            limit=2,
        )


def test_query_dataset_all(real_client, seed_database):
    datasets = real_client.scicat.query_datasets({})
    actual = {ds.pid: ds for ds in datasets}
    # We cannot test `datasets` directly because there are other datasets
    # in the database from other tests.
    for ds in SEED.values():
        assert actual[ds.pid] == ds
