Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add deploy cli #258

Open
wants to merge 10 commits into
base: hmi-release
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added nucleus/deploy/cli/__init__.py
Empty file.
27 changes: 27 additions & 0 deletions nucleus/deploy/cli/bin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import click

from nucleus.deploy.cli.bundles import bundles
from nucleus.deploy.cli.endpoints import endpoints


@click.group("cli")
def entry_point():
"""Launch CLI

\b
██╗ █████╗ ██╗ ██╗███╗ ██╗ ██████╗██╗ ██╗
██║ ██╔══██╗██║ ██║████╗ ██║██╔════╝██║ ██║
██║ ███████║██║ ██║██╔██╗ ██║██║ ███████║
██║ ██╔══██║██║ ██║██║╚██╗██║██║ ██╔══██║
███████╗██║ ██║╚██████╔╝██║ ╚████║╚██████╗██║ ██║
╚══════╝╚═╝ ╚═╝ ╚═════╝ ╚═╝ ╚═══╝ ╚═════╝╚═╝ ╚═╝

`scale-launch` is a command line interface to interact with Scale Launch
"""


entry_point.add_command(bundles) # type: ignore
entry_point.add_command(endpoints) # type: ignore

if __name__ == "__main__":
entry_point()
72 changes: 72 additions & 0 deletions nucleus/deploy/cli/bundles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import click
from rich.console import Console
from rich.syntax import Syntax
from rich.table import Column, Table

from nucleus.deploy.cli.client import init_client


@click.group("bundles")
def bundles():
"""Bundles is a wrapper around model bundles in Scale Launch"""


@bundles.command("list")
def list_bundles():
"""List all of your Bundles"""
client = init_client()

table = Table(
Column("Bundle Id", overflow="fold", min_width=24),
"Bundle name",
"Location",
"Packaging type",
title="Bundles",
title_justify="left",
)

for model_bundle in client.list_model_bundles():
table.add_row(
model_bundle.bundle_id,
model_bundle.bundle_name,
model_bundle.location,
model_bundle.packaging_type,
)
console = Console()
console.print(table)


@bundles.command("get")
@click.argument("bundle_name")
def get_bundle(bundle_name):
"""Print bundle info"""
client = init_client()

model_bundle = client.get_model_bundle(bundle_name)

console = Console()
console.print(f"bundle_id: {model_bundle.bundle_id}")
console.print(f"bundle_name: {model_bundle.bundle_name}")
console.print(f"location: {model_bundle.location}")
console.print(f"packaging_type: {model_bundle.packaging_type}")
console.print(f"env_params: {model_bundle.env_params}")
console.print(f"requirements: {model_bundle.requirements}")

console.print("metadata:")
for meta_name, meta_value in model_bundle.metadata.items():
# TODO print non-code metadata differently
console.print(f"{meta_name}:", style="yellow")
syntax = Syntax(meta_value, "python")
console.print(syntax)


@bundles.command("delete")
@click.argument("bundle_name")
def delete_bundle(bundle_name):
"""Delete a model bundle"""
client = init_client()

console = Console()
model_bundle = client.get_model_bundle(bundle_name)
res = client.delete_model_bundle(model_bundle)
console.print(res)
14 changes: 14 additions & 0 deletions nucleus/deploy/cli/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import functools
import os

import nucleus


@functools.lru_cache()
def init_client():
api_key = os.environ.get("LAUNCH_API_KEY", None)
if api_key:
client = nucleus.deploy.DeployClient(api_key)
else:
raise RuntimeError("No LAUNCH_API_KEY set")
return client
48 changes: 48 additions & 0 deletions nucleus/deploy/cli/endpoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import click
from rich.console import Console
from rich.table import Table

from nucleus.deploy.cli.client import init_client
from nucleus.deploy.model_endpoint import AsyncModelEndpoint, Endpoint


@click.group("endpoints")
def endpoints():
"""Endpoints is a wrapper around model bundles in Scale Launch"""


@endpoints.command("list")
def list_endpoints():
"""List all of your Bundles"""
client = init_client()

table = Table(
"Endpoint name",
"Metadata",
"Endpoint type",
title="Endpoints",
title_justify="left",
)

for endpoint_sync_async in client.list_model_endpoints():
endpoint = endpoint_sync_async.endpoint
table.add_row(
endpoint.name,
endpoint.metadata,
endpoint.endpoint_type,
)
console = Console()
console.print(table)


@endpoints.command("delete")
@click.argument("endpoint_name")
def delete_bundle(endpoint_name):
"""Delete a model bundle"""
client = init_client()

console = Console()
endpoint = Endpoint(name=endpoint_name)
dummy_endpoint = AsyncModelEndpoint(endpoint=endpoint, client=client)
res = client.delete_model_endpoint(dummy_endpoint)
console.print(res)
40 changes: 31 additions & 9 deletions nucleus/deploy/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@
get_imports,
)
from nucleus.deploy.model_bundle import ModelBundle
from nucleus.deploy.model_endpoint import AsyncModelEndpoint, SyncModelEndpoint
from nucleus.deploy.model_endpoint import (
AsyncModelEndpoint,
Endpoint,
SyncModelEndpoint,
)
from nucleus.deploy.request_validation import validate_task_request

DEFAULT_NETWORK_TIMEOUT_SEC = 120
Expand Down Expand Up @@ -384,7 +388,7 @@ def create_model_endpoint(
"""
payload = dict(
endpoint_name=endpoint_name,
bundle_name=model_bundle.name,
bundle_name=model_bundle.bundle_name,
cpus=cpus,
memory=memory,
gpus=gpus,
Expand All @@ -411,10 +415,11 @@ def create_model_endpoint(
logger.info(
"Endpoint creation task id is %s", endpoint_creation_task_id
)
endpoint = Endpoint(name=endpoint_name)
if endpoint_type == "async":
return AsyncModelEndpoint(endpoint_id=endpoint_name, client=self)
return AsyncModelEndpoint(endpoint=endpoint, client=self)
elif endpoint_type == "sync":
return SyncModelEndpoint(endpoint_id=endpoint_name, client=self)
return SyncModelEndpoint(endpoint=endpoint, client=self)
else:
raise ValueError(
"Endpoint should be one of the types 'sync' or 'async'"
Expand All @@ -431,10 +436,23 @@ def list_model_bundles(self) -> List[ModelBundle]:
"""
resp = self.connection.get("model_bundle")
model_bundles = [
ModelBundle(name=item["bundle_name"]) for item in resp["bundles"]
ModelBundle.from_dict(item) for item in resp["bundles"] # type: ignore
]
return model_bundles

def get_model_bundle(self, bundle_name: str) -> ModelBundle:
"""
Returns a Model Bundle object specified by `bundle_name`.

Returns:
A ModelBundle object
"""
resp = self.connection.get(f"model_bundle/{bundle_name}")
assert (
len(resp["bundles"]) == 1
), f"Bundle with name `{bundle_name}` not found"
return ModelBundle.from_dict(resp["bundles"][0]) # type: ignore

def list_model_endpoints(
self,
) -> List[Union[AsyncModelEndpoint, SyncModelEndpoint]]:
Expand All @@ -447,12 +465,16 @@ def list_model_endpoints(
"""
resp = self.connection.get(ENDPOINT_PATH)
async_endpoints: List[Union[AsyncModelEndpoint, SyncModelEndpoint]] = [
AsyncModelEndpoint(endpoint_id=endpoint["name"], client=self)
AsyncModelEndpoint(
endpoint=Endpoint.from_dict(endpoint), client=self # type: ignore
)
for endpoint in resp["endpoints"]
if endpoint["endpoint_type"] == "async"
]
sync_endpoints: List[Union[AsyncModelEndpoint, SyncModelEndpoint]] = [
SyncModelEndpoint(endpoint_id=endpoint["name"], client=self)
SyncModelEndpoint(
endpoint=Endpoint.from_dict(endpoint), client=self # type: ignore
)
for endpoint in resp["endpoints"]
if endpoint["endpoint_type"] == "sync"
]
Expand All @@ -462,7 +484,7 @@ def delete_model_bundle(self, model_bundle: ModelBundle):
"""
Deletes the model bundle on the server.
"""
route = f"model_bundle/{model_bundle.name}"
route = f"model_bundle/{model_bundle.bundle_name}"
resp = self.connection.delete(route)
return resp["deleted"]

Expand All @@ -472,7 +494,7 @@ def delete_model_endpoint(
"""
Deletes a model endpoint.
"""
route = f"{ENDPOINT_PATH}/{model_endpoint.endpoint_id}"
route = f"{ENDPOINT_PATH}/{model_endpoint.endpoint.name}"
resp = self.connection.delete(route)
return resp["deleted"]

Expand Down
20 changes: 16 additions & 4 deletions nucleus/deploy/model_bundle.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,23 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from dataclasses_json import Undefined, dataclass_json


@dataclass_json(undefined=Undefined.EXCLUDE)
@dataclass
class ModelBundle:
"""
Represents a ModelBundle.
TODO fill this out with more than just a name potentially.
"""

def __init__(self, name):
self.name = name
bundle_name: str
bundle_id: Optional[str] = None
env_params: Optional[Dict[str, str]] = None
location: Optional[str] = None
metadata: Optional[Dict[Any, Any]] = None
packaging_type: Optional[str] = None
requirements: Optional[List[str]] = None

def __str__(self):
return f"ModelBundle(name={self.name})"
return f"ModelBundle(bundle_name={self.bundle_name})"
36 changes: 27 additions & 9 deletions nucleus/deploy/model_endpoint.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,33 @@
import concurrent.futures
import uuid
from collections import Counter
from dataclasses import dataclass
from typing import Dict, Optional, Sequence

from dataclasses_json import Undefined, dataclass_json

from nucleus.deploy.request_validation import validate_task_request

TASK_PENDING_STATE = "PENDING"
TASK_SUCCESS_STATE = "SUCCESS"
TASK_FAILURE_STATE = "FAILURE"


@dataclass_json(undefined=Undefined.EXCLUDE)
@dataclass
class Endpoint:
"""
Represents an Endpoint from the database.
"""

name: str
metadata: Optional[Dict] = None
endpoint_type: Optional[str] = None

def __str__(self):
return f"Endpoint(name={self.name})"


class EndpointRequest:
"""
Represents a single request to either a SyncModelEndpoint or AsyncModelEndpoint.
Expand Down Expand Up @@ -62,16 +80,16 @@ def __str__(self):


class SyncModelEndpoint:
def __init__(self, endpoint_id: str, client):
self.endpoint_id = endpoint_id
def __init__(self, endpoint: Endpoint, client):
self.endpoint = endpoint
self.client = client

def __str__(self):
return f"SyncModelEndpoint <endpoint_id:{self.endpoint_id}>"
return f"SyncModelEndpoint <endpoint_name:{self.endpoint.name}>"

def predict(self, request: EndpointRequest) -> EndpointResponse:
raw_response = self.client.sync_request(
self.endpoint_id,
self.endpoint.name,
url=request.url,
args=request.args,
return_pickled=request.return_pickled,
Expand All @@ -92,17 +110,17 @@ class AsyncModelEndpoint:
A higher level abstraction for a Model Endpoint.
"""

def __init__(self, endpoint_id: str, client):
def __init__(self, endpoint: Endpoint, client):
"""
Parameters:
endpoint_id: The unique name of the ModelEndpoint
endpoint: Endpoint object.
client: A DeployClient object
"""
self.endpoint_id = endpoint_id
self.endpoint = endpoint
self.client = client

def __str__(self):
return f"AsyncModelEndpoint <endpoint_id:{self.endpoint_id}>"
return f"AsyncModelEndpoint <endpoint_name:{self.endpoint.name}>"

def predict_batch(
self, requests: Sequence[EndpointRequest]
Expand All @@ -129,7 +147,7 @@ def single_request(request):
# request has keys url and args

inner_inference_request = self.client.async_request(
endpoint_id=self.endpoint_id,
endpoint_id=self.endpoint.name,
url=request.url,
args=request.args,
return_pickled=request.return_pickled,
Expand Down
Loading