From 5dad122c6f7a0b4e3cc5183be0159b1106964418 Mon Sep 17 00:00:00 2001 From: Casey Waldren Date: Mon, 25 Nov 2024 18:47:51 -0800 Subject: [PATCH] add TrackRequest --- ldai/tracker.go | 55 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/ldai/tracker.go b/ldai/tracker.go index 1d12788c..2d6e8b17 100644 --- a/ldai/tracker.go +++ b/ldai/tracker.go @@ -5,6 +5,7 @@ import ( "github.com/launchdarkly/go-sdk-common/v3/ldcontext" "github.com/launchdarkly/go-sdk-common/v3/ldvalue" "github.com/launchdarkly/go-server-sdk/v7/interfaces" + "time" ) const ( @@ -23,6 +24,23 @@ type TokenUsage struct { Output int } +func (t TokenUsage) Set() bool { + return t.Total > 0 || t.Input > 0 || t.Output > 0 +} + +type Metrics struct { + LatencyMs float64 +} + +func (m Metrics) Set() bool { + return m.LatencyMs != 0 +} + +type ProviderResponse struct { + Usage TokenUsage + Metrics Metrics +} + type Feedback string const ( @@ -112,3 +130,40 @@ func (t *Tracker) TrackUsage(usage TokenUsage) error { return nil } + +func measureDurationOfTask[T any](task func() (T, error)) (T, int64, error) { + start := time.Now() + result, err := task() + duration := time.Since(start).Milliseconds() + return result, duration, err +} + +func (t *Tracker) TrackRequest(task func() (ProviderResponse, error)) (ProviderResponse, error) { + usage, duration, err := measureDurationOfTask(task) + + if err != nil { + t.logger.Warn("Error executing request: %s", err.Error()) + return ProviderResponse{}, err + } + if err := t.TrackSuccess(); err != nil { + t.logger.Warn("Error tracking success metric for request: %s", err.Error()) + } + + if usage.Metrics.Set() { + if err := t.TrackDuration(usage.Metrics.LatencyMs); err != nil { + t.logger.Warn("Error tracking duration metric (user provided) for request: %s", err.Error()) + } + } else { + if err := t.TrackDuration(float64(duration)); err != nil { + t.logger.Warn("Error tracking duration metric (automatically measured) for request: %s", err.Error()) + } + } + + if usage.Usage.Set() { + if err := t.TrackUsage(usage.Usage); err != nil { + t.logger.Warn("Error tracking token usage for request: %s", err.Error()) + } + } + + return usage, nil +}