diff --git a/pkg/anthropic/client/bedrock/client.go b/pkg/anthropic/client/bedrock/client.go index 5be580f..d4be286 100644 --- a/pkg/anthropic/client/bedrock/client.go +++ b/pkg/anthropic/client/bedrock/client.go @@ -21,20 +21,30 @@ const ( BedrockModelClaude3Sonnet = "anthropic.claude-3-sonnet-20240229-v1:0" BedrockModelClaude3Haiku = "anthropic.claude-3-haiku-20240307-v1:0" BedrockModelClaudeV2_1 = "anthropic.claude-v2:1" + + // Cross-region top-level region code + CRUS = "us" + CREU = "eu" ) type Client struct { - brCli *bedrockruntime.Client + brCli *bedrockruntime.Client + crInferenceRegion string } type Config struct { - Region string - AccessKeyID string - SecretAccessKey string - SessionToken string + Region string + AccessKeyID string + SecretAccessKey string + SessionToken string + CrossRegionInference bool } func MakeClient(ctx context.Context, cfg Config) (*Client, error) { + if cfg.Region == "" { + return nil, fmt.Errorf("Region is requried for establishing anthropic bedrock client") + } + awsCfg, err := config.LoadDefaultConfig( ctx, config.WithRegion(cfg.Region), @@ -54,30 +64,54 @@ func MakeClient(ctx context.Context, cfg Config) (*Client, error) { return nil, err } + regionPrefix := "" + if cfg.CrossRegionInference { + // extract the first 2 letters from the region + regionPrefix = cfg.Region[:2] + if regionPrefix != CRUS && regionPrefix != CREU { + return nil, fmt.Errorf( + "Cross region inference is only supported for: '%s', '%s'; Region prefix: '%s' is not supported", + CRUS, + CREU, + regionPrefix, + ) + } + } + return &Client{ - brCli: bedrockruntime.NewFromConfig(awsCfg), + brCli: bedrockruntime.NewFromConfig(awsCfg), + crInferenceRegion: regionPrefix, }, nil } // adaptModelForMessage takes the model as defined in anthropic.Model and adapts it to the model Bedrock expects -func adaptModelForMessage(model anthropic.Model) (string, error) { - if model == anthropic.Claude35Sonnet { - return BedrockModelClaude35Sonnet, nil +func (c *Client) adaptModelForMessage(model anthropic.Model) (string, error) { + adaptedModel := "" + + switch model { + case anthropic.Claude35Sonnet: + adaptedModel = BedrockModelClaude35Sonnet + case anthropic.Claude3Opus: + adaptedModel = BedrockModelClaude3Opus + case anthropic.Claude3Sonnet: + adaptedModel = BedrockModelClaude3Sonnet + case anthropic.Claude3Haiku: + adaptedModel = BedrockModelClaude3Haiku + case anthropic.ClaudeV2_1: + adaptedModel = BedrockModelClaudeV2_1 + default: + return "", fmt.Errorf("model %s is not compatible with the bedrock message endpoint", model) } - if model == anthropic.Claude3Opus { - return BedrockModelClaude3Opus, nil - } - if model == anthropic.Claude3Sonnet { - return BedrockModelClaude3Sonnet, nil - } - if model == anthropic.Claude3Haiku { - return BedrockModelClaude3Haiku, nil + + if c.crInferenceRegion == "" { + return adaptedModel, nil } - if model == anthropic.ClaudeV2_1 { - return BedrockModelClaudeV2_1, nil + + if adaptedModel == BedrockModelClaudeV2_1 { + return "", fmt.Errorf("Bedrock model %s is not compatible with cross-region inference", adaptedModel) } - return "", fmt.Errorf("model %s is not compatible with the bedrock message endpoint", model) + return fmt.Sprintf("%s.%s", c.crInferenceRegion, adaptedModel), nil } // adaptModelForCompletion takes the model as defined in anthropic.Model and adapts it to the model Bedrock expects diff --git a/pkg/anthropic/client/bedrock/client_test.go b/pkg/anthropic/client/bedrock/client_test.go new file mode 100644 index 0000000..a17f4c7 --- /dev/null +++ b/pkg/anthropic/client/bedrock/client_test.go @@ -0,0 +1,212 @@ +package bedrock + +import ( + "context" + "fmt" + "strings" + "testing" + + "github.com/madebywelch/anthropic-go/v3/pkg/anthropic" +) + +func Test_Client_Success_RegionOnly(t *testing.T) { + client, err := MakeClient(context.Background(), Config{ + Region: "us-west-2", + }) + + assertSuccessClient(t, client, err, "") +} + +func Test_Client_Success_RegionWithCredentials(t *testing.T) { + client, err := MakeClient(context.Background(), Config{ + Region: "us-west-2", + AccessKeyID: "hello-there", + SecretAccessKey: "general-kenobi", + SessionToken: "order-66", + }) + + assertSuccessClient(t, client, err, "") +} + +func Test_Client_Success_RegionWithCrossRegionInference(t *testing.T) { + client, err := MakeClient(context.Background(), Config{ + Region: "us-west-2", + CrossRegionInference: true, + }) + + assertSuccessClient(t, client, err, "us") +} + +func Test_Client_Failure_MissingRegion(t *testing.T) { + client, err := MakeClient(context.Background(), Config{}) + if err == nil { + t.Error("Expected an error when region is not set") + } + + if client != nil { + t.Error("Unexpected value for client when region is not set") + } +} + +func Test_Client_Failure_UnsupportedCrossRegionInference(t *testing.T) { + client, err := MakeClient(context.Background(), Config{ + Region: "he-llothere", + CrossRegionInference: true, + }) + if err == nil { + t.Error("Expected an error when using an unsupported region for cross region inference") + } + + if !strings.Contains(err.Error(), "region inference is only supported") { + t.Errorf("Exepcted an error for unsupported region inference: %s", err.Error()) + } + + if client != nil { + t.Error("Unexpected value for client when region is not set") + } +} + +type modelTest struct { + modelInput anthropic.Model + expectedModelOutput string +} + +func Test_adaptModelForMessage_Success_NonCrossRegion(t *testing.T) { + client, err := MakeClient(context.Background(), Config{ + Region: "us-west-2", + }) + if err != nil { + t.Errorf("Unexpected error when establishing client %s", err.Error()) + } + + testCases := []*modelTest{ + { + modelInput: anthropic.Claude35Sonnet, + expectedModelOutput: BedrockModelClaude35Sonnet, + }, + { + modelInput: anthropic.Claude3Opus, + expectedModelOutput: BedrockModelClaude3Opus, + }, + { + modelInput: anthropic.Claude3Sonnet, + expectedModelOutput: BedrockModelClaude3Sonnet, + }, + { + modelInput: anthropic.Claude3Haiku, + expectedModelOutput: BedrockModelClaude3Haiku, + }, + { + modelInput: anthropic.ClaudeV2_1, + expectedModelOutput: BedrockModelClaudeV2_1, + }, + } + + result := "" + for _, testCase := range testCases { + result, err = client.adaptModelForMessage(testCase.modelInput) + if err != nil { + t.Errorf("Unexpected error when adapting model: %s", err.Error()) + } + + if result != testCase.expectedModelOutput { + t.Errorf("Error when adapting model. Expected: %s, Actual: %s", testCase.expectedModelOutput, result) + } + } +} + +func Test_adaptModelForMessage_Failure_UnsupportedModel(t *testing.T) { + client, err := MakeClient(context.Background(), Config{ + Region: "us-west-2", + }) + if err != nil { + t.Errorf("Unexpected error when establishing client %s", err.Error()) + } + + result, err := client.adaptModelForMessage("hello-there") + if err == nil { + t.Error("Expected an error when adapting unsupported model") + } + + if result != "" { + t.Errorf("Unexpected result for adaptModel: %s", result) + } +} + +func Test_adaptModelForMessage_Success_CrossRegion(t *testing.T) { + client, err := MakeClient(context.Background(), Config{ + Region: "eu-west-1", + CrossRegionInference: true, + }) + if err != nil { + t.Errorf("Unexpected error when establishing client %s", err.Error()) + } + + testCases := []*modelTest{ + { + modelInput: anthropic.Claude35Sonnet, + expectedModelOutput: fmt.Sprintf("%s.%s", client.crInferenceRegion, BedrockModelClaude35Sonnet), + }, + { + modelInput: anthropic.Claude3Opus, + expectedModelOutput: fmt.Sprintf("%s.%s", client.crInferenceRegion, BedrockModelClaude3Opus), + }, + { + modelInput: anthropic.Claude3Sonnet, + expectedModelOutput: fmt.Sprintf("%s.%s", client.crInferenceRegion, BedrockModelClaude3Sonnet), + }, + { + modelInput: anthropic.Claude3Haiku, + expectedModelOutput: fmt.Sprintf("%s.%s", client.crInferenceRegion, BedrockModelClaude3Haiku), + }, + } + + result := "" + for _, testCase := range testCases { + result, err = client.adaptModelForMessage(testCase.modelInput) + if err != nil { + t.Errorf("Unexpected error when adapting model: %s", err.Error()) + } + + if result != testCase.expectedModelOutput { + t.Errorf("Error when adapting model. Expected: %s, Actual: %s", testCase.expectedModelOutput, result) + } + } +} + +func Test_adaptModelForMessage_Failure_ClaudeV2_CrossRegionInference(t *testing.T) { + client, err := MakeClient(context.Background(), Config{ + Region: "eu-west-1", + CrossRegionInference: true, + }) + if err != nil { + t.Errorf("Unexpected error when establishing client %s", err.Error()) + } + + result, err := client.adaptModelForMessage(anthropic.ClaudeV2_1) + if err == nil { + t.Error("Expected an error when using cross region inference on claude v2.1") + } + + if !strings.Contains(err.Error(), "not compatible with cross-region") { + t.Error("Expected a 'not compatible with cross-region' error") + } + + if result != "" { + t.Errorf("Unexpected result for adaptModel: %s", result) + } +} + +func assertSuccessClient(t *testing.T, client *Client, err error, crRegionValue string) { + if err != nil { + t.Errorf("Unexpected error %s", err.Error()) + } + + if client.brCli == nil { + t.Error("Unexpected nil for brCli") + } + + if client.crInferenceRegion != crRegionValue { + t.Errorf("Unexpected value for inference region %s", client.crInferenceRegion) + } +} diff --git a/pkg/anthropic/client/bedrock/message.go b/pkg/anthropic/client/bedrock/message.go index 196a772..04d405f 100644 --- a/pkg/anthropic/client/bedrock/message.go +++ b/pkg/anthropic/client/bedrock/message.go @@ -21,7 +21,7 @@ func (c *Client) Message(ctx context.Context, req *anthropic.MessageRequest) (*a } func (c *Client) sendMessageRequest(ctx context.Context, req *anthropic.MessageRequest) (*anthropic.MessageResponse, error) { - adaptedModel, err := adaptModelForMessage(req.Model) + adaptedModel, err := c.adaptModelForMessage(req.Model) if err != nil { return nil, err } diff --git a/pkg/anthropic/client/bedrock/message_stream.go b/pkg/anthropic/client/bedrock/message_stream.go index 9de80f1..b798f4c 100644 --- a/pkg/anthropic/client/bedrock/message_stream.go +++ b/pkg/anthropic/client/bedrock/message_stream.go @@ -37,7 +37,7 @@ func (c *Client) handleMessageStreaming( defer close(msCh) defer close(errCh) - adaptedModel, err := adaptModelForMessage(req.Model) + adaptedModel, err := c.adaptModelForMessage(req.Model) if err != nil { errCh <- fmt.Errorf("error adapting model: %w", err) return