From c641abbd2524ada673ebc153cc5f4b48a19567ca Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tam=C3=A1s=20Nagy?=
Date: Mon, 23 Oct 2023 21:23:10 +0200
Subject: [PATCH] Implement dataset.query_objects method (#402)

* Implement dataset.query_objects method

* Remove true_negative from enum

* Fix confusion category true positive case

* Rename IOUMatch to EvaluationMatch

* Rename documentation

* Add documentation to EvaluationMatch

* Propagate model_run_id parameter

* Bump sdk version

---------

Co-authored-by: Gunnar Atli Thoroddsen
---
 CHANGELOG.md                |  13 +++++
 nucleus/async_job.py        |   2 +-
 nucleus/constants.py        |   6 +++
 nucleus/dataset.py          |  58 ++++++++++++++++++--
 nucleus/evaluation_match.py | 102 ++++++++++++++++++++++++++++++++++++
 pyproject.toml              |   2 +-
 6 files changed, 178 insertions(+), 5 deletions(-)
 create mode 100644 nucleus/evaluation_match.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0beb94b9..fe0282d5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,19 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to
 [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [0.16.4](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.16.4) - 2023-10-23
+
+### Added
+- Added a `query_objects` method on the Dataset class.
+- Example:
+```python
+>>> ds = client.get_dataset('ds_id')
+>>> objects = ds.query_objects('annotations.metadata.distance_to_device > 150', ObjectQueryType.GROUND_TRUTH_ONLY)
+[CuboidAnnotation(label="", dimensions={}, ...), ...]
+```
+- Added `EvaluationMatch` class to represent IOU Matches, False Positives, and False Negatives retrieved through the `query_objects` method.
+
+
 ## [0.16.3](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.16.3) - 2023-10-10
 
 ### Added
diff --git a/nucleus/async_job.py b/nucleus/async_job.py
index ab04d161..e108e9aa 100644
--- a/nucleus/async_job.py
+++ b/nucleus/async_job.py
@@ -177,7 +177,7 @@ def result_urls(self, wait_for_completion=True) -> List[str]:
 
         Parameters:
             wait_for_completion: Defines whether the call shall wait for
-            the job to complete. Defaults to True
+                the job to complete. Defaults to True
 
         Returns:
             A list of signed Scale URLs which contain batches of embeddings.
diff --git a/nucleus/constants.py b/nucleus/constants.py
index 1c48eb63..31d1e710 100644
--- a/nucleus/constants.py
+++ b/nucleus/constants.py
@@ -59,6 +59,8 @@
 FX_KEY = "fx"
 FY_KEY = "fy"
 GEOMETRY_KEY = "geometry"
+GROUND_TRUTH_ANNOTATION_ID_KEY = "ground_truth_annotation_id"
+GROUND_TRUTH_ANNOTATION_LABEL_KEY = "ground_truth_annotation_label"
 HEADING_KEY = "heading"
 HEIGHT_KEY = "height"
 ID_KEY = "id"
@@ -68,6 +70,7 @@
 IMAGE_URL_KEY = "image_url"
 INDEX_KEY = "index"
 INDEX_CONTINUOUS_ENABLE_KEY = "enable"
+IOU_KEY = "iou"
 ITEMS_KEY = "items"
 ITEM_KEY = "item"
 ITEM_METADATA_SCHEMA_KEY = "item_metadata_schema"
@@ -97,6 +100,8 @@
 MODEL_TAGS_KEY = "tags"
 MODEL_ID_KEY = "model_id"
 MODEL_RUN_ID_KEY = "model_run_id"
+MODEL_PREDICTION_ID_KEY = "model_prediction_id"
+MODEL_PREDICTION_LABEL_KEY = "model_prediction_label"
 NAME_KEY = "name"
 NEW_ITEMS = "new_items"
 NUCLEUS_ENDPOINT = "https://api.scale.com/v1/nucleus"
@@ -135,6 +140,7 @@
 TRACK_REFERENCE_ID_KEY = "track_reference_id"
 TRACK_REFERENCE_IDS_KEY = "track_reference_ids"
 TRACKS_KEY = "tracks"
+TRUE_POSITIVE_KEY = "true_positive"
 TYPE_KEY = "type"
 UPDATED_ITEMS = "updated_items"
 UPDATE_KEY = "update"
diff --git a/nucleus/dataset.py b/nucleus/dataset.py
index 38ed2964..2d5a64d5 100644
--- a/nucleus/dataset.py
+++ b/nucleus/dataset.py
@@ -1,5 +1,6 @@
 import datetime
 import os
+from enum import Enum
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -16,7 +17,8 @@
 
 from nucleus.annotation_uploader import AnnotationUploader, PredictionUploader
 from nucleus.async_job import AsyncJob, EmbeddingsExportJob
-from nucleus.prediction import Prediction, from_json
+from nucleus.evaluation_match import EvaluationMatch
+from nucleus.prediction import from_json as prediction_from_json
 from nucleus.track import Track
 from nucleus.url_utils import sanitize_string_args
 from nucleus.utils import (
@@ -77,6 +79,7 @@
     construct_model_run_creation_payload,
     construct_taxonomy_payload,
 )
+from .prediction import Prediction
 from .scene import LidarScene, Scene, VideoScene, check_all_scene_paths_remote
 from .slice import (
     Slice,
@@ -98,6 +101,14 @@
 WARN_FOR_LARGE_SCENES_UPLOAD = 5
 
 
+class ObjectQueryType(str, Enum):
+    IOU = "iou"
+    FALSE_POSITIVE = "false_positive"
+    FALSE_NEGATIVE = "false_negative"
+    PREDICTIONS_ONLY = "predictions_only"
+    GROUND_TRUTH_ONLY = "ground_truth_only"
+
+
 class Dataset:
     """Datasets are collections of your data that can be associated with models.
 
@@ -1681,7 +1692,7 @@ def upload_predictions(
         :class:`Category`, and :class:`Category` predictions. Cuboid predictions
         can only be uploaded to a :class:`pointcloud DatasetItem`.
 
-        When uploading an prediction, you need to specify which item you are
+        When uploading a prediction, you need to specify which item you are
         annotating via the reference_id you provided when uploading the image
         or pointcloud.
 
@@ -1854,7 +1865,7 @@ def prediction_loc(self, model, reference_id, annotation_id):
             :class:`KeypointsPrediction` \
         ]: Model prediction object with the specified annotation ID.
""" - return from_json( + return prediction_from_json( self._client.make_request( payload=None, route=f"dataset/{self.id}/model/{model.id}/loc/{reference_id}/{annotation_id}", @@ -1999,6 +2010,47 @@ def query_scenes(self, query: str) -> Iterable[Scene]: for item_json in json_generator: yield Scene.from_json(item_json, None, True) + def query_objects( + self, + query: str, + query_type: ObjectQueryType, + model_run_id: Optional[str] = None, + ) -> Iterable[Union[Annotation, Prediction, EvaluationMatch]]: + """ + Fetches all objects in the dataset that pertain to a given structured query. + The results are either Predictions, Annotations, or Evaluation Matches, based on the objectType input parameter + + Args: + query: Structured query compatible with the `Nucleus query language `_. + objectType: Defines the type of the object to query + + Returns: + An iterable of either Predictions, Annotations, or Evaluation Matches + """ + json_generator = paginate_generator( + client=self._client, + endpoint=f"dataset/{self.id}/queryObjectsPage", + result_key=ITEMS_KEY, + page_size=MAX_ES_PAGE_SIZE, + query=query, + patch_mode=query_type, + model_run_id=model_run_id, + ) + + for item_json in json_generator: + if query_type == ObjectQueryType.GROUND_TRUTH_ONLY: + yield Annotation.from_json(item_json) + elif query_type == ObjectQueryType.PREDICTIONS_ONLY: + yield prediction_from_json(item_json) + elif query_type in [ + ObjectQueryType.IOU, + ObjectQueryType.FALSE_POSITIVE, + ObjectQueryType.FALSE_NEGATIVE, + ]: + yield EvaluationMatch.from_json(item_json) + else: + raise ValueError("Unknown object type", query_type) + @property def tracks(self) -> List[Track]: """Tracks unique to this dataset. diff --git a/nucleus/evaluation_match.py b/nucleus/evaluation_match.py new file mode 100644 index 00000000..a0a2deec --- /dev/null +++ b/nucleus/evaluation_match.py @@ -0,0 +1,102 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Optional + +from .constants import ( + DATASET_ITEM_ID_KEY, + GROUND_TRUTH_ANNOTATION_ID_KEY, + GROUND_TRUTH_ANNOTATION_LABEL_KEY, + IOU_KEY, + MODEL_PREDICTION_ID_KEY, + MODEL_PREDICTION_LABEL_KEY, + MODEL_RUN_ID_KEY, + TRUE_POSITIVE_KEY, +) + + +class ConfusionCategory(Enum): + TRUE_POSITIVE = "true_positive" + FALSE_POSITIVE = "false_positive" + FALSE_NEGATIVE = "false_negative" + + +def infer_confusion_category( + true_positive: bool, + ground_truth_annotation_label: str, + model_prediction_label: str, +): + confusion_category = ConfusionCategory.FALSE_NEGATIVE + + if ( + true_positive + or model_prediction_label == ground_truth_annotation_label + ): + confusion_category = ConfusionCategory.TRUE_POSITIVE + elif model_prediction_label is not None: + confusion_category = ConfusionCategory.FALSE_POSITIVE + + return confusion_category + + +@dataclass +class EvaluationMatch: + """ + EvaluationMatch is a result from a model run evaluation. It can represent a true positive, false positive, + or false negative. + + The matching only matches the strongest prediction for each annotation, so if there are multiple predictions + that overlap a single annotation only the one with the highest overlap metric will be matched. + + The model prediction label and the ground truth annotation label can differ for true positives if there is configured + an allowed_label_mapping for the model run. + + NOTE: There is no iou thresholding applied to these matches, so it is possible to have a true positive with a low + iou score. 
+    If you manually reject matches, remember that a rejected match produces both a false positive and a false
+    negative; otherwise you'll skew your aggregates.
+
+    Attributes:
+        model_run_id (str): The ID of the model run that produced this match.
+        model_prediction_id (Optional[str]): The ID of the model prediction that was matched. None if the match was a false negative.
+        ground_truth_annotation_id (Optional[str]): The ID of the ground truth annotation that was matched. None if the match was a false positive.
+        iou (float): The intersection over union score of the match.
+        dataset_item_id (str): The ID of the dataset item that was matched.
+        confusion_category (ConfusionCategory): The confusion category of the match.
+        model_prediction_label (Optional[str]): The label of the model prediction that was matched. None if the match was a false negative.
+        ground_truth_annotation_label (Optional[str]): The label of the ground truth annotation that was matched. None if the match was a false positive.
+    """
+
+    model_run_id: str
+    model_prediction_id: Optional[str]  # field is nullable
+    ground_truth_annotation_id: Optional[str]  # field is nullable
+    iou: float
+    dataset_item_id: str
+    confusion_category: ConfusionCategory
+    model_prediction_label: Optional[str]  # field is nullable
+    ground_truth_annotation_label: Optional[str]  # field is nullable
+
+    @classmethod
+    def from_json(cls, payload: dict):
+        is_true_positive = payload.get(TRUE_POSITIVE_KEY, False)
+        model_prediction_label = payload.get(MODEL_PREDICTION_LABEL_KEY, None)
+        ground_truth_annotation_label = payload.get(
+            GROUND_TRUTH_ANNOTATION_LABEL_KEY, None
+        )
+
+        confusion_category = infer_confusion_category(
+            true_positive=is_true_positive,
+            ground_truth_annotation_label=ground_truth_annotation_label,
+            model_prediction_label=model_prediction_label,
+        )
+
+        return cls(
+            model_run_id=payload[MODEL_RUN_ID_KEY],
+            model_prediction_id=payload.get(MODEL_PREDICTION_ID_KEY, None),
+            ground_truth_annotation_id=payload.get(
+                GROUND_TRUTH_ANNOTATION_ID_KEY, None
+            ),
+            iou=payload[IOU_KEY],
+            dataset_item_id=payload[DATASET_ITEM_ID_KEY],
+            confusion_category=confusion_category,
+            model_prediction_label=model_prediction_label,
+            ground_truth_annotation_label=ground_truth_annotation_label,
+        )
diff --git a/pyproject.toml b/pyproject.toml
index eadd0ca9..889d87bb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,7 +25,7 @@ ignore = ["E501", "E741", "E731", "F401"]  # Easy ignore for getting it running
 
 [tool.poetry]
 name = "scale-nucleus"
-version = "0.16.3"
+version = "0.16.4"
 description = "The official Python client library for Nucleus, the Data Platform for AI"
 license = "MIT"
 authors = ["Scale AI Nucleus Team "]
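
For reference, a minimal usage sketch of the new `query_objects` API (not part of the patch). The API key, dataset ID, model run ID, and query string are placeholders, and `ObjectQueryType` is imported from `nucleus.dataset`, where this patch defines it:

```python
from collections import Counter

import nucleus
from nucleus.dataset import ObjectQueryType

client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")  # placeholder API key
ds = client.get_dataset("ds_id")  # placeholder dataset ID

# IOU, FALSE_POSITIVE, and FALSE_NEGATIVE queries yield EvaluationMatch objects;
# model_run_id is forwarded to the endpoint (assumed here to scope results to one run).
matches = ds.query_objects(
    "annotations.metadata.distance_to_device > 150",  # placeholder query
    ObjectQueryType.IOU,
    model_run_id="model_run_id",  # placeholder model run ID
)
print(Counter(match.confusion_category for match in matches))

# GROUND_TRUTH_ONLY yields Annotation objects; PREDICTIONS_ONLY yields Predictions.
ground_truth = list(
    ds.query_objects(
        "annotations.metadata.distance_to_device > 150",
        ObjectQueryType.GROUND_TRUTH_ONLY,
    )
)
print(len(ground_truth))
```

Since `query_objects` is a generator backed by `paginate_generator`, results stream page by page rather than being loaded all at once.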
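
The confusion-category bucketing in `evaluation_match.py` can also be checked in isolation; a small illustration with made-up labels, relying only on what the new module defines:

```python
from nucleus.evaluation_match import ConfusionCategory, infer_confusion_category

# A match flagged true_positive stays a true positive even if the labels differ,
# e.g. when an allowed_label_mapping is configured for the model run.
assert infer_confusion_category(
    true_positive=True,
    ground_truth_annotation_label="car",
    model_prediction_label="vehicle",
) == ConfusionCategory.TRUE_POSITIVE

# A prediction with no matching ground truth label is a false positive.
assert infer_confusion_category(
    true_positive=False,
    ground_truth_annotation_label=None,
    model_prediction_label="car",
) == ConfusionCategory.FALSE_POSITIVE

# A ground truth annotation with no matching prediction is a false negative.
assert infer_confusion_category(
    true_positive=False,
    ground_truth_annotation_label="car",
    model_prediction_label=None,
) == ConfusionCategory.FALSE_NEGATIVE
```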