Skip to content

Commit

Permalink
Add support for bedrock cross-region inference (#39)
Browse files Browse the repository at this point in the history
  • Loading branch information
pigeonlaser authored Sep 3, 2024
1 parent 061e224 commit 159b128
Show file tree
Hide file tree
Showing 4 changed files with 268 additions and 22 deletions.
74 changes: 54 additions & 20 deletions pkg/anthropic/client/bedrock/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,30 @@ const (
BedrockModelClaude3Sonnet = "anthropic.claude-3-sonnet-20240229-v1:0"
BedrockModelClaude3Haiku = "anthropic.claude-3-haiku-20240307-v1:0"
BedrockModelClaudeV2_1 = "anthropic.claude-v2:1"

// Cross-region top-level region code
CRUS = "us"
CREU = "eu"
)

type Client struct {
brCli *bedrockruntime.Client
brCli *bedrockruntime.Client
crInferenceRegion string
}

type Config struct {
Region string
AccessKeyID string
SecretAccessKey string
SessionToken string
Region string
AccessKeyID string
SecretAccessKey string
SessionToken string
CrossRegionInference bool
}

func MakeClient(ctx context.Context, cfg Config) (*Client, error) {
if cfg.Region == "" {
return nil, fmt.Errorf("Region is requried for establishing anthropic bedrock client")
}

awsCfg, err := config.LoadDefaultConfig(
ctx,
config.WithRegion(cfg.Region),
Expand All @@ -54,30 +64,54 @@ func MakeClient(ctx context.Context, cfg Config) (*Client, error) {
return nil, err
}

regionPrefix := ""
if cfg.CrossRegionInference {
// extract the first 2 letters from the region
regionPrefix = cfg.Region[:2]
if regionPrefix != CRUS && regionPrefix != CREU {
return nil, fmt.Errorf(
"Cross region inference is only supported for: '%s', '%s'; Region prefix: '%s' is not supported",
CRUS,
CREU,
regionPrefix,
)
}
}

return &Client{
brCli: bedrockruntime.NewFromConfig(awsCfg),
brCli: bedrockruntime.NewFromConfig(awsCfg),
crInferenceRegion: regionPrefix,
}, nil
}

// adaptModelForMessage takes the model as defined in anthropic.Model and adapts it to the model Bedrock expects
func adaptModelForMessage(model anthropic.Model) (string, error) {
if model == anthropic.Claude35Sonnet {
return BedrockModelClaude35Sonnet, nil
func (c *Client) adaptModelForMessage(model anthropic.Model) (string, error) {
adaptedModel := ""

switch model {
case anthropic.Claude35Sonnet:
adaptedModel = BedrockModelClaude35Sonnet
case anthropic.Claude3Opus:
adaptedModel = BedrockModelClaude3Opus
case anthropic.Claude3Sonnet:
adaptedModel = BedrockModelClaude3Sonnet
case anthropic.Claude3Haiku:
adaptedModel = BedrockModelClaude3Haiku
case anthropic.ClaudeV2_1:
adaptedModel = BedrockModelClaudeV2_1
default:
return "", fmt.Errorf("model %s is not compatible with the bedrock message endpoint", model)
}
if model == anthropic.Claude3Opus {
return BedrockModelClaude3Opus, nil
}
if model == anthropic.Claude3Sonnet {
return BedrockModelClaude3Sonnet, nil
}
if model == anthropic.Claude3Haiku {
return BedrockModelClaude3Haiku, nil

if c.crInferenceRegion == "" {
return adaptedModel, nil
}
if model == anthropic.ClaudeV2_1 {
return BedrockModelClaudeV2_1, nil

if adaptedModel == BedrockModelClaudeV2_1 {
return "", fmt.Errorf("Bedrock model %s is not compatible with cross-region inference", adaptedModel)
}

return "", fmt.Errorf("model %s is not compatible with the bedrock message endpoint", model)
return fmt.Sprintf("%s.%s", c.crInferenceRegion, adaptedModel), nil
}

// adaptModelForCompletion takes the model as defined in anthropic.Model and adapts it to the model Bedrock expects
Expand Down
212 changes: 212 additions & 0 deletions pkg/anthropic/client/bedrock/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
package bedrock

import (
"context"
"fmt"
"strings"
"testing"

"github.com/madebywelch/anthropic-go/v3/pkg/anthropic"
)

func Test_Client_Success_RegionOnly(t *testing.T) {
client, err := MakeClient(context.Background(), Config{
Region: "us-west-2",
})

assertSuccessClient(t, client, err, "")
}

func Test_Client_Success_RegionWithCredentials(t *testing.T) {
client, err := MakeClient(context.Background(), Config{
Region: "us-west-2",
AccessKeyID: "hello-there",
SecretAccessKey: "general-kenobi",
SessionToken: "order-66",
})

assertSuccessClient(t, client, err, "")
}

func Test_Client_Success_RegionWithCrossRegionInference(t *testing.T) {
client, err := MakeClient(context.Background(), Config{
Region: "us-west-2",
CrossRegionInference: true,
})

assertSuccessClient(t, client, err, "us")
}

func Test_Client_Failure_MissingRegion(t *testing.T) {
client, err := MakeClient(context.Background(), Config{})
if err == nil {
t.Error("Expected an error when region is not set")
}

if client != nil {
t.Error("Unexpected value for client when region is not set")
}
}

// Test_Client_Failure_UnsupportedCrossRegionInference enables cross-region
// inference for a region whose two-letter prefix is neither "us" nor "eu" and
// expects MakeClient to return the "only supported for" error and no client.
func Test_Client_Failure_UnsupportedCrossRegionInference(t *testing.T) {
	client, err := MakeClient(context.Background(), Config{
		Region:               "he-llothere",
		CrossRegionInference: true,
	})
	// Fatal rather than Error: the next check calls err.Error(), which would
	// panic on a nil error and hide the real failure.
	if err == nil {
		t.Fatal("Expected an error when using an unsupported region for cross region inference")
	}

	if !strings.Contains(err.Error(), "region inference is only supported") {
		t.Errorf("Expected an error for unsupported region inference: %s", err.Error())
	}

	if client != nil {
		t.Error("Unexpected value for client when cross region inference is not supported")
	}
}

// modelTest is a single table-driven case for adaptModelForMessage: an input
// model paired with the string the adapter is expected to return for it.
type modelTest struct {
// modelInput is the model handed to adaptModelForMessage.
modelInput anthropic.Model
// expectedModelOutput is the value adaptModelForMessage should return for modelInput.
expectedModelOutput string
}

// Test_adaptModelForMessage_Success_NonCrossRegion checks that, without
// cross-region inference, every supported anthropic model is mapped to its
// plain Bedrock model identifier.
func Test_adaptModelForMessage_Success_NonCrossRegion(t *testing.T) {
	client, err := MakeClient(context.Background(), Config{
		Region: "us-west-2",
	})
	// Fatal rather than Error: on failure client is nil and
	// adaptModelForMessage would panic when it reads the client's fields.
	if err != nil {
		t.Fatalf("Unexpected error when establishing client %s", err.Error())
	}

	testCases := []modelTest{
		{modelInput: anthropic.Claude35Sonnet, expectedModelOutput: BedrockModelClaude35Sonnet},
		{modelInput: anthropic.Claude3Opus, expectedModelOutput: BedrockModelClaude3Opus},
		{modelInput: anthropic.Claude3Sonnet, expectedModelOutput: BedrockModelClaude3Sonnet},
		{modelInput: anthropic.Claude3Haiku, expectedModelOutput: BedrockModelClaude3Haiku},
		{modelInput: anthropic.ClaudeV2_1, expectedModelOutput: BedrockModelClaudeV2_1},
	}

	for _, testCase := range testCases {
		// One subtest per model so a failure names the offending case.
		t.Run(testCase.expectedModelOutput, func(t *testing.T) {
			result, err := client.adaptModelForMessage(testCase.modelInput)
			if err != nil {
				t.Fatalf("Unexpected error when adapting model: %s", err.Error())
			}

			if result != testCase.expectedModelOutput {
				t.Errorf("Error when adapting model. Expected: %s, Actual: %s", testCase.expectedModelOutput, result)
			}
		})
	}
}

// Test_adaptModelForMessage_Failure_UnsupportedModel checks that an unknown
// model name yields an error and an empty adapted model string.
func Test_adaptModelForMessage_Failure_UnsupportedModel(t *testing.T) {
	client, err := MakeClient(context.Background(), Config{
		Region: "us-west-2",
	})
	// Fatal rather than Error: the assertions below say nothing useful about
	// the adapter if the client could not be built.
	if err != nil {
		t.Fatalf("Unexpected error when establishing client %s", err.Error())
	}

	result, err := client.adaptModelForMessage("hello-there")
	if err == nil {
		t.Error("Expected an error when adapting unsupported model")
	}

	if result != "" {
		t.Errorf("Unexpected result for adaptModel: %s", result)
	}
}

// Test_adaptModelForMessage_Success_CrossRegion checks that, with cross-region
// inference enabled for an "eu-..." region, each supported model is mapped to
// "<prefix>.<bedrock model id>" with the "eu" prefix.
func Test_adaptModelForMessage_Success_CrossRegion(t *testing.T) {
	client, err := MakeClient(context.Background(), Config{
		Region:               "eu-west-1",
		CrossRegionInference: true,
	})
	// Fatal rather than Error: on failure client is nil and
	// adaptModelForMessage would panic when it reads the client's fields.
	if err != nil {
		t.Fatalf("Unexpected error when establishing client %s", err.Error())
	}

	// Expectations are pinned to the CREU constant instead of being derived
	// from client.crInferenceRegion, so a wrongly computed prefix fails the
	// test rather than being echoed back into the expected value.
	testCases := []modelTest{
		{
			modelInput:          anthropic.Claude35Sonnet,
			expectedModelOutput: fmt.Sprintf("%s.%s", CREU, BedrockModelClaude35Sonnet),
		},
		{
			modelInput:          anthropic.Claude3Opus,
			expectedModelOutput: fmt.Sprintf("%s.%s", CREU, BedrockModelClaude3Opus),
		},
		{
			modelInput:          anthropic.Claude3Sonnet,
			expectedModelOutput: fmt.Sprintf("%s.%s", CREU, BedrockModelClaude3Sonnet),
		},
		{
			modelInput:          anthropic.Claude3Haiku,
			expectedModelOutput: fmt.Sprintf("%s.%s", CREU, BedrockModelClaude3Haiku),
		},
	}

	for _, testCase := range testCases {
		// One subtest per model so a failure names the offending case.
		t.Run(testCase.expectedModelOutput, func(t *testing.T) {
			result, err := client.adaptModelForMessage(testCase.modelInput)
			if err != nil {
				t.Fatalf("Unexpected error when adapting model: %s", err.Error())
			}

			if result != testCase.expectedModelOutput {
				t.Errorf("Error when adapting model. Expected: %s, Actual: %s", testCase.expectedModelOutput, result)
			}
		})
	}
}

// Test_adaptModelForMessage_Failure_ClaudeV2_CrossRegionInference checks that
// requesting Claude v2.1 on a cross-region client yields the
// "not compatible with cross-region" error and an empty adapted model string.
func Test_adaptModelForMessage_Failure_ClaudeV2_CrossRegionInference(t *testing.T) {
	client, err := MakeClient(context.Background(), Config{
		Region:               "eu-west-1",
		CrossRegionInference: true,
	})
	// Fatal rather than Error: on failure client is nil and
	// adaptModelForMessage would panic when it reads the client's fields.
	if err != nil {
		t.Fatalf("Unexpected error when establishing client %s", err.Error())
	}

	result, err := client.adaptModelForMessage(anthropic.ClaudeV2_1)
	// Fatal rather than Error: the next check calls err.Error(), which would
	// panic on a nil error.
	if err == nil {
		t.Fatal("Expected an error when using cross region inference on claude v2.1")
	}

	if !strings.Contains(err.Error(), "not compatible with cross-region") {
		t.Error("Expected a 'not compatible with cross-region' error")
	}

	if result != "" {
		t.Errorf("Unexpected result for adaptModel: %s", result)
	}
}

// assertSuccessClient fails the calling test unless MakeClient succeeded:
// err must be nil, client must be non-nil with a Bedrock runtime client set,
// and client.crInferenceRegion must equal crRegionValue.
func assertSuccessClient(t *testing.T, client *Client, err error, crRegionValue string) {
	// Report failures at the caller's line instead of inside this helper.
	t.Helper()

	// Both checks are fatal: the field checks below dereference client.
	if err != nil {
		t.Fatalf("Unexpected error %s", err.Error())
	}

	if client == nil {
		t.Fatal("Unexpected nil client")
	}

	if client.brCli == nil {
		t.Error("Unexpected nil for brCli")
	}

	if client.crInferenceRegion != crRegionValue {
		t.Errorf("Unexpected value for inference region %q, expected %q", client.crInferenceRegion, crRegionValue)
	}
}
2 changes: 1 addition & 1 deletion pkg/anthropic/client/bedrock/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func (c *Client) Message(ctx context.Context, req *anthropic.MessageRequest) (*a
}

func (c *Client) sendMessageRequest(ctx context.Context, req *anthropic.MessageRequest) (*anthropic.MessageResponse, error) {
adaptedModel, err := adaptModelForMessage(req.Model)
adaptedModel, err := c.adaptModelForMessage(req.Model)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/anthropic/client/bedrock/message_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func (c *Client) handleMessageStreaming(
defer close(msCh)
defer close(errCh)

adaptedModel, err := adaptModelForMessage(req.Model)
adaptedModel, err := c.adaptModelForMessage(req.Model)
if err != nil {
errCh <- fmt.Errorf("error adapting model: %w", err)
return
Expand Down

0 comments on commit 159b128

Please sign in to comment.