Implement dataset.query_objects method (#402)
* Implement dataset.query_objects method

* Remove true_negative from enum

* Fix confusion category true positive case

* Rename IOUMatch to EvaluationMatch

* Rename documentation

* Add documentation to EvaluationMatch

* Propagate model_run_id parameter

* Bump sdk version

---------

Co-authored-by: Gunnar Atli Thoroddsen <[email protected]>
ntamas92 and gatli authored Oct 23, 2023
1 parent 27c7dfd commit c641abb
Showing 6 changed files with 178 additions and 5 deletions.
13 changes: 13 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,19 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).


## [0.16.4](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.16.4) - 2023-10-23

### Added
- Added a `query_objects` method on the Dataset class.
- Example
```python
>>> ds = client.get_dataset('ds_id')
>>> objects = ds.query_objects('annotations.metadata.distance_to_device > 150', ObjectQueryType.GROUND_TRUTH_ONLY)
[CuboidAnnotation(label="", dimensions={}, ...), ...]
```
- Added `EvaluationMatch` class to represent IOU matches, false positives, and false negatives retrieved through the `query_objects` method (see the sketch below).
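
A minimal sketch of retrieving evaluation matches with the new method; the API key, dataset ID, model run ID, and metadata field below are placeholders, not values from this release:

```python
import nucleus
from nucleus.dataset import ObjectQueryType

client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")  # placeholder API key
ds = client.get_dataset("ds_id")  # placeholder dataset ID

# Retrieve IOU matches for a (hypothetical) metadata filter, scoped to one model run.
matches = ds.query_objects(
    "annotations.metadata.distance_to_device > 150",
    ObjectQueryType.IOU,
    model_run_id="run_id",  # placeholder model run ID
)
for match in matches:
    # Each item is an EvaluationMatch; false negatives carry no prediction ID.
    print(match.confusion_category, match.iou, match.model_prediction_id)
```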


## [0.16.3](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.16.3) - 2023-10-10

### Added
2 changes: 1 addition & 1 deletion nucleus/async_job.py
@@ -177,7 +177,7 @@ def result_urls(self, wait_for_completion=True) -> List[str]:
Parameters:
wait_for_completion: Defines whether the call shall wait for
the job to complete. Defaults to True
the job to complete. Defaults to True
Returns:
A list of signed Scale URLs which contain batches of embeddings.
6 changes: 6 additions & 0 deletions nucleus/constants.py
@@ -59,6 +59,8 @@
FX_KEY = "fx"
FY_KEY = "fy"
GEOMETRY_KEY = "geometry"
GROUND_TRUTH_ANNOTATION_ID_KEY = "ground_truth_annotation_id"
GROUND_TRUTH_ANNOTATION_LABEL_KEY = "ground_truth_annotation_label"
HEADING_KEY = "heading"
HEIGHT_KEY = "height"
ID_KEY = "id"
@@ -68,6 +70,7 @@
IMAGE_URL_KEY = "image_url"
INDEX_KEY = "index"
INDEX_CONTINUOUS_ENABLE_KEY = "enable"
IOU_KEY = "iou"
ITEMS_KEY = "items"
ITEM_KEY = "item"
ITEM_METADATA_SCHEMA_KEY = "item_metadata_schema"
@@ -97,6 +100,8 @@
MODEL_TAGS_KEY = "tags"
MODEL_ID_KEY = "model_id"
MODEL_RUN_ID_KEY = "model_run_id"
MODEL_PREDICTION_ID_KEY = "model_prediction_id"
MODEL_PREDICTION_LABEL_KEY = "model_prediction_label"
NAME_KEY = "name"
NEW_ITEMS = "new_items"
NUCLEUS_ENDPOINT = "https://api.scale.com/v1/nucleus"
@@ -135,6 +140,7 @@
TRACK_REFERENCE_ID_KEY = "track_reference_id"
TRACK_REFERENCE_IDS_KEY = "track_reference_ids"
TRACKS_KEY = "tracks"
TRUE_POSITIVE_KEY = "true_positive"
TYPE_KEY = "type"
UPDATED_ITEMS = "updated_items"
UPDATE_KEY = "update"
58 changes: 55 additions & 3 deletions nucleus/dataset.py
@@ -1,5 +1,6 @@
import datetime
import os
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
@@ -16,7 +17,8 @@

from nucleus.annotation_uploader import AnnotationUploader, PredictionUploader
from nucleus.async_job import AsyncJob, EmbeddingsExportJob
from nucleus.prediction import Prediction, from_json
from nucleus.evaluation_match import EvaluationMatch
from nucleus.prediction import from_json as prediction_from_json
from nucleus.track import Track
from nucleus.url_utils import sanitize_string_args
from nucleus.utils import (
@@ -77,6 +79,7 @@
construct_model_run_creation_payload,
construct_taxonomy_payload,
)
from .prediction import Prediction
from .scene import LidarScene, Scene, VideoScene, check_all_scene_paths_remote
from .slice import (
Slice,
@@ -98,6 +101,14 @@
WARN_FOR_LARGE_SCENES_UPLOAD = 5


class ObjectQueryType(str, Enum):
    IOU = "iou"
    FALSE_POSITIVE = "false_positive"
    FALSE_NEGATIVE = "false_negative"
    PREDICTIONS_ONLY = "predictions_only"
    GROUND_TRUTH_ONLY = "ground_truth_only"


class Dataset:
"""Datasets are collections of your data that can be associated with models.
@@ -1681,7 +1692,7 @@ def upload_predictions(
:class:`Category<CategoryPrediction>`, and :class:`Category<SceneCategoryPrediction>` predictions. Cuboid predictions
can only be uploaded to a :class:`pointcloud DatasetItem<LidarScene>`.
When uploading an prediction, you need to specify which item you are
When uploading a prediction, you need to specify which item you are
annotating via the reference_id you provided when uploading the image
or pointcloud.
@@ -1854,7 +1865,7 @@ def prediction_loc(self, model, reference_id, annotation_id):
:class:`KeypointsPrediction` \
]: Model prediction object with the specified annotation ID.
"""
return from_json(
return prediction_from_json(
self._client.make_request(
payload=None,
route=f"dataset/{self.id}/model/{model.id}/loc/{reference_id}/{annotation_id}",
@@ -1999,6 +2010,47 @@ def query_scenes(self, query: str) -> Iterable[Scene]:
        for item_json in json_generator:
            yield Scene.from_json(item_json, None, True)

    def query_objects(
        self,
        query: str,
        query_type: ObjectQueryType,
        model_run_id: Optional[str] = None,
    ) -> Iterable[Union[Annotation, Prediction, EvaluationMatch]]:
        """
        Fetches all objects in the dataset that pertain to a given structured query.

        The results are either Predictions, Annotations, or EvaluationMatches, depending on the query_type parameter.

        Args:
            query: Structured query compatible with the `Nucleus query language <https://nucleus.scale.com/docs/query-language-reference>`_.
            query_type: Defines the type of object to query for.
            model_run_id: Optional model run ID; restricts prediction and match queries to that model run.

        Returns:
            An iterable of either Predictions, Annotations, or EvaluationMatches.
        """
        json_generator = paginate_generator(
            client=self._client,
            endpoint=f"dataset/{self.id}/queryObjectsPage",
            result_key=ITEMS_KEY,
            page_size=MAX_ES_PAGE_SIZE,
            query=query,
            patch_mode=query_type,
            model_run_id=model_run_id,
        )

        for item_json in json_generator:
            if query_type == ObjectQueryType.GROUND_TRUTH_ONLY:
                yield Annotation.from_json(item_json)
            elif query_type == ObjectQueryType.PREDICTIONS_ONLY:
                yield prediction_from_json(item_json)
            elif query_type in [
                ObjectQueryType.IOU,
                ObjectQueryType.FALSE_POSITIVE,
                ObjectQueryType.FALSE_NEGATIVE,
            ]:
                yield EvaluationMatch.from_json(item_json)
            else:
                raise ValueError("Unknown object type", query_type)

    @property
    def tracks(self) -> List[Track]:
        """Tracks unique to this dataset.
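As a rough illustration of the dispatch in `query_objects` above, a false-negative query might look like the following sketch; the client setup, dataset ID, model run ID, and query string are hypothetical placeholders:

```python
import nucleus
from nucleus.dataset import ObjectQueryType

client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")  # placeholder API key
ds = client.get_dataset("ds_id")  # placeholder dataset ID

# GROUND_TRUTH_ONLY yields Annotation subclasses, PREDICTIONS_ONLY yields Prediction
# subclasses, and IOU / FALSE_POSITIVE / FALSE_NEGATIVE yield EvaluationMatch objects.
false_negatives = ds.query_objects(
    "annotations.metadata.distance_to_device > 150",  # hypothetical query
    ObjectQueryType.FALSE_NEGATIVE,
    model_run_id="run_id",  # placeholder; scopes results to a single model run
)
missed_labels = [m.ground_truth_annotation_label for m in false_negatives]
```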
102 changes: 102 additions & 0 deletions nucleus/evaluation_match.py
@@ -0,0 +1,102 @@
from dataclasses import dataclass
from enum import Enum
from typing import Optional

from .constants import (
    DATASET_ITEM_ID_KEY,
    GROUND_TRUTH_ANNOTATION_ID_KEY,
    GROUND_TRUTH_ANNOTATION_LABEL_KEY,
    IOU_KEY,
    MODEL_PREDICTION_ID_KEY,
    MODEL_PREDICTION_LABEL_KEY,
    MODEL_RUN_ID_KEY,
    TRUE_POSITIVE_KEY,
)


class ConfusionCategory(Enum):
    TRUE_POSITIVE = "true_positive"
    FALSE_POSITIVE = "false_positive"
    FALSE_NEGATIVE = "false_negative"


def infer_confusion_category(
    true_positive: bool,
    ground_truth_annotation_label: str,
    model_prediction_label: str,
):
    confusion_category = ConfusionCategory.FALSE_NEGATIVE

    if (
        true_positive
        or model_prediction_label == ground_truth_annotation_label
    ):
        confusion_category = ConfusionCategory.TRUE_POSITIVE
    elif model_prediction_label is not None:
        confusion_category = ConfusionCategory.FALSE_POSITIVE

    return confusion_category


@dataclass
class EvaluationMatch:
    """
    EvaluationMatch is a result from a model run evaluation. It can represent a true positive, false positive,
    or false negative.

    The matching keeps only the strongest prediction for each annotation, so if multiple predictions overlap a
    single annotation, only the one with the highest overlap metric is matched.

    The model prediction label and the ground truth annotation label can differ for true positives if an
    allowed_label_mapping is configured for the model run.

    NOTE: No IOU thresholding is applied to these matches, so it is possible to have a true positive with a low
    IOU score. If manually rejecting matches, remember that a rejected match produces both a false positive and a
    false negative; otherwise you'll skew your aggregates.

    Attributes:
        model_run_id (str): The ID of the model run that produced this match.
        model_prediction_id (str): The ID of the model prediction that was matched. None if the match was a false negative.
        ground_truth_annotation_id (str): The ID of the ground truth annotation that was matched. None if the match was a false positive.
        iou (float): The intersection over union score of the match.
        dataset_item_id (str): The ID of the dataset item that was matched.
        confusion_category (ConfusionCategory): The confusion category of the match.
        model_prediction_label (str): The label of the model prediction that was matched. None if the match was a false negative.
        ground_truth_annotation_label (str): The label of the ground truth annotation that was matched. None if the match was a false positive.
    """

    model_run_id: str
    model_prediction_id: Optional[str]  # field is nullable
    ground_truth_annotation_id: Optional[str]  # field is nullable
    iou: float
    dataset_item_id: str
    confusion_category: ConfusionCategory
    model_prediction_label: Optional[str]  # field is nullable
    ground_truth_annotation_label: Optional[str]  # field is nullable

    @classmethod
    def from_json(cls, payload: dict):
        is_true_positive = payload.get(TRUE_POSITIVE_KEY, False)
        model_prediction_label = payload.get(MODEL_PREDICTION_LABEL_KEY, None)
        ground_truth_annotation_label = payload.get(
            GROUND_TRUTH_ANNOTATION_LABEL_KEY, None
        )

        confusion_category = infer_confusion_category(
            true_positive=is_true_positive,
            ground_truth_annotation_label=ground_truth_annotation_label,
            model_prediction_label=model_prediction_label,
        )

        return cls(
            model_run_id=payload[MODEL_RUN_ID_KEY],
            model_prediction_id=payload.get(MODEL_PREDICTION_ID_KEY, None),
            ground_truth_annotation_id=payload.get(
                GROUND_TRUTH_ANNOTATION_ID_KEY, None
            ),
            iou=payload[IOU_KEY],
            dataset_item_id=payload[DATASET_ITEM_ID_KEY],
            confusion_category=confusion_category,
            model_prediction_label=model_prediction_label,
            ground_truth_annotation_label=ground_truth_annotation_label,
        )
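
A short sketch of how `infer_confusion_category` and `EvaluationMatch.from_json` behave. The payload dict is hand-written for illustration, and it assumes the pre-existing `DATASET_ITEM_ID_KEY` constant resolves to `"dataset_item_id"`:

```python
from nucleus.evaluation_match import (
    ConfusionCategory,
    EvaluationMatch,
    infer_confusion_category,
)

# Labels disagree and the match was not flagged true positive -> false positive.
assert (
    infer_confusion_category(
        true_positive=False,
        ground_truth_annotation_label="car",
        model_prediction_label="truck",
    )
    == ConfusionCategory.FALSE_POSITIVE
)

# Hand-written payload mirroring the keys read by from_json (illustration only;
# assumes DATASET_ITEM_ID_KEY == "dataset_item_id").
match = EvaluationMatch.from_json(
    {
        "model_run_id": "run_id",
        "model_prediction_id": "pred_id",
        "ground_truth_annotation_id": "ann_id",
        "iou": 0.42,
        "dataset_item_id": "item_id",
        "true_positive": True,
        "model_prediction_label": "car",
        "ground_truth_annotation_label": "car",
    }
)
assert match.confusion_category is ConfusionCategory.TRUE_POSITIVE
```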
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -25,7 +25,7 @@ ignore = ["E501", "E741", "E731", "F401"] # Easy ignore for getting it running

[tool.poetry]
name = "scale-nucleus"
version = "0.16.3"
version = "0.16.4"
description = "The official Python client library for Nucleus, the Data Platform for AI"
license = "MIT"
authors = ["Scale AI Nucleus Team <[email protected]>"]
