Skip to content

Commit

Permalink
feat: pass tracker's config into TrackRequest (#210)
Browse files Browse the repository at this point in the history
This PR improves the ergonomics of `TrackRequest`. The user callback now
receives a `Config` directly, which can be accessed to provide
parameters for the request.

An alternative would be directly exposing the `Config` as a public
variable of `Tracker`, or forcing the user to instead reference the
config returned from a call to `cfg, tracker := client.Config(..)`.

Those alternatives are not great, because they make the user carry
around the config and manually inject it into the callback. This way,
the callback can be defined wherever and not need to worry about how to
receive the config.
  • Loading branch information
cwaldren-ld authored Dec 5, 2024
1 parent 215116a commit 8321db6
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 8 deletions.
13 changes: 7 additions & 6 deletions ldai/tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,25 +200,26 @@ func (t *Tracker) TrackUsage(usage TokenUsage) error {
return nil
}

func measureDurationOfTask[T any](stopwatch Stopwatch, task func() (T, error)) (T, time.Duration, error) {
func measureDurationOfTask[T any, A any](stopwatch Stopwatch, arg A, task func(A) (T, error)) (T, time.Duration, error) {
stopwatch.Start()
result, err := task()
result, err := task(arg)
return result, stopwatch.Stop(), err
}

// TrackRequest tracks metrics for a model evaluation request. The task function should return a ProviderResponse
// which can be used to specify request metrics and token usage.
// which can be used to specify request metrics and token usage. All fields of the returned ProviderResponse are optional.
//
// All fields of the ProviderResponse are optional.
// The task function will be passed the current AI config, which can be used to obtain any parameters or messages
// relevant to the request.
//
// If the task returns an error, then the request is not considered successful and no metrics are tracked.
// Otherwise, the following metrics are tracked:
// 1. Successful model evaluation.
// 2. Any metrics that were that set in the ProviderResponse
// 2a) If Latency was not set in the ProviderResponse's Metrics field, an automatically measured duration.
// 3. Any token usage that was set in the ProviderResponse.
func (t *Tracker) TrackRequest(task func() (ProviderResponse, error)) (ProviderResponse, error) {
usage, duration, err := measureDurationOfTask(t.stopwatch, task)
func (t *Tracker) TrackRequest(task func(c *Config) (ProviderResponse, error)) (ProviderResponse, error) {
usage, duration, err := measureDurationOfTask(t.stopwatch, t.config, task)

if err != nil {
t.logWarning("error executing request: %v", err)
Expand Down
28 changes: 26 additions & 2 deletions ldai/tracker_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ldai

import (
"github.com/launchdarkly/go-server-sdk/ldai/datamodel"
"testing"
"time"

Expand Down Expand Up @@ -78,7 +79,7 @@ func TestTracker_TrackRequest(t *testing.T) {
},
}

r, err := tracker.TrackRequest(func() (ProviderResponse, error) {
r, err := tracker.TrackRequest(func(c *Config) (ProviderResponse, error) {
return expectedResponse, nil
})

Expand Down Expand Up @@ -110,6 +111,29 @@ func TestTracker_TrackRequest(t *testing.T) {
assert.ElementsMatch(t, expectedEvents, events.events)
}

func TestTracker_TrackRequestReceivesConfig(t *testing.T) {
events := newMockEvents()

expectedConfig := NewConfig().
WithMessage("hello", datamodel.Assistant).
WithModelId("model").
WithProviderId("provider").
WithModelParam("param", ldvalue.String("value")).
WithCustomModelParam("custom", ldvalue.String("value")).
Enable().
Build()

tracker := newTracker("key", "versionKey", events, &expectedConfig, ldcontext.New("key"), nil)

var gotConfig *Config
_, _ = tracker.TrackRequest(func(c *Config) (ProviderResponse, error) {
gotConfig = c
return ProviderResponse{}, nil
})

assert.Equal(t, expectedConfig, *gotConfig)
}

type mockStopwatch time.Duration

func (m mockStopwatch) Start() {}
Expand All @@ -130,7 +154,7 @@ func TestTracker_LatencyMeasuredIfNotProvided(t *testing.T) {
},
}

r, err := tracker.TrackRequest(func() (ProviderResponse, error) {
r, err := tracker.TrackRequest(func(c *Config) (ProviderResponse, error) {
return expectedResponse, nil
})

Expand Down

0 comments on commit 8321db6

Please sign in to comment.