From 2d03c54c479d257c81633b182ff87e11164c3233 Mon Sep 17 00:00:00 2001 From: Gunnar Atli Thoroddsen Date: Thu, 11 Aug 2022 13:49:00 +0200 Subject: [PATCH] Speed up slices and add client.slices for all customer slices (#339) * Speed up slices and add client.slices for all customer slices * Speed up unit test listing * Refactor slow __post_init__ pattern away * Remove test_reprs * Update version and CHANGELOG * Try to increase resource class * Patch unavailable fromisoformat in Python 3.6 * Bump to 0.14.14 --- CHANGELOG.md | 9 ++++ cli/slices.py | 42 ++++++------------ nucleus/__init__.py | 6 +++ nucleus/slice.py | 44 ++++++++++++++++++- nucleus/validate/client.py | 17 ++++--- .../scenario_test_evaluations.py | 8 ---- nucleus/validate/scenario_test.py | 36 +++++++++------ nucleus/validate/scenario_test_evaluation.py | 42 +++++++++--------- pyproject.toml | 2 +- tests/test_slice.py | 9 ---- 10 files changed, 126 insertions(+), 89 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8ce7bc41..e856d01b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,11 +5,20 @@ All notable changes to the [Nucleus Python Client](https://github.com/scaleapi/n The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.14.14](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.14.14) - 2022-08-11 + +### Added +- client.slices to list all of a user's slices, independent of dataset + +### Fixed +- Validate unit test listing and evaluation history listing. Now uses new bulk fetch endpoints for faster listing. 
+ ## [0.14.13](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.14.13) - 2022-08-10 ### Fixed - Fix payload parsing for scene export + ## [0.14.12](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.14.12) - 2022-08-05 ### Added diff --git a/cli/slices.py b/cli/slices.py index 11dad560..25eecf9c 100644 --- a/cli/slices.py +++ b/cli/slices.py @@ -23,12 +23,9 @@ def slices(ctx, web): @slices.command("list") def list_slices(): """List all available Slices""" - with Live( - Spinner("dots4", text="Finding your Slices!"), - vertical_overflow="visible", - ) as live: - client = init_client() - datasets = client.datasets + client = init_client() + console = Console() + with console.status("Finding your Slices!", spinner="dots4"): table = Table( Column("id", overflow="fold", min_width=24), "name", @@ -37,26 +34,15 @@ def list_slices(): title=":cake: Slices", title_justify="left", ) - errors = {} - for ds in datasets: - try: - ds_slices = ds.slices - if ds_slices: - for slc_id in ds_slices: - slice_url = nucleus_url(f"{ds.id}/{slc_id}") - slice_info = client.get_slice(slc_id).info() - table.add_row( - slc_id, slice_info["name"], ds.name, slice_url - ) - live.update(table) - except NucleusAPIError as e: - errors[ds.id] = e - - error_tree = Tree( - ":x: Encountered the following errors while fetching information" - ) - for ds_id, error in errors.items(): - dataset_branch = error_tree.add(f"Dataset: {ds_id}") - dataset_branch.add(f"Error: {error}") + datasets = client.datasets + id_to_datasets = {d.id: d for d in datasets} + all_slices = client.slices + for s in all_slices: + table.add_row( + s.id, + s.name, + id_to_datasets[s.dataset_id].name, + nucleus_url(f"{s.dataset_id}/{s.id}"), + ) - Console().print(error_tree) + console.print(table) diff --git a/nucleus/__init__.py b/nucleus/__init__.py index f695ba18..aaaca51d 100644 --- a/nucleus/__init__.py +++ b/nucleus/__init__.py @@ -235,6 +235,12 @@ def jobs( """ return self.list_jobs() + @property 
+ def slices(self) -> List[Slice]: + response = self.make_request({}, "slice/", requests.get) + slices = [Slice.from_request(info, self) for info in response] + return slices + @deprecated(msg="Use the NucleusClient.models property in the future.") def list_models(self) -> List[Model]: return self.models diff --git a/nucleus/slice.py b/nucleus/slice.py index b25b145f..03070938 100644 --- a/nucleus/slice.py +++ b/nucleus/slice.py @@ -1,5 +1,6 @@ +import datetime import warnings -from typing import Dict, Iterable, List, Set, Tuple, Union +from typing import Dict, Iterable, List, Optional, Set, Tuple, Union import requests @@ -49,9 +50,11 @@ def __init__(self, slice_id: str, client): self._client = client self._name = None self._dataset_id = None + self._created_at = None + self._pending_job_count = None def __repr__(self): - return f"Slice(slice_id='{self.id}', client={self._client})" + return f"Slice(slice_id='{self.id}', name={self._name}, dataset_id={self._dataset_id})" def __eq__(self, other): if self.id == other.id: @@ -59,6 +62,43 @@ def __eq__(self, other): return True return False + @property + def created_at(self) -> Optional[datetime.datetime]: + """Timestamp of creation of the slice + + Returns: + datetime of creation or None if not created yet + """ + if self._created_at is None: + self._created_at = self.info().get("created_at", None) + return self._created_at + + @property + def pending_job_count(self) -> Optional[int]: + if self._pending_job_count is None: + self._pending_job_count = self.info().get( + "pending_job_count", None + ) + return self._pending_job_count + + @classmethod + def from_request(cls, request, client): + instance = cls(request["id"], client) + instance._name = request.get("name", None) + instance._dataset_id = request.get("dataset_id", None) + created_at_str = request.get("created_at").rstrip("Z") + if hasattr(datetime.datetime, "fromisoformat"): + instance._created_at = datetime.datetime.fromisoformat( + created_at_str + ) + else: 
+ fmt_str = r"%Y-%m-%dT%H:%M:%S.%f" # replaces fromisoformat, which is not available in Python 3.6 + instance._created_at = datetime.datetime.strptime( + created_at_str, fmt_str + ) + instance._pending_job_count = request.get("pending_job_count", None) + return instance + @property def slice_id(self): warnings.warn( diff --git a/nucleus/validate/client.py b/nucleus/validate/client.py index dc330faf..5d94173c 100644 --- a/nucleus/validate/client.py +++ b/nucleus/validate/client.py @@ -107,13 +107,17 @@ def create_scenario_test( ).dict(), "validate/scenario_test", ) - return ScenarioTest(response[SCENARIO_TEST_ID_KEY], self.connection) + return ScenarioTest.from_id( + response[SCENARIO_TEST_ID_KEY], self.connection + ) def get_scenario_test(self, scenario_test_id: str) -> ScenarioTest: response = self.connection.get( f"validate/scenario_test/{scenario_test_id}", ) - return ScenarioTest(response["unit_test"]["id"], self.connection) + return ScenarioTest.from_id( + response["unit_test"]["id"], self.connection + ) @property def scenario_tests(self) -> List[ScenarioTest]: @@ -131,12 +135,13 @@ def scenario_tests(self) -> List[ScenarioTest]: A list of ScenarioTest objects. """ response = self.connection.get( - "validate/scenario_test", + "validate/scenario_test/details", ) - return [ - ScenarioTest(test_id, self.connection) - for test_id in response["scenario_test_ids"] + tests = [ + ScenarioTest.from_response(payload, self.connection) + for payload in response ] + return tests def delete_scenario_test(self, scenario_test_id: str) -> bool: """Deletes a Scenario Test. 
:: diff --git a/nucleus/validate/data_transfer_objects/scenario_test_evaluations.py b/nucleus/validate/data_transfer_objects/scenario_test_evaluations.py index b0fa30a4..14205ba8 100644 --- a/nucleus/validate/data_transfer_objects/scenario_test_evaluations.py +++ b/nucleus/validate/data_transfer_objects/scenario_test_evaluations.py @@ -5,14 +5,6 @@ from nucleus.pydantic_base import ImmutableModel -class EvalDetail(ImmutableModel): - id: str - - -class GetEvalHistory(ImmutableModel): - evaluations: List[EvalDetail] - - class EvaluationResult(ImmutableModel): item_ref_id: str score: float diff --git a/nucleus/validate/scenario_test.py b/nucleus/validate/scenario_test.py index 189f5fa8..90bb97b4 100644 --- a/nucleus/validate/scenario_test.py +++ b/nucleus/validate/scenario_test.py @@ -18,10 +18,7 @@ THRESHOLD_KEY, ThresholdComparison, ) -from .data_transfer_objects.scenario_test_evaluations import ( - EvaluationResult, - GetEvalHistory, -) +from .data_transfer_objects.scenario_test_evaluations import EvaluationResult from .data_transfer_objects.scenario_test_metric import AddScenarioTestFunction from .eval_functions.available_eval_functions import ( EvalFunction, @@ -52,13 +49,24 @@ class ScenarioTest: slice_id: str = field(init=False) baseline_model_id: Optional[str] = None - def __post_init__(self): + @classmethod + def from_id(cls, unit_test_id: str, connection: Connection): # TODO(gunnar): Remove this pattern. It's too slow. 
We should get all the info required in one call - response = self.connection.get( - f"validate/scenario_test/{self.id}/info", + response = connection.get( + f"validate/scenario_test/{unit_test_id}/info", ) - self.name = response[NAME_KEY] - self.slice_id = response[SLICE_ID_KEY] + instance = cls(unit_test_id, connection) + instance.name = response[NAME_KEY] + instance.slice_id = response[SLICE_ID_KEY] + return instance + + @classmethod + def from_response(cls, response, connection: Connection): + instance = cls(response["id"], connection) + instance.name = response[NAME_KEY] + instance.slice_id = response[SLICE_ID_KEY] + instance.baseline_model_id = response.get("baseline_model_id", None) + return instance def add_eval_function( self, eval_function: EvalFunction @@ -148,13 +156,13 @@ def get_eval_history(self) -> List[ScenarioTestEvaluation]: A list of :class:`ScenarioTestEvaluation` objects. """ response = self.connection.get( - f"validate/scenario_test/{self.id}/eval_history", + f"validate/scenario_test/{self.id}/eval_history/details", ) - eval_history = GetEvalHistory.parse_obj(response) - return [ - ScenarioTestEvaluation(evaluation.id, self.connection) - for evaluation in eval_history.evaluations + evaluations = [ + ScenarioTestEvaluation.from_request(eval_payload, self.connection) + for eval_payload in response ] + return evaluations def get_items(self) -> List[DatasetItem]: response = self.connection.get( diff --git a/nucleus/validate/scenario_test_evaluation.py b/nucleus/validate/scenario_test_evaluation.py index 0fafc141..f1ccb462 100644 --- a/nucleus/validate/scenario_test_evaluation.py +++ b/nucleus/validate/scenario_test_evaluation.py @@ -1,5 +1,5 @@ """Data types for Scenario Test Evaluation results.""" -from dataclasses import InitVar, dataclass, field +from dataclasses import dataclass, field from enum import Enum from typing import List, Optional @@ -77,31 +77,30 @@ class ScenarioTestEvaluation: status: ScenarioTestEvaluationStatus = 
field(init=False) result: Optional[float] = field(init=False) passed: bool = field(init=False) - item_evals: List[ScenarioTestItemEvaluation] = field(init=False) - connection: InitVar[Connection] - - def __post_init__(self, connection: Connection): - # TODO(gunnar): Having the function call /info on every construction is too slow. The original - # endpoint should rather return the necessary human-readable information - response = connection.make_request( + connection: Connection = field(init=False, repr=False) + + @classmethod + def from_request(cls, response, connection): + instance = cls(response["id"]) + instance.connection = connection + + instance.scenario_test_id = response[SCENARIO_TEST_ID_KEY] + instance.eval_function_id = response[EVAL_FUNCTION_ID_KEY] + instance.model_id = response[MODEL_ID_KEY] + instance.status = ScenarioTestEvaluationStatus(response[STATUS_KEY]) + instance.result = try_convert_float(response[RESULT_KEY]) + instance.passed = bool(response[PASS_KEY]) + return instance + + @property + def item_evals(self) -> List[ScenarioTestItemEvaluation]: + response = self.connection.make_request( {}, f"validate/eval/{self.id}/info", requests_command=requests.get, ) - eval_response = response[SCENARIO_TEST_EVAL_KEY] items_response = response[ITEM_EVAL_KEY] - - self.scenario_test_id: str = eval_response[SCENARIO_TEST_ID_KEY] - self.eval_function_id: str = eval_response[EVAL_FUNCTION_ID_KEY] - self.model_id: str = eval_response[MODEL_ID_KEY] - self.status: ScenarioTestEvaluationStatus = ( - ScenarioTestEvaluationStatus(eval_response[STATUS_KEY]) - ) - self.result: Optional[float] = try_convert_float( - eval_response[RESULT_KEY] - ) - self.passed: bool = bool(eval_response[PASS_KEY]) - self.item_evals: List[ScenarioTestItemEvaluation] = [ + items = [ ScenarioTestItemEvaluation( evaluation_id=res[EVALUATION_ID_KEY], scenario_test_id=res[SCENARIO_TEST_ID_KEY], @@ -112,3 +111,4 @@ def __post_init__(self, connection: Connection): ) for res in items_response ] 
+ return items diff --git a/pyproject.toml b/pyproject.toml index df6a8395..6ccff52b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ exclude = ''' [tool.poetry] name = "scale-nucleus" -version = "0.14.13" +version = "0.14.14" description = "The official Python client library for Nucleus, the Data Platform for AI" license = "MIT" authors = ["Scale AI Nucleus Team "] diff --git a/tests/test_slice.py b/tests/test_slice.py index 184db6aa..e9e19a09 100644 --- a/tests/test_slice.py +++ b/tests/test_slice.py @@ -35,15 +35,6 @@ def slc(CLIENT, dataset): CLIENT.delete_slice(slc.id) -def test_reprs(): - # Have to define here in order to have access to all relevant objects - def test_repr(test_object: any): - assert eval(str(test_object)) == test_object - - client = NucleusClient(api_key="fake_key") - test_repr(Slice(slice_id="fake_slice_id", client=client)) - - def test_slice_create_and_delete_and_list(dataset: Dataset): ds_items = dataset.items