From 975ce958fde50ae82a046148285628b9e809feba Mon Sep 17 00:00:00 2001 From: Stainless Bot <107565488+stainless-bot@users.noreply.github.com> Date: Wed, 31 Jan 2024 21:15:55 +0100 Subject: [PATCH] feat(bedrock): include bedrock SDK (#328) The standalone package is being deprecated in favour of `anthropic[bedrock]` --- README.md | 25 ++ examples/bedrock.py | 39 +++ pyproject.toml | 3 +- requirements-dev.lock | 10 +- requirements.lock | 8 +- src/anthropic/__init__.py | 1 + src/anthropic/lib/bedrock/__init__.py | 1 + src/anthropic/lib/bedrock/_auth.py | 42 +++ src/anthropic/lib/bedrock/_client.py | 279 +++++++++++++++++++ src/anthropic/lib/bedrock/_stream.py | 53 ++++ src/anthropic/lib/bedrock/_stream_decoder.py | 64 +++++ 11 files changed, 522 insertions(+), 3 deletions(-) create mode 100644 examples/bedrock.py create mode 100644 src/anthropic/lib/bedrock/__init__.py create mode 100644 src/anthropic/lib/bedrock/_auth.py create mode 100644 src/anthropic/lib/bedrock/_client.py create mode 100644 src/anthropic/lib/bedrock/_stream.py create mode 100644 src/anthropic/lib/bedrock/_stream_decoder.py diff --git a/README.md b/README.md index cc50bf78..789f53f5 100644 --- a/README.md +++ b/README.md @@ -141,6 +141,31 @@ Streaming with `client.beta.messages.stream(...)` exposes [various helpers for y Alternatively, you can use `client.beta.messages.create(..., stream=True)` which only returns an async iterable of the events in the stream and thus uses less memory (it does not build up a final message object for you). +## AWS Bedrock + +This library also provides support for the [Anthropic Bedrock API](https://aws.amazon.com/bedrock/claude/) if you install this library with the `bedrock` extra, e.g. `pip install -U anthropic[bedrock]`. + +You can then import and instantiate a separate `AnthropicBedrock` class, the rest of the API is the same. + +```py +from anthropic import AI_PROMPT, HUMAN_PROMPT, AnthropicBedrock + +client = AnthropicBedrock() + +completion = client.completions.create( + model="anthropic.claude-instant-v1", + prompt=f"{HUMAN_PROMPT} hey!{AI_PROMPT}", + stop_sequences=[HUMAN_PROMPT], + max_tokens_to_sample=500, + temperature=0.5, + top_k=250, + top_p=0.5, +) +print(completion.completion) +``` + +For a more fully fledged example see [`examples/bedrock.py`](https://github.com/anthropics/anthropic-sdk-python/blob/main/examples/bedrock.py). + ## Token counting You can estimate billing for a given request with the `client.count_tokens()` method, eg: diff --git a/examples/bedrock.py b/examples/bedrock.py new file mode 100644 index 00000000..5161ffcb --- /dev/null +++ b/examples/bedrock.py @@ -0,0 +1,39 @@ +#!/usr/bin/env -S poetry run python + +# Note: you must have installed `anthropic` with the `bedrock` extra +# e.g. `pip install -U anthropic[bedrock]` + +from anthropic import AI_PROMPT, HUMAN_PROMPT, AnthropicBedrock + +# Note: this assumes you have AWS credentials configured. +# +# https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html +client = AnthropicBedrock() + +print("------ standard response ------") +completion = client.completions.create( + model="anthropic.claude-instant-v1", + prompt=f"{HUMAN_PROMPT} hey!{AI_PROMPT}", + stop_sequences=[HUMAN_PROMPT], + max_tokens_to_sample=500, + temperature=0.5, + top_k=250, + top_p=0.5, +) +print(completion.completion) + + +question = """ +Hey Claude! How can I recursively list all files in a directory in Python? +""" + +print("------ streamed response ------") +stream = client.completions.create( + model="anthropic.claude-instant-v1", + prompt=f"{HUMAN_PROMPT} {question}{AI_PROMPT}", + max_tokens_to_sample=500, + stream=True, +) +for item in stream: + print(item.completion, end="") +print() diff --git a/pyproject.toml b/pyproject.toml index 4ff23511..d26ba7c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ classifiers = [ [project.optional-dependencies] vertex = ["google-auth >=2, <3"] +bedrock = ["boto3 >= 1.28.57", "botocore >= 1.31.57"] [project.urls] Homepage = "https://github.com/anthropics/anthropic-sdk-python" @@ -59,7 +60,7 @@ dev-dependencies = [ "nox", "dirty-equals>=0.6.0", "importlib-metadata>=6.7.0", - + "boto3-stubs >= 1" ] [tool.rye.scripts] diff --git a/requirements-dev.lock b/requirements-dev.lock index 07c3326f..8a9aeacd 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -11,6 +11,10 @@ annotated-types==0.6.0 anyio==4.1.0 argcomplete==3.1.2 attrs==23.1.0 +boto3==1.28.58 +boto3-stubs==1.28.41 +botocore==1.31.58 +botocore-stubs==1.34.31 cachetools==5.3.2 certifi==2023.7.22 charset-normalizer==3.3.2 @@ -29,6 +33,7 @@ huggingface-hub==0.16.4 idna==3.4 importlib-metadata==7.0.0 iniconfig==2.0.0 +jmespath==1.0.1 mypy==1.7.1 mypy-extensions==1.0.0 nodeenv==1.8.0 @@ -51,14 +56,17 @@ requests==2.31.0 respx==0.20.2 rsa==4.9 ruff==0.1.9 +s3transfer==0.7.0 six==1.16.0 sniffio==1.3.0 time-machine==2.9.0 tokenizers==0.14.0 tomli==2.0.1 tqdm==4.66.1 +types-awscrt==0.20.3 +types-s3transfer==0.10.0 typing-extensions==4.8.0 -urllib3==2.1.0 +urllib3==1.26.18 virtualenv==20.24.5 zipp==3.17.0 # The following packages are considered to be unsafe in a requirements file: diff --git a/requirements.lock b/requirements.lock index e4692c10..0882deb2 100644 --- a/requirements.lock +++ b/requirements.lock @@ -9,6 +9,8 @@ -e file:. annotated-types==0.6.0 anyio==4.1.0 +boto3==1.34.31 +botocore==1.34.31 cachetools==5.3.2 certifi==2023.7.22 charset-normalizer==3.3.2 @@ -22,16 +24,20 @@ httpcore==1.0.2 httpx==0.25.2 huggingface-hub==0.16.4 idna==3.4 +jmespath==1.0.1 packaging==23.2 pyasn1==0.5.1 pyasn1-modules==0.3.0 pydantic==2.4.2 pydantic-core==2.10.1 +python-dateutil==2.8.2 pyyaml==6.0.1 requests==2.31.0 rsa==4.9 +s3transfer==0.10.0 +six==1.16.0 sniffio==1.3.0 tokenizers==0.14.0 tqdm==4.66.1 typing-extensions==4.8.0 -urllib3==2.1.0 +urllib3==1.26.18 diff --git a/src/anthropic/__init__.py b/src/anthropic/__init__.py index e4febc89..67737ea2 100644 --- a/src/anthropic/__init__.py +++ b/src/anthropic/__init__.py @@ -72,6 +72,7 @@ ] from .lib.vertex import * +from .lib.bedrock import * from .lib.streaming import * _setup_logging() diff --git a/src/anthropic/lib/bedrock/__init__.py b/src/anthropic/lib/bedrock/__init__.py new file mode 100644 index 00000000..69440c76 --- /dev/null +++ b/src/anthropic/lib/bedrock/__init__.py @@ -0,0 +1 @@ +from ._client import AnthropicBedrock as AnthropicBedrock, AsyncAnthropicBedrock as AsyncAnthropicBedrock diff --git a/src/anthropic/lib/bedrock/_auth.py b/src/anthropic/lib/bedrock/_auth.py new file mode 100644 index 00000000..503094ea --- /dev/null +++ b/src/anthropic/lib/bedrock/_auth.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +import httpx + + +def get_auth_headers( + *, + method: str, + url: str, + headers: httpx.Headers, + aws_access_key: str | None, + aws_secret_key: str | None, + aws_session_token: str | None, + region: str | None, + data: str | None, +) -> dict[str, str]: + import boto3 + from botocore.auth import SigV4Auth + from botocore.awsrequest import AWSRequest + + session = boto3.Session( + region_name=region, + aws_access_key_id=aws_access_key, + aws_secret_access_key=aws_secret_key, + aws_session_token=aws_session_token, + ) + + # The connection header may be stripped by a proxy somewhere, so the receiver + # of this message may not see this header, so we remove it from the set of headers + # that are signed. + headers = headers.copy() + del headers["connection"] + + request = AWSRequest(method=method.upper(), url=url, headers=headers, data=data) + credentials = session.get_credentials() + + signer = SigV4Auth(credentials, "bedrock", session.region_name) + signer.add_auth(request) + + prepped = request.prepare() + + return dict(prepped.headers) diff --git a/src/anthropic/lib/bedrock/_client.py b/src/anthropic/lib/bedrock/_client.py new file mode 100644 index 00000000..9e4e8318 --- /dev/null +++ b/src/anthropic/lib/bedrock/_client.py @@ -0,0 +1,279 @@ +from __future__ import annotations + +import os +from typing import Any, Union, Mapping, TypeVar +from typing_extensions import override, get_origin + +import httpx + +from ... import _exceptions +from ._stream import BedrockStream, AsyncBedrockStream +from ..._types import NOT_GIVEN, NotGiven, ResponseT +from ..._utils import is_dict +from ..._version import __version__ +from ..._response import extract_stream_chunk_type +from ..._streaming import Stream, AsyncStream +from ..._exceptions import APIStatusError +from ..._base_client import DEFAULT_MAX_RETRIES, BaseClient, SyncAPIClient, AsyncAPIClient, FinalRequestOptions +from ...resources.completions import Completions, AsyncCompletions + +DEFAULT_VERSION = "bedrock-2023-05-31" + +_HttpxClientT = TypeVar("_HttpxClientT", bound=Union[httpx.Client, httpx.AsyncClient]) +_DefaultStreamT = TypeVar("_DefaultStreamT", bound=Union[Stream[Any], AsyncStream[Any]]) + + +class BaseBedrockClient(BaseClient[_HttpxClientT, _DefaultStreamT]): + @override + def _build_request( + self, + options: FinalRequestOptions, + ) -> httpx.Request: + if is_dict(options.json_data): + options.json_data.setdefault("anthropic_version", DEFAULT_VERSION) + + if options.url == "/v1/complete" and options.method == "post": + if not is_dict(options.json_data): + raise RuntimeError("Expected dictionary json_data for post /completions endpoint") + + model = options.json_data.pop("model", None) + stream = options.json_data.pop("stream", False) + if stream: + options.url = f"/model/{model}/invoke-with-response-stream" + else: + options.url = f"/model/{model}/invoke" + + return super()._build_request(options) + + @override + def _make_status_error( + self, + err_msg: str, + *, + body: object, + response: httpx.Response, + ) -> APIStatusError: + if response.status_code == 400: + return _exceptions.BadRequestError(err_msg, response=response, body=body) + + if response.status_code == 401: + return _exceptions.AuthenticationError(err_msg, response=response, body=body) + + if response.status_code == 403: + return _exceptions.PermissionDeniedError(err_msg, response=response, body=body) + + if response.status_code == 404: + return _exceptions.NotFoundError(err_msg, response=response, body=body) + + if response.status_code == 409: + return _exceptions.ConflictError(err_msg, response=response, body=body) + + if response.status_code == 422: + return _exceptions.UnprocessableEntityError(err_msg, response=response, body=body) + + if response.status_code == 429: + return _exceptions.RateLimitError(err_msg, response=response, body=body) + + if response.status_code >= 500: + return _exceptions.InternalServerError(err_msg, response=response, body=body) + return APIStatusError(err_msg, response=response, body=body) + + +class AnthropicBedrock(BaseBedrockClient[httpx.Client, Stream[Any]], SyncAPIClient): + completions: Completions + + def __init__( + self, + aws_secret_key: str | None = None, + aws_access_key: str | None = None, + aws_region: str | None = None, + aws_session_token: str | None = None, + base_url: str | httpx.URL | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + max_retries: int = DEFAULT_MAX_RETRIES, + default_headers: Mapping[str, str] | None = None, + default_query: Mapping[str, object] | None = None, + # Configure a custom httpx client. See the [httpx documentation](https://www.python-httpx.org/api/#client) for more details. + http_client: httpx.Client | None = None, + # Enable or disable schema validation for data returned by the API. + # When enabled an error APIResponseValidationError is raised + # if the API responds with invalid data for the expected schema. + # + # This parameter may be removed or changed in the future. + # If you rely on this feature, please open a GitHub issue + # outlining your use-case to help us decide if it should be + # part of our public interface in the future. + _strict_response_validation: bool = False, + ) -> None: + self.aws_secret_key = aws_secret_key + + self.aws_access_key = aws_access_key + + if aws_region is None: + aws_region = os.environ.get("AWS_REGION") or "us-east-1" + self.aws_region = aws_region + + self.aws_session_token = aws_session_token + + if base_url is None: + base_url = os.environ.get("ANTHROPIC_BEDROCK_BASE_URL") + if base_url is None: + base_url = f"https://bedrock-runtime.{self.aws_region}.amazonaws.com" + + super().__init__( + version=__version__, + base_url=base_url, + timeout=timeout, + max_retries=max_retries, + custom_headers=default_headers, + custom_query=default_query, + http_client=http_client, + _strict_response_validation=_strict_response_validation, + ) + + self._default_stream_cls = BedrockStream + + self.completions = Completions(self) + + @override + def _prepare_request(self, request: httpx.Request) -> None: + from ._auth import get_auth_headers + + data = request.read().decode() + + headers = get_auth_headers( + method=request.method, + url=str(request.url), + headers=request.headers, + aws_access_key=self.aws_access_key, + aws_secret_key=self.aws_secret_key, + aws_session_token=self.aws_session_token, + region=self.aws_region or "us-east-1", + data=data, + ) + request.headers.update(headers) + + @override + def _process_response( + self, + *, + cast_to: type[ResponseT], + options: FinalRequestOptions, + response: httpx.Response, + stream: bool, + stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None, + ) -> ResponseT: + if stream_cls is not None and get_origin(stream_cls) == Stream: + chunk_type = extract_stream_chunk_type(stream_cls) + + # the type: ignore is required as mypy doesn't like us + # dynamically created a concrete type like this + stream_cls = BedrockStream[chunk_type] # type: ignore + + return super()._process_response( + cast_to=cast_to, + options=options, + response=response, + stream=stream, + stream_cls=stream_cls, + ) + + +class AsyncAnthropicBedrock(BaseBedrockClient[httpx.AsyncClient, AsyncStream[Any]], AsyncAPIClient): + completions: AsyncCompletions + + def __init__( + self, + aws_secret_key: str | None = None, + aws_access_key: str | None = None, + aws_region: str | None = None, + aws_session_token: str | None = None, + base_url: str | httpx.URL | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + max_retries: int = DEFAULT_MAX_RETRIES, + default_headers: Mapping[str, str] | None = None, + default_query: Mapping[str, object] | None = None, + # Configure a custom httpx client. See the [httpx documentation](https://www.python-httpx.org/api/#client) for more details. + http_client: httpx.AsyncClient | None = None, + # Enable or disable schema validation for data returned by the API. + # When enabled an error APIResponseValidationError is raised + # if the API responds with invalid data for the expected schema. + # + # This parameter may be removed or changed in the future. + # If you rely on this feature, please open a GitHub issue + # outlining your use-case to help us decide if it should be + # part of our public interface in the future. + _strict_response_validation: bool = False, + ) -> None: + self.aws_secret_key = aws_secret_key + + self.aws_access_key = aws_access_key + + if aws_region is None: + aws_region = os.environ.get("AWS_REGION") or "us-east-1" + self.aws_region = aws_region + + self.aws_session_token = aws_session_token + + if base_url is None: + base_url = os.environ.get("ANTHROPIC_BEDROCK_BASE_URL") + if base_url is None: + base_url = f"https://bedrock-runtime.{self.aws_region}.amazonaws.com" + + super().__init__( + version=__version__, + base_url=base_url, + timeout=timeout, + max_retries=max_retries, + custom_headers=default_headers, + custom_query=default_query, + http_client=http_client, + _strict_response_validation=_strict_response_validation, + ) + + self._default_stream_cls = AsyncBedrockStream + + self.completions = AsyncCompletions(self) + + @override + async def _prepare_request(self, request: httpx.Request) -> None: + from ._auth import get_auth_headers + + data = request.read().decode() + + headers = get_auth_headers( + method=request.method, + url=str(request.url), + headers=request.headers, + aws_access_key=self.aws_access_key, + aws_secret_key=self.aws_secret_key, + aws_session_token=self.aws_session_token, + region=self.aws_region or "us-east-1", + data=data, + ) + request.headers.update(headers) + + @override + async def _process_response( + self, + *, + cast_to: type[ResponseT], + options: FinalRequestOptions, + response: httpx.Response, + stream: bool, + stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None, + ) -> ResponseT: + if stream_cls is not None and get_origin(stream_cls) == AsyncStream: + chunk_type = extract_stream_chunk_type(stream_cls) + + # the type: ignore is required as mypy doesn't like us + # dynamically created a concrete type like this + stream_cls = AsyncBedrockStream[chunk_type] # type: ignore + + return await super()._process_response( + cast_to=cast_to, + options=options, + response=response, + stream=stream, + stream_cls=stream_cls, + ) diff --git a/src/anthropic/lib/bedrock/_stream.py b/src/anthropic/lib/bedrock/_stream.py new file mode 100644 index 00000000..2140acdc --- /dev/null +++ b/src/anthropic/lib/bedrock/_stream.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from typing import TypeVar, Iterator +from typing_extensions import AsyncIterator, override + +import httpx + +from ..._client import Anthropic, AsyncAnthropic +from ..._streaming import Stream, AsyncStream, ServerSentEvent +from ._stream_decoder import AWSEventStreamDecoder + +_T = TypeVar("_T") + + +class BedrockStream(Stream[_T]): + # the AWS decoder expects `bytes` instead of `str` + _decoder: AWSEventStreamDecoder # type: ignore + + def __init__( + self, + *, + cast_to: type[_T], + response: httpx.Response, + client: Anthropic, + ) -> None: + super().__init__(cast_to=cast_to, response=response, client=client) + + self._decoder = AWSEventStreamDecoder() + + @override + def _iter_events(self) -> Iterator[ServerSentEvent]: + yield from self._decoder.iter(self.response.iter_bytes()) + + +class AsyncBedrockStream(AsyncStream[_T]): + # the AWS decoder expects `bytes` instead of `str` + _decoder: AWSEventStreamDecoder # type: ignore + + def __init__( + self, + *, + cast_to: type[_T], + response: httpx.Response, + client: AsyncAnthropic, + ) -> None: + super().__init__(cast_to=cast_to, response=response, client=client) + + self._decoder = AWSEventStreamDecoder() + + @override + async def _iter_events(self) -> AsyncIterator[ServerSentEvent]: + async for sse in self._decoder.aiter(self.response.aiter_bytes()): + yield sse diff --git a/src/anthropic/lib/bedrock/_stream_decoder.py b/src/anthropic/lib/bedrock/_stream_decoder.py new file mode 100644 index 00000000..a99647d7 --- /dev/null +++ b/src/anthropic/lib/bedrock/_stream_decoder.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Iterator, AsyncIterator +from functools import lru_cache + +from ..._streaming import ServerSentEvent + +if TYPE_CHECKING: + from botocore.model import Shape + from botocore.eventstream import EventStreamMessage + + +@lru_cache(maxsize=None) +def get_response_stream_shape() -> Shape: + from botocore.model import ServiceModel + from botocore.loaders import Loader + + loader = Loader() + bedrock_service_dict = loader.load_service_model("bedrock-runtime", "service-2") + bedrock_service_model = ServiceModel(bedrock_service_dict) + return bedrock_service_model.shape_for("ResponseStream") + + +class AWSEventStreamDecoder: + def __init__(self) -> None: + from botocore.parsers import EventStreamJSONParser + + self.parser = EventStreamJSONParser() + + def iter(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]: + """Given an iterator that yields lines, iterate over it & yield every event encountered""" + from botocore.eventstream import EventStreamBuffer + + event_stream_buffer = EventStreamBuffer() + for chunk in iterator: + event_stream_buffer.add_data(chunk) + for event in event_stream_buffer: + message = self._parse_message_from_event(event) + if message: + yield ServerSentEvent(data=message, event="completion") + + async def aiter(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]: + """Given an async iterator that yields lines, iterate over it & yield every event encountered""" + from botocore.eventstream import EventStreamBuffer + + event_stream_buffer = EventStreamBuffer() + async for chunk in iterator: + event_stream_buffer.add_data(chunk) + for event in event_stream_buffer: + message = self._parse_message_from_event(event) + if message: + yield ServerSentEvent(data=message, event="completion") + + def _parse_message_from_event(self, event: EventStreamMessage) -> str | None: + response_dict = event.to_response_dict() + parsed_response = self.parser.parse(response_dict, get_response_stream_shape()) + if response_dict["status_code"] != 200: + raise ValueError(f"Bad response code, expected 200: {response_dict}") + + chunk = parsed_response.get("chunk") + if not chunk: + return None + + return chunk.get("bytes").decode() # type: ignore[no-any-return]