diff --git a/nucleus/dataset.py b/nucleus/dataset.py
index 16092b1f..6bb5d773 100644
--- a/nucleus/dataset.py
+++ b/nucleus/dataset.py
@@ -1,9 +1,9 @@
-from typing import List, Dict, Any, Optional, Union
+from typing import List, Dict, Any, Optional
+
+from nucleus.utils import format_dataset_item_response
 from .dataset_item import DatasetItem
 from .annotation import (
     Annotation,
-    BoxAnnotation,
-    PolygonAnnotation,
 )
 from .constants import (
     DATASET_NAME_KEY,
@@ -13,10 +13,7 @@
     DATASET_ITEM_IDS_KEY,
     REFERENCE_IDS_KEY,
     NAME_KEY,
-    ITEM_KEY,
     DEFAULT_ANNOTATION_UPDATE_MODE,
-    ANNOTATIONS_KEY,
-    ANNOTATION_TYPES,
 )
 
 from .payload_constructor import construct_model_run_creation_payload
@@ -109,7 +106,7 @@ def create_model_run(
 
     def annotate(
         self,
-        annotations: List[Union[BoxAnnotation, PolygonAnnotation]],
+        annotations: List[Annotation],
         update: Optional[bool] = DEFAULT_ANNOTATION_UPDATE_MODE,
         batch_size: int = 5000,
     ) -> dict:
@@ -179,7 +176,7 @@ def iloc(self, i: int) -> dict:
         }
         """
         response = self._client.dataitem_iloc(self.id, i)
-        return self._format_dataset_item_response(response)
+        return format_dataset_item_response(response)
 
     def refloc(self, reference_id: str) -> dict:
         """
@@ -192,7 +189,7 @@ def refloc(self, reference_id: str) -> dict:
         }
         """
         response = self._client.dataitem_ref_id(self.id, reference_id)
-        return self._format_dataset_item_response(response)
+        return format_dataset_item_response(response)
 
     def loc(self, dataset_item_id: str) -> dict:
         """
@@ -205,7 +202,7 @@ def loc(self, dataset_item_id: str) -> dict:
         }
         """
         response = self._client.dataitem_loc(self.id, dataset_item_id)
-        return self._format_dataset_item_response(response)
+        return format_dataset_item_response(response)
 
     def create_slice(
         self,
@@ -247,25 +244,6 @@ def delete_item(self, item_id: str = None, reference_id: str = None):
     def list_autotags(self):
         return self._client.list_autotags(self.id)
 
-    def _format_dataset_item_response(self, response: dict) -> dict:
-        item = response.get(ITEM_KEY, None)
-        annotation_payload = response.get(ANNOTATIONS_KEY, {})
-        if not item or not annotation_payload:
-            # An error occured
-            return response
-
-        annotation_response = {}
-        for annotation_type in ANNOTATION_TYPES:
-            if annotation_type in annotation_payload:
-                annotation_response[annotation_type] = [
-                    Annotation.from_json(ann)
-                    for ann in annotation_payload[annotation_type]
-                ]
-        return {
-            ITEM_KEY: DatasetItem.from_json(item),
-            ANNOTATIONS_KEY: annotation_response,
-        }
-
     def create_custom_index(self, embeddings_url: str):
         return self._client.create_custom_index(self.id, embeddings_url)
 
diff --git a/nucleus/slice.py b/nucleus/slice.py
index 267df929..fa004905 100644
--- a/nucleus/slice.py
+++ b/nucleus/slice.py
@@ -1,4 +1,9 @@
-from typing import List
+from typing import Dict, List, Iterable, Set, Tuple, Optional, Union
+from nucleus.dataset_item import DatasetItem
+from nucleus.annotation import Annotation
+from nucleus.utils import format_dataset_item_response
+
+from .constants import DEFAULT_ANNOTATION_UPDATE_MODE
 
 
 class Slice:
@@ -9,6 +14,7 @@ class Slice:
     def __init__(self, slice_id: str, client):
         self.slice_id = slice_id
         self._client = client
+        self._dataset_id = None
 
     def __repr__(self):
         return f"Slice(slice_id='{self.slice_id}', client={self._client})"
@@ -19,6 +25,13 @@ def __eq__(self, other):
             return True
         return False
 
+    @property
+    def dataset_id(self):
+        """The id of the dataset this slice belongs to."""
+        if self._dataset_id is None:
+            self.info()
+        return self._dataset_id
+
     def info(self) -> dict:
         """
         This endpoint provides information about specified slice.
@@ -30,7 +43,9 @@ def info(self) -> dict:
             "dataset_items",
         }
         """
-        return self._client.slice_info(self.slice_id)
+        info = self._client.slice_info(self.slice_id)
+        self._dataset_id = info["dataset_id"]
+        return info
 
     def append(
         self,
@@ -57,3 +72,118 @@ def append(
             reference_ids=reference_ids,
         )
         return response
+
+    def items_and_annotation_generator(
+        self,
+    ) -> Iterable[Dict[str, Union[DatasetItem, Dict[str, List[Annotation]]]]]:
+        """Returns an iterable of all DatasetItems and Annotations in this slice.
+
+        Returns:
+            An iterable, where each item is a dict with two keys representing a row
+            in the dataset.
+            * One value in the dict is the DatasetItem, containing a reference to the
+            item that was annotated, for example an image_url.
+            * The other value is a dictionary containing all the annotations for this
+            dataset item, sorted by annotation type.
+        """
+        info = self.info()
+        for item_metadata in info["dataset_items"]:
+            yield format_dataset_item_response(
+                self._client.dataitem_loc(
+                    dataset_id=info["dataset_id"],
+                    dataset_item_id=item_metadata["id"],
+                )
+            )
+
+    def items_and_annotations(
+        self,
+    ) -> List[Dict[str, Union[DatasetItem, Dict[str, List[Annotation]]]]]:
+        """Returns a list of all DatasetItems and Annotations in this slice.
+
+        Returns:
+            A list, where each item is a dict with two keys representing a row
+            in the dataset.
+            * One value in the dict is the DatasetItem, containing a reference to the
+            item that was annotated.
+            * The other value is a dictionary containing all the annotations for this
+            dataset item, sorted by annotation type.
+        """
+        return list(self.items_and_annotation_generator())
+
+    def annotate(
+        self,
+        annotations: List[Annotation],
+        update: Optional[bool] = DEFAULT_ANNOTATION_UPDATE_MODE,
+        batch_size: int = 5000,
+        strict=True,
+    ):
+        """Update annotations within this slice.
+
+        Args:
+            annotations: List of annotations to upload
+            batch_size: How many annotations to send per request.
+            strict: Whether to first check that the annotations belong to this slice.
+                Set to false to avoid this check and speed up upload.
+        """
+        if strict:
+            (
+                annotations_are_in_slice,
+                item_ids_not_found_in_slice,
+                reference_ids_not_found_in_slice,
+            ) = check_annotations_are_in_slice(annotations, self)
+            if not annotations_are_in_slice:
+                message = "Not all annotations are in this slice.\n"
+                if item_ids_not_found_in_slice:
+                    message += f"Item ids not found in slice: {item_ids_not_found_in_slice} \n"
+                if reference_ids_not_found_in_slice:
+                    message += f"Reference ids not found in slice: {reference_ids_not_found_in_slice}"
+                raise ValueError(message)
+        self._client.annotate_dataset(
+            dataset_id=self.dataset_id,
+            annotations=annotations,
+            update=update,
+            batch_size=batch_size,
+        )
+
+
+def check_annotations_are_in_slice(
+    annotations: List[Annotation], slice_to_check: Slice
+) -> Tuple[bool, Set[str], Set[str]]:
+    """Check membership of the annotation targets within this slice.
+
+    annotations: Annotations with ids referring to targets.
+    slice_to_check: The slice to check against.
+
+
+    Returns:
+        A tuple, where the first element is true/false whether the annotations are all
+        in the slice.
+        The second element is the list of item_ids not in the slice.
+        The third element is the list of ref_ids not in the slice.
+    """
+    info = slice_to_check.info()
+
+    item_ids_not_found_in_slice = {
+        annotation.item_id
+        for annotation in annotations
+        if annotation.item_id is not None
+    }.difference(
+        {item_metadata["id"] for item_metadata in info["dataset_items"]}
+    )
+    reference_ids_not_found_in_slice = {
+        annotation.reference_id
+        for annotation in annotations
+        if annotation.reference_id is not None
+    }.difference(
+        {item_metadata["ref_id"] for item_metadata in info["dataset_items"]}
+    )
+    if item_ids_not_found_in_slice or reference_ids_not_found_in_slice:
+        annotations_are_in_slice = False
+    else:
+        annotations_are_in_slice = True
+
+    return (
+        annotations_are_in_slice,
+        item_ids_not_found_in_slice,
+        reference_ids_not_found_in_slice,
+    )
diff --git a/nucleus/utils.py b/nucleus/utils.py
index 5320b8d7..c74b0007 100644
--- a/nucleus/utils.py
+++ b/nucleus/utils.py
@@ -1,8 +1,18 @@
+"""Shared stateless utility function library"""
+
+
 from typing import List, Union, Dict
 
+from nucleus.annotation import Annotation
 from .dataset_item import DatasetItem
 from .prediction import BoxPrediction, PolygonPrediction
+from .constants import (
+    ITEM_KEY,
+    ANNOTATIONS_KEY,
+    ANNOTATION_TYPES,
+)
+
 
 
 def _get_all_field_values(metadata_list: List[dict], key: str):
     return {metadata[key] for metadata in metadata_list if key in metadata}
@@ -34,3 +44,29 @@ def suggest_metadata_schema(
             entry["type"] = "text"
         schema[key] = entry
     return schema
+
+
+def format_dataset_item_response(response: dict) -> dict:
+    """Format the raw client response into api objects."""
+    if ANNOTATIONS_KEY not in response:
+        raise ValueError(
+            f"Server response was missing the annotation key: {response}"
+        )
+    if ITEM_KEY not in response:
+        raise ValueError(
+            f"Server response was missing the item key: {response}"
+        )
+    item = response[ITEM_KEY]
+    annotation_payload = response[ANNOTATIONS_KEY]
+
+    annotation_response = {}
+    for annotation_type in ANNOTATION_TYPES:
+        if annotation_type in annotation_payload:
+            annotation_response[annotation_type] = [
+                Annotation.from_json(ann)
+                for ann in annotation_payload[annotation_type]
+            ]
+    return {
+        ITEM_KEY: DatasetItem.from_json(item),
+        ANNOTATIONS_KEY: annotation_response,
+    }
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
index 570d6915..04108e41 100644
--- a/tests/test_dataset.py
+++ b/tests/test_dataset.py
@@ -127,69 +127,3 @@ def test_dataset_list_autotags(CLIENT, dataset):
     # List of Autotags should be empty
     autotag_response = CLIENT.list_autotags(dataset.id)
     assert autotag_response == []
-
-
-def test_slice_create_and_delete_and_list(dataset):
-    # Dataset upload
-    ds_items = []
-    for url in TEST_IMG_URLS:
-        ds_items.append(
-            DatasetItem(
-                image_location=url,
-                reference_id=reference_id_from_url(url),
-            )
-        )
-    response = dataset.append(ds_items)
-    assert ERROR_PAYLOAD not in response.json()
-
-    # Slice creation
-    slc = dataset.create_slice(
-        name=TEST_SLICE_NAME,
-        reference_ids=[item.reference_id for item in ds_items[:2]],
-    )
-
-    dataset_slices = dataset.slices
-    assert len(dataset_slices) == 1
-    assert slc.slice_id == dataset_slices[0]
-
-    response = slc.info()
-    assert response["name"] == TEST_SLICE_NAME
-    assert response["dataset_id"] == dataset.id
-    assert len(response["dataset_items"]) == 2
-    for item in ds_items[:2]:
-        assert (
-            item.reference_id == response["dataset_items"][0]["ref_id"]
-            or item.reference_id == response["dataset_items"][1]["ref_id"]
-        )
-
-
-def test_slice_append(dataset):
-    # Dataset upload
-    ds_items = []
-    for url in TEST_IMG_URLS:
-        ds_items.append(
-            DatasetItem(
-                image_location=url,
-                reference_id=reference_id_from_url(url),
-            )
-        )
-    response = dataset.append(ds_items)
-    assert ERROR_PAYLOAD not in response.json()
-
-    # Slice creation
-    slc = dataset.create_slice(
-        name=TEST_SLICE_NAME,
-        reference_ids=[ds_items[0].reference_id],
-    )
-
-    # Insert duplicate first item
-    slc.append(reference_ids=[item.reference_id for item in ds_items[:3]])
-
-    response = slc.info()
-    assert len(response["dataset_items"]) == 3
-    for item in ds_items[:3]:
-        assert (
-            item.reference_id == response["dataset_items"][0]["ref_id"]
-            or item.reference_id == response["dataset_items"][1]["ref_id"]
-            or item.reference_id == response["dataset_items"][2]["ref_id"]
-        )
diff --git a/tests/test_slice.py b/tests/test_slice.py
index 2e76ad3e..6826008c 100644
--- a/tests/test_slice.py
+++ b/tests/test_slice.py
@@ -1,5 +1,22 @@
 import pytest
-from nucleus import Slice, NucleusClient
+from nucleus import Slice, NucleusClient, DatasetItem, BoxAnnotation
+from nucleus.constants import ERROR_PAYLOAD, ITEM_KEY
+from helpers import (
+    TEST_DATASET_NAME,
+    TEST_IMG_URLS,
+    TEST_SLICE_NAME,
+    TEST_BOX_ANNOTATIONS,
+    reference_id_from_url,
+)
+
+
+@pytest.fixture()
+def dataset(CLIENT):
+    ds = CLIENT.create_dataset(TEST_DATASET_NAME)
+    yield ds
+
+    response = CLIENT.delete_dataset(ds.id)
+    assert response == {}
 
 
 def test_reprs():
@@ -9,3 +26,118 @@ def test_repr(test_object: any):
     client = NucleusClient(api_key="fake_key")
 
     test_repr(Slice(slice_id="fake_slice_id", client=client))
+
+
+def test_slice_create_and_delete_and_list(dataset):
+    # Dataset upload
+    ds_items = []
+    for url in TEST_IMG_URLS:
+        ds_items.append(
+            DatasetItem(
+                image_location=url,
+                reference_id=reference_id_from_url(url),
+            )
+        )
+    response = dataset.append(ds_items)
+    assert ERROR_PAYLOAD not in response.json()
+
+    # Slice creation
+    slc = dataset.create_slice(
+        name=TEST_SLICE_NAME,
+        reference_ids=[item.reference_id for item in ds_items[:2]],
+    )
+
+    dataset_slices = dataset.slices
+    assert len(dataset_slices) == 1
+    assert slc.slice_id == dataset_slices[0]
+
+    response = slc.info()
+    assert response["name"] == TEST_SLICE_NAME
+    assert response["dataset_id"] == dataset.id
+    assert len(response["dataset_items"]) == 2
+    for item in ds_items[:2]:
+        assert (
+            item.reference_id == response["dataset_items"][0]["ref_id"]
+            or item.reference_id == response["dataset_items"][1]["ref_id"]
+        )
+
+
+def test_slice_create_and_annotate(dataset):
+    # Dataset upload
+    url = TEST_IMG_URLS[0]
+    annotation_in_slice = BoxAnnotation(**TEST_BOX_ANNOTATIONS[0])
+    annotation_not_in_slice = BoxAnnotation(**TEST_BOX_ANNOTATIONS[1])
+
+    ds_items = []
+    ds_items.append(
+        DatasetItem(
+            image_location=url,
+            reference_id=reference_id_from_url(url),
+        )
+    )
+    response = dataset.append(ds_items)
+    assert ERROR_PAYLOAD not in response.json()
+
+    # Slice creation
+    slc = dataset.create_slice(
+        name=TEST_SLICE_NAME,
+        reference_ids=[item.reference_id for item in ds_items[:2]],
+    )
+
+    slc.annotate(annotations=[annotation_in_slice])
+    with pytest.raises(ValueError) as not_in_slice_error:
+        slc.annotate(annotations=[annotation_not_in_slice])
+
+    assert (
+        annotation_not_in_slice.reference_id
+        in not_in_slice_error.value.args[0]
+    )
+
+    slc.annotate(annotations=[annotation_not_in_slice], strict=False)
+
+
+def test_slice_append(dataset):
+    # Dataset upload
+    ds_items = []
+    for url in TEST_IMG_URLS:
+        ds_items.append(
+            DatasetItem(
+                image_location=url,
+                reference_id=reference_id_from_url(url),
+            )
+        )
+    response = dataset.append(ds_items)
+    assert ERROR_PAYLOAD not in response.json()
+
+    # Slice creation
+    slc = dataset.create_slice(
+        name=TEST_SLICE_NAME,
+        reference_ids=[ds_items[0].reference_id],
+    )
+
+    # Insert duplicate first item
+    slc.append(reference_ids=[item.reference_id for item in ds_items[:3]])
+
+    response = slc.info()
+    assert len(response["dataset_items"]) == 3
+    for item in ds_items[:3]:
+        assert (
+            item.reference_id == response["dataset_items"][0]["ref_id"]
+            or item.reference_id == response["dataset_items"][1]["ref_id"]
+            or item.reference_id == response["dataset_items"][2]["ref_id"]
+        )
+
+    all_stored_items = [_[ITEM_KEY] for _ in slc.items_and_annotations()]
+
+    def sort_by_reference_id(items):
+        # Remove the generated item_ids and standardize
+        # empty metadata so we can do an equality check.
+        for item in items:
+            item.item_id = None
+            if item.metadata == {}:
+                item.metadata = None
+        return sorted(items, key=lambda x: x.reference_id)
+
+    assert sort_by_reference_id(all_stored_items) == sort_by_reference_id(
+        ds_items[:3]
+    )