Skip to content

Commit

Permalink
Implement compatiblity checks & message endpoint streaming (#15)
Browse files Browse the repository at this point in the history
* Implement compatiblity checks & message endpoint streaming

* Handle unmarshal errors on stream error parsing

* Extend message compatibility check to reg endpoint and expose usage

---------

Co-authored-by: Zachery Stuart <[email protected]>
  • Loading branch information
pigeonlaser and zachery-stuart authored Mar 5, 2024
1 parent 9099c07 commit f991339
Show file tree
Hide file tree
Showing 8 changed files with 665 additions and 9 deletions.
14 changes: 13 additions & 1 deletion pkg/anthropic/complete.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,29 @@ func (c *Client) Complete(req *CompletionRequest) (*CompletionResponse, error) {
return nil, fmt.Errorf("cannot use Complete with a streaming request, use CompleteStream instead")
}

if !req.Model.IsCompleteCompatible() {
return nil, fmt.Errorf("model %s is not compatible with the completion endpoint", req.Model)
}

return c.sendCompletionRequest(req)
}

func (c *Client) CompleteStream(req *CompletionRequest) (<-chan StreamResponse, <-chan error) {
events := make(chan StreamResponse)
errCh := make(chan error)

// make this a buffered channel to allow for the error case below to return
errCh := make(chan error, 1)

if !req.Stream {
errCh <- fmt.Errorf("cannot use CompleteStream with a non-streaming request, use Complete instead")
return events, errCh
}

if !req.Model.IsCompleteCompatible() {
errCh <- fmt.Errorf("model %s is not compatible with the completion endpoint", req.Model)
return events, errCh
}

go c.handleStreaming(events, errCh, req)

return events, errCh
Expand Down Expand Up @@ -104,6 +115,7 @@ func (c *Client) processSseStream(reader io.Reader, events chan StreamResponse)

for scanner.Scan() {
line := scanner.Text()

if strings.HasPrefix(line, "data:") {
data := strings.TrimSpace(line[5:])
var event StreamResponse
Expand Down
79 changes: 79 additions & 0 deletions pkg/anthropic/complete_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package anthropic

import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
Expand Down Expand Up @@ -83,3 +84,81 @@ func TestCompleteWithParameters(t *testing.T) {
t.Errorf("Expected stop sequence %q, got %q", "Why is the sky blue?", request.StopSequences[1])
}
}

func TestCompleteIncompatibleModel(t *testing.T) {
// Create client
client, err := NewClient("fake-api-key")
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

// Prepare a completion request
request := NewCompletionRequest("Why is the sky blue?",
WithModel[CompletionRequest](Claude3Opus),
)

// Call the Complete method expecting an error
_, err = client.Complete(request)
if err == nil {
t.Fatal("Expected an incompatibility error, got none")
}

// Check the error message
expErr := fmt.Sprintf("model %s is not compatible with the completion endpoint", request.Model)
if err.Error() != expErr {
t.Fatalf("Expected error %s, got %s", expErr, err.Error())
}
}

func TestCompleteStreamNoStreamFlag(t *testing.T) {
// Create client
client, err := NewClient("fake-api-key")
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

// Prepare a completion request
request := NewCompletionRequest("Why is the sky blue?",
WithModel[CompletionRequest](Claude3Opus),
)

// Call the Complete method expecting an error
_, errCh := client.CompleteStream(request)
err = <-errCh
if err == nil {
t.Fatal("Expected a missing stream flag error, got none")
}

// Check the error message
expErr := "cannot use CompleteStream with a non-streaming request, use Complete instead"
if err.Error() != expErr {
t.Fatalf("Expected error %s, got %s", expErr, err.Error())
}
}

func TestCompleteStreamIncompatibleModel(t *testing.T) {
// Create client
client, err := NewClient("fake-api-key")
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

// Prepare a completion request
request := NewCompletionRequest("Why is the sky blue?",
WithModel[CompletionRequest](Claude3Opus),
WithStream[CompletionRequest](true),
)

// Call the Complete method expecting an error
_, errCh := client.CompleteStream(request)
err = <-errCh
if err == nil {
t.Fatal("Expected an incompatibility error, got none")
}

// Check the error message
expErr := fmt.Sprintf("model %s is not compatible with the completion endpoint", request.Model)
if err.Error() != expErr {
t.Fatalf("Expected error %s, got %s", expErr, err.Error())
}
}
99 changes: 95 additions & 4 deletions pkg/anthropic/message.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,46 @@
package anthropic

import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
)

func (c *Client) Message(req *MessageRequest) (*MessageResponse, error) {
if req.Stream {
return nil, fmt.Errorf("cannot use Message with streaming enabled, use MessageStream instead (not yet supported)")
return nil, fmt.Errorf("cannot use Message with streaming enabled, use MessageStream instead.")
}

if !req.Model.IsMessageCompatible() {
return nil, fmt.Errorf("model %s is not compatible with the message endpoint", req.Model)
}

return c.sendMessageRequest(req)
}

// MessageStream (NOT YET SUPPORTED) returns a channel of StreamResponse objects and a channel of errors.
func (c *Client) MessageStream(req *MessageRequest) (<-chan StreamResponse, <-chan error) {
return nil, nil
func (c *Client) MessageStream(req *MessageRequest) (<-chan MessageStreamResponse, <-chan error) {
events := make(chan MessageStreamResponse)

// make this a buffered channel to allow for the error case below to return
errCh := make(chan error, 1)

if !req.Stream {
errCh <- fmt.Errorf("cannot use MessageStream with a non-streaming request, use Message instead")
return events, errCh
}

if !req.Model.IsMessageCompatible() {
errCh <- fmt.Errorf("model %s is not compatible with the messagestream endpoint", req.Model)
return events, errCh
}

go c.handleMessageStreaming(events, errCh, req)

return events, errCh
}

func (c *Client) sendMessageRequest(req *MessageRequest) (*MessageResponse, error) {
Expand Down Expand Up @@ -53,3 +76,71 @@ func (c *Client) sendMessageRequest(req *MessageRequest) (*MessageResponse, erro

return &messageResponse, nil
}

func (c *Client) handleMessageStreaming(events chan MessageStreamResponse, errCh chan error, req *MessageRequest) {
defer close(events)

data, err := json.Marshal(req)
if err != nil {
errCh <- fmt.Errorf("error marshalling message request: %w", err)
return
}

request, err := http.NewRequest("POST", fmt.Sprintf("%s/v1/messages", c.baseURL), bytes.NewBuffer(data))
if err != nil {
errCh <- fmt.Errorf("error creating new request: %w", err)
return
}

request.Header.Set("Content-Type", "application/json")
request.Header.Set("X-Api-Key", c.apiKey)
request.Header.Set("Accept", "text/event-stream")

response, err := c.doRequest(request)
if err != nil {
errCh <- fmt.Errorf("error sending message request: %w", err)
return
}
defer response.Body.Close()

err = c.processMessageSseStream(response.Body, events)
if err != nil {
errCh <- err
}
}

func (c *Client) processMessageSseStream(reader io.Reader, events chan MessageStreamResponse) error {
scanner := bufio.NewScanner(reader)

for scanner.Scan() {
line := scanner.Text()

if strings.HasPrefix(line, "data:") {
data := strings.TrimSpace(line[5:])

event := &MessageEvent{}
err := json.Unmarshal([]byte(data), event)
if err != nil {
return fmt.Errorf("error decoding event data: %w", err)
}

msg, err := parseMessageEvent(event.Type, data)

if err != nil {
if _, ok := err.(UnsupportedEventType); ok {
// ignore unsupported event types
} else {
return fmt.Errorf("error processing message stream: %v", err)
}
}

events <- msg
}
}

if err := scanner.Err(); err != nil {
return fmt.Errorf("error reading from stream: %w", err)
}

return nil
}
Loading

0 comments on commit f991339

Please sign in to comment.