diff --git a/README.md b/README.md index e2c5fab..464d8fa 100644 --- a/README.md +++ b/README.md @@ -1,150 +1,55 @@ -# Unofficial Anthropic SDK in Go +# Unofficial Anthropic SDK for Go -This project provides an unofficial Go SDK for Anthropic, a A next-generation AI assistant for your tasks, no matter the scale. The SDK makes it easy to interact with the Anthropic API in Go applications. For more information about Anthropic, including API documentation, visit the official [Anthropic documentation.](https://console.anthropic.com/docs) +This project provides an unofficial Go SDK for Anthropic, a next-generation AI assistant platform. The SDK simplifies interactions with the Anthropic API in Go applications. For more information about Anthropic and its API, visit the [official Anthropic documentation](https://console.anthropic.com/docs). -[![GoDoc](https://godoc.org/github.com/madebywelch/anthropic-go?status.svg)](https://pkg.go.dev/github.com/madebywelch/anthropic-go/v2) +[![GoDoc](https://godoc.org/github.com/madebywelch/anthropic-go?status.svg)](https://pkg.go.dev/github.com/madebywelch/anthropic-go/v3) ## Installation -You can install the Anthropic SDK in Go using go get: +Install the Anthropic SDK for Go using: ```go -go get github.com/madebywelch/anthropic-go/v2 +go get github.com/madebywelch/anthropic-go/v3 ``` -## Usage - -To use the Anthropic SDK, you'll need to initialize a client and make requests to the Anthropic API. Here's an example of initializing a client and performing a regular and a streaming completion: +## Features -## Completion Example - -```go -package main +- Support for both native Anthropic API and AWS Bedrock +- Completion and streaming completion +- Message and streaming message support +- Tool usage capabilities -import ( - "fmt" +## Quick Start - "github.com/madebywelch/anthropic-go/v2/pkg/anthropic" - "github.com/madebywelch/anthropic-go/v2/pkg/anthropic/utils" -) - -func main() { - client, err := anthropic.NewClient("your-api-key") - if err != nil { - panic(err) - } - - prompt, err := utils.GetPrompt("Why is the sky blue?") - if err != nil { - panic(err) - } - - request := anthropic.NewCompletionRequest( - prompt, - anthropic.WithModel[anthropic.CompletionRequest](anthropic.ClaudeV2_1), - anthropic.WithMaxTokens[anthropic.CompletionRequest](100), - ) - - // Note: Only use client.Complete when streaming is disabled, otherwise use client.CompleteStream! - response, err := client.Complete(request) - if err != nil { - panic(err) - } - - fmt.Printf("Completion: %s\n", response.Completion) -} -``` - -### Completion Example Output - -``` -The sky appears blue to us due to the way the atmosphere scatters light from the sun -``` - -## Completion Streaming Example +Here's a basic example of using the SDK: ```go package main import ( + "context" "fmt" - "github.com/madebywelch/anthropic-go/v2/pkg/anthropic" - "github.com/madebywelch/anthropic-go/v2/pkg/anthropic/utils" + "github.com/madebywelch/anthropic-go/v3/pkg/anthropic" + "github.com/madebywelch/anthropic-go/v3/pkg/anthropic/client/native" ) func main() { - client, err := anthropic.NewClient("your-api-key") + ctx := context.Background() + client, err := native.MakeClient(native.Config{ + APIKey: "your-api-key", + }) if err != nil { panic(err) } - prompt, err := utils.GetPrompt("Why is the sky blue?") - if err != nil { - panic(err) - } - - request := anthropic.NewCompletionRequest( - prompt, - anthropic.WithModel[anthropic.CompletionRequest](anthropic.ClaudeV2_1), - anthropic.WithMaxTokens[anthropic.CompletionRequest](100), - anthropic.WithStreaming[anthropic.CompletionRequest](true), - ) - - // Note: Only use client.CompleteStream when streaming is enabled, otherwise use client.Complete! - resps, errs := client.CompleteStream(request) - - for { - select { - case resp := <-resps: - fmt.Printf("Completion: %s\n", resp.Completion) - case err := <-errs: - panic(err) - } - } -} -``` - -### Completion Streaming Example Output - -``` -There -are -a -few -reasons -why -the -sky -appears -``` - -## Messages Example - -```go -package main - -import ( - "fmt" - - "github.com/madebywelch/anthropic-go/v2/pkg/anthropic" -) - -func main() { - client, err := anthropic.NewClient("your-api-key") - if err != nil { - panic(err) - } - - // Prepare a message request request := anthropic.NewMessageRequest( []anthropic.MessagePartRequest{{Role: "user", Content: []anthropic.ContentBlock{anthropic.NewTextContentBlock("Hello, world!")}}}, - anthropic.WithModel[anthropic.MessageRequest](anthropic.ClaudeV2_1), + anthropic.WithModel[anthropic.MessageRequest](anthropic.Claude35Sonnet), anthropic.WithMaxTokens[anthropic.MessageRequest](20), ) - // Call the Message method - response, err := client.Message(request) + response, err := client.Message(ctx, request) if err != nil { panic(err) } @@ -153,152 +58,20 @@ func main() { } ``` -### Messages Example Output - -``` -{ID:msg_01W3bZkuMrS3h1ehqTdF84vv Type:message Model:claude-2.1 Role:assistant Content:[{Type:text Text:Hello!}] StopReason:end_turn Stop: StopSequence:} -``` - -## Messages Streaming Example - -```go -package main - -import ( - "fmt" - "os" - - "github.com/madebywelch/anthropic-go/v2/pkg/anthropic" -) - -func main() { - apiKey, ok := os.LookupEnv("ANTHROPIC_API_KEY") - if !ok { - fmt.Printf("missing ANTHROPIC_API_KEY environment variable") - } - client, err := anthropic.NewClient(apiKey) - if err != nil { - panic(err) - } - - // Prepare a message request - request := anthropic.NewMessageRequest( - []anthropic.MessagePartRequest{{Role: "user", Content: "Hello, Good Morning!"}}, - anthropic.WithModel[anthropic.MessageRequest](anthropic.ClaudeV2_1), - anthropic.WithMaxTokens[anthropic.MessageRequest](20), - anthropic.WithStreaming[anthropic.MessageRequest](true), - ) - - // Call the Message method - resps, errors := client.MessageStream(request) - - for { - select { - case response := <-resps: - if response.Type == "content_block_delta" { - fmt.Println(response.Delta.Text) - } - if response.Type == "message_stop" { - fmt.Println("Message stop") - return - } - case err := <-errors: - fmt.Println(err) - return - } - } -} -``` - -### Messages Streaming Example Output - -``` -Good - morning -! - As - an - AI - language - model -, - I - don -'t - have - feelings - or - a - physical - state -, - but -``` - -## Messages Tools Example - -```go -package main - -import ( - "github.com/madebywelch/anthropic-go/v2/pkg/anthropic" -) - -func main() { - client, err := anthropic.NewClient("your-api-key") - if err != nil { - panic(err) - } - - // Prepare a message request - request := &anthropic.MessageRequest{ - Model: anthropic.Claude3Opus, - MaxTokensToSample: 1024, - Tools: []anthropic.Tool{ - { - Name: "get_weather", - Description: "Get the weather", - InputSchema: anthropic.InputSchema{ - Type: "object", - Properties: map[string]anthropic.Property{ - "city": {Type: "string", Description: "city to get the weather for"}, - "unit": {Type: "string", Enum: []string{"celsius", "fahrenheit"}, Description: "temperature unit to return"}}, - Required: []string{"city"}, - }, - }, - }, - Messages: []anthropic.MessagePartRequest{ - { - Role: "user", - Content: []anthropic.ContentBlock{ - anthropic.NewTextContentBlock("what's the weather in Charleston?"), - }, - }, - }, - } - - // Call the Message method - response, err := client.Message(request) - if err != nil { - panic(err) - } +## Usage - if response.StopReason == "tool_use" { - // Do something with the tool response - } -} -``` +For more detailed usage examples, including streaming, completion, and tool usage, please refer to the `pkg/internal/examples` directory in the repository. ## Contributing -Contributions to this project are welcome. To contribute, follow these steps: +Contributions are welcome! To contribute: -- Fork this repository -- Create a new branch (`git checkout -b feature/my-new-feature`) -- Commit your changes (`git commit -am 'Add some feature'`) -- Push the branch (`git push origin feature/my-new-feature`) -- Create a new pull request +1. Fork this repository +2. Create a new branch (`git checkout -b feature/my-new-feature`) +3. Commit your changes (`git commit -am 'Add some feature'`) +4. Push the branch (`git push origin feature/my-new-feature`) +5. Create a new pull request ## License -This project is licensed under the Apache License, Version 2.0 - see the [LICENSE](LICENSE) file for details. +This project is licensed under the Apache License, Version 2.0. See the [LICENSE](LICENSE) file for details. diff --git a/go.mod b/go.mod index ccc9e9e..5c49378 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,24 @@ -module github.com/madebywelch/anthropic-go/v2 +module github.com/madebywelch/anthropic-go/v3 go 1.20 + +require ( + github.com/aws/aws-sdk-go-v2 v1.30.1 + github.com/aws/aws-sdk-go-v2/config v1.27.23 + github.com/aws/aws-sdk-go-v2/credentials v1.17.23 + github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.12.1 +) + +require ( + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.3 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.9 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.13 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.13 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.15 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.22.1 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.1 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.30.1 // indirect + github.com/aws/smithy-go v1.20.3 // indirect +) diff --git a/go.sum b/go.sum index e69de29..5eda614 100644 --- a/go.sum +++ b/go.sum @@ -0,0 +1,30 @@ +github.com/aws/aws-sdk-go-v2 v1.30.1 h1:4y/5Dvfrhd1MxRDD77SrfsDaj8kUkkljU7XE83NPV+o= +github.com/aws/aws-sdk-go-v2 v1.30.1/go.mod h1:nIQjQVp5sfpQcTc9mPSr1B0PaWK5ByX9MOoDadSN4lc= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.3 h1:tW1/Rkad38LA15X4UQtjXZXNKsCgkshC3EbmcUmghTg= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.3/go.mod h1:UbnqO+zjqk3uIt9yCACHJ9IVNhyhOCnYk8yA19SAWrM= +github.com/aws/aws-sdk-go-v2/config v1.27.23 h1:Cr/gJEa9NAS7CDAjbnB7tHYb3aLZI2gVggfmSAasDac= +github.com/aws/aws-sdk-go-v2/config v1.27.23/go.mod h1:WMMYHqLCFu5LH05mFOF5tsq1PGEMfKbu083VKqLCd0o= +github.com/aws/aws-sdk-go-v2/credentials v1.17.23 h1:G1CfmLVoO2TdQ8z9dW+JBc/r8+MqyPQhXCafNZcXVZo= +github.com/aws/aws-sdk-go-v2/credentials v1.17.23/go.mod h1:V/DvSURn6kKgcuKEk4qwSwb/fZ2d++FFARtWSbXnLqY= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.9 h1:Aznqksmd6Rfv2HQN9cpqIV/lQRMaIpJkLLaJ1ZI76no= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.9/go.mod h1:WQr3MY7AxGNxaqAtsDWn+fBxmd4XvLkzeqQ8P1VM0/w= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.13 h1:5SAoZ4jYpGH4721ZNoS1znQrhOfZinOhc4XuTXx/nVc= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.13/go.mod h1:+rdA6ZLpaSeM7tSg/B0IEDinCIBJGmW8rKDFkYpP04g= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.13 h1:WIijqeaAO7TYFLbhsZmi2rgLEAtWOC1LhxCAVTJlSKw= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.13/go.mod h1:i+kbfa76PQbWw/ULoWnp51EYVWH4ENln76fLQE3lXT8= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 h1:hT8rVHwugYE2lEfdFE0QWVo81lF7jMrYJVDWI+f+VxU= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0/go.mod h1:8tu/lYfQfFe6IGnaOdrpVgEL2IrrDOf6/m9RQum4NkY= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.12.1 h1:3B45hjMYPuv9K3M8dBUhQiLaZz6QIOF3AYgCadMoUpQ= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.12.1/go.mod h1:jeJzYp86gwna3f1bV3q0A9pxOyrdK4D0thCZ84ru6L0= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3 h1:dT3MqvGhSoaIhRseqw2I0yH81l7wiR2vjs57O51EAm8= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3/go.mod h1:GlAeCkHwugxdHaueRr4nhPuY+WW+gR8UjlcqzPr1SPI= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.15 h1:I9zMeF107l0rJrpnHpjEiiTSCKYAIw8mALiXcPsGBiA= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.15/go.mod h1:9xWJ3Q/S6Ojusz1UIkfycgD1mGirJfLLKqq3LPT7WN8= +github.com/aws/aws-sdk-go-v2/service/sso v1.22.1 h1:p1GahKIjyMDZtiKoIn0/jAj/TkMzfzndDv5+zi2Mhgc= +github.com/aws/aws-sdk-go-v2/service/sso v1.22.1/go.mod h1:/vWdhoIoYA5hYoPZ6fm7Sv4d8701PiG5VKe8/pPJL60= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.1 h1:lCEv9f8f+zJ8kcFeAjRZsekLd/x5SAm96Cva+VbUdo8= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.1/go.mod h1:xyFHA4zGxgYkdD73VeezHt3vSKEG9EmFnGwoKlP00u4= +github.com/aws/aws-sdk-go-v2/service/sts v1.30.1 h1:+woJ607dllHJQtsnJLi52ycuqHMwlW+Wqm2Ppsfp4nQ= +github.com/aws/aws-sdk-go-v2/service/sts v1.30.1/go.mod h1:jiNR3JqT15Dm+QWq2SRgh0x0bCNSRP2L25+CqPNpJlQ= +github.com/aws/smithy-go v1.20.3 h1:ryHwveWzPV5BIof6fyDvor6V3iUL7nTfiTKXHiW05nE= +github.com/aws/smithy-go v1.20.3/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E= diff --git a/pkg/anthropic/client.go b/pkg/anthropic/client.go deleted file mode 100644 index 6cca571..0000000 --- a/pkg/anthropic/client.go +++ /dev/null @@ -1,30 +0,0 @@ -package anthropic - -import ( - "net/http" -) - -// Client represents the Anthropic API client and its configuration. -type Client struct { - httpClient *http.Client - apiKey string - baseURL string -} - -// NewClient initializes a new Anthropic API client with the required headers. -func NewClient(apiKey string, options ...GenericOption[Client]) (*Client, error) { - if apiKey == "" { - return nil, ErrAnthropicApiKeyRequired - } - - client := &Client{ - httpClient: &http.Client{}, - apiKey: apiKey, - baseURL: "https://api.anthropic.com", - } - for _, opt := range options { - opt(client) - } - - return client, nil -} diff --git a/pkg/anthropic/client/bedrock/client.go b/pkg/anthropic/client/bedrock/client.go new file mode 100644 index 0000000..1e7745d --- /dev/null +++ b/pkg/anthropic/client/bedrock/client.go @@ -0,0 +1,130 @@ +package bedrock + +import ( + "context" + "fmt" + "regexp" + "strconv" + + "github.com/madebywelch/anthropic-go/v3/pkg/anthropic" + + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" +) + +const ( + AnthropicVersion = "bedrock-2023-05-31" + + BedrockModelClaude3Opus = "anthropic.claude-3-opus-20240229-v1:0" + BedrockModelClaude3Sonnet = "anthropic.claude-3-sonnet-20240229-v1:0" + BedrockModelClaude3Haiku = "anthropic.claude-3-haiku-20240307-v1:0" + BedrockModelClaudeV2_1 = "anthropic.claude-v2:1" +) + +type Client struct { + brCli *bedrockruntime.Client +} + +type Config struct { + Region string + AccessKeyID string + SecretAccessKey string + SessionToken string +} + +func MakeClient(ctx context.Context, cfg Config) (*Client, error) { + awsCfg, err := config.LoadDefaultConfig( + ctx, + config.WithRegion(cfg.Region), + ) + + // override config load with static credentials if provided + if cfg.AccessKeyID != "" && cfg.SecretAccessKey != "" { + credsProvider := credentials.NewStaticCredentialsProvider(cfg.AccessKeyID, cfg.SecretAccessKey, cfg.SessionToken) + awsCfg, err = config.LoadDefaultConfig( + ctx, + config.WithRegion(cfg.Region), + config.WithCredentialsProvider(credsProvider), + ) + } + + if err != nil { + return nil, err + } + + return &Client{ + brCli: bedrockruntime.NewFromConfig(awsCfg), + }, nil +} + +// adaptModelForMessage takes the model as defined in anthropic.Model and adapts it to the model Bedrock expects +func adaptModelForMessage(model anthropic.Model) (string, error) { + if model == anthropic.Claude3Opus { + return BedrockModelClaude3Opus, nil + } + if model == anthropic.Claude3Sonnet { + return BedrockModelClaude3Sonnet, nil + } + if model == anthropic.Claude3Haiku { + return BedrockModelClaude3Haiku, nil + } + if model == anthropic.ClaudeV2_1 { + return BedrockModelClaudeV2_1, nil + } + + return "", fmt.Errorf("model %s is not compatible with the bedrock message endpoint", model) +} + +// adaptModelForCompletion takes the model as defined in anthropic.Model and adapts it to the model Bedrock expects +func adaptModelForCompletion(model anthropic.Model) (string, error) { + if model == anthropic.ClaudeV2_1 { + return BedrockModelClaudeV2_1, nil + } + + return "", fmt.Errorf("model %s is not compatible with the bedrock completion endpoint", model) +} + +// MessageRequest is an override for the default message request to adapt the request for the Bedrock API. +type MessageRequest struct { + anthropic.MessageRequest + AnthropicVersion string `json:"anthropic_version"` + Model bool `json:"model,omitempty"` // shadow for Model + Stream bool `json:"stream,omitempty"` // shadow for Stream +} + +func adaptMessageRequest(req *anthropic.MessageRequest) *MessageRequest { + return &MessageRequest{ + MessageRequest: *req, + AnthropicVersion: AnthropicVersion, + } +} + +type CompleteRequest struct { + anthropic.CompletionRequest + AnthropicVersion string `json:"anthropic_version"` + Model bool `json:"model,omitempty"` // shadow for Model + Stream bool `json:"stream,omitempty"` // shadow for Stream +} + +func adaptCompletionRequest(req *anthropic.CompletionRequest) *CompleteRequest { + return &CompleteRequest{ + CompletionRequest: *req, + AnthropicVersion: AnthropicVersion, + } +} + +func extractErrStatusCode(err error) int { + re := regexp.MustCompile(`StatusCode: (\d+)`) + match := re.FindStringSubmatch(err.Error()) + + if len(match) > 1 { + res, err := strconv.Atoi(match[1]) + if err != nil { + return 0 + } + return res + } + + return 0 +} diff --git a/pkg/anthropic/client/bedrock/complete.go b/pkg/anthropic/client/bedrock/complete.go new file mode 100644 index 0000000..da5137f --- /dev/null +++ b/pkg/anthropic/client/bedrock/complete.go @@ -0,0 +1,54 @@ +package bedrock + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/madebywelch/anthropic-go/v3/pkg/anthropic" +) + +func (c *Client) Complete(ctx context.Context, req *anthropic.CompletionRequest) (*anthropic.CompletionResponse, error) { + err := anthropic.ValidateCompleteRequest(req) + if err != nil { + return nil, err + } + + return c.sendCompleteRequest(ctx, req) +} + +func (c *Client) sendCompleteRequest(ctx context.Context, req *anthropic.CompletionRequest) (*anthropic.CompletionResponse, error) { + adaptedModel, err := adaptModelForCompletion(req.Model) + if err != nil { + return nil, err + } + + // Adapt the request to a Bedrock request + bedReq := adaptCompletionRequest(req) + + data, err := json.Marshal(bedReq) + if err != nil { + return nil, fmt.Errorf("error marshalling complete request: %w", err) + } + + response, err := c.brCli.InvokeModel(ctx, &bedrockruntime.InvokeModelInput{ + Body: data, + ModelId: aws.String(adaptedModel), + ContentType: aws.String("application/json"), + }) + + if err != nil { + errStatusCode := extractErrStatusCode(err) + return nil, anthropic.MapHTTPStatusCodeToError(errStatusCode) + } + + compResp := &anthropic.CompletionResponse{} + err = json.Unmarshal(response.Body, compResp) + if err != nil { + return nil, fmt.Errorf("error unmarshalling complete response: %w", err) + } + + return compResp, nil +} diff --git a/pkg/anthropic/client/bedrock/complete_stream.go b/pkg/anthropic/client/bedrock/complete_stream.go new file mode 100644 index 0000000..24902f5 --- /dev/null +++ b/pkg/anthropic/client/bedrock/complete_stream.go @@ -0,0 +1,89 @@ +package bedrock + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/madebywelch/anthropic-go/v3/pkg/anthropic" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" +) + +func (c *Client) CompleteStream(ctx context.Context, req *anthropic.CompletionRequest) (<-chan *anthropic.StreamResponse, <-chan error) { + cCh := make(chan *anthropic.StreamResponse) + errCh := make(chan error, 1) + + err := anthropic.ValidateCompleteStreamRequest(req) + if err != nil { + errCh <- err + close(cCh) + close(errCh) + return cCh, errCh + } + + go c.handleCompleteStreaming(ctx, req, cCh, errCh) + return cCh, errCh +} + +func (c *Client) handleCompleteStreaming( + ctx context.Context, + req *anthropic.CompletionRequest, + cCh chan<- *anthropic.StreamResponse, + errCh chan<- error, +) { + defer close(cCh) + defer close(errCh) + + adaptedModel, err := adaptModelForCompletion(req.Model) + if err != nil { + errCh <- err + return + } + + // Adapt the request to a Bedrock request + bedReq := adaptCompletionRequest(req) + + data, err := json.Marshal(bedReq) + if err != nil { + errCh <- fmt.Errorf("error marshalling complete request: %w", err) + return + } + + response, err := c.brCli.InvokeModelWithResponseStream( + ctx, + &bedrockruntime.InvokeModelWithResponseStreamInput{ + Body: data, + ModelId: aws.String(adaptedModel), + ContentType: aws.String("application/json"), + }, + ) + if err != nil { + errStatusCode := extractErrStatusCode(err) + errCh <- anthropic.MapHTTPStatusCodeToError(errStatusCode) + return + } + + for event := range response.GetStream().Events() { + select { + case <-ctx.Done(): + return + default: + } + + if v, ok := event.(*types.ResponseStreamMemberChunk); ok { + streamResp := &anthropic.StreamResponse{} + err = json.Unmarshal(v.Value.Bytes, streamResp) + if err != nil { + errCh <- fmt.Errorf("error unmarshalling stream response: %w", err) + return + } + + fmt.Println(streamResp) + + cCh <- streamResp + } + } +} diff --git a/pkg/anthropic/client/bedrock/message.go b/pkg/anthropic/client/bedrock/message.go new file mode 100644 index 0000000..196a772 --- /dev/null +++ b/pkg/anthropic/client/bedrock/message.go @@ -0,0 +1,55 @@ +package bedrock + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/madebywelch/anthropic-go/v3/pkg/anthropic" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" +) + +func (c *Client) Message(ctx context.Context, req *anthropic.MessageRequest) (*anthropic.MessageResponse, error) { + err := anthropic.ValidateMessageRequest(req) + if err != nil { + return nil, err + } + + return c.sendMessageRequest(ctx, req) +} + +func (c *Client) sendMessageRequest(ctx context.Context, req *anthropic.MessageRequest) (*anthropic.MessageResponse, error) { + adaptedModel, err := adaptModelForMessage(req.Model) + if err != nil { + return nil, err + } + + // Adapt the request to a Bedrock request + bedReq := adaptMessageRequest(req) + + data, err := json.Marshal(bedReq) + if err != nil { + return nil, fmt.Errorf("error marshalling message request: %w", err) + } + + response, err := c.brCli.InvokeModel(ctx, &bedrockruntime.InvokeModelInput{ + Body: data, + ModelId: aws.String(adaptedModel), + ContentType: aws.String("application/json"), + }) + + if err != nil { + errStatusCode := extractErrStatusCode(err) + return nil, anthropic.MapHTTPStatusCodeToError(errStatusCode) + } + + msgResp := &anthropic.MessageResponse{} + err = json.Unmarshal(response.Body, msgResp) + if err != nil { + return nil, fmt.Errorf("error unmarshalling message response: %w", err) + } + + return msgResp, nil +} diff --git a/pkg/anthropic/client/bedrock/message_stream.go b/pkg/anthropic/client/bedrock/message_stream.go new file mode 100644 index 0000000..9de80f1 --- /dev/null +++ b/pkg/anthropic/client/bedrock/message_stream.go @@ -0,0 +1,99 @@ +package bedrock + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/madebywelch/anthropic-go/v3/pkg/anthropic" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" +) + +func (c *Client) MessageStream(ctx context.Context, req *anthropic.MessageRequest) (<-chan *anthropic.MessageStreamResponse, <-chan error) { + msCh := make(chan *anthropic.MessageStreamResponse) + errCh := make(chan error, 1) + + err := anthropic.ValidateMessageStreamRequest(req) + if err != nil { + errCh <- err + close(msCh) + close(errCh) + return msCh, errCh + } + + go c.handleMessageStreaming(ctx, req, msCh, errCh) + return msCh, errCh +} + +func (c *Client) handleMessageStreaming( + ctx context.Context, + req *anthropic.MessageRequest, + msCh chan<- *anthropic.MessageStreamResponse, + errCh chan<- error, +) { + defer close(msCh) + defer close(errCh) + + adaptedModel, err := adaptModelForMessage(req.Model) + if err != nil { + errCh <- fmt.Errorf("error adapting model: %w", err) + return + } + + // Adapt the request to a Bedrock request + bedReq := adaptMessageRequest(req) + + data, err := json.Marshal(bedReq) + if err != nil { + errCh <- fmt.Errorf("error marshalling message request: %w", err) + return + } + + response, err := c.brCli.InvokeModelWithResponseStream( + ctx, + &bedrockruntime.InvokeModelWithResponseStreamInput{ + Body: data, + ModelId: aws.String(adaptedModel), + ContentType: aws.String("application/json"), + }, + ) + if err != nil { + errStatusCode := extractErrStatusCode(err) + errCh <- anthropic.MapHTTPStatusCodeToError(errStatusCode) + return + } + + for event := range response.GetStream().Events() { + select { + case <-ctx.Done(): + return + default: + } + + if v, ok := event.(*types.ResponseStreamMemberChunk); ok { + event := &anthropic.MessageEvent{} + err := json.Unmarshal(v.Value.Bytes, event) + if err != nil { + errCh <- fmt.Errorf("error decoding event data: %w", err) + return + } + msg, err := anthropic.ParseMessageEvent( + anthropic.MessageEventType(event.Type), + string(v.Value.Bytes), + ) + if err != nil { + if _, ok := err.(anthropic.UnsupportedEventType); ok { + // ignore unsupported event types + } else { + errCh <- fmt.Errorf("error processing message stream: %v", err) + return + } + } + + msCh <- msg + } + } +} diff --git a/pkg/anthropic/client/client.go b/pkg/anthropic/client/client.go new file mode 100644 index 0000000..8cd36e6 --- /dev/null +++ b/pkg/anthropic/client/client.go @@ -0,0 +1,31 @@ +package client + +import ( + "context" + "fmt" + + "github.com/madebywelch/anthropic-go/v3/pkg/anthropic" + + "github.com/madebywelch/anthropic-go/v3/pkg/anthropic/client/bedrock" + "github.com/madebywelch/anthropic-go/v3/pkg/anthropic/client/native" +) + +type ClientType string + +type Client interface { + Message(context.Context, *anthropic.MessageRequest) (*anthropic.MessageResponse, error) + MessageStream(context.Context, *anthropic.MessageRequest) (<-chan *anthropic.MessageStreamResponse, <-chan error) + Complete(context.Context, *anthropic.CompletionRequest) (*anthropic.CompletionResponse, error) + CompleteStream(context.Context, *anthropic.CompletionRequest) (<-chan *anthropic.StreamResponse, <-chan error) +} + +func MakeClient(ctx context.Context, config interface{}) (Client, error) { + switch cfg := config.(type) { + case bedrock.Config: + return bedrock.MakeClient(ctx, cfg) + case native.Config: + return native.MakeClient(cfg) + } + + return nil, fmt.Errorf("unknown client config") +} diff --git a/pkg/anthropic/client/client_test.go b/pkg/anthropic/client/client_test.go new file mode 100644 index 0000000..abb5658 --- /dev/null +++ b/pkg/anthropic/client/client_test.go @@ -0,0 +1,40 @@ +package client + +import ( + "context" + "testing" + + "github.com/madebywelch/anthropic-go/v3/pkg/anthropic/client/bedrock" + "github.com/madebywelch/anthropic-go/v3/pkg/anthropic/client/native" +) + +func TestMakeClientNativeSuccess(t *testing.T) { + config := native.Config{ + APIKey: "test", + } + ctx := context.Background() + _, err := MakeClient(ctx, config) + if err != nil { + t.Errorf("expected nil, got %v", err) + } +} + +func TestMakeClientBedrockSuccess(t *testing.T) { + config := bedrock.Config{ + Region: "us-west-2", + } + ctx := context.Background() + _, err := MakeClient(ctx, config) + if err != nil { + t.Errorf("expected nil, got %v", err) + } +} + +func TestMakeClientInvalidConfig(t *testing.T) { + config := "perhaps-the-archive-is-incomplete" + ctx := context.Background() + _, err := MakeClient(ctx, config) + if err == nil { + t.Errorf("expected error, got nil") + } +} diff --git a/pkg/anthropic/client/native/client.go b/pkg/anthropic/client/native/client.go new file mode 100644 index 0000000..46acd02 --- /dev/null +++ b/pkg/anthropic/client/native/client.go @@ -0,0 +1,41 @@ +package native + +import ( + "net/http" + + "github.com/madebywelch/anthropic-go/v3/pkg/anthropic" +) + +type Client struct { + httpClient *http.Client + apiKey string + baseURL string +} + +type Config struct { + APIKey string + BaseURL string + + // Optional (defaults to http.DefaultClient) + HTTPClient *http.Client +} + +func MakeClient(cfg Config) (*Client, error) { + if cfg.APIKey == "" { + return nil, anthropic.ErrAnthropicApiKeyRequired + } + + if cfg.BaseURL == "" { + cfg.BaseURL = "https://api.anthropic.com" + } + + if cfg.HTTPClient == nil { + cfg.HTTPClient = http.DefaultClient + } + + return &Client{ + httpClient: cfg.HTTPClient, + apiKey: cfg.APIKey, + baseURL: cfg.BaseURL, + }, nil +} diff --git a/pkg/anthropic/client/native/complete.go b/pkg/anthropic/client/native/complete.go new file mode 100644 index 0000000..685689b --- /dev/null +++ b/pkg/anthropic/client/native/complete.go @@ -0,0 +1,53 @@ +package native + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + + "github.com/madebywelch/anthropic-go/v3/pkg/anthropic" +) + +func (c *Client) Complete(ctx context.Context, req *anthropic.CompletionRequest) (*anthropic.CompletionResponse, error) { + err := anthropic.ValidateCompleteRequest(req) + if err != nil { + return nil, err + } + + return c.sendCompleteRequest(ctx, req) +} + +func (c *Client) sendCompleteRequest(ctx context.Context, req *anthropic.CompletionRequest) (*anthropic.CompletionResponse, error) { + // Marshal the request to JSON + data, err := json.Marshal(req) + if err != nil { + return nil, fmt.Errorf("error marshalling completion request: %w", err) + } + + // Create the HTTP request + requestURL := fmt.Sprintf("%s/v1/complete", c.baseURL) + request, err := http.NewRequestWithContext(ctx, "POST", requestURL, bytes.NewBuffer(data)) + if err != nil { + return nil, fmt.Errorf("error creating new request: %w", err) + } + request.Header.Set("Content-Type", "application/json") + request.Header.Set("X-Api-Key", c.apiKey) + + // Use the DoRequest method to send the HTTP request + response, err := c.doRequest(request) + if err != nil { + return nil, fmt.Errorf("error sending completion request: %w", err) + } + defer response.Body.Close() + + // Decode the response body to a CompletionResponse object + completionResponse := &anthropic.CompletionResponse{} + err = json.NewDecoder(response.Body).Decode(&completionResponse) + if err != nil { + return nil, fmt.Errorf("error decoding completion response: %w", err) + } + + return completionResponse, nil +} diff --git a/pkg/anthropic/client/native/complete_stream.go b/pkg/anthropic/client/native/complete_stream.go new file mode 100644 index 0000000..23ea205 --- /dev/null +++ b/pkg/anthropic/client/native/complete_stream.go @@ -0,0 +1,93 @@ +package native + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "github.com/madebywelch/anthropic-go/v3/pkg/anthropic" +) + +func (c *Client) CompleteStream(ctx context.Context, req *anthropic.CompletionRequest) (<-chan *anthropic.StreamResponse, <-chan error) { + cCh := make(chan *anthropic.StreamResponse) + errCh := make(chan error, 1) + + err := anthropic.ValidateCompleteStreamRequest(req) + if err != nil { + errCh <- err + close(cCh) + close(errCh) + return cCh, errCh + } + + go c.handleCompleteStreaming(ctx, req, cCh, errCh) + return cCh, errCh +} + +func (c *Client) handleCompleteStreaming( + ctx context.Context, + req *anthropic.CompletionRequest, + cCh chan<- *anthropic.StreamResponse, + errCh chan<- error, +) { + defer close(cCh) + defer close(errCh) + + data, err := json.Marshal(req) + if err != nil { + errCh <- fmt.Errorf("error marshalling completion request: %w", err) + return + } + + request, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("%s/v1/complete", c.baseURL), bytes.NewBuffer(data)) + if err != nil { + errCh <- fmt.Errorf("error creating new request: %w", err) + return + } + + request.Header.Set("Content-Type", "application/json") + request.Header.Set("X-Api-Key", c.apiKey) + request.Header.Set("Accept", "text/event-stream") + + response, err := c.doRequest(request) + if err != nil { + errCh <- fmt.Errorf("error sending completion request: %w", err) + return + } + defer response.Body.Close() + + err = c.processSseStream(response.Body, cCh) + if err != nil { + errCh <- err + } +} + +func (c *Client) processSseStream(reader io.Reader, cCh chan<- *anthropic.StreamResponse) error { + scanner := bufio.NewScanner(reader) + + for scanner.Scan() { + line := scanner.Text() + + if strings.HasPrefix(line, "data:") { + data := strings.TrimSpace(line[5:]) + event := &anthropic.StreamResponse{} + err := json.Unmarshal([]byte(data), event) + if err != nil { + return fmt.Errorf("error decoding event data: %w", err) + } + + cCh <- event + } + } + + if err := scanner.Err(); err != nil { + return fmt.Errorf("error reading from stream: %w", err) + } + + return nil +} diff --git a/pkg/anthropic/http.go b/pkg/anthropic/client/native/http.go similarity index 79% rename from pkg/anthropic/http.go rename to pkg/anthropic/client/native/http.go index 8608b52..a39d5eb 100644 --- a/pkg/anthropic/http.go +++ b/pkg/anthropic/client/native/http.go @@ -1,8 +1,10 @@ // Package client contains the HTTP client and related functionality for the anthropic package. -package anthropic +package native import ( "net/http" + + "github.com/madebywelch/anthropic-go/v3/pkg/anthropic" ) const ( @@ -10,8 +12,6 @@ const ( AnthropicAPIVersion = "2023-06-01" // AnthropicAPIMessagesBeta is the beta version of the Anthropics API that enables the messages endpoint. AnthropicAPIMessagesBeta = "messages-2023-12-15" - // AnthropicAPIToolsBeta is the beta version of the Anthropic API that enables the tools endpoint. - AnthropicAPIToolsBeta = "tools-2024-04-04" ) // doRequest sends an HTTP request and returns the response, handling any non-OK HTTP status codes. @@ -24,7 +24,7 @@ func (c *Client) doRequest(request *http.Request) (*http.Response, error) { } if response.StatusCode != http.StatusOK { - err = mapHTTPStatusCodeToError(response.StatusCode) + err = anthropic.MapHTTPStatusCodeToError(response.StatusCode) return nil, err } diff --git a/pkg/anthropic/client/native/message.go b/pkg/anthropic/client/native/message.go new file mode 100644 index 0000000..4ffe745 --- /dev/null +++ b/pkg/anthropic/client/native/message.go @@ -0,0 +1,57 @@ +package native + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + + "github.com/madebywelch/anthropic-go/v3/pkg/anthropic" +) + +func (c *Client) Message(ctx context.Context, req *anthropic.MessageRequest) (*anthropic.MessageResponse, error) { + err := anthropic.ValidateMessageRequest(req) + if err != nil { + return nil, err + } + + return c.sendMessageRequest(ctx, req) +} + +func (c *Client) sendMessageRequest( + ctx context.Context, + req *anthropic.MessageRequest, +) (*anthropic.MessageResponse, error) { + // Marshal the request to JSON + data, err := json.Marshal(req) + if err != nil { + return nil, fmt.Errorf("error marshalling message request: %w", err) + } + + // Create the HTTP request + requestURL := fmt.Sprintf("%s/v1/messages", c.baseURL) + request, err := http.NewRequestWithContext(ctx, "POST", requestURL, bytes.NewBuffer(data)) + if err != nil { + return nil, fmt.Errorf("error creating new request: %w", err) + } + request.Header.Set("Content-Type", "application/json") + request.Header.Set("X-Api-Key", c.apiKey) + request.Header.Set("anthropic-beta", AnthropicAPIMessagesBeta) + + // Use the doRequest method to send the HTTP request + response, err := c.doRequest(request) + if err != nil { + return nil, fmt.Errorf("error sending message request: %w", err) + } + defer response.Body.Close() + + // Decode the response body to a MessageResponse object + messageResponse := &anthropic.MessageResponse{} + err = json.NewDecoder(response.Body).Decode(messageResponse) + if err != nil { + return nil, fmt.Errorf("error decoding message response: %w", err) + } + + return messageResponse, nil +} diff --git a/pkg/anthropic/client/native/message_stream.go b/pkg/anthropic/client/native/message_stream.go new file mode 100644 index 0000000..3b98a39 --- /dev/null +++ b/pkg/anthropic/client/native/message_stream.go @@ -0,0 +1,105 @@ +package native + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "github.com/madebywelch/anthropic-go/v3/pkg/anthropic" +) + +func (c *Client) MessageStream(ctx context.Context, req *anthropic.MessageRequest) (<-chan *anthropic.MessageStreamResponse, <-chan error) { + msCh := make(chan *anthropic.MessageStreamResponse) + errCh := make(chan error, 1) + + err := anthropic.ValidateMessageStreamRequest(req) + if err != nil { + errCh <- err + close(msCh) + close(errCh) + return msCh, errCh + } + + go c.handleMessageStreaming(ctx, req, msCh, errCh) + + return msCh, errCh +} + +func (c *Client) handleMessageStreaming( + ctx context.Context, + req *anthropic.MessageRequest, + msCh chan<- *anthropic.MessageStreamResponse, + errCh chan<- error, +) { + defer close(msCh) + defer close(errCh) + + data, err := json.Marshal(req) + if err != nil { + errCh <- fmt.Errorf("error marshalling message request: %w", err) + return + } + + request, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("%s/v1/messages", c.baseURL), bytes.NewBuffer(data)) + if err != nil { + errCh <- fmt.Errorf("error creating new request: %w", err) + return + } + + request.Header.Set("Content-Type", "application/json") + request.Header.Set("X-Api-Key", c.apiKey) + request.Header.Set("Accept", "text/event-stream") + + response, err := c.doRequest(request) + if err != nil { + errCh <- fmt.Errorf("error sending message request: %w", err) + return + } + defer response.Body.Close() + + err = c.processMessageSseStream(response.Body, msCh) + if err != nil { + errCh <- err + } +} + +func (c *Client) processMessageSseStream(reader io.Reader, events chan<- *anthropic.MessageStreamResponse) error { + scanner := bufio.NewScanner(reader) + + for scanner.Scan() { + line := scanner.Text() + + if strings.HasPrefix(line, "data:") { + data := strings.TrimSpace(line[5:]) + + event := &anthropic.MessageEvent{} + err := json.Unmarshal([]byte(data), event) + if err != nil { + return fmt.Errorf("error decoding event data: %w", err) + } + + msg, err := anthropic.ParseMessageEvent(anthropic.MessageEventType(event.Type), data) + + if err != nil { + if _, ok := err.(anthropic.UnsupportedEventType); ok { + // ignore unsupported event types + } else { + return fmt.Errorf("error processing message stream: %v", err) + } + } + + events <- msg + } + } + + if err := scanner.Err(); err != nil { + return fmt.Errorf("error reading from stream: %w", err) + } + + return nil +} diff --git a/pkg/anthropic/message_test.go b/pkg/anthropic/client/native/message_stream_test.go similarity index 59% rename from pkg/anthropic/message_test.go rename to pkg/anthropic/client/native/message_stream_test.go index 0010200..25c7147 100644 --- a/pkg/anthropic/message_test.go +++ b/pkg/anthropic/client/native/message_stream_test.go @@ -1,141 +1,43 @@ -package anthropic +package native import ( - "encoding/json" + "context" "fmt" "net/http" "net/http/httptest" "strings" "testing" -) - -func TestMessage(t *testing.T) { - // Mock server for successful message response - testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - resp := MessageResponse{ - ID: "12345", - Type: "testType", - Model: "testModel", - Role: "user", - Content: []MessagePartResponse{{ - Type: "text", - Text: "Test message", - }}, - Usage: MessageUsage{ - InputTokens: 10, - OutputTokens: 5, - }, - } - json.NewEncoder(w).Encode(resp) - })) - defer testServer.Close() - - // Create a new client with the test server's URL - client, err := NewClient("fake-api-key") - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - client.baseURL = testServer.URL // Override baseURL to point to the test server - - // Prepare a message request - request := &MessageRequest{ - Model: ClaudeV2_1, - Messages: []MessagePartRequest{{Role: "user", Content: []ContentBlock{NewTextContentBlock("Hello")}}}, - } - - // Call the Message method - response, err := client.Message(request) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - // Check the response - expectedContent := "Test message" - if len(response.Content) == 0 || response.Content[0].Text != expectedContent { - t.Errorf("Expected message %q, got %q", expectedContent, response.Content[0].Text) - } -} - -func TestMessageErrorHandling(t *testing.T) { - // Mock server for error response - testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - })) - defer testServer.Close() - - // Create a new client with the test server's URL - client, err := NewClient("fake-api-key") - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - client.baseURL = testServer.URL // Override baseURL to point to the test server - // Prepare a message request - request := &MessageRequest{ - Model: ClaudeV2_1, - Messages: []MessagePartRequest{{Role: "user", Content: []ContentBlock{NewTextContentBlock("Hello")}}}, - } - - // Call the Message method expecting an error - _, err = client.Message(request) - if err == nil { - t.Fatal("Expected an error, got none") - } -} - -func TestMessageIncompatibleModel(t *testing.T) { - // Create client - client, err := NewClient("fake-api-key") - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - // Prepare a message request with streaming set to true - request := &MessageRequest{ - Model: ClaudeV2, - Messages: []MessagePartRequest{{Role: "user", Content: []ContentBlock{NewTextContentBlock("Hello")}}}, - } - - // Call the MessageStream method expecting an error - _, err = client.Message(request) - - if err == nil { - t.Fatal("Expected an error for streaming not supported, got none") - } - - expErr := fmt.Sprintf("model %s is not compatible with the message endpoint", request.Model) - - if err.Error() != expErr { - t.Fatalf( - "Expected error %s, got %s", - expErr, - err.Error(), - ) - } -} + "github.com/madebywelch/anthropic-go/v3/pkg/anthropic" +) func TestMessageStreamNoStreamFlag(t *testing.T) { // Create client - client, err := NewClient("fake-api-key") + client, err := MakeClient(Config{ + APIKey: "fake-api-key", + }) if err != nil { t.Fatalf("Unexpected error: %v", err) } // Prepare a message request without streaming - request := &MessageRequest{ - Model: ClaudeV2, - Messages: []MessagePartRequest{{Role: "user", Content: []ContentBlock{NewTextContentBlock("Hello")}}}, + request := &anthropic.MessageRequest{ + Model: anthropic.Claude3Opus, + Messages: []anthropic.MessagePartRequest{{ + Role: "user", + Content: []anthropic.ContentBlock{anthropic.NewTextContentBlock("Hello")}, + }}, } // Call the MessageStream method expecting an error - _, errCh := client.MessageStream(request) + _, errCh := client.MessageStream(context.Background(), request) err = <-errCh if err == nil { t.Fatal("Expected an error for streaming without a stream request") } - expErr := "cannot use MessageStream with a non-streaming request, use Message instead" + expErr := "cannot use MessageStream with streaming disabled, use Message instead" if err.Error() != expErr { t.Fatalf( @@ -148,20 +50,25 @@ func TestMessageStreamNoStreamFlag(t *testing.T) { func TestMessageStreamIncompatibleModel(t *testing.T) { // Create client - client, err := NewClient("fake-api-key") + client, err := MakeClient(Config{ + APIKey: "fake-api-key", + }) if err != nil { t.Fatalf("Unexpected error: %v", err) } - // Prepare a message request with streaming set to true - request := &MessageRequest{ - Model: ClaudeV2, - Messages: []MessagePartRequest{{Role: "user", Content: []ContentBlock{NewTextContentBlock("Hello")}}}, - Stream: true, + // Prepare a message request without streaming + request := &anthropic.MessageRequest{ + Model: anthropic.ClaudeV2, + Messages: []anthropic.MessagePartRequest{{ + Role: "user", + Content: []anthropic.ContentBlock{anthropic.NewTextContentBlock("Hello")}, + }}, + Stream: true, } // Call the MessageStream method expecting an error - _, errCh := client.MessageStream(request) + _, errCh := client.MessageStream(context.Background(), request) err = <-errCh if err == nil { @@ -208,22 +115,26 @@ func TestMessageStreamSuccess(t *testing.T) { defer testServer.Close() // Create a new client with the test server's URL - client, err := NewClient("fake-api-key") + client, err := MakeClient(Config{ + APIKey: "fake-api-key", + BaseURL: testServer.URL, + }) if err != nil { t.Fatalf("Unexpected error: %v", err) } - client.baseURL = testServer.URL // Override baseURL to point to the test server // Prepare a message request - request := &MessageRequest{ - Model: Claude3Opus, - Messages: []MessagePartRequest{{Role: "user", Content: []ContentBlock{NewTextContentBlock("Hello")}}}, - Stream: true, + request := &anthropic.MessageRequest{ + Model: anthropic.Claude3Opus, + Messages: []anthropic.MessagePartRequest{{ + Role: "user", + Content: []anthropic.ContentBlock{anthropic.NewTextContentBlock("Hello")}, + }}, + Stream: true, } - // Call the Complete method - rCh, errCh := client.MessageStream(request) - chunk := MessageStreamResponse{} + rCh, errCh := client.MessageStream(context.Background(), request) + chunk := &anthropic.MessageStreamResponse{} final := strings.Builder{} inputTokens := 0 outputTokens := 0 @@ -278,22 +189,27 @@ func TestMessageStreamErrorInStream(t *testing.T) { defer testServer.Close() // Create a new client with the test server's URL - client, err := NewClient("fake-api-key") + client, err := MakeClient(Config{ + APIKey: "fake-api-key", + BaseURL: testServer.URL, + }) if err != nil { t.Fatalf("Unexpected error: %v", err) } - client.baseURL = testServer.URL // Override baseURL to point to the test server // Prepare a message request - request := &MessageRequest{ - Model: Claude3Opus, - Messages: []MessagePartRequest{{Role: "user", Content: []ContentBlock{NewTextContentBlock("Hello")}}}, - Stream: true, + request := &anthropic.MessageRequest{ + Model: anthropic.Claude3Opus, + Messages: []anthropic.MessagePartRequest{{ + Role: "user", + Content: []anthropic.ContentBlock{anthropic.NewTextContentBlock("Hello")}, + }}, + Stream: true, } // Call the Complete method - rCh, errCh := client.MessageStream(request) - var chunk MessageStreamResponse + rCh, errCh := client.MessageStream(context.Background(), request) + var chunk *anthropic.MessageStreamResponse final := strings.Builder{} done := false diff --git a/pkg/anthropic/client/native/message_test.go b/pkg/anthropic/client/native/message_test.go new file mode 100644 index 0000000..40044f1 --- /dev/null +++ b/pkg/anthropic/client/native/message_test.go @@ -0,0 +1,133 @@ +package native + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/madebywelch/anthropic-go/v3/pkg/anthropic" +) + +func TestMessage(t *testing.T) { + // Mock server for successful message response + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := &anthropic.MessageResponse{ + ID: "12345", + Type: "testType", + Model: "testModel", + Role: "user", + Content: []anthropic.MessagePartResponse{{ + Type: "text", + Text: "Test message", + }}, + Usage: anthropic.MessageUsage{ + InputTokens: 10, + OutputTokens: 5, + }, + } + json.NewEncoder(w).Encode(resp) + })) + defer testServer.Close() + + ctx := context.Background() + + // Create a new client with the test server's URL + client, err := MakeClient(Config{ + APIKey: "fake-api-key", + BaseURL: testServer.URL, + }) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + // Prepare a message request + request := &anthropic.MessageRequest{ + Model: anthropic.Claude3Opus, + Messages: []anthropic.MessagePartRequest{{ + Role: "user", + Content: []anthropic.ContentBlock{anthropic.NewTextContentBlock("Hello")}, + }}, + } + + // Call the Message method + response, err := client.Message(ctx, request) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + // Check the response + expectedContent := "Test message" + if len(response.Content) == 0 || response.Content[0].Text != expectedContent { + t.Errorf("Expected message %q, got %q", expectedContent, response.Content[0].Text) + } +} + +func TestMessageErrorHandling(t *testing.T) { + // Mock server for error response + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + })) + defer testServer.Close() + + // Create a new client with the test server's URL + client, err := MakeClient(Config{ + APIKey: "fake-api-key", + BaseURL: testServer.URL, + }) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + client.baseURL = testServer.URL // Override baseURL to point to the test server + + // Prepare a message request + request := &anthropic.MessageRequest{ + Model: anthropic.Claude3Opus, + Messages: []anthropic.MessagePartRequest{{ + Role: "user", + Content: []anthropic.ContentBlock{anthropic.NewTextContentBlock("Hello")}, + }}, + } + + // Call the Message method expecting an error + _, err = client.Message(context.Background(), request) + if err == nil { + t.Fatal("Expected an error, got none") + } +} + +func TestMessageIncompatibleModel(t *testing.T) { + // Create client + client, err := MakeClient(Config{APIKey: "fake-api-key"}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + // Prepare a message request with streaming set to true + request := &anthropic.MessageRequest{ + Model: anthropic.ClaudeV2, + Messages: []anthropic.MessagePartRequest{{ + Role: "user", + Content: []anthropic.ContentBlock{anthropic.NewTextContentBlock("Hello")}, + }}, + } + + // Call the MessageStream method expecting an error + _, err = client.Message(context.Background(), request) + + if err == nil { + t.Fatal("Expected an error for streaming not supported, got none") + } + + expErr := fmt.Sprintf("model %s is not compatible with the message endpoint", request.Model) + + if err.Error() != expErr { + t.Fatalf( + "Expected error %s, got %s", + expErr, + err.Error(), + ) + } +} diff --git a/pkg/anthropic/complete.go b/pkg/anthropic/complete.go deleted file mode 100644 index 696a86b..0000000 --- a/pkg/anthropic/complete.go +++ /dev/null @@ -1,136 +0,0 @@ -package anthropic - -import ( - "bufio" - "bytes" - "encoding/json" - "fmt" - "io" - "net/http" - "strings" -) - -// Complete sends a completion request to the API and returns a single completion. -func (c *Client) Complete(req *CompletionRequest) (*CompletionResponse, error) { - if req.Stream { - return nil, fmt.Errorf("cannot use Complete with a streaming request, use CompleteStream instead") - } - - if !req.Model.IsCompleteCompatible() { - return nil, fmt.Errorf("model %s is not compatible with the completion endpoint", req.Model) - } - - return c.sendCompletionRequest(req) -} - -func (c *Client) CompleteStream(req *CompletionRequest) (<-chan StreamResponse, <-chan error) { - events := make(chan StreamResponse) - - // make this a buffered channel to allow for the error case below to return - errCh := make(chan error, 1) - - if !req.Stream { - errCh <- fmt.Errorf("cannot use CompleteStream with a non-streaming request, use Complete instead") - return events, errCh - } - - if !req.Model.IsCompleteCompatible() { - errCh <- fmt.Errorf("model %s is not compatible with the completion endpoint", req.Model) - return events, errCh - } - - go c.handleStreaming(events, errCh, req) - - return events, errCh -} - -// sendCompletionRequest sends a completion request to the API and returns a single completion. -func (c *Client) sendCompletionRequest(req *CompletionRequest) (*CompletionResponse, error) { - // Marshal the request to JSON - data, err := json.Marshal(req) - if err != nil { - return nil, fmt.Errorf("error marshalling completion request: %w", err) - } - - // Create the HTTP request - requestURL := fmt.Sprintf("%s/v1/complete", c.baseURL) - request, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(data)) - if err != nil { - return nil, fmt.Errorf("error creating new request: %w", err) - } - request.Header.Set("Content-Type", "application/json") - request.Header.Set("X-Api-Key", c.apiKey) - - // Use the DoRequest method to send the HTTP request - response, err := c.doRequest(request) - if err != nil { - return nil, fmt.Errorf("error sending completion request: %w", err) - } - defer response.Body.Close() - - // Decode the response body to a CompletionResponse object - var completionResponse CompletionResponse - err = json.NewDecoder(response.Body).Decode(&completionResponse) - if err != nil { - return nil, fmt.Errorf("error decoding completion response: %w", err) - } - - return &completionResponse, nil -} - -func (c *Client) handleStreaming(events chan StreamResponse, errCh chan error, req *CompletionRequest) { - defer close(events) - - data, err := json.Marshal(req) - if err != nil { - errCh <- fmt.Errorf("error marshalling completion request: %w", err) - return - } - - request, err := http.NewRequest("POST", fmt.Sprintf("%s/v1/complete", c.baseURL), bytes.NewBuffer(data)) - if err != nil { - errCh <- fmt.Errorf("error creating new request: %w", err) - return - } - - request.Header.Set("Content-Type", "application/json") - request.Header.Set("X-Api-Key", c.apiKey) - request.Header.Set("Accept", "text/event-stream") - - response, err := c.doRequest(request) - if err != nil { - errCh <- fmt.Errorf("error sending completion request: %w", err) - return - } - defer response.Body.Close() - - err = c.processSseStream(response.Body, events) - if err != nil { - errCh <- err - } -} - -func (c *Client) processSseStream(reader io.Reader, events chan StreamResponse) error { - scanner := bufio.NewScanner(reader) - - for scanner.Scan() { - line := scanner.Text() - - if strings.HasPrefix(line, "data:") { - data := strings.TrimSpace(line[5:]) - var event StreamResponse - err := json.Unmarshal([]byte(data), &event) - if err != nil { - return fmt.Errorf("error decoding event data: %w", err) - } - - events <- event - } - } - - if err := scanner.Err(); err != nil { - return fmt.Errorf("error reading from stream: %w", err) - } - - return nil -} diff --git a/pkg/anthropic/complete_test.go b/pkg/anthropic/complete_test.go deleted file mode 100644 index eb89caa..0000000 --- a/pkg/anthropic/complete_test.go +++ /dev/null @@ -1,164 +0,0 @@ -package anthropic - -import ( - "fmt" - "net/http" - "net/http/httptest" - "testing" -) - -func TestComplete(t *testing.T) { - // Create a test server to mock the Anthropics API - testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(`{"completion": "Test completion"}`)) // Mock response - })) - defer testServer.Close() - - // Create a new client with the test server's URL - client, err := NewClient("fake-api-key") - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - client.baseURL = testServer.URL // Override baseURL to point to the test server - - // Prepare a completion request - request := NewCompletionRequest("Why is the sky blue?") - - // Call the Complete method - response, err := client.Complete(request) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - // Check the response - expectedCompletion := "Test completion" - if response.Completion != expectedCompletion { - t.Errorf("Expected completion %q, got %q", expectedCompletion, response.Completion) - } -} - -func TestCompleteWithParameters(t *testing.T) { - // Prepare a completion request - request := NewCompletionRequest("Why is the sky blue?", - WithModel[CompletionRequest](ClaudeInstantV1_1_100k), - WithTemperature[CompletionRequest](0.5), - WithMaxTokens[CompletionRequest](10), - WithTopK[CompletionRequest](5), - WithTopP[CompletionRequest](0.9), - WithStopSequences[CompletionRequest]([]string{"\n", "Why is the sky blue?"}), - ) - - if request.Prompt != "Why is the sky blue?" { - t.Errorf("Expected prompt %q, got %q", "Why is the sky blue?", request.Prompt) - } - - if request.Model != ClaudeInstantV1_1_100k { - t.Errorf("Expected model %q, got %q", ClaudeInstantV1_1_100k, request.Model) - } - - if request.Temperature != 0.5 { - t.Errorf("Expected temperature %f, got %f", 0.5, request.Temperature) - } - - if request.MaxTokensToSample != 10 { - t.Errorf("Expected max tokens %d, got %d", 10, request.MaxTokensToSample) - } - - if request.TopK != 5 { - t.Errorf("Expected top k %d, got %d", 5, request.TopK) - } - - if request.TopP != 0.9 { - t.Errorf("Expected top p %f, got %f", 0.9, request.TopP) - } - - if len(request.StopSequences) != 2 { - t.Errorf("Expected stop sequences length %d, got %d", 2, len(request.StopSequences)) - } - - if request.StopSequences[0] != "\n" { - t.Errorf("Expected stop sequence %q, got %q", "\n", request.StopSequences[0]) - } - - if request.StopSequences[1] != "Why is the sky blue?" { - t.Errorf("Expected stop sequence %q, got %q", "Why is the sky blue?", request.StopSequences[1]) - } -} - -func TestCompleteIncompatibleModel(t *testing.T) { - // Create client - client, err := NewClient("fake-api-key") - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - // Prepare a completion request - request := NewCompletionRequest("Why is the sky blue?", - WithModel[CompletionRequest](Claude3Opus), - ) - - // Call the Complete method expecting an error - _, err = client.Complete(request) - if err == nil { - t.Fatal("Expected an incompatibility error, got none") - } - - // Check the error message - expErr := fmt.Sprintf("model %s is not compatible with the completion endpoint", request.Model) - if err.Error() != expErr { - t.Fatalf("Expected error %s, got %s", expErr, err.Error()) - } -} - -func TestCompleteStreamNoStreamFlag(t *testing.T) { - // Create client - client, err := NewClient("fake-api-key") - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - // Prepare a completion request - request := NewCompletionRequest("Why is the sky blue?", - WithModel[CompletionRequest](Claude3Opus), - ) - - // Call the Complete method expecting an error - _, errCh := client.CompleteStream(request) - err = <-errCh - if err == nil { - t.Fatal("Expected a missing stream flag error, got none") - } - - // Check the error message - expErr := "cannot use CompleteStream with a non-streaming request, use Complete instead" - if err.Error() != expErr { - t.Fatalf("Expected error %s, got %s", expErr, err.Error()) - } -} - -func TestCompleteStreamIncompatibleModel(t *testing.T) { - // Create client - client, err := NewClient("fake-api-key") - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - // Prepare a completion request - request := NewCompletionRequest("Why is the sky blue?", - WithModel[CompletionRequest](Claude3Opus), - WithStream[CompletionRequest](true), - ) - - // Call the Complete method expecting an error - _, errCh := client.CompleteStream(request) - err = <-errCh - if err == nil { - t.Fatal("Expected an incompatibility error, got none") - } - - // Check the error message - expErr := fmt.Sprintf("model %s is not compatible with the completion endpoint", request.Model) - if err.Error() != expErr { - t.Fatalf("Expected error %s, got %s", expErr, err.Error()) - } -} diff --git a/pkg/anthropic/errors.go b/pkg/anthropic/errors.go index 2d81534..bac80d2 100644 --- a/pkg/anthropic/errors.go +++ b/pkg/anthropic/errors.go @@ -16,7 +16,7 @@ var ( ) // mapHTTPStatusCodeToError maps an HTTP status code to an error. -func mapHTTPStatusCodeToError(code int) error { +func MapHTTPStatusCodeToError(code int) error { switch code { case http.StatusBadRequest: return ErrAnthropicInvalidRequest diff --git a/pkg/anthropic/errors_test.go b/pkg/anthropic/errors_test.go new file mode 100644 index 0000000..352b3b5 --- /dev/null +++ b/pkg/anthropic/errors_test.go @@ -0,0 +1,28 @@ +package anthropic + +import ( + "errors" + "net/http" + "testing" +) + +func TestMapHTTPStatusCodeToError(t *testing.T) { + tests := []struct { + code int + expected error + }{ + {http.StatusBadRequest, ErrAnthropicInvalidRequest}, + {http.StatusUnauthorized, ErrAnthropicUnauthorized}, + {http.StatusForbidden, ErrAnthropicForbidden}, + {http.StatusTooManyRequests, ErrAnthropicRateLimit}, + {http.StatusInternalServerError, ErrAnthropicInternalServer}, + {http.StatusNotFound, errors.New("unknown error occurred")}, + } + + for _, test := range tests { + err := MapHTTPStatusCodeToError(test.code) + if err.Error() != test.expected.Error() { + t.Errorf("Expected error '%s', got '%s'", test.expected.Error(), err.Error()) + } + } +} diff --git a/pkg/anthropic/events.go b/pkg/anthropic/events.go index 5799063..217139b 100644 --- a/pkg/anthropic/events.go +++ b/pkg/anthropic/events.go @@ -15,6 +15,7 @@ const ( MessageEventTypeContentBlockStop MessageEventType = "content_block_stop" MessageEventTypeMessageDelta MessageEventType = "message_delta" MessageEventTypeMessageStop MessageEventType = "message_stop" + MessageEventTypeError MessageEventType = "error" // Constants for completion event types CompletionEventTypeCompletion CompletionEventType = "completion" diff --git a/pkg/anthropic/message.go b/pkg/anthropic/message.go deleted file mode 100644 index a2f3d13..0000000 --- a/pkg/anthropic/message.go +++ /dev/null @@ -1,161 +0,0 @@ -package anthropic - -import ( - "bufio" - "bytes" - "encoding/json" - "fmt" - "io" - "net/http" - "strings" -) - -func (c *Client) Message(req *MessageRequest) (*MessageResponse, error) { - if req.Stream { - return nil, fmt.Errorf("cannot use Message with streaming enabled, use MessageStream instead") - } - - if !req.Model.IsMessageCompatible() { - return nil, fmt.Errorf("model %s is not compatible with the message endpoint", req.Model) - } - - if !req.Model.IsImageCompatible() && req.ContainsImageContent() { - return nil, fmt.Errorf("model %s does not support image content", req.Model) - } - - if req.CountImageContent() > 20 { - return nil, fmt.Errorf("too many image content blocks, maximum is 20") - } - - return c.sendMessageRequest(req) -} - -func (c *Client) MessageStream(req *MessageRequest) (<-chan MessageStreamResponse, <-chan error) { - events := make(chan MessageStreamResponse) - - // make this a buffered channel to allow for the error case below to return - errCh := make(chan error, 1) - - if !req.Stream { - errCh <- fmt.Errorf("cannot use MessageStream with a non-streaming request, use Message instead") - return events, errCh - } - - if !req.Model.IsMessageCompatible() { - errCh <- fmt.Errorf("model %s is not compatible with the messagestream endpoint", req.Model) - return events, errCh - } - - if req.Stream && len(req.Tools) > 0 { - // https://docs.anthropic.com/claude/docs/tool-use - // Streaming (stream=true) is not yet supported. We plan to add streaming support in a future beta version. - errCh <- fmt.Errorf("cannot use streaming with tools") - return events, errCh - } - - go c.handleMessageStreaming(events, errCh, req) - - return events, errCh -} - -func (c *Client) sendMessageRequest(req *MessageRequest) (*MessageResponse, error) { - // Marshal the request to JSON - data, err := json.Marshal(req) - if err != nil { - return nil, fmt.Errorf("error marshalling completion request: %w", err) - } - - // Create the HTTP request - requestURL := fmt.Sprintf("%s/v1/messages", c.baseURL) - request, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(data)) - if err != nil { - return nil, fmt.Errorf("error creating new request: %w", err) - } - request.Header.Set("Content-Type", "application/json") - request.Header.Set("X-Api-Key", c.apiKey) - request.Header.Set("anthropic-beta", AnthropicAPIToolsBeta) - - // Use the DoRequest method to send the HTTP request - response, err := c.doRequest(request) - if err != nil { - return nil, fmt.Errorf("error sending completion request: %w", err) - } - defer response.Body.Close() - - // Decode the response body to a MessageResponse object - var messageResponse MessageResponse - err = json.NewDecoder(response.Body).Decode(&messageResponse) - if err != nil { - return nil, fmt.Errorf("error decoding message response: %w", err) - } - - return &messageResponse, nil -} - -func (c *Client) handleMessageStreaming(events chan MessageStreamResponse, errCh chan error, req *MessageRequest) { - defer close(events) - - data, err := json.Marshal(req) - if err != nil { - errCh <- fmt.Errorf("error marshalling message request: %w", err) - return - } - - request, err := http.NewRequest("POST", fmt.Sprintf("%s/v1/messages", c.baseURL), bytes.NewBuffer(data)) - if err != nil { - errCh <- fmt.Errorf("error creating new request: %w", err) - return - } - - request.Header.Set("Content-Type", "application/json") - request.Header.Set("X-Api-Key", c.apiKey) - request.Header.Set("Accept", "text/event-stream") - - response, err := c.doRequest(request) - if err != nil { - errCh <- fmt.Errorf("error sending message request: %w", err) - return - } - defer response.Body.Close() - - err = c.processMessageSseStream(response.Body, events) - if err != nil { - errCh <- err - } -} - -func (c *Client) processMessageSseStream(reader io.Reader, events chan MessageStreamResponse) error { - scanner := bufio.NewScanner(reader) - - for scanner.Scan() { - line := scanner.Text() - - if strings.HasPrefix(line, "data:") { - data := strings.TrimSpace(line[5:]) - - event := &MessageEvent{} - err := json.Unmarshal([]byte(data), event) - if err != nil { - return fmt.Errorf("error decoding event data: %w", err) - } - - msg, err := parseMessageEvent(event.Type, data) - - if err != nil { - if _, ok := err.(UnsupportedEventType); ok { - // ignore unsupported event types - } else { - return fmt.Errorf("error processing message stream: %v", err) - } - } - - events <- msg - } - } - - if err := scanner.Err(); err != nil { - return fmt.Errorf("error reading from stream: %w", err) - } - - return nil -} diff --git a/pkg/anthropic/message_event.go b/pkg/anthropic/message_event.go index b61a674..0f1778b 100644 --- a/pkg/anthropic/message_event.go +++ b/pkg/anthropic/message_event.go @@ -85,40 +85,40 @@ func (e UnsupportedEventType) Error() string { return e.Msg } -func parseMessageEvent(eventType, event string) (MessageStreamResponse, error) { - messageStreamResponse := MessageStreamResponse{} +func ParseMessageEvent(eventType MessageEventType, event string) (*MessageStreamResponse, error) { + messageStreamResponse := &MessageStreamResponse{} var err error switch eventType { - case "message_start": + case MessageEventTypeMessageStart: messageStartEvent := &MessageStartEvent{} err = json.Unmarshal([]byte(event), &messageStartEvent) messageStreamResponse.Type = messageStartEvent.Type messageStreamResponse.Usage = messageStartEvent.Message.Usage - case "content_block_start": + case MessageEventTypeContentBlockStart: contentBlockEvent := &ContentBlockStartEvent{} err = json.Unmarshal([]byte(event), &contentBlockEvent) messageStreamResponse.Type = contentBlockEvent.Type - case "ping": + case MessageEventTypePing: pingEvent := &PingEvent{} err = json.Unmarshal([]byte(event), &pingEvent) messageStreamResponse.Type = pingEvent.Type - case "content_block_delta": + case MessageEventTypeContentBlockDelta: contentBlockEvent := &ContentBlockDeltaEvent{} err = json.Unmarshal([]byte(event), &contentBlockEvent) messageStreamResponse.Type = contentBlockEvent.Type messageStreamResponse.Delta.Type = contentBlockEvent.Delta.Type messageStreamResponse.Delta.Text = contentBlockEvent.Delta.Text - case "content_block_stop": + case MessageEventTypeContentBlockStop: contentBlockStopEvent := &ContentBlockStopEvent{} err = json.Unmarshal([]byte(event), &contentBlockStopEvent) messageStreamResponse.Type = contentBlockStopEvent.Type - case "message_delta": + case MessageEventTypeMessageDelta: messageDeltaEvent := &MessageDeltaEvent{} err = json.Unmarshal([]byte(event), &messageDeltaEvent) @@ -126,12 +126,12 @@ func parseMessageEvent(eventType, event string) (MessageStreamResponse, error) { messageStreamResponse.Delta.StopReason = messageDeltaEvent.Delta.StopReason messageStreamResponse.Delta.StopSequence = messageDeltaEvent.Delta.StopSequence messageStreamResponse.Usage.OutputTokens = messageDeltaEvent.Usage.OutputTokens - case "message_stop": + case MessageEventTypeMessageStop: messageStopEvent := &MessageStopEvent{} err = json.Unmarshal([]byte(event), &messageStopEvent) messageStreamResponse.Type = messageStopEvent.Type - case "error": + case MessageEventTypeError: messageErrorEvent := &MessageErrorEvent{} err = json.Unmarshal([]byte(event), &messageErrorEvent) if err != nil { diff --git a/pkg/anthropic/message_event_test.go b/pkg/anthropic/message_event_test.go new file mode 100644 index 0000000..03dc413 --- /dev/null +++ b/pkg/anthropic/message_event_test.go @@ -0,0 +1,163 @@ +package anthropic + +import ( + "reflect" + "testing" +) + +func TestParseMessageEvent(t *testing.T) { + events := []struct { + eventType MessageEventType + event string + expected *MessageStreamResponse + expErrStr string + }{ + { + eventType: MessageEventTypeMessageStart, + event: `{ + "type": "message_start", + "message": { + "id": "123", + "type": "text", + "role": "user", + "content": ["Hello, world!"], + "model": "claude-v2_1", + "stop_reason": "", + "stop_sequence": "", + "usage": { + "input_tokens": 10, + "output_tokens": 20 + } + } + }`, + expected: &MessageStreamResponse{ + Type: "message_start", + Usage: MessageStreamUsage{ + InputTokens: 10, + OutputTokens: 20, + }, + }, + }, + { + eventType: MessageEventTypeContentBlockStart, + event: `{ + "type": "content_block_start", + "index": 1, + "content_block": { + "type": "text", + "text": "This is a content block" + } + }`, + expected: &MessageStreamResponse{ + Type: "content_block_start", + }, + }, + { + eventType: MessageEventTypePing, + event: `{ + "type": "ping" + }`, + expected: &MessageStreamResponse{ + Type: "ping", + }, + }, + { + eventType: MessageEventTypeContentBlockDelta, + event: `{ + "type": "content_block_delta", + "delta": { + "type": "text", + "text": "This is a content block delta" + } + }`, + expected: &MessageStreamResponse{ + Type: "content_block_delta", + Delta: MessageStreamDelta{ + Type: "text", + Text: "This is a content block delta", + StopReason: "", + StopSequence: "", + }, + }, + }, + { + eventType: MessageEventTypeContentBlockStop, + event: `{ + "type": "content_block_stop" + }`, + expected: &MessageStreamResponse{ + Type: "content_block_stop", + }, + }, + { + eventType: MessageEventTypeMessageDelta, + event: `{ + "type": "message_delta", + "delta": { + "stop_reason": "something", + "stop_sequence": "else" + }, + "usage": { + "output_tokens": 20 + } + }`, + expected: &MessageStreamResponse{ + Type: "message_delta", + Delta: MessageStreamDelta{ + StopReason: "something", + StopSequence: "else", + }, + Usage: MessageStreamUsage{ + OutputTokens: 20, + }, + }, + }, + { + eventType: MessageEventTypeMessageStop, + event: `{ + "type": "message_stop" + }`, + expected: &MessageStreamResponse{ + Type: "message_stop", + }, + }, + { + eventType: MessageEventTypeError, + event: `{ + "type": "message_error", + "error": { + "type": "error", + "message": "This is an error" + } + }`, + expected: &MessageStreamResponse{}, + expErrStr: "error type: error, message: This is an error", + }, + } + + for _, test := range events { + response, err := ParseMessageEvent(test.eventType, test.event) + if err != nil && test.expErrStr == "" { + t.Errorf("unexpected error: %v", err) + } + + if err != nil && err.Error() != test.expErrStr { + t.Errorf("unexpected error, got: %v, want: %v", err, test.expErrStr) + } + + if !reflect.DeepEqual(response, test.expected) { + t.Errorf("unexpected response, got: %v, want: %v", response, test.expected) + } + } +} + +func TestUnsupportedEventType(t *testing.T) { + res, err := ParseMessageEvent(MessageEventType("not-a-real-type"), "") + if err == nil { + t.Errorf("expected error, got: %v", res) + } + + if err.Error() != "unknown event type" { + t.Errorf("unexpected error, got: %v", err) + } +} diff --git a/pkg/anthropic/models_test.go b/pkg/anthropic/models_test.go new file mode 100644 index 0000000..09b05ad --- /dev/null +++ b/pkg/anthropic/models_test.go @@ -0,0 +1,145 @@ +package anthropic + +import "testing" + +type modelTest struct { + model Model + imageSupport bool + messageSupport bool + completeSupport bool +} + +func getTestCases() []modelTest { + return []modelTest{ + { + model: Claude3Opus, + imageSupport: true, + messageSupport: true, + completeSupport: false, + }, + { + model: Claude3Sonnet, + imageSupport: true, + messageSupport: true, + completeSupport: false, + }, + { + model: Claude3Haiku, + imageSupport: true, + messageSupport: true, + completeSupport: false, + }, + { + model: ClaudeV2_1, + imageSupport: false, + messageSupport: true, + completeSupport: true, + }, + { + model: ClaudeV2, + imageSupport: false, + messageSupport: false, + completeSupport: true, + }, + { + model: ClaudeV1, + imageSupport: false, + messageSupport: false, + completeSupport: true, + }, + { + model: ClaudeV1_100k, + imageSupport: false, + messageSupport: false, + completeSupport: true, + }, + { + model: ClaudeInstantV1, + imageSupport: false, + messageSupport: false, + completeSupport: true, + }, + { + model: ClaudeInstantV1_100k, + imageSupport: false, + messageSupport: false, + completeSupport: true, + }, + { + model: ClaudeV1_3, + imageSupport: false, + messageSupport: false, + completeSupport: true, + }, + { + model: ClaudeV1_3_100k, + imageSupport: false, + messageSupport: false, + completeSupport: true, + }, + { + model: ClaudeV1_2, + imageSupport: false, + messageSupport: false, + completeSupport: true, + }, + { + model: ClaudeV1_0, + imageSupport: false, + messageSupport: false, + completeSupport: true, + }, + { + model: ClaudeInstantV1_1, + imageSupport: false, + messageSupport: false, + completeSupport: true, + }, + { + model: ClaudeInstantV1_1_100k, + imageSupport: false, + messageSupport: false, + completeSupport: true, + }, + { + model: ClaudeInstantV1_0, + imageSupport: false, + messageSupport: false, + completeSupport: true, + }, + { + model: Model("NOT A REAL MODEL"), + imageSupport: false, + messageSupport: false, + completeSupport: false, + }, + } +} + +func TestIsImageCompatible(t *testing.T) { + testCases := getTestCases() + for _, test := range testCases { + result := test.model.IsImageCompatible() + if result != test.imageSupport { + t.Errorf("IsImageCompatible() for model %s returned %t, expected %t", test.model, result, test.imageSupport) + } + } +} +func TestIsMessageCompatible(t *testing.T) { + testCases := getTestCases() + for _, test := range testCases { + result := test.model.IsMessageCompatible() + if result != test.messageSupport { + t.Errorf("IsMessageCompatible() for model %s returned %t, expected %t", test.model, result, test.messageSupport) + } + } +} +func TestIsCompleteCompatible(t *testing.T) { + testCases := getTestCases() + for _, test := range testCases { + result := test.model.IsCompleteCompatible() + if result != test.completeSupport { + t.Errorf("IsCompleteCompatible() for model %s returned %t, expected %t", test.model, result, test.completeSupport) + } + } +} diff --git a/pkg/anthropic/options.go b/pkg/anthropic/options.go index 99b12fc..5caa1c0 100644 --- a/pkg/anthropic/options.go +++ b/pkg/anthropic/options.go @@ -1,7 +1,5 @@ package anthropic -import "net/http" - type CompletionOption func(*CompletionRequest) type MessageOption func(*MessageRequest) @@ -126,12 +124,3 @@ func WithTopP[T any](topP float64) GenericOption[T] { } } } - -// WithHTTPClient sets a custom HTTP client for the Client. -func WithHTTPClient[T any](httpClient *http.Client) GenericOption[T] { - return func(r *T) { - if v, ok := any(r).(*Client); ok { - v.httpClient = httpClient - } - } -} diff --git a/pkg/anthropic/request.go b/pkg/anthropic/request.go index 70ac2ef..0b4c13b 100644 --- a/pkg/anthropic/request.go +++ b/pkg/anthropic/request.go @@ -3,7 +3,7 @@ package anthropic // CompletionRequest is the request to the Anthropic API for a completion. type CompletionRequest struct { Prompt string `json:"prompt"` - Model Model `json:"model"` + Model Model `json:"model,omitempty"` MaxTokensToSample int `json:"max_tokens_to_sample"` StopSequences []string `json:"stop_sequences,omitempty"` // optional Stream bool `json:"stream,omitempty"` // optional diff --git a/pkg/anthropic/utils/regions_test.go b/pkg/anthropic/utils/regions_test.go new file mode 100644 index 0000000..6a845ff --- /dev/null +++ b/pkg/anthropic/utils/regions_test.go @@ -0,0 +1,21 @@ +package utils + +import ( + "testing" +) + +func TestIsRegionSupported(t *testing.T) { + regionTests := map[string]bool{ + "United States of America": true, + "Kamino": false, + "Germany": true, + "": false, + } + + for region, expected := range regionTests { + actual := IsRegionSupported(region) + if actual != expected { + t.Errorf("IsRegionSupported(%q) = %v; expected %v", region, actual, expected) + } + } +} diff --git a/pkg/anthropic/validate.go b/pkg/anthropic/validate.go new file mode 100644 index 0000000..537ac14 --- /dev/null +++ b/pkg/anthropic/validate.go @@ -0,0 +1,67 @@ +package anthropic + +import "fmt" + +func ValidateMessageRequest(req *MessageRequest) error { + if req.Stream { + return fmt.Errorf("cannot use Message with streaming enabled, use MessageStream instead") + } + + if !req.Model.IsMessageCompatible() { + return fmt.Errorf("model %s is not compatible with the message endpoint", req.Model) + } + + if !req.Model.IsImageCompatible() && req.ContainsImageContent() { + return fmt.Errorf("model %s does not support image content", req.Model) + } + + if req.CountImageContent() > 20 { + return fmt.Errorf("too many image content blocks, maximum is 20") + } + + return nil +} + +func ValidateMessageStreamRequest(req *MessageRequest) error { + if !req.Stream { + return fmt.Errorf("cannot use MessageStream with streaming disabled, use Message instead") + } + + if !req.Model.IsMessageCompatible() { + return fmt.Errorf("model %s is not compatible with the messagestream endpoint", req.Model) + } + + if !req.Model.IsImageCompatible() && req.ContainsImageContent() { + return fmt.Errorf("model %s does not support image content", req.Model) + } + + if req.CountImageContent() > 20 { + return fmt.Errorf("too many image content blocks, maximum is 20") + } + + return nil +} + +func ValidateCompleteRequest(req *CompletionRequest) error { + if req.Stream { + return fmt.Errorf("cannot use Complete with streaming enabled, use CompleteStream instead") + } + + if !req.Model.IsCompleteCompatible() { + return fmt.Errorf("model %s is not compatible with the completion endpoint", req.Model) + } + + return nil +} + +func ValidateCompleteStreamRequest(req *CompletionRequest) error { + if !req.Stream { + return fmt.Errorf("cannot use CompleteStream with streaming disabled, use Complete instead") + } + + if !req.Model.IsCompleteCompatible() { + return fmt.Errorf("model %s is not compatible with the completion endpoint", req.Model) + } + + return nil +} diff --git a/pkg/anthropic/validate_test.go b/pkg/anthropic/validate_test.go new file mode 100644 index 0000000..042b720 --- /dev/null +++ b/pkg/anthropic/validate_test.go @@ -0,0 +1,208 @@ +package anthropic + +import ( + "fmt" + "testing" +) + +type validateMessageTestCase struct { + request *MessageRequest + expErr string +} + +type validateCompleteTestCase struct { + request *CompletionRequest + expErr string +} + +func TestValidateMessageRequest(t *testing.T) { + requests := []validateMessageTestCase{ + { + request: &MessageRequest{ + Stream: true, + }, + expErr: "cannot use Message with streaming enabled, use MessageStream instead", + }, + { + request: &MessageRequest{ + Stream: false, + Model: Model("not-a-valid-model"), + }, + expErr: "model not-a-valid-model is not compatible with the message endpoint", + }, + { + request: &MessageRequest{ + Stream: false, + Model: ClaudeV2_1, + Messages: []MessagePartRequest{{ + Role: "user", + Content: []ContentBlock{ + NewImageContentBlock(MediaTypeJPEG, "a-gosh-dang-hot-dog"), + }, + }}, + }, + expErr: fmt.Sprintf("model %s does not support image content", ClaudeV2_1), + }, + { + request: &MessageRequest{ + Stream: false, + Model: Claude3Opus, + Messages: []MessagePartRequest{{ + Role: "user", + Content: getTwentyOneImgs(), + }}, + }, + expErr: fmt.Sprintf("too many image content blocks, maximum is 20"), + }, + } + + for _, test := range requests { + err := ValidateMessageRequest(test.request) + if err == nil && test.expErr != "" { + t.Errorf("Expected error %s, got nil", test.expErr) + } + + if err == nil { + continue + } + + if err.Error() != test.expErr { + t.Errorf("Expected error %s, got %s", test.expErr, err.Error()) + } + } +} + +func getTwentyOneImgs() []ContentBlock { + blocks := []ContentBlock{} + for i := 0; i < 21; i++ { + blocks = append( + blocks, + NewImageContentBlock(MediaTypeJPEG, fmt.Sprintf("a-gosh-dang-hot-dog-%d", i)), + ) + } + + return blocks +} + +func TestValidateMessageStreamRequest(t *testing.T) { + requests := []validateMessageTestCase{ + { + request: &MessageRequest{ + Stream: false, + }, + expErr: "cannot use MessageStream with streaming disabled, use Message instead", + }, + { + request: &MessageRequest{ + Stream: true, + Model: Model("not-a-valid-model"), + }, + expErr: "model not-a-valid-model is not compatible with the messagestream endpoint", + }, + { + request: &MessageRequest{ + Stream: true, + Model: ClaudeV2_1, + Messages: []MessagePartRequest{{ + Role: "user", + Content: []ContentBlock{ + NewImageContentBlock(MediaTypeJPEG, "a-gosh-dang-hot-dog"), + }, + }}, + }, + expErr: fmt.Sprintf("model %s does not support image content", ClaudeV2_1), + }, + { + request: &MessageRequest{ + Stream: true, + Model: Claude3Opus, + Messages: []MessagePartRequest{{ + Role: "user", + Content: getTwentyOneImgs(), + }}, + }, + expErr: fmt.Sprintf("too many image content blocks, maximum is 20"), + }, + } + + for _, test := range requests { + err := ValidateMessageStreamRequest(test.request) + if err == nil && test.expErr != "" { + t.Errorf("Expected error %s, got nil", test.expErr) + } + + if err == nil { + continue + } + + if err.Error() != test.expErr { + t.Errorf("Expected error %s, got %s", test.expErr, err.Error()) + } + } +} + +func TestValidateCompleteRequest(t *testing.T) { + requests := []validateCompleteTestCase{ + { + request: &CompletionRequest{ + Stream: true, + }, + expErr: "cannot use Complete with streaming enabled, use CompleteStream instead", + }, + { + request: &CompletionRequest{ + Stream: false, + Model: Model("not-a-valid-model"), + }, + expErr: "model not-a-valid-model is not compatible with the completion endpoint", + }, + } + + for _, test := range requests { + err := ValidateCompleteRequest(test.request) + if err == nil && test.expErr != "" { + t.Errorf("Expected error %s, got nil", test.expErr) + } + + if err == nil { + continue + } + + if err.Error() != test.expErr { + t.Errorf("Expected error %s, got %s", test.expErr, err.Error()) + } + } +} + +func TestValidateCompleteStreamRequest(t *testing.T) { + requests := []validateCompleteTestCase{ + { + request: &CompletionRequest{ + Stream: false, + }, + expErr: "cannot use CompleteStream with streaming disabled, use Complete instead", + }, + { + request: &CompletionRequest{ + Stream: true, + Model: Model("not-a-valid-model"), + }, + expErr: "model not-a-valid-model is not compatible with the completion endpoint", + }, + } + + for _, test := range requests { + err := ValidateCompleteStreamRequest(test.request) + if err == nil && test.expErr != "" { + t.Errorf("Expected error %s, got nil", test.expErr) + } + + if err == nil { + continue + } + + if err.Error() != test.expErr { + t.Errorf("Expected error %s, got %s", test.expErr, err.Error()) + } + } +} diff --git a/pkg/internal/examples/completion/regular/example.go b/pkg/internal/examples/completion/regular/example.go index 166e540..b59e906 100644 --- a/pkg/internal/examples/completion/regular/example.go +++ b/pkg/internal/examples/completion/regular/example.go @@ -1,34 +1,32 @@ package main import ( + "context" "fmt" - "github.com/madebywelch/anthropic-go/v2/pkg/anthropic" - "github.com/madebywelch/anthropic-go/v2/pkg/anthropic/utils" + "github.com/madebywelch/anthropic-go/v3/pkg/anthropic" + "github.com/madebywelch/anthropic-go/v3/pkg/anthropic/client/native" ) func main() { - client, err := anthropic.NewClient("your-api-key") + ctx := context.Background() + client, err := native.MakeClient(native.Config{ + APIKey: "your-api-key", + }) if err != nil { panic(err) } - prompt, err := utils.GetPrompt("Why is the sky blue?") - if err != nil { - panic(err) - } - - request := anthropic.NewCompletionRequest( - prompt, - anthropic.WithModel[anthropic.CompletionRequest](anthropic.ClaudeV2_1), - anthropic.WithMaxTokens[anthropic.CompletionRequest](100), + request := anthropic.NewMessageRequest( + []anthropic.MessagePartRequest{{Role: "user", Content: []anthropic.ContentBlock{anthropic.NewTextContentBlock("Hello, world!")}}}, + anthropic.WithModel[anthropic.MessageRequest](anthropic.Claude35Sonnet), + anthropic.WithMaxTokens[anthropic.MessageRequest](20), ) - // Note: Only use client.Complete when streaming is disabled, otherwise use client.CompleteStream! - response, err := client.Complete(request) + response, err := client.Message(ctx, request) if err != nil { panic(err) } - fmt.Printf("Completion: %s\n", response.Completion) + fmt.Println(response.Content) } diff --git a/pkg/internal/examples/completion/stream/example.go b/pkg/internal/examples/completion/stream/example.go index 93c9add..28a8f47 100644 --- a/pkg/internal/examples/completion/stream/example.go +++ b/pkg/internal/examples/completion/stream/example.go @@ -1,14 +1,19 @@ package main import ( + "context" "fmt" - "github.com/madebywelch/anthropic-go/v2/pkg/anthropic" - "github.com/madebywelch/anthropic-go/v2/pkg/anthropic/utils" + "github.com/madebywelch/anthropic-go/v3/pkg/anthropic" + "github.com/madebywelch/anthropic-go/v3/pkg/anthropic/client/native" + "github.com/madebywelch/anthropic-go/v3/pkg/anthropic/utils" ) func main() { - client, err := anthropic.NewClient("your-api-key") + ctx := context.Background() + client, err := native.MakeClient(native.Config{ + APIKey: "your-api-key", + }) if err != nil { panic(err) } @@ -26,7 +31,7 @@ func main() { ) // Note: Only use client.CompleteStream when streaming is enabled, otherwise use client.Complete! - resps, errs := client.CompleteStream(request) + resps, errs := client.CompleteStream(ctx, request) for { select { diff --git a/pkg/internal/examples/messages/regular/example.go b/pkg/internal/examples/messages/regular/example.go index e4a65a1..cf7774c 100644 --- a/pkg/internal/examples/messages/regular/example.go +++ b/pkg/internal/examples/messages/regular/example.go @@ -1,13 +1,18 @@ package main import ( + "context" "fmt" - "github.com/madebywelch/anthropic-go/v2/pkg/anthropic" + "github.com/madebywelch/anthropic-go/v3/pkg/anthropic" + "github.com/madebywelch/anthropic-go/v3/pkg/anthropic/client/native" ) func main() { - client, err := anthropic.NewClient("your-api-key") + ctx := context.Background() + client, err := native.MakeClient(native.Config{ + APIKey: "your-api-key", + }) if err != nil { panic(err) } @@ -20,7 +25,7 @@ func main() { ) // Call the Message method - response, err := client.Message(request) + response, err := client.Message(ctx, request) if err != nil { panic(err) } diff --git a/pkg/internal/examples/messages/stream/example.go b/pkg/internal/examples/messages/stream/example.go index db62c06..2caf08f 100644 --- a/pkg/internal/examples/messages/stream/example.go +++ b/pkg/internal/examples/messages/stream/example.go @@ -1,15 +1,20 @@ package main import ( + "context" "fmt" "strings" "time" - "github.com/madebywelch/anthropic-go/v2/pkg/anthropic" + "github.com/madebywelch/anthropic-go/v3/pkg/anthropic" + "github.com/madebywelch/anthropic-go/v3/pkg/anthropic/client/native" ) func main() { - client, err := anthropic.NewClient("your-api-key") + ctx := context.Background() + client, err := native.MakeClient(native.Config{ + APIKey: "your-api-key", + }) if err != nil { panic(err) } @@ -28,10 +33,10 @@ func main() { anthropic.WithStream[anthropic.MessageRequest](true), ) - rCh, errCh := client.MessageStream(request) + rCh, errCh := client.MessageStream(ctx, request) final := strings.Builder{} - chunk := anthropic.MessageStreamResponse{} + chunk := &anthropic.MessageStreamResponse{} done := false for { diff --git a/pkg/internal/examples/messages/tools/example.go b/pkg/internal/examples/messages/tools/example.go index 409351b..43426c4 100644 --- a/pkg/internal/examples/messages/tools/example.go +++ b/pkg/internal/examples/messages/tools/example.go @@ -1,11 +1,17 @@ package main import ( - "github.com/madebywelch/anthropic-go/v2/pkg/anthropic" + "context" + + "github.com/madebywelch/anthropic-go/v3/pkg/anthropic" + "github.com/madebywelch/anthropic-go/v3/pkg/anthropic/client/native" ) func main() { - client, err := anthropic.NewClient("your-api-key") + ctx := context.Background() + client, err := native.MakeClient(native.Config{ + APIKey: "your-api-key", + }) if err != nil { panic(err) } @@ -38,7 +44,7 @@ func main() { } // Call the Message method - response, err := client.Message(request) + response, err := client.Message(ctx, request) if err != nil { panic(err) } diff --git a/pkg/internal/integration_tests/complete_integration_test.go b/pkg/internal/integration_tests/complete_integration_test.go index 3db8e42..7254d47 100644 --- a/pkg/internal/integration_tests/complete_integration_test.go +++ b/pkg/internal/integration_tests/complete_integration_test.go @@ -1,12 +1,14 @@ package integration_tests import ( + "context" "os" "strings" "testing" - "github.com/madebywelch/anthropic-go/v2/pkg/anthropic" - "github.com/madebywelch/anthropic-go/v2/pkg/anthropic/utils" + "github.com/madebywelch/anthropic-go/v3/pkg/anthropic" + "github.com/madebywelch/anthropic-go/v3/pkg/anthropic/client/native" + "github.com/madebywelch/anthropic-go/v3/pkg/anthropic/utils" ) func TestCompleteIntegration(t *testing.T) { @@ -16,8 +18,12 @@ func TestCompleteIntegration(t *testing.T) { t.Skip("ANTHROPIC_API_KEY environment variable is not set, skipping integration test") } + ctx := context.Background() + // Create a new client - client, err := anthropic.NewClient(apiKey) + anthropicClient, err := native.MakeClient(native.Config{ + APIKey: apiKey, + }) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -32,7 +38,7 @@ func TestCompleteIntegration(t *testing.T) { request := anthropic.NewCompletionRequest(prompt) // Call the Complete method - response, err := client.Complete(request) + response, err := anthropicClient.Complete(ctx, request) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -52,8 +58,12 @@ func TestCompleteStreamIntegration(t *testing.T) { t.Skip("ANTHROPIC_API_KEY environment variable is not set, skipping integration test") } + ctx := context.Background() + // Create a new client - client, err := anthropic.NewClient(apiKey) + anthropicClient, err := native.MakeClient(native.Config{ + APIKey: apiKey, + }) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -65,19 +75,19 @@ func TestCompleteStreamIntegration(t *testing.T) { } // Prepare a completion request - request := anthropic.NewCompletionRequest(prompt, - anthropic.WithStreaming[anthropic.CompletionRequest](true), - anthropic.WithMaxTokens[anthropic.CompletionRequest](10), + request := anthropic.NewCompletionRequest(prompt, + anthropic.WithStreaming[anthropic.CompletionRequest](true), + anthropic.WithMaxTokens[anthropic.CompletionRequest](10), anthropic.WithModel[anthropic.CompletionRequest](anthropic.ClaudeV2_1)) // Call the Complete method (should return an error since streaming is enabled) - _, err = client.Complete(request) + _, err = anthropicClient.Complete(ctx, request) if err == nil { t.Fatalf("Expected error: %v", err) } // Call the CompleteStream method - res, errs := client.CompleteStream(request) + res, errs := anthropicClient.CompleteStream(ctx, request) MAX_ITERATIONS := 10 builder := strings.Builder{} diff --git a/pkg/internal/integration_tests/message_integration_test.go b/pkg/internal/integration_tests/message_integration_test.go index eb74f40..4760510 100644 --- a/pkg/internal/integration_tests/message_integration_test.go +++ b/pkg/internal/integration_tests/message_integration_test.go @@ -1,10 +1,12 @@ package integration_tests import ( + "context" "os" "testing" - "github.com/madebywelch/anthropic-go/v2/pkg/anthropic" + "github.com/madebywelch/anthropic-go/v3/pkg/anthropic" + "github.com/madebywelch/anthropic-go/v3/pkg/anthropic/client/native" ) func TestMessageWithToolsIntegration(t *testing.T) { @@ -13,7 +15,11 @@ func TestMessageWithToolsIntegration(t *testing.T) { t.Skip("ANTHROPIC_API_KEY environment variable is not set, skipping integration test") } - client, err := anthropic.NewClient(apiKey) + ctx := context.Background() + + anthropicClient, err := native.MakeClient(native.Config{ + APIKey: apiKey, + }) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -44,7 +50,7 @@ func TestMessageWithToolsIntegration(t *testing.T) { }, } - response, err := client.Message(request) + response, err := anthropicClient.Message(ctx, request) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -64,7 +70,9 @@ func TestMessageWithForcedToolIntegration(t *testing.T) { t.Skip("ANTHROPIC_API_KEY environment variable is not set, skipping integration test") } - client, err := anthropic.NewClient(apiKey) + anthropicClient, err := native.MakeClient(native.Config{ + APIKey: apiKey, + }) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -99,7 +107,7 @@ func TestMessageWithForcedToolIntegration(t *testing.T) { }, } - response, err := client.Message(request) + response, err := anthropicClient.Message(context.Background(), request) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -119,7 +127,9 @@ func TestMessageWithImageIntegration(t *testing.T) { t.Skip("ANTHROPIC_API_KEY environment variable is not set, skipping integration test") } - client, err := anthropic.NewClient(apiKey) + anthropicClient, err := native.MakeClient(native.Config{ + APIKey: apiKey, + }) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -138,7 +148,7 @@ func TestMessageWithImageIntegration(t *testing.T) { }, } - response, err := client.Message(request) + response, err := anthropicClient.Message(context.Background(), request) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -156,7 +166,9 @@ func TestMessageIntegration(t *testing.T) { } // Create a new client - client, err := anthropic.NewClient(apiKey) + anthropicClient, err := native.MakeClient(native.Config{ + APIKey: apiKey, + }) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -169,7 +181,7 @@ func TestMessageIntegration(t *testing.T) { } // Call the Message method - response, err := client.Message(request) + response, err := anthropicClient.Message(context.Background(), request) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -193,7 +205,9 @@ func TestMessageErrorHandlingIntegration(t *testing.T) { } // Create a new client - client, err := anthropic.NewClient(apiKey) + anthropicClient, err := native.MakeClient(native.Config{ + APIKey: apiKey, + }) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -205,12 +219,9 @@ func TestMessageErrorHandlingIntegration(t *testing.T) { } // Call the Message method expecting an error - _, err = client.Message(request) + _, err = anthropicClient.Message(context.Background(), request) // We're expecting an error here because we didn't set the required field MaxTokensToSample if err == nil { t.Fatal("Expected an error, got none") } } - -// - TODO: TestMessageWithParametersIntegration: to test sending a message with various parameters -// - TODO: TestMessageStreamIntegration: to ensure the function correctly handles streaming requests