From cc44b019fe5ed25b8ed463de7363249d46637dff Mon Sep 17 00:00:00 2001 From: Casey Waldren Date: Wed, 4 Dec 2024 13:15:47 -0800 Subject: [PATCH] more comments, and allow time to be mocked --- ldai/client.go | 8 ++- ldai/client_test.go | 3 +- ldai/tracker.go | 120 ++++++++++++++++++++++++++++------- ldai/tracker_test.go | 145 ++++++++++++++++++++++++++++++++++++++++--- 4 files changed, 241 insertions(+), 35 deletions(-) diff --git a/ldai/client.go b/ldai/client.go index b0f434c9..74b17897 100644 --- a/ldai/client.go +++ b/ldai/client.go @@ -3,6 +3,7 @@ package ldai import ( "encoding/json" "fmt" + "github.com/alexkappa/mustache" "github.com/launchdarkly/go-sdk-common/v3/ldcontext" "github.com/launchdarkly/go-sdk-common/v3/ldvalue" @@ -31,6 +32,7 @@ type ServerSDK interface { } // Client is the main entrypoint for the AI SDK. A client can be used to obtain an AI config from LaunchDarkly. +// Unless otherwise noted, the Client's methods are not safe for concurrent use. type Client struct { sdk ServerSDK logger interfaces.LDLoggers @@ -73,13 +75,13 @@ func (c *Client) Config( // empty object.) 
if result.Type() != ldvalue.ObjectType { c.logConfigWarning(key, "unmarshalling failed, expected JSON object but got %s", result.Type().String()) - return defaultValue, NewTracker(key, c.sdk, &defaultValue, context, c.logger) + return defaultValue, newTracker(key, c.sdk, &defaultValue, context, c.logger) } var parsed datamodel.Config if err := json.Unmarshal([]byte(result.JSONString()), &parsed); err != nil { c.logConfigWarning(key, "unmarshalling failed: %v", err) - return defaultValue, NewTracker(key, c.sdk, &defaultValue, context, c.logger) + return defaultValue, newTracker(key, c.sdk, &defaultValue, context, c.logger) } mergedVariables := map[string]interface{}{ @@ -111,7 +113,7 @@ func (c *Client) Config( } cfg := builder.Build() - return cfg, NewTracker(key, c.sdk, &cfg, context, c.logger) + return cfg, newTracker(key, c.sdk, &cfg, context, c.logger) } func getAllAttributes(context ldcontext.Context) map[string]interface{} { diff --git a/ldai/client_test.go b/ldai/client_test.go index a254c521..99260e6e 100644 --- a/ldai/client_test.go +++ b/ldai/client_test.go @@ -2,6 +2,8 @@ package ldai import ( "errors" + "testing" + "github.com/launchdarkly/go-sdk-common/v3/ldcontext" "github.com/launchdarkly/go-sdk-common/v3/ldlog" "github.com/launchdarkly/go-sdk-common/v3/ldlogtest" @@ -10,7 +12,6 @@ import ( "github.com/launchdarkly/go-server-sdk/v7/interfaces" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "testing" ) type mockServerSDK struct { diff --git a/ldai/tracker.go b/ldai/tracker.go index a34f944a..8cec8152 100644 --- a/ldai/tracker.go +++ b/ldai/tracker.go @@ -2,10 +2,11 @@ package ldai import ( "fmt" + "time" + "github.com/launchdarkly/go-sdk-common/v3/ldcontext" "github.com/launchdarkly/go-sdk-common/v3/ldvalue" "github.com/launchdarkly/go-server-sdk/v7/interfaces" - "time" ) const ( @@ -18,55 +19,108 @@ const ( tokenOutput = "$ld:ai:tokens:output" ) +// TokenUsage represents the token usage returned by a model provider for a 
specific request. type TokenUsage struct { - Total int - Input int + // Total is the total number of tokens used. + Total int + // Input is the number of input tokens used. + Input int + // Output is the number of output tokens used. Output int } +// Set returns true if any of the fields are non-zero. func (t TokenUsage) Set() bool { return t.Total > 0 || t.Input > 0 || t.Output > 0 } +// Metrics represents the metrics returned by a model provider for a specific request. type Metrics struct { - LatencyMs float64 + // Latency is the latency of the request. + Latency time.Duration } +// Set returns true if the latency is non-zero. func (m Metrics) Set() bool { - return m.LatencyMs != 0 + return m.Latency != 0 } +// ProviderResponse represents the response from a model provider for a specific request. type ProviderResponse struct { - Usage TokenUsage + // Usage is the token usage. + Usage TokenUsage + // Metrics is the request metrics. Metrics Metrics } +// Feedback represents the feedback provided by a user for a model evaluation. type Feedback string const ( + // Positive feedback (result was good). Positive Feedback = "positive" + // Negative feedback (result was bad). Negative Feedback = "negative" ) +// EventSink represents the Tracker's requirements for delivering analytic events. This is generally satisfied +// by the LaunchDarkly SDK's TrackMetric method. +type EventSink interface { + // TrackMetric sends a named analytic event to LaunchDarkly relevant to a particular context, and containing a + // metric value and additional data. + TrackMetric( + eventName string, + context ldcontext.Context, + metricValue float64, + data ldvalue.Value, + ) error +} + +// Stopwatch is used to measure the duration of a task. Start will always be called before Stop. +// If an implementation is not provided, the Tracker uses a default implementation that delegates to +// time.Now and time.Since. +type Stopwatch interface { + // Start starts the stopwatch. 
+ Start() + // Stop stops the stopwatch and returns the duration since Start was called. + Stop() time.Duration +} + +// Tracker is used to track metrics for AI config evaluation. +// Unless otherwise noted, the Tracker's methods are not safe for concurrent use. type Tracker struct { key string config *Config context ldcontext.Context - events EventTracker + events EventSink trackData ldvalue.Value logger interfaces.LDLoggers + stopwatch Stopwatch } -type EventTracker interface { - TrackMetric( - eventName string, - context ldcontext.Context, - metricValue float64, - data ldvalue.Value, - ) error +// Used if a custom Stopwatch is not provided. +type defaultStopwatch struct { + start time.Time +} + +// Start saves the current time using time.Now. +func (d *defaultStopwatch) Start() { + d.start = time.Now() +} + +// Stop returns the duration since Start was called using time.Since. +func (d *defaultStopwatch) Stop() time.Duration { + return time.Since(d.start) +} + +// newTracker creates a new Tracker with the specified key, event sink, config, context, and loggers. +func newTracker(key string, events EventSink, config *Config, ctx ldcontext.Context, loggers interfaces.LDLoggers) *Tracker { + return newTrackerWithStopwatch(key, events, config, ctx, loggers, &defaultStopwatch{}) } -func NewTracker(key string, events EventTracker, config *Config, ctx ldcontext.Context, loggers interfaces.LDLoggers) *Tracker { +// newTrackerWithStopwatch creates a new Tracker with the specified key, event sink, config, context, loggers, and +// stopwatch. This method is used for testing purposes. 
+func newTrackerWithStopwatch(key string, events EventSink, config *Config, ctx ldcontext.Context, loggers interfaces.LDLoggers, stopwatch Stopwatch) *Tracker { if config == nil { panic("LaunchDarkly SDK programmer error: config must never be nil") } @@ -82,6 +136,7 @@ func NewTracker(key string, events EventTracker, config *Config, ctx ldcontext.C events: events, context: ctx, logger: loggers, + stopwatch: stopwatch, } } @@ -90,10 +145,15 @@ func (t *Tracker) logWarning(format string, args ...interface{}) { t.logger.Warnf(prefix+format, args...) } -func (t *Tracker) TrackDuration(durationMs float64) error { - return t.events.TrackMetric(duration, t.context, durationMs, t.trackData) +// TrackDuration tracks the duration of a task. For example, the duration of a model evaluation request may be +// tracked here. See also TrackRequest. +// The duration in milliseconds must fit within a float64. +func (t *Tracker) TrackDuration(dur time.Duration) error { + return t.events.TrackMetric(duration, t.context, float64(dur.Milliseconds()), t.trackData) } +// TrackFeedback tracks the feedback provided by a user for a model evaluation. If the feedback is not +// Positive or Negative, returns an error and does not track anything. func (t *Tracker) TrackFeedback(feedback Feedback) error { switch feedback { case Positive: @@ -105,10 +165,12 @@ func (t *Tracker) TrackFeedback(feedback Feedback) error { } } +// TrackSuccess tracks a successful model evaluation. func (t *Tracker) TrackSuccess() error { return t.events.TrackMetric(generation, t.context, 1, t.trackData) } +// TrackUsage tracks the token usage for a model evaluation. 
func (t *Tracker) TrackUsage(usage TokenUsage) error { var failed bool @@ -138,15 +200,25 @@ func (t *Tracker) TrackUsage(usage TokenUsage) error { return nil } -func measureDurationOfTask[T any](task func() (T, error)) (T, int64, error) { - start := time.Now() +func measureDurationOfTask[T any](stopwatch Stopwatch, task func() (T, error)) (T, time.Duration, error) { + stopwatch.Start() result, err := task() - duration := time.Since(start).Milliseconds() - return result, duration, err + return result, stopwatch.Stop(), err } +// TrackRequest tracks metrics for a model evaluation request. The task function should return a ProviderResponse +// which can be used to specify request metrics and token usage. +// +// All fields of the ProviderResponse are optional. +// +// If the task returns an error, then the request is not considered successful and no metrics are tracked. +// Otherwise, the following metrics are tracked: +// 1. Successful model evaluation. +// 2. Any metrics that were set in the ProviderResponse. +// 2a) If Latency was not set in the ProviderResponse's Metrics field, an automatically measured duration. +// 3. Any token usage that was set in the ProviderResponse. 
func (t *Tracker) TrackRequest(task func() (ProviderResponse, error)) (ProviderResponse, error) { - usage, duration, err := measureDurationOfTask(task) + usage, duration, err := measureDurationOfTask(t.stopwatch, task) if err != nil { t.logWarning("error executing request: %v", err) @@ -157,11 +229,11 @@ func (t *Tracker) TrackRequest(task func() (ProviderResponse, error)) (ProviderR } if usage.Metrics.Set() { - if err := t.TrackDuration(usage.Metrics.LatencyMs); err != nil { + if err := t.TrackDuration(usage.Metrics.Latency); err != nil { t.logWarning("error tracking duration metric (user provided) for request: %v", err) } } else { - if err := t.TrackDuration(float64(duration)); err != nil { + if err := t.TrackDuration(duration); err != nil { t.logWarning("error tracking duration metric (automatically measured) for request: %v", err) } } diff --git a/ldai/tracker_test.go b/ldai/tracker_test.go index 77c3dc85..5a74ec7b 100644 --- a/ldai/tracker_test.go +++ b/ldai/tracker_test.go @@ -1,11 +1,14 @@ package ldai import ( + "testing" + "time" + "github.com/launchdarkly/go-sdk-common/v3/ldcontext" "github.com/launchdarkly/go-sdk-common/v3/ldlogtest" "github.com/launchdarkly/go-sdk-common/v3/ldvalue" "github.com/stretchr/testify/assert" - "testing" + "github.com/stretchr/testify/require" ) type mockEvents struct { @@ -31,13 +34,13 @@ func (m *mockEvents) TrackMetric(eventName string, context ldcontext.Context, me func TestTracker_NewPanicsWithNilConfig(t *testing.T) { assert.Panics(t, func() { - NewTracker("key", newMockEvents(), nil, ldcontext.New("key"), nil) + newTracker("key", newMockEvents(), nil, ldcontext.New("key"), nil) }) } func TestTracker_NewDoesNotPanicWithConfig(t *testing.T) { assert.NotPanics(t, func() { - NewTracker("key", newMockEvents(), &Config{}, ldcontext.New("key"), nil) + newTracker("key", newMockEvents(), &Config{}, ldcontext.New("key"), nil) }) } @@ -49,7 +52,7 @@ func makeTrackData(configKey, versionKey string) ldvalue.Value { func 
TestTracker_TrackSuccess(t *testing.T) { events := newMockEvents() - tracker := NewTracker("key", events, &Config{}, ldcontext.New("key"), nil) + tracker := newTracker("key", events, &Config{}, ldcontext.New("key"), nil) assert.NoError(t, tracker.TrackSuccess()) expectedEvent := trackEvent{ @@ -64,14 +67,14 @@ func TestTracker_TrackSuccess(t *testing.T) { func TestTracker_TrackRequest(t *testing.T) { events := newMockEvents() - tracker := NewTracker("key", events, &Config{}, ldcontext.New("key"), nil) + tracker := newTracker("key", events, &Config{}, ldcontext.New("key"), nil) expectedResponse := ProviderResponse{ Usage: TokenUsage{ Total: 1, }, Metrics: Metrics{ - LatencyMs: 1.0, + Latency: 10 * time.Millisecond, }, } @@ -92,7 +95,7 @@ func TestTracker_TrackRequest(t *testing.T) { expectedDurationEvent := trackEvent{ name: "$ld:ai:duration:total", context: ldcontext.New("key"), - metricValue: 1.0, + metricValue: 10.0, data: makeTrackData("key", ""), } @@ -106,3 +109,131 @@ func TestTracker_TrackRequest(t *testing.T) { expectedEvents := []trackEvent{expectedSuccessEvent, expectedDurationEvent, expectedTokenUsageEvent} assert.ElementsMatch(t, expectedEvents, events.events) } + +type mockStopwatch time.Duration + +func (m mockStopwatch) Start() {} + +func (m mockStopwatch) Stop() time.Duration { + return time.Duration(m) +} + +func TestTracker_LatencyMeasuredIfNotProvided(t *testing.T) { + events := newMockEvents() + + tracker := newTrackerWithStopwatch( + "key", events, &Config{}, ldcontext.New("key"), nil, mockStopwatch(42*time.Millisecond)) + + expectedResponse := ProviderResponse{ + Usage: TokenUsage{ + Total: 1, + }, + } + + r, err := tracker.TrackRequest(func() (ProviderResponse, error) { + return expectedResponse, nil + }) + + assert.NoError(t, err) + assert.Equal(t, expectedResponse, r) + + require.Equal(t, 3, len(events.events)) + gotEvent := events.events[1] + assert.Equal(t, "$ld:ai:duration:total", gotEvent.name) + assert.Equal(t, 42.0, 
gotEvent.metricValue) +} + +func TestTracker_TrackDuration(t *testing.T) { + events := newMockEvents() + tracker := newTracker("key", events, &Config{}, ldcontext.New("key"), nil) + + assert.NoError(t, tracker.TrackDuration(time.Millisecond*10)) + + expectedEvent := trackEvent{ + name: "$ld:ai:duration:total", + context: ldcontext.New("key"), + metricValue: 10.0, + data: makeTrackData("key", ""), + } + + assert.ElementsMatch(t, []trackEvent{expectedEvent}, events.events) +} + +func TestTracker_TrackFeedback(t *testing.T) { + events := newMockEvents() + tracker := newTracker("key", events, &Config{}, ldcontext.New("key"), nil) + + assert.NoError(t, tracker.TrackFeedback(Positive)) + assert.NoError(t, tracker.TrackFeedback(Negative)) + assert.Error(t, tracker.TrackFeedback("not a valid feedback value")) + + expectedPositiveEvent := trackEvent{ + name: "$ld:ai:feedback:user:positive", + context: ldcontext.New("key"), + metricValue: 1.0, + data: makeTrackData("key", ""), + } + + expectedNegativeEvent := trackEvent{ + name: "$ld:ai:feedback:user:negative", + context: ldcontext.New("key"), + metricValue: 1.0, + data: makeTrackData("key", ""), + } + + assert.ElementsMatch(t, []trackEvent{expectedPositiveEvent, expectedNegativeEvent}, events.events) +} + +func TestTracker_TrackUsage(t *testing.T) { + t.Run("only one field set, only one event", func(t *testing.T) { + events := newMockEvents() + tracker := newTracker("key", events, &Config{}, ldcontext.New("key"), nil) + + assert.NoError(t, tracker.TrackUsage(TokenUsage{ + Total: 42, + })) + + expectedEvent := trackEvent{ + name: "$ld:ai:tokens:total", + context: ldcontext.New("key"), + metricValue: 42.0, + data: makeTrackData("key", ""), + } + + assert.ElementsMatch(t, []trackEvent{expectedEvent}, events.events) + }) + + t.Run("all fields set, all events", func(t *testing.T) { + events := newMockEvents() + tracker := newTracker("key", events, &Config{}, ldcontext.New("key"), nil) + + assert.NoError(t, 
tracker.TrackUsage(TokenUsage{ + Total: 42, + Input: 20, + Output: 22, + })) + + expectedTotal := trackEvent{ + name: "$ld:ai:tokens:total", + context: ldcontext.New("key"), + metricValue: 42.0, + data: makeTrackData("key", ""), + } + + expectedInput := trackEvent{ + name: "$ld:ai:tokens:input", + context: ldcontext.New("key"), + metricValue: 20.0, + data: makeTrackData("key", ""), + } + + expectedOutput := trackEvent{ + name: "$ld:ai:tokens:output", + context: ldcontext.New("key"), + metricValue: 22.0, + data: makeTrackData("key", ""), + } + + assert.ElementsMatch(t, []trackEvent{expectedTotal, expectedInput, expectedOutput}, events.events) + }) +}