Skip to content

Commit

Permalink
mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
axl1313 committed Aug 12, 2024
1 parent 9cd1647 commit 1570a61
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 19 deletions.
30 changes: 15 additions & 15 deletions cleanlab_studio/internal/api/beta_api.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from typing import Any, Dict, List
from typing import List, cast

import requests

from .api import API_BASE_URL, construct_headers
from .api_helper import JSONDict, UploadParts, handle_api_error
from cleanlab_studio.internal.types import JSONDict

from .api import API_BASE_URL
from .api_helper import UploadParts, construct_headers, handle_api_error

experimental_jobs_base_url = f"{API_BASE_URL}/v0/experimental_jobs"


def initialize_upload(
api_key: str, filename: str, file_type: str, file_size: int
) -> Dict[str, Any]:
def initialize_upload(api_key: str, filename: str, file_type: str, file_size: int) -> JSONDict:
url = f"{experimental_jobs_base_url}/upload/initialize"
headers = construct_headers(api_key)
data = {
Expand All @@ -20,7 +20,7 @@ def initialize_upload(
}
resp = requests.post(url, headers=headers, json=data)
resp.raise_for_status()
return resp.json()
return cast(JSONDict, resp.json())


def complete_upload(api_key: str, dataset_id: str, upload_parts: UploadParts) -> JSONDict:
Expand All @@ -32,15 +32,15 @@ def complete_upload(api_key: str, dataset_id: str, upload_parts: UploadParts) ->
}
resp = requests.post(url, headers=headers, json=data)
handle_api_error(resp)
return resp.json()
return cast(JSONDict, resp.json())


def get_dataset(api_key: str, dataset_id: str) -> JSONDict:
url = f"{experimental_jobs_base_url}/datasets/{dataset_id}"
headers = construct_headers(api_key)
resp = requests.get(url, headers=headers)
handle_api_error(resp)
return resp.json()
return cast(JSONDict, resp.json())


def run_job(api_key: str, dataset_id: str, job_definition_name: str) -> JSONDict:
Expand All @@ -52,44 +52,44 @@ def run_job(api_key: str, dataset_id: str, job_definition_name: str) -> JSONDict
}
resp = requests.post(url, headers=headers, json=data)
handle_api_error(resp)
return resp.json()
return cast(JSONDict, resp.json())


def get_job(api_key: str, job_id: str) -> JSONDict:
url = f"{experimental_jobs_base_url}/{job_id}"
headers = construct_headers(api_key)
resp = requests.get(url, headers=headers)
handle_api_error(resp)
return resp.json()
return cast(JSONDict, resp.json())


def get_job_status(api_key: str, job_id: str) -> JSONDict:
url = f"{experimental_jobs_base_url}/{job_id}/status"
headers = construct_headers(api_key)
resp = requests.get(url, headers=headers)
resp.raise_for_status()
return resp.json()
return cast(JSONDict, resp.json())


def get_results(api_key: str, job_id: str) -> JSONDict:
url = f"{experimental_jobs_base_url}/{job_id}/results"
headers = construct_headers(api_key)
resp = requests.get(url, headers=headers)
resp.raise_for_status()
return resp.json()
return cast(JSONDict, resp.json())


def list_datasets(api_key: str) -> List[JSONDict]:
url = f"{experimental_jobs_base_url}/datasets"
headers = construct_headers(api_key)
resp = requests.get(url, headers=headers)
handle_api_error(resp)
return resp.json()["datasets"]
return cast(List[JSONDict], resp.json()["datasets"])


def list_jobs(api_key: str) -> List[JSONDict]:
url = f"{experimental_jobs_base_url}/jobs"
headers = construct_headers(api_key)
resp = requests.get(url, headers=headers)
handle_api_error(resp)
return resp.json()["jobs"]
return cast(List[JSONDict], resp.json()["jobs"])
2 changes: 1 addition & 1 deletion cleanlab_studio/internal/studio_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from aiohttp_retry import Optional
from typing import Optional

from cleanlab_studio.errors import MissingAPIKeyError, VersionError
from cleanlab_studio.internal.api import api
Expand Down
5 changes: 3 additions & 2 deletions cleanlab_studio/studio_beta/beta_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,13 @@ def wait_until_ready(self, timeout: Optional[int] = None) -> None:

def download_results(self, output_filepath: str) -> None:
output_path = pathlib.Path(output_filepath)
if self.status != JobStatus.READY:
raise BetaJobError("Job must be ready to download results")

if self.status == JobStatus.FAILED:
raise BetaJobError("Job failed, cannot download results")

if self.status != JobStatus.READY:
raise BetaJobError("Job must be ready to download results")

results = get_results(self._api_key, self.id)
if output_path.suffix != results["result_file_type"]:
raise DownloadResultsError(
Expand Down
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
python_requires=">=3.8",
install_requires=[
"aiohttp>=3.8.1",
"aiohttp-retry>=2.4.0",
"Click>=8.1.0,<=8.1.3",
"colorama>=0.4.4",
"nest_asyncio>=1.5.0",
Expand Down

0 comments on commit 1570a61

Please sign in to comment.