diff --git a/ldai/tracker.go b/ldai/tracker.go index 3dbb9686..e4b325ca 100644 --- a/ldai/tracker.go +++ b/ldai/tracker.go @@ -200,16 +200,17 @@ func (t *Tracker) TrackUsage(usage TokenUsage) error { return nil } -func measureDurationOfTask[T any](stopwatch Stopwatch, task func() (T, error)) (T, time.Duration, error) { +func measureDurationOfTask[T any, A any](stopwatch Stopwatch, arg A, task func(A) (T, error)) (T, time.Duration, error) { stopwatch.Start() - result, err := task() + result, err := task(arg) return result, stopwatch.Stop(), err } // TrackRequest tracks metrics for a model evaluation request. The task function should return a ProviderResponse -// which can be used to specify request metrics and token usage. +// which can be used to specify request metrics and token usage. All fields of the returned ProviderResponse are optional. // -// All fields of the ProviderResponse are optional. +// The task function will be passed the current AI config, which can be used to obtain any parameters or messages +// relevant to the request. // // If the task returns an error, then the request is not considered successful and no metrics are tracked. // Otherwise, the following metrics are tracked: @@ -217,8 +218,8 @@ func measureDurationOfTask[T any](stopwatch Stopwatch, task func() (T, error)) ( // 2. Any metrics that were that set in the ProviderResponse // 2a) If Latency was not set in the ProviderResponse's Metrics field, an automatically measured duration. // 3. Any token usage that was set in the ProviderResponse. -func (t *Tracker) TrackRequest(task func() (ProviderResponse, error)) (ProviderResponse, error) { - usage, duration, err := measureDurationOfTask(t.stopwatch, task) +func (t *Tracker) TrackRequest(task func(c *Config) (ProviderResponse, error)) (ProviderResponse, error) { + usage, duration, err := measureDurationOfTask(t.stopwatch, t.config, task) if err != nil { t.logWarning("error executing request: %v", err) diff --git a/ldai/tracker_test.go b/ldai/tracker_test.go index cfa5ba95..769ae2a4 100644 --- a/ldai/tracker_test.go +++ b/ldai/tracker_test.go @@ -1,6 +1,7 @@ package ldai import ( + "github.com/launchdarkly/go-server-sdk/ldai/datamodel" "testing" "time" @@ -78,7 +79,7 @@ func TestTracker_TrackRequest(t *testing.T) { }, } - r, err := tracker.TrackRequest(func() (ProviderResponse, error) { + r, err := tracker.TrackRequest(func(c *Config) (ProviderResponse, error) { return expectedResponse, nil }) @@ -110,6 +111,29 @@ func TestTracker_TrackRequest(t *testing.T) { assert.ElementsMatch(t, expectedEvents, events.events) } +func TestTracker_TrackRequestReceivesConfig(t *testing.T) { + events := newMockEvents() + + expectedConfig := NewConfig(). + WithMessage("hello", datamodel.Assistant). + WithModelId("model"). + WithProviderId("provider"). + WithModelParam("param", ldvalue.String("value")). + WithCustomModelParam("custom", ldvalue.String("value")). + Enable(). + Build() + + tracker := newTracker("key", "versionKey", events, &expectedConfig, ldcontext.New("key"), nil) + + var gotConfig *Config + _, _ = tracker.TrackRequest(func(c *Config) (ProviderResponse, error) { + gotConfig = c + return ProviderResponse{}, nil + }) + + assert.Equal(t, expectedConfig, *gotConfig) +} + type mockStopwatch time.Duration func (m mockStopwatch) Start() {} @@ -130,7 +154,7 @@ func TestTracker_LatencyMeasuredIfNotProvided(t *testing.T) { }, } - r, err := tracker.TrackRequest(func() (ProviderResponse, error) { + r, err := tracker.TrackRequest(func(c *Config) (ProviderResponse, error) { return expectedResponse, nil })