Skip to content

Commit

Permalink
Speed up slices and add client.slices for all customer slices (#339)
Browse files Browse the repository at this point in the history
* Speed up slices and add client.slices for all customer slices

* Speed up unit test listing

* Refactor slow __post_init__ pattern away

* Remove test_reprs

* Update version and CHANGELOG

* Try to increase resource class

* Patch unavailable fromisoformat in Python 3.6

* Bump to 0.14.14
  • Loading branch information
gatli authored Aug 11, 2022
1 parent e90b2c7 commit 2d03c54
Show file tree
Hide file tree
Showing 10 changed files with 126 additions and 89 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,20 @@ All notable changes to the [Nucleus Python Client](https://github.com/scaleapi/n
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.14.14](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.14.14) - 2022-08-11

### Added
- `client.slices` to list all of a user's slices, independent of dataset

### Fixed
- Validate unit test listing and evaluation history listing. Now uses new bulk fetch endpoints for faster listing.

## [0.14.13](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.14.13) - 2022-08-10

### Fixed
- Fix payload parsing for scene export


## [0.14.12](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.14.12) - 2022-08-05

### Added
Expand Down
42 changes: 14 additions & 28 deletions cli/slices.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,9 @@ def slices(ctx, web):
@slices.command("list")
def list_slices():
"""List all available Slices"""
with Live(
Spinner("dots4", text="Finding your Slices!"),
vertical_overflow="visible",
) as live:
client = init_client()
datasets = client.datasets
client = init_client()
console = Console()
with console.status("Finding your Slices!", spinner="dots4"):
table = Table(
Column("id", overflow="fold", min_width=24),
"name",
Expand All @@ -37,26 +34,15 @@ def list_slices():
title=":cake: Slices",
title_justify="left",
)
errors = {}
for ds in datasets:
try:
ds_slices = ds.slices
if ds_slices:
for slc_id in ds_slices:
slice_url = nucleus_url(f"{ds.id}/{slc_id}")
slice_info = client.get_slice(slc_id).info()
table.add_row(
slc_id, slice_info["name"], ds.name, slice_url
)
live.update(table)
except NucleusAPIError as e:
errors[ds.id] = e

error_tree = Tree(
":x: Encountered the following errors while fetching information"
)
for ds_id, error in errors.items():
dataset_branch = error_tree.add(f"Dataset: {ds_id}")
dataset_branch.add(f"Error: {error}")
datasets = client.datasets
id_to_datasets = {d.id: d for d in datasets}
all_slices = client.slices
for s in all_slices:
table.add_row(
s.id,
s.name,
id_to_datasets[s.dataset_id].name,
nucleus_url(f"{s.dataset_id}/{s.id}"),
)

Console().print(error_tree)
console.print(table)
6 changes: 6 additions & 0 deletions nucleus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,12 @@ def jobs(
"""
return self.list_jobs()

@property
def slices(self) -> List[Slice]:
    """List the Slices returned by the ``slice/`` endpoint.

    Issues one GET request through ``make_request`` and wraps every entry
    of the response with ``Slice.from_request``, binding each Slice to this
    client.

    Returns:
        One :class:`Slice` per entry in the response.
    """
    payloads = self.make_request({}, "slice/", requests.get)
    return [Slice.from_request(payload, self) for payload in payloads]

# Deprecated alias: returns exactly what the `models` property returns.
# The `deprecated` decorator is given a message that points callers at
# `NucleusClient.models` as the replacement.
@deprecated(msg="Use the NucleusClient.models property in the future.")
def list_models(self) -> List[Model]:
return self.models
Expand Down
44 changes: 42 additions & 2 deletions nucleus/slice.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
import warnings
from typing import Dict, Iterable, List, Set, Tuple, Union
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union

import requests

Expand Down Expand Up @@ -49,16 +50,55 @@ def __init__(self, slice_id: str, client):
self._client = client
self._name = None
self._dataset_id = None
self._created_at = None
self._pending_job_count = None

def __repr__(self):
return f"Slice(slice_id='{self.id}', client={self._client})"
return f"Slice(slice_id='{self.id}', name={self._name}, dataset_id={self._dataset_id})"

def __eq__(self, other):
    """Compare two slices by id and by the client they are bound to.

    Args:
        other: Object to compare against.

    Returns:
        True when ``other`` is a slice with the same ``id`` and an equal
        ``_client``; False when either differs; ``NotImplemented`` when
        ``other`` is not a slice, so Python can fall back to its default
        comparison instead of raising ``AttributeError``.
    """
    # Comparing against None, strings, etc. used to crash on `other.id`.
    if not isinstance(other, type(self)):
        return NotImplemented
    if self.id != other.id:
        return False
    if self._client == other._client:
        return True
    return False

@property
def created_at(self) -> Optional[datetime.datetime]:
    """Creation timestamp of the slice.

    The value is cached on the instance. When nothing is cached yet it is
    read from the ``created_at`` entry of ``self.info()``.

    NOTE(review): ``from_request`` caches a parsed ``datetime``, while this
    path stores whatever ``info()`` holds under ``created_at`` unparsed;
    confirm that value is a ``datetime`` and not an ISO string.

    Returns:
        The cached or freshly fetched value, or None when ``info()`` has no
        ``created_at`` entry.
    """
    if self._created_at is not None:
        return self._created_at
    slice_info = self.info()
    self._created_at = slice_info.get("created_at", None)
    return self._created_at

@property
def pending_job_count(self) -> Optional[int]:
    """Pending job count of the slice.

    The value is cached on the instance. When nothing is cached yet it is
    read from the ``pending_job_count`` entry of ``self.info()``. A cached
    ``0`` counts as a value and is returned without a new lookup.

    Returns:
        The cached or freshly fetched count, or None when ``info()`` has no
        ``pending_job_count`` entry.
    """
    if self._pending_job_count is not None:
        return self._pending_job_count
    slice_info = self.info()
    self._pending_job_count = slice_info.get("pending_job_count", None)
    return self._pending_job_count

@classmethod
def from_request(cls, request, client):
    """Build a slice from one entry of a slice listing response.

    Args:
        request: Mapping describing the slice. ``id`` is required;
            ``name``, ``dataset_id``, ``created_at`` and
            ``pending_job_count`` are optional.
        client: Client the new instance is bound to.

    Returns:
        The populated instance. Optional fields missing from ``request``
        are left unset (None).

    Raises:
        KeyError: If ``request`` has no ``id`` entry.
        ValueError: If ``created_at`` is present but cannot be parsed.
    """
    instance = cls(request["id"], client)
    instance._name = request.get("name", None)
    instance._dataset_id = request.get("dataset_id", None)

    created_at_raw = request.get("created_at", None)
    # A missing/null timestamp used to crash on `.rstrip`; treat it like the
    # other optional fields and leave the attribute as the constructor set it.
    if created_at_raw is not None:
        # Drop the trailing "Z" (UTC designator) before parsing.
        created_at_str = created_at_raw.rstrip("Z")
        if hasattr(datetime.datetime, "fromisoformat"):
            instance._created_at = datetime.datetime.fromisoformat(
                created_at_str
            )
        else:
            # Replaces fromisoformat, which is not available in Python 3.6.
            # Fall back to a seconds-only format when no fraction is present.
            if "." in created_at_str:
                fmt_str = r"%Y-%m-%dT%H:%M:%S.%f"
            else:
                fmt_str = r"%Y-%m-%dT%H:%M:%S"
            instance._created_at = datetime.datetime.strptime(
                created_at_str, fmt_str
            )

    instance._pending_job_count = request.get("pending_job_count", None)
    return instance

@property
def slice_id(self):
warnings.warn(
Expand Down
17 changes: 11 additions & 6 deletions nucleus/validate/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,17 @@ def create_scenario_test(
).dict(),
"validate/scenario_test",
)
return ScenarioTest(response[SCENARIO_TEST_ID_KEY], self.connection)
return ScenarioTest.from_id(
response[SCENARIO_TEST_ID_KEY], self.connection
)

def get_scenario_test(self, scenario_test_id: str) -> ScenarioTest:
response = self.connection.get(
f"validate/scenario_test/{scenario_test_id}",
)
return ScenarioTest(response["unit_test"]["id"], self.connection)
return ScenarioTest.from_id(
response["unit_test"]["id"], self.connection
)

@property
def scenario_tests(self) -> List[ScenarioTest]:
Expand All @@ -131,12 +135,13 @@ def scenario_tests(self) -> List[ScenarioTest]:
A list of ScenarioTest objects.
"""
response = self.connection.get(
"validate/scenario_test",
"validate/scenario_test/details",
)
return [
ScenarioTest(test_id, self.connection)
for test_id in response["scenario_test_ids"]
tests = [
ScenarioTest.from_response(payload, self.connection)
for payload in response
]
return tests

def delete_scenario_test(self, scenario_test_id: str) -> bool:
"""Deletes a Scenario Test. ::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,6 @@
from nucleus.pydantic_base import ImmutableModel


class EvalDetail(ImmutableModel):
id: str


class GetEvalHistory(ImmutableModel):
evaluations: List[EvalDetail]


class EvaluationResult(ImmutableModel):
item_ref_id: str
score: float
Expand Down
36 changes: 22 additions & 14 deletions nucleus/validate/scenario_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,7 @@
THRESHOLD_KEY,
ThresholdComparison,
)
from .data_transfer_objects.scenario_test_evaluations import (
EvaluationResult,
GetEvalHistory,
)
from .data_transfer_objects.scenario_test_evaluations import EvaluationResult
from .data_transfer_objects.scenario_test_metric import AddScenarioTestFunction
from .eval_functions.available_eval_functions import (
EvalFunction,
Expand Down Expand Up @@ -52,13 +49,24 @@ class ScenarioTest:
slice_id: str = field(init=False)
baseline_model_id: Optional[str] = None

def __post_init__(self):
@classmethod
def from_id(cls, unit_test_id: str, connection: Connection):
# TODO(gunnar): Remove this pattern. It's too slow. We should get all the info required in one call
response = self.connection.get(
f"validate/scenario_test/{self.id}/info",
response = connection.get(
f"validate/scenario_test/{unit_test_id}/info",
)
self.name = response[NAME_KEY]
self.slice_id = response[SLICE_ID_KEY]
instance = cls(unit_test_id, connection)
instance.name = response[NAME_KEY]
instance.slice_id = response[SLICE_ID_KEY]
return instance

@classmethod
def from_response(cls, response, connection: Connection):
    """Build a scenario test from an already-fetched payload.

    The method itself issues no request; every field is copied from
    ``response``.

    Args:
        response: Mapping with ``id``, the name and slice id entries, and
            optionally ``baseline_model_id``.
        connection: Connection handed to the constructor.

    Returns:
        The populated instance; ``baseline_model_id`` is None when the
        payload does not carry it.
    """
    scenario_test = cls(response["id"], connection)
    test_name = response[NAME_KEY]
    scenario_test.name = test_name
    test_slice_id = response[SLICE_ID_KEY]
    scenario_test.slice_id = test_slice_id
    scenario_test.baseline_model_id = response.get("baseline_model_id", None)
    return scenario_test

def add_eval_function(
self, eval_function: EvalFunction
Expand Down Expand Up @@ -148,13 +156,13 @@ def get_eval_history(self) -> List[ScenarioTestEvaluation]:
A list of :class:`ScenarioTestEvaluation` objects.
"""
response = self.connection.get(
f"validate/scenario_test/{self.id}/eval_history",
f"validate/scenario_test/{self.id}/eval_history/details",
)
eval_history = GetEvalHistory.parse_obj(response)
return [
ScenarioTestEvaluation(evaluation.id, self.connection)
for evaluation in eval_history.evaluations
evaluations = [
ScenarioTestEvaluation.from_request(eval_payload, self.connection)
for eval_payload in response
]
return evaluations

def get_items(self) -> List[DatasetItem]:
response = self.connection.get(
Expand Down
42 changes: 21 additions & 21 deletions nucleus/validate/scenario_test_evaluation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Data types for Scenario Test Evaluation results."""
from dataclasses import InitVar, dataclass, field
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional

Expand Down Expand Up @@ -77,31 +77,30 @@ class ScenarioTestEvaluation:
status: ScenarioTestEvaluationStatus = field(init=False)
result: Optional[float] = field(init=False)
passed: bool = field(init=False)
item_evals: List[ScenarioTestItemEvaluation] = field(init=False)
connection: InitVar[Connection]

def __post_init__(self, connection: Connection):
# TODO(gunnar): Having the function call /info on every construction is too slow. The original
# endpoint should rather return the necessary human-readable information
response = connection.make_request(
connection: Connection = field(init=False, repr=False)

@classmethod
def from_request(cls, response, connection):
    """Build an evaluation from an already-fetched payload.

    The method itself issues no request; every field is copied or converted
    from ``response``.

    Args:
        response: Mapping with ``id`` plus the scenario test id, eval
            function id, model id, status, result and pass entries.
        connection: Connection stored on the instance for later lookups.

    Returns:
        The populated instance.
    """
    evaluation = cls(response["id"])
    evaluation.connection = connection

    # Fields copied verbatim from the payload, in the original read order.
    verbatim_fields = (
        ("scenario_test_id", SCENARIO_TEST_ID_KEY),
        ("eval_function_id", EVAL_FUNCTION_ID_KEY),
        ("model_id", MODEL_ID_KEY),
    )
    for attribute, payload_key in verbatim_fields:
        setattr(evaluation, attribute, response[payload_key])

    # Fields that need conversion before being stored.
    raw_status = response[STATUS_KEY]
    evaluation.status = ScenarioTestEvaluationStatus(raw_status)
    raw_result = response[RESULT_KEY]
    evaluation.result = try_convert_float(raw_result)
    raw_passed = response[PASS_KEY]
    evaluation.passed = bool(raw_passed)
    return evaluation

@property
def item_evals(self) -> List[ScenarioTestItemEvaluation]:
response = self.connection.make_request(
{},
f"validate/eval/{self.id}/info",
requests_command=requests.get,
)
eval_response = response[SCENARIO_TEST_EVAL_KEY]
items_response = response[ITEM_EVAL_KEY]

self.scenario_test_id: str = eval_response[SCENARIO_TEST_ID_KEY]
self.eval_function_id: str = eval_response[EVAL_FUNCTION_ID_KEY]
self.model_id: str = eval_response[MODEL_ID_KEY]
self.status: ScenarioTestEvaluationStatus = (
ScenarioTestEvaluationStatus(eval_response[STATUS_KEY])
)
self.result: Optional[float] = try_convert_float(
eval_response[RESULT_KEY]
)
self.passed: bool = bool(eval_response[PASS_KEY])
self.item_evals: List[ScenarioTestItemEvaluation] = [
items = [
ScenarioTestItemEvaluation(
evaluation_id=res[EVALUATION_ID_KEY],
scenario_test_id=res[SCENARIO_TEST_ID_KEY],
Expand All @@ -112,3 +111,4 @@ def __post_init__(self, connection: Connection):
)
for res in items_response
]
return items
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ exclude = '''

[tool.poetry]
name = "scale-nucleus"
version = "0.14.13"
version = "0.14.14"
description = "The official Python client library for Nucleus, the Data Platform for AI"
license = "MIT"
authors = ["Scale AI Nucleus Team <[email protected]>"]
Expand Down
9 changes: 0 additions & 9 deletions tests/test_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,6 @@ def slc(CLIENT, dataset):
CLIENT.delete_slice(slc.id)


def test_reprs():
# Have to define here in order to have access to all relevant objects
def test_repr(test_object: any):
assert eval(str(test_object)) == test_object

client = NucleusClient(api_key="fake_key")
test_repr(Slice(slice_id="fake_slice_id", client=client))


def test_slice_create_and_delete_and_list(dataset: Dataset):
ds_items = dataset.items

Expand Down

0 comments on commit 2d03c54

Please sign in to comment.