Skip to content

Commit

Permalink
Merge pull request #57 from scaleapi/da/slice-support
Browse files Browse the repository at this point in the history
Da/slice support
  • Loading branch information
ardila authored Apr 21, 2021
2 parents 4d09675 + 342d12f commit 2e515a4
Show file tree
Hide file tree
Showing 5 changed files with 308 additions and 98 deletions.
36 changes: 7 additions & 29 deletions nucleus/dataset.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import List, Dict, Any, Optional, Union
from typing import List, Dict, Any, Optional

from nucleus.utils import format_dataset_item_response
from .dataset_item import DatasetItem
from .annotation import (
Annotation,
BoxAnnotation,
PolygonAnnotation,
)
from .constants import (
DATASET_NAME_KEY,
Expand All @@ -13,10 +13,7 @@
DATASET_ITEM_IDS_KEY,
REFERENCE_IDS_KEY,
NAME_KEY,
ITEM_KEY,
DEFAULT_ANNOTATION_UPDATE_MODE,
ANNOTATIONS_KEY,
ANNOTATION_TYPES,
)
from .payload_constructor import construct_model_run_creation_payload

Expand Down Expand Up @@ -109,7 +106,7 @@ def create_model_run(

def annotate(
self,
annotations: List[Union[BoxAnnotation, PolygonAnnotation]],
annotations: List[Annotation],
update: Optional[bool] = DEFAULT_ANNOTATION_UPDATE_MODE,
batch_size: int = 5000,
) -> dict:
Expand Down Expand Up @@ -179,7 +176,7 @@ def iloc(self, i: int) -> dict:
}
"""
response = self._client.dataitem_iloc(self.id, i)
return self._format_dataset_item_response(response)
return format_dataset_item_response(response)

def refloc(self, reference_id: str) -> dict:
"""
Expand All @@ -192,7 +189,7 @@ def refloc(self, reference_id: str) -> dict:
}
"""
response = self._client.dataitem_ref_id(self.id, reference_id)
return self._format_dataset_item_response(response)
return format_dataset_item_response(response)

def loc(self, dataset_item_id: str) -> dict:
"""
Expand All @@ -205,7 +202,7 @@ def loc(self, dataset_item_id: str) -> dict:
}
"""
response = self._client.dataitem_loc(self.id, dataset_item_id)
return self._format_dataset_item_response(response)
return format_dataset_item_response(response)

def create_slice(
self,
Expand Down Expand Up @@ -247,25 +244,6 @@ def delete_item(self, item_id: str = None, reference_id: str = None):
def list_autotags(self):
    """Fetch the autotags defined on this dataset (delegates to the API client)."""
    return self._client.list_autotags(self.id)

def _format_dataset_item_response(self, response: dict) -> dict:
    """Convert a raw dataset-item payload into client API objects.

    If either the item or the annotations are missing from the payload
    (i.e. the server reported an error), the raw response is returned
    unchanged so the caller can inspect it.
    """
    item = response.get(ITEM_KEY, None)
    annotation_payload = response.get(ANNOTATIONS_KEY, {})
    if not item or not annotation_payload:
        # Error response from the server: pass it through untouched.
        return response

    formatted_annotations = {
        annotation_type: [
            Annotation.from_json(ann)
            for ann in annotation_payload[annotation_type]
        ]
        for annotation_type in ANNOTATION_TYPES
        if annotation_type in annotation_payload
    }
    return {
        ITEM_KEY: DatasetItem.from_json(item),
        ANNOTATIONS_KEY: formatted_annotations,
    }

def create_custom_index(self, embeddings_url: str):
    """Create a custom index for this dataset from precomputed embeddings.

    Args:
        embeddings_url: Location of the embeddings — presumably a remote
            URL; confirm the expected format against the client API.
    """
    return self._client.create_custom_index(self.id, embeddings_url)

Expand Down
134 changes: 132 additions & 2 deletions nucleus/slice.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from typing import List
from typing import Dict, List, Iterable, Set, Tuple, Optional, Union
from nucleus.dataset_item import DatasetItem
from nucleus.annotation import Annotation
from nucleus.utils import format_dataset_item_response

from .constants import DEFAULT_ANNOTATION_UPDATE_MODE


class Slice:
Expand All @@ -9,6 +14,7 @@ class Slice:
def __init__(self, slice_id: str, client):
    """Wrap an existing slice identified by `slice_id`, using `client` for API calls."""
    self.slice_id = slice_id
    self._client = client
    # Cached dataset id; stays None until resolved lazily via info().
    self._dataset_id = None

def __repr__(self):
    """Unambiguous debug representation including the backing client."""
    return "Slice(slice_id='{}', client={})".format(self.slice_id, self._client)
Expand All @@ -19,6 +25,13 @@ def __eq__(self, other):
return True
return False

@property
def dataset_id(self):
    """The id of the dataset this slice belongs to."""
    # Lazily resolved: info() populates self._dataset_id as a side effect,
    # so the server is only queried the first time this property is read.
    if self._dataset_id is None:
        self.info()
    return self._dataset_id

def info(self) -> dict:
"""
This endpoint provides information about specified slice.
Expand All @@ -30,7 +43,9 @@ def info(self) -> dict:
"dataset_items",
}
"""
return self._client.slice_info(self.slice_id)
info = self._client.slice_info(self.slice_id)
self._dataset_id = info["dataset_id"]
return info

def append(
self,
Expand All @@ -57,3 +72,118 @@ def append(
reference_ids=reference_ids,
)
return response

def items_and_annotation_generator(
    self,
) -> Iterable[Dict[str, Union[DatasetItem, Dict[str, List[Annotation]]]]]:
    """Lazily yield each DatasetItem in this slice along with its annotations.

    Yields:
        One dict per dataset row, holding two values:
        * the DatasetItem, which references the annotated item
          (for example an image_url), and
        * a dict of all annotations for that item, keyed by annotation type.
    """
    slice_info = self.info()
    dataset_id = slice_info["dataset_id"]
    for item_metadata in slice_info["dataset_items"]:
        raw_response = self._client.dataitem_loc(
            dataset_id=dataset_id,
            dataset_item_id=item_metadata["id"],
        )
        yield format_dataset_item_response(raw_response)

def items_and_annotations(
    self,
) -> List[Dict[str, Union[DatasetItem, Dict[str, List[Annotation]]]]]:
    """Return every DatasetItem in this slice together with its annotations.

    Returns:
        A list with one dict per dataset row, holding two values:
        * the DatasetItem, which references the annotated item, and
        * a dict of all annotations for that item, keyed by annotation type.
    """
    rows = self.items_and_annotation_generator()
    return list(rows)

def annotate(
    self,
    annotations: List[Annotation],
    update: Optional[bool] = DEFAULT_ANNOTATION_UPDATE_MODE,
    batch_size: int = 5000,
    strict=True,
):
    """Update annotations within this slice.

    Args:
        annotations: List of annotations to upload.
        update: Whether existing annotations should be overwritten.
        batch_size: How many annotations to send per request.
        strict: Whether to first check that the annotations belong to this
            slice. Set to false to avoid this check and speed up upload.

    Raises:
        ValueError: In strict mode, when any annotation targets an item
            that is not part of this slice.
    """
    if strict:
        all_in_slice, missing_item_ids, missing_ref_ids = (
            check_annotations_are_in_slice(annotations, self)
        )
        if not all_in_slice:
            error_parts = ["Not all annotations are in this slice.\n"]
            if missing_item_ids:
                error_parts.append(
                    f"Item ids not found in slice: {missing_item_ids} \n"
                )
            if missing_ref_ids:
                error_parts.append(
                    f"Reference ids not found in slice: {missing_ref_ids}"
                )
            raise ValueError("".join(error_parts))
    self._client.annotate_dataset(
        dataset_id=self.dataset_id,
        annotations=annotations,
        update=update,
        batch_size=batch_size,
    )


def check_annotations_are_in_slice(
    annotations: List[Annotation], slice_to_check: Slice
) -> Tuple[bool, Set[str], Set[str]]:
    """Check membership of the annotation targets within a slice.

    Args:
        annotations: Annotations whose item_id / reference_id refer to targets.
        slice_to_check: The slice to check membership against.

    Returns:
        A tuple of three elements:
        1. True if every annotation target is in the slice, False otherwise.
        2. The set of annotation item_ids not found in the slice.
        3. The set of annotation reference_ids not found in the slice.
    """
    info = slice_to_check.info()
    # Hoisted: both difference computations scan the same item metadata.
    dataset_items = info["dataset_items"]

    item_ids_not_found_in_slice = {
        annotation.item_id
        for annotation in annotations
        if annotation.item_id is not None
    }.difference(
        {item_metadata["id"] for item_metadata in dataset_items}
    )
    reference_ids_not_found_in_slice = {
        annotation.reference_id
        for annotation in annotations
        if annotation.reference_id is not None
    }.difference(
        {item_metadata["ref_id"] for item_metadata in dataset_items}
    )
    # All targets are in the slice exactly when neither difference is populated.
    annotations_are_in_slice = not (
        item_ids_not_found_in_slice or reference_ids_not_found_in_slice
    )

    return (
        annotations_are_in_slice,
        item_ids_not_found_in_slice,
        reference_ids_not_found_in_slice,
    )
36 changes: 36 additions & 0 deletions nucleus/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
"""Shared stateless utility function library"""


from typing import List, Union, Dict

from nucleus.annotation import Annotation
from .dataset_item import DatasetItem
from .prediction import BoxPrediction, PolygonPrediction

from .constants import (
ITEM_KEY,
ANNOTATIONS_KEY,
ANNOTATION_TYPES,
)


def _get_all_field_values(metadata_list: List[dict], key: str):
return {metadata[key] for metadata in metadata_list if key in metadata}
Expand Down Expand Up @@ -34,3 +44,29 @@ def suggest_metadata_schema(
entry["type"] = "text"
schema[key] = entry
return schema


def format_dataset_item_response(response: dict) -> dict:
    """Format the raw client response into api objects.

    Raises:
        ValueError: When the payload lacks the annotation or item key.
    """
    # Validate the payload shape before touching its contents; the
    # annotation key is checked first to preserve the error precedence.
    for required_key, label in (
        (ANNOTATIONS_KEY, "annotation"),
        (ITEM_KEY, "item"),
    ):
        if required_key not in response:
            raise ValueError(
                f"Server response was missing the {label} key: {response}"
            )
    item_payload = response[ITEM_KEY]
    annotations_by_type = response[ANNOTATIONS_KEY]

    formatted_annotations = {
        annotation_type: [
            Annotation.from_json(ann)
            for ann in annotations_by_type[annotation_type]
        ]
        for annotation_type in ANNOTATION_TYPES
        if annotation_type in annotations_by_type
    }
    return {
        ITEM_KEY: DatasetItem.from_json(item_payload),
        ANNOTATIONS_KEY: formatted_annotations,
    }
66 changes: 0 additions & 66 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,69 +127,3 @@ def test_dataset_list_autotags(CLIENT, dataset):
# List of Autotags should be empty
autotag_response = CLIENT.list_autotags(dataset.id)
assert autotag_response == []


def test_slice_create_and_delete_and_list(dataset):
    """Upload items, create a slice over a subset, then verify listing and info."""
    # Dataset upload
    ds_items = [
        DatasetItem(
            image_location=url,
            reference_id=reference_id_from_url(url),
        )
        for url in TEST_IMG_URLS
    ]
    response = dataset.append(ds_items)
    assert ERROR_PAYLOAD not in response.json()

    # Slice creation over the first two items
    slc = dataset.create_slice(
        name=TEST_SLICE_NAME,
        reference_ids=[item.reference_id for item in ds_items[:2]],
    )

    dataset_slices = dataset.slices
    assert len(dataset_slices) == 1
    assert slc.slice_id == dataset_slices[0]

    response = slc.info()
    assert response["name"] == TEST_SLICE_NAME
    assert response["dataset_id"] == dataset.id
    assert len(response["dataset_items"]) == 2
    returned_ref_ids = {entry["ref_id"] for entry in response["dataset_items"]}
    for item in ds_items[:2]:
        assert item.reference_id in returned_ref_ids


def test_slice_append(dataset):
    """Appending to a slice deduplicates items that are already present."""
    # Dataset upload
    ds_items = [
        DatasetItem(
            image_location=url,
            reference_id=reference_id_from_url(url),
        )
        for url in TEST_IMG_URLS
    ]
    response = dataset.append(ds_items)
    assert ERROR_PAYLOAD not in response.json()

    # Slice creation seeded with only the first item
    slc = dataset.create_slice(
        name=TEST_SLICE_NAME,
        reference_ids=[ds_items[0].reference_id],
    )

    # Append three items, including the duplicate first one
    slc.append(reference_ids=[item.reference_id for item in ds_items[:3]])

    response = slc.info()
    assert len(response["dataset_items"]) == 3
    returned_ref_ids = {entry["ref_id"] for entry in response["dataset_items"]}
    for item in ds_items[:3]:
        assert item.reference_id in returned_ref_ids
Loading

0 comments on commit 2e515a4

Please sign in to comment.