Skip to content

Commit

Permalink
[ENH] Cloud client (#1462)
Browse files Browse the repository at this point in the history
## Description of changes

This PR adds the Chroma `CloudClient` as a client type. It takes an API
key (or reads it from the environment, or prompts the user for it if
it's not found) , and tries to connect to the Chroma cloud using the API
key.

Under the hood, this is a specialized HTTPClient with ssl always turned
on, and which uses token-based authentication with the API key as the
token.

This PR also moves the `AuthorizationError` from auth to the general set
of errors, so it can be handled elsewhere by the Chroma system.

## Test plan

Added a new test, `test_cloud_client`, which creates a mock 'cloud'
server and then tests connection with valid and invalid API keys.

## Documentation Changes

Cloud is not yet available, so `CloudClient` is for now not to be used
by external users of Chroma. Therefore, we won't land documentation
changes with this PR.

## TODOs
- [x] Tests
- [ ] JavaScript @jeffchuber to take a look
- [ ] ~Docs~
  • Loading branch information
atroyn authored Dec 8, 2023
1 parent a846cbe commit 4ae47cd
Show file tree
Hide file tree
Showing 7 changed files with 187 additions and 21 deletions.
51 changes: 51 additions & 0 deletions chromadb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
from chromadb.api.client import Client as ClientCreator
from chromadb.api.client import AdminClient as AdminClientCreator
from chromadb.auth.token import TokenTransportHeader
import chromadb.config
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings
from chromadb.api import AdminAPI, ClientAPI
Expand Down Expand Up @@ -181,6 +182,56 @@ def HttpClient(
return ClientCreator(tenant=tenant, database=database, settings=settings)


def CloudClient(
tenant: str,
database: str,
api_key: Optional[str] = None,
settings: Optional[Settings] = None,
*, # Following arguments are keyword-only, intended for testing only.
cloud_host: str = "api.trychroma.com",
cloud_port: str = "8000",
enable_ssl: bool = True,
) -> ClientAPI:
"""
Creates a client to connect to a tennant and database on the Chroma cloud.
Args:
tenant: The tenant to use for this client.
database: The database to use for this client.
api_key: The api key to use for this client.
"""

# If no API key is provided, try to load it from the environment variable
if api_key is None:
import os

api_key = os.environ.get("CHROMA_API_KEY")

# If the API key is still not provided, prompt the user
if api_key is None:
print(
"\033[93mDon't have an API key?\033[0m Get one at https://app.trychroma.com"
)
api_key = input("Please enter your Chroma API key: ")

if settings is None:
settings = Settings()

settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI"
settings.chroma_server_host = cloud_host
settings.chroma_server_http_port = cloud_port
# Always use SSL for cloud
settings.chroma_server_ssl_enabled = enable_ssl

settings.chroma_client_auth_provider = "chromadb.auth.token.TokenAuthClientProvider"
settings.chroma_client_auth_credentials = api_key
settings.chroma_client_auth_token_transport_header = (
TokenTransportHeader.X_CHROMA_TOKEN.name
)

return ClientCreator(tenant=tenant, database=database, settings=settings)


def Client(
settings: Settings = __settings,
tenant: str = DEFAULT_TENANT,
Expand Down
10 changes: 7 additions & 3 deletions chromadb/api/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import ClassVar, Dict, Optional, Sequence
from uuid import UUID
import uuid

from overrides import override
import requests
Expand All @@ -22,6 +23,7 @@
from chromadb.config import Settings, System
from chromadb.config import DEFAULT_TENANT, DEFAULT_DATABASE
from chromadb.api.models.Collection import Collection
from chromadb.errors import ChromaError
from chromadb.telemetry.product import ProductTelemetryClient
from chromadb.telemetry.product.events import ClientStartEvent
from chromadb.types import Database, Tenant, Where, WhereDocument
Expand Down Expand Up @@ -78,9 +80,8 @@ def _get_identifier_from_settings(settings: Settings) -> str:
"ephemeral" # TODO: support pathing and multiple ephemeral clients
)
elif api_impl == "chromadb.api.fastapi.FastAPI":
identifier = (
f"{settings.chroma_server_host}:{settings.chroma_server_http_port}"
)
# FastAPI clients can all use unique system identifiers since their configurations can be independent, e.g. different auth tokens
identifier = str(uuid.uuid4())
else:
raise ValueError(f"Unsupported Chroma API implementation {api_impl}")

Expand Down Expand Up @@ -429,6 +430,9 @@ def _validate_tenant_database(self, tenant: str, database: str) -> None:
raise ValueError(
"Could not connect to a Chroma server. Are you sure it is running?"
)
# Propagate ChromaErrors
except ChromaError as e:
raise e
except Exception:
raise ValueError(
f"Could not connect to tenant {tenant}. Are you sure it exists?"
Expand Down
11 changes: 0 additions & 11 deletions chromadb/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,14 +436,3 @@ def __init__(self, system: System) -> None:
@abstractmethod
def get_configuration(self) -> T:
pass


class AuthorizationError(ChromaError):
@override
def code(self) -> int:
return 403

@classmethod
@override
def name(cls) -> str:
return "AuthorizationError"
7 changes: 4 additions & 3 deletions chromadb/auth/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@
from overrides import override
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import Response, JSONResponse
from starlette.responses import Response
from starlette.types import ASGIApp

from chromadb.config import DEFAULT_TENANT, System
from chromadb.auth import (
AuthorizationContext,
AuthorizationError,
AuthorizationRequestContext,
AuthzAction,
AuthzResource,
Expand All @@ -28,6 +27,7 @@
ServerAuthorizationProvider,
)
from chromadb.auth.registry import resolve_provider
from chromadb.errors import AuthorizationError
from chromadb.telemetry.opentelemetry import (
OpenTelemetryGranularity,
trace_method,
Expand Down Expand Up @@ -143,7 +143,8 @@ async def dispatch(
FastAPIServerAuthenticationRequest(request)
)
if not response or not response.success():
return JSONResponse({"error": "Unauthorized"}, status_code=401)
return AuthorizationError("Unauthorized").fastapi_json_response()

request.state.user_identity = response.get_user_identity()
return await call_next(request)

Expand Down
21 changes: 20 additions & 1 deletion chromadb/errors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import abstractmethod
from typing import Dict, Type
from overrides import overrides, EnforceOverrides
from fastapi.responses import JSONResponse


class ChromaError(Exception, EnforceOverrides):
Expand All @@ -13,10 +14,16 @@ def message(self) -> str:

@classmethod
@abstractmethod
def name(self) -> str:
def name(cls) -> str:
"""Return the error name"""
pass

def fastapi_json_response(self) -> JSONResponse:
return JSONResponse(
content={"error": self.name(), "message": self.message()},
status_code=self.code(),
)


class InvalidDimensionException(ChromaError):
@classmethod
Expand Down Expand Up @@ -64,11 +71,23 @@ def name(cls) -> str:
return "InvalidHTTPVersion"


class AuthorizationError(ChromaError):
@overrides
def code(self) -> int:
return 401

@classmethod
@overrides
def name(cls) -> str:
return "AuthorizationError"


error_types: Dict[str, Type[ChromaError]] = {
"InvalidDimension": InvalidDimensionException,
"InvalidCollection": InvalidCollectionException,
"IDAlreadyExists": IDAlreadyExistsError,
"DuplicateID": DuplicateIDError,
"InvalidUUID": InvalidUUIDError,
"InvalidHTTPVersion": InvalidHTTPVersion,
"AuthorizationError": AuthorizationError,
}
4 changes: 1 addition & 3 deletions chromadb/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,7 @@ async def catch_exceptions_middleware(
try:
return await call_next(request)
except ChromaError as e:
return JSONResponse(
content={"error": e.name(), "message": e.message()}, status_code=e.code()
)
return e.fastapi_json_response()
except Exception as e:
logger.exception(e)
return JSONResponse(content={"error": repr(e)}, status_code=500)
Expand Down
104 changes: 104 additions & 0 deletions chromadb/test/client/test_cloud_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import multiprocessing
from typing import Any, Dict, Generator, Optional, Tuple
import pytest
from chromadb import CloudClient
from chromadb.api import ServerAPI
from chromadb.auth.token import TokenTransportHeader
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System
from chromadb.errors import AuthorizationError

from chromadb.test.conftest import _await_server, _run_server, find_free_port

TOKEN_TRANSPORT_HEADER = TokenTransportHeader.X_CHROMA_TOKEN.name
TEST_CLOUD_HOST = "localhost"


@pytest.fixture(scope="module")
def valid_token() -> str:
return "valid_token"


@pytest.fixture(scope="module")
def mock_cloud_server(valid_token: str) -> Generator[System, None, None]:
chroma_server_auth_provider: str = "chromadb.auth.token.TokenAuthServerProvider"
chroma_server_auth_credentials_provider: str = (
"chromadb.auth.token.TokenConfigServerAuthCredentialsProvider"
)
chroma_server_auth_credentials: str = valid_token
chroma_server_auth_token_transport_header: str = TOKEN_TRANSPORT_HEADER

port = find_free_port()

args: Tuple[
int,
bool,
Optional[str],
Optional[str],
Optional[str],
Optional[str],
Optional[str],
Optional[str],
Optional[str],
Optional[str],
Optional[Dict[str, Any]],
] = (
port,
False,
None,
chroma_server_auth_provider,
chroma_server_auth_credentials_provider,
None,
chroma_server_auth_credentials,
chroma_server_auth_token_transport_header,
None,
None,
None,
)
ctx = multiprocessing.get_context("spawn")
proc = ctx.Process(target=_run_server, args=args, daemon=True)
proc.start()

settings = Settings(
chroma_api_impl="chromadb.api.fastapi.FastAPI",
chroma_server_host=TEST_CLOUD_HOST,
chroma_server_http_port=str(port),
chroma_client_auth_provider="chromadb.auth.token.TokenAuthClientProvider",
chroma_client_auth_credentials=valid_token,
chroma_client_auth_token_transport_header=TOKEN_TRANSPORT_HEADER,
)

system = System(settings)
api = system.instance(ServerAPI)
system.start()
_await_server(api)
yield system
system.stop()
proc.kill()


def test_valid_key(mock_cloud_server: System, valid_token: str) -> None:
valid_client = CloudClient(
tenant=DEFAULT_TENANT,
database=DEFAULT_DATABASE,
api_key=valid_token,
cloud_host=TEST_CLOUD_HOST,
cloud_port=mock_cloud_server.settings.chroma_server_http_port, # type: ignore
enable_ssl=False,
)

assert valid_client.heartbeat()


def test_invalid_key(mock_cloud_server: System, valid_token: str) -> None:
# Try to connect to the default tenant and database with an invalid token
invalid_token = valid_token + "_invalid"
with pytest.raises(AuthorizationError):
client = CloudClient(
tenant=DEFAULT_TENANT,
database=DEFAULT_DATABASE,
api_key=invalid_token,
cloud_host=TEST_CLOUD_HOST,
cloud_port=mock_cloud_server.settings.chroma_server_http_port, # type: ignore
enable_ssl=False,
)
client.heartbeat()

0 comments on commit 4ae47cd

Please sign in to comment.