cleanlab studio beta api #281

Open · wants to merge 6 commits into main

Changes from all commits
8 changes: 8 additions & 0 deletions cleanlab_studio/errors.py
@@ -152,3 +152,11 @@ def __init__(self, filepath: Union[str, pathlib.Path] = "") -> None:
if isinstance(filepath, pathlib.Path):
filepath = str(filepath)
super().__init__(f"File could not be found at {filepath}. Please check the file path.")


class BetaJobError(HandledError):
pass


class DownloadResultsError(HandledError):
pass
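
Both new exception types subclass HandledError, so callers can treat them like the SDK's other handled errors. A minimal usage sketch (the fetch_job_results function below is hypothetical, standing in for the beta-jobs client code that would raise these):

from cleanlab_studio.errors import BetaJobError, DownloadResultsError

def fetch_job_results(job_id: str) -> dict:
    # Hypothetical stand-in for the beta API client: raise BetaJobError for a
    # job in a bad state, DownloadResultsError when results cannot be fetched.
    if not job_id:
        raise BetaJobError("A job ID is required.")
    raise DownloadResultsError(f"Results for job {job_id} are not ready yet.")

try:
    results = fetch_job_results("1234")
except (BetaJobError, DownloadResultsError) as e:
    print(f"Beta job failed: {e}")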
121 changes: 46 additions & 75 deletions cleanlab_studio/internal/api/api.py
@@ -40,52 +40,23 @@
pyspark_exists = False

from cleanlab_studio.errors import NotInstalledError
from cleanlab_studio.internal.api.api_helper import check_uuid_well_formed
from cleanlab_studio.internal.api.api_helper import (
UploadParts,
check_uuid_well_formed,
construct_headers,
handle_api_error,
)
from cleanlab_studio.internal.types import JSONDict, SchemaOverride
from cleanlab_studio.version import __version__

base_url = os.environ.get("CLEANLAB_API_BASE_URL", "https://api.cleanlab.ai/api")
cli_base_url = f"{base_url}/cli/v0"
upload_base_url = f"{base_url}/upload/v1"
dataset_base_url = f"{base_url}/datasets"
project_base_url = f"{base_url}/projects"
cleanset_base_url = f"{base_url}/cleansets"
model_base_url = f"{base_url}/v1/deployment"
tlm_base_url = f"{base_url}/v0/trustworthy_llm"


def _construct_headers(
api_key: Optional[str], content_type: Optional[str] = "application/json"
) -> JSONDict:
retval = dict()
if api_key:
retval["Authorization"] = f"bearer {api_key}"
if content_type:
retval["Content-Type"] = content_type
retval["Client-Type"] = "python-api"
return retval


def handle_api_error(res: requests.Response) -> None:
handle_api_error_from_json(res.json(), res.status_code)


def handle_api_error_from_json(res_json: JSONDict, status_code: Optional[int] = None) -> None:
if "code" in res_json and "description" in res_json: # AuthError or UserQuotaError format
if res_json["code"] == "user_soft_quota_exceeded":
pass # soft quota limit is going away soon, so ignore it
else:
raise APIError(res_json["description"])

if res_json.get("error", None) is not None:
error = res_json["error"]
if (
status_code == 422
and isinstance(error, dict)
and error.get("code", None) == "UNSUPPORTED_PROJECT_CONFIGURATION"
):
raise InvalidProjectConfiguration(error["description"])
raise APIError(res_json["error"])
API_BASE_URL = os.environ.get("CLEANLAB_API_BASE_URL", "https://api.cleanlab.ai/api")
cli_base_url = f"{API_BASE_URL}/cli/v0"
upload_base_url = f"{API_BASE_URL}/upload/v1"
dataset_base_url = f"{API_BASE_URL}/datasets"
project_base_url = f"{API_BASE_URL}/projects"
cleanset_base_url = f"{API_BASE_URL}/cleansets"
model_base_url = f"{API_BASE_URL}/v1/deployment"
tlm_base_url = f"{API_BASE_URL}/v0/trustworthy_llm"


def handle_rate_limit_error_from_resp(resp: aiohttp.ClientResponse) -> None:
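
A note on the renamed API_BASE_URL constant above: it is still read from the CLEANLAB_API_BASE_URL environment variable, and it is evaluated once at import time, so pointing the client at a different deployment only works if the variable is set before the module is imported. A small sketch (the staging URL is illustrative, not a real endpoint):

import os

# Must happen before the api module is imported, since API_BASE_URL is
# evaluated at import time.
os.environ["CLEANLAB_API_BASE_URL"] = "https://staging.example-cleanlab.ai/api"

from cleanlab_studio.internal.api import api

print(api.API_BASE_URL)  # https://staging.example-cleanlab.ai/api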
@@ -134,7 +105,7 @@ def validate_api_key(api_key: str) -> bool:
res = requests.get(
cli_base_url + "/validate",
json=dict(api_key=api_key),
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)
valid: bool = res.json()["valid"]
@@ -154,7 +125,7 @@ def initialize_upload(
res = requests.post(
f"{upload_base_url}/file/initialize",
json=dict(size_in_bytes=str(file_size), filename=filename, file_type=file_type),
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)
upload_id: str = res.json()["upload_id"]
@@ -163,13 +134,13 @@
return upload_id, part_sizes, presigned_posts


def complete_file_upload(api_key: str, upload_id: str, upload_parts: List[JSONDict]) -> None:
def complete_file_upload(api_key: str, upload_id: str, upload_parts: UploadParts) -> None:
check_uuid_well_formed(upload_id, "upload ID")
request_json = dict(upload_id=upload_id, upload_parts=upload_parts)
res = requests.post(
f"{upload_base_url}/file/complete",
json=request_json,
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)

@@ -184,7 +155,7 @@ def confirm_upload(
res = requests.post(
f"{upload_base_url}/confirm",
json=request_json,
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)

@@ -199,7 +170,7 @@ def update_schema(
res = requests.patch(
f"{upload_base_url}/schema",
json=request_json,
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)

@@ -209,7 +180,7 @@ def get_ingestion_status(api_key: str, upload_id: str) -> JSONDict:
res = requests.get(
f"{upload_base_url}/total_progress",
params=dict(upload_id=upload_id),
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)
res_json: JSONDict = res.json()
@@ -221,7 +192,7 @@ def get_dataset_id(api_key: str, upload_id: str) -> JSONDict:
res = requests.get(
f"{upload_base_url}/dataset_id",
params=dict(upload_id=upload_id),
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)
res_json: JSONDict = res.json()
@@ -232,7 +203,7 @@ def get_project_of_cleanset(api_key: str, cleanset_id: str) -> str:
check_uuid_well_formed(cleanset_id, "cleanset ID")
res = requests.get(
cli_base_url + f"/cleansets/{cleanset_id}/project",
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)
project_id: str = res.json()["project_id"]
@@ -243,7 +214,7 @@ def get_label_column_of_project(api_key: str, project_id: str) -> str:
check_uuid_well_formed(project_id, "project ID")
res = requests.get(
cli_base_url + f"/projects/{project_id}/label_column",
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)
label_column: str = res.json()["label_column"]
@@ -274,7 +245,7 @@ def download_cleanlab_columns(
include_cleanlab_columns=include_cleanlab_columns,
include_project_details=include_project_details,
),
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)
id_col = get_id_column(api_key, cleanset_id)
@@ -306,7 +277,7 @@ def download_array(
check_uuid_well_formed(cleanset_id, "cleanset ID")
res = requests.get(
cli_base_url + f"/cleansets/{cleanset_id}/{name}",
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)
res_json: JSONDict = res.json()
@@ -323,7 +294,7 @@ def get_id_column(api_key: str, cleanset_id: str) -> str:
check_uuid_well_formed(cleanset_id, "cleanset ID")
res = requests.get(
cli_base_url + f"/cleansets/{cleanset_id}/id_column",
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)
id_column: str = res.json()["id_column"]
@@ -334,7 +305,7 @@ def get_dataset_of_project(api_key: str, project_id: str) -> str:
check_uuid_well_formed(project_id, "project ID")
res = requests.get(
cli_base_url + f"/projects/{project_id}/dataset",
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)
dataset_id: str = res.json()["dataset_id"]
@@ -345,7 +316,7 @@ def get_dataset_schema(api_key: str, dataset_id: str) -> JSONDict:
check_uuid_well_formed(dataset_id, "dataset ID")
res = requests.get(
cli_base_url + f"/datasets/{dataset_id}/schema",
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)
schema: JSONDict = res.json()["schema"]
@@ -357,7 +328,7 @@ def get_dataset_details(api_key: str, dataset_id: str, task_type: Optional[str])
res = requests.get(
project_base_url + f"/dataset_details/{dataset_id}",
params=dict(tasktype=task_type),
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)
dataset_details: JSONDict = res.json()
@@ -368,7 +339,7 @@ def check_column_diversity(api_key: str, dataset_id: str, column_name: str) -> J
check_uuid_well_formed(dataset_id, "dataset ID")
res = requests.get(
dataset_base_url + f"/diversity/{dataset_id}/{column_name}",
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)
column_diversity: JSONDict = res.json()
@@ -379,7 +350,7 @@ def is_valid_multilabel_column(api_key: str, dataset_id: str, column_name: str)
check_uuid_well_formed(dataset_id, "dataset ID")
res = requests.get(
dataset_base_url + f"/check_valid_multilabel/{dataset_id}/{column_name}",
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)
multilabel_column: JSONDict = res.json()
@@ -410,7 +381,7 @@ def clean_dataset(
)
res = requests.post(
project_base_url + f"/clean",
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
json=request_json,
)
handle_api_error(res)
@@ -422,7 +393,7 @@ def get_latest_cleanset_id(api_key: str, project_id: str) -> str:
check_uuid_well_formed(project_id, "project ID")
res = requests.get(
cleanset_base_url + f"/project/{project_id}/latest_cleanset_id",
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)
cleanset_id = res.json()["cleanset_id"]
@@ -448,7 +419,7 @@ def get_dataset_id_for_name(
res = requests.get(
dataset_base_url + f"/dataset_id_for_name",
params=dict(dataset_name=dataset_name),
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)
return cast(Optional[str], res.json().get("dataset_id", None))
@@ -458,7 +429,7 @@ def get_cleanset_status(api_key: str, cleanset_id: str) -> JSONDict:
check_uuid_well_formed(cleanset_id, "cleanset ID")
res = requests.get(
cleanset_base_url + f"/{cleanset_id}/status",
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)
status: JSONDict = res.json()
@@ -467,13 +438,13 @@ def get_cleanset_status(api_key: str, cleanset_id: str) -> JSONDict:

def delete_dataset(api_key: str, dataset_id: str) -> None:
check_uuid_well_formed(dataset_id, "dataset ID")
res = requests.delete(dataset_base_url + f"/{dataset_id}", headers=_construct_headers(api_key))
res = requests.delete(dataset_base_url + f"/{dataset_id}", headers=construct_headers(api_key))
handle_api_error(res)


def delete_project(api_key: str, project_id: str) -> None:
check_uuid_well_formed(project_id, "project ID")
res = requests.delete(project_base_url + f"/{project_id}", headers=_construct_headers(api_key))
res = requests.delete(project_base_url + f"/{project_id}", headers=construct_headers(api_key))
handle_api_error(res)


@@ -528,7 +499,7 @@ def deploy_model(api_key: str, cleanset_id: str, model_name: str) -> str:
check_uuid_well_formed(cleanset_id, "cleanset ID")
res = requests.post(
model_base_url,
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
json=dict(cleanset_id=cleanset_id, deployment_name=model_name),
)

@@ -542,7 +513,7 @@ def get_deployment_status(api_key: str, model_id: str) -> str:
check_uuid_well_formed(model_id, "model ID")
res = requests.get(
f"{model_base_url}/{model_id}",
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)
deployment: JSONDict = res.json()
@@ -555,7 +526,7 @@ def upload_predict_batch(api_key: str, model_id: str, batch: io.StringIO) -> str
url = f"{model_base_url}/{model_id}/upload"
res = requests.post(
url,
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)

handle_api_error(res)
@@ -573,7 +544,7 @@ def start_prediction(api_key: str, model_id: str, query_id: str) -> None:
check_uuid_well_formed(query_id, "query ID")
res = requests.post(
f"{model_base_url}/{model_id}/predict/{query_id}",
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)

handle_api_error(res)
@@ -584,7 +555,7 @@ def get_prediction_status(api_key: str, query_id: str) -> Dict[str, str]:
check_uuid_well_formed(query_id, "query ID")
res = requests.get(
f"{model_base_url}/predict/{query_id}",
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)

@@ -596,7 +567,7 @@ def get_deployed_model_info(api_key: str, model_id: str) -> Dict[str, str]:
check_uuid_well_formed(model_id, "model ID")
res = requests.get(
f"{model_base_url}/{model_id}",
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)

@@ -672,7 +643,7 @@ async def tlm_prompt(
res = await client_session.post(
f"{tlm_base_url}/prompt",
json=dict(prompt=prompt, quality=quality_preset, options=options or {}),
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)

res_json = await res.json()
@@ -733,7 +704,7 @@ async def tlm_get_confidence_score(
quality=quality_preset,
options=options or {},
),
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)

res_json = await res.json()
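
The api_helper module that now hosts these shared helpers is not part of this diff, but the code removed from api.py implies most of its contents. A plausible sketch, assuming construct_headers and handle_api_error keep the removed implementations essentially verbatim, that UploadParts aliases the List[JSONDict] type complete_file_upload previously accepted, and that APIError and InvalidProjectConfiguration live in cleanlab_studio.errors (all assumptions, since the real file is not shown):

# cleanlab_studio/internal/api/api_helper.py -- sketch only, not from this diff
from typing import List, Optional

import requests

# Assumed import location for these exception types; their imports fall in
# the elided portion of the api.py diff above.
from cleanlab_studio.errors import APIError, InvalidProjectConfiguration
from cleanlab_studio.internal.types import JSONDict

# Assumption: UploadParts aliases the List[JSONDict] parameter type that
# complete_file_upload took before this change.
UploadParts = List[JSONDict]


def construct_headers(
    api_key: Optional[str], content_type: Optional[str] = "application/json"
) -> JSONDict:
    # Same logic as the _construct_headers removed from api.py, minus the
    # leading underscore so it can be shared across modules.
    retval = dict()
    if api_key:
        retval["Authorization"] = f"bearer {api_key}"
    if content_type:
        retval["Content-Type"] = content_type
    retval["Client-Type"] = "python-api"
    return retval


def handle_api_error(res: requests.Response) -> None:
    handle_api_error_from_json(res.json(), res.status_code)


def handle_api_error_from_json(res_json: JSONDict, status_code: Optional[int] = None) -> None:
    if "code" in res_json and "description" in res_json:  # AuthError or UserQuotaError format
        if res_json["code"] == "user_soft_quota_exceeded":
            pass  # soft quota limit is going away soon, so ignore it
        else:
            raise APIError(res_json["description"])
    if res_json.get("error", None) is not None:
        error = res_json["error"]
        if (
            status_code == 422
            and isinstance(error, dict)
            and error.get("code", None) == "UNSUPPORTED_PROJECT_CONFIGURATION"
        ):
            raise InvalidProjectConfiguration(error["description"])
        raise APIError(res_json["error"])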