diff --git a/ldai/README.md b/ldai/README.md index 428f2b4f..ac2e27ca 100644 --- a/ldai/README.md +++ b/ldai/README.md @@ -1,6 +1,6 @@ LaunchDarkly Server-side AI SDK for Go ============================================== -[![Actions Status](https://github.com/launchdarkly/go-server-sdk/actions/workflows/ldoai-ci.yml/badge.svg?branch=v7)](https://github.com/launchdarkly/go-server-sdk/actions/workflows/ldai-ci.yml) +[![Actions Status](https://github.com/launchdarkly/go-server-sdk/actions/workflows/ldai-ci.yml/badge.svg?branch=v7)](https://github.com/launchdarkly/go-server-sdk/actions/workflows/ldai-ci.yml) LaunchDarkly overview ------------------------- @@ -20,22 +20,32 @@ import ( ) ``` -Configure the base LaunchDarkly client: +Configure the base LaunchDarkly Server SDK: ```go -client, _ = ld.MakeClient("your-sdk-key", 5*time.Second) +sdkClient, _ = ld.MakeClient("your-sdk-key", 5*time.Second) ``` -Instantiate the AI client: +Instantiate the AI client, passing in the base Server SDK: ```go - -aiClient := ldai.New(client) +aiClient, err := ldai.NewClient(sdkClient) ``` +Fetch a model configuration for a specific LaunchDarkly context: +```go +// The default value 'ldai.Disabled()' be returned if LaunchDarkly is unavailable or the config +// cannot be fetched. To customize the default value, use ldai.NewConfig(). +config, tracker := aiClient.Config("your-model-key", ldcontext.New("user-key"), ldai.Disabled(), nil) + +// Access the methods on config, and optionally use the returned tracker to generate analytic events +// related to usage of the model config. +``` Learn more ----------- Read our [documentation](http://docs.launchdarkly.com) for in-depth instructions on configuring and using LaunchDarkly. +You can also head straight to the [complete reference guide for this SDK](https://docs.launchdarkly.com/sdk/ai/go). + Contributing ------------ diff --git a/ldai/client.go b/ldai/client.go index 6d699185..2f66f0d7 100644 --- a/ldai/client.go +++ b/ldai/client.go @@ -1,12 +1,160 @@ package ldai +import ( + "encoding/json" + "fmt" + + "github.com/launchdarkly/go-server-sdk/ldai/datamodel" + + "github.com/alexkappa/mustache" + "github.com/launchdarkly/go-sdk-common/v3/ldcontext" + "github.com/launchdarkly/go-sdk-common/v3/ldvalue" + "github.com/launchdarkly/go-server-sdk/v7/interfaces" +) + +// Defines the Mustache variable name used to access the provided context. +const ldContextVariable = "ldctx" + +// ServerSDK defines the required methods for the AI SDK to interact with LaunchDarkly. These methods are +// satisfied by the LaunchDarkly Go Server SDK. type ServerSDK interface { + JSONVariation( + key string, + context ldcontext.Context, + defaultVal ldvalue.Value, + ) (ldvalue.Value, error) + Loggers() interfaces.LDLoggers + TrackMetric( + eventName string, + context ldcontext.Context, + metricValue float64, + data ldvalue.Value, + ) error } +// Client is the main entrypoint for the AI SDK. A client can be used to obtain an AI config from LaunchDarkly. +// Unless otherwise noted, the Client's method are not safe for concurrent use. type Client struct { - sdk ServerSDK + sdk ServerSDK + logger interfaces.LDLoggers +} + +// NewClient creates a new AI Client. The provided SDK interface must not be nil. The client will use the provided SDK's +// loggers to log warnings and errors. +func NewClient(sdk ServerSDK) (*Client, error) { + if sdk == nil { + return nil, fmt.Errorf("sdk must not be nil") + } + return &Client{ + sdk: sdk, + logger: sdk.Loggers(), + }, nil +} + +func (c *Client) logConfigWarning(key string, format string, args ...interface{}) { + prefix := "AI config '" + key + "': " + c.logger.Warnf(prefix+format, args...) +} + +// Config evaluates an AI config named by a given key for the given context. +// +// The config's messages will undergo Mustache template interpolation using the provided variables, which may be +// nil. If the config cannot be evaluated or LaunchDarkly is unreachable, the default value is returned. Note that +// the messages in the default will not undergo template interpolation. +// +// To send analytic events to LaunchDarkly related to the AI config, call methods on the returned Tracker. +func (c *Client) Config( + key string, + context ldcontext.Context, + defaultValue Config, + variables map[string]interface{}, +) (Config, *Tracker) { + + result, _ := c.sdk.JSONVariation(key, context, defaultValue.AsLdValue()) + + // The spec requires the config to at least be an object (although all properties are optional, so it may be an + // empty object.) + if result.Type() != ldvalue.ObjectType { + c.logConfigWarning(key, "unmarshalling failed, expected JSON object but got %s", result.Type().String()) + return defaultValue, newTracker(key, c.sdk, &defaultValue, context, c.logger) + } + + var parsed datamodel.Config + if err := json.Unmarshal([]byte(result.JSONString()), &parsed); err != nil { + c.logConfigWarning(key, "unmarshalling failed: %v", err) + return defaultValue, newTracker(key, c.sdk, &defaultValue, context, c.logger) + } + + mergedVariables := map[string]interface{}{ + ldContextVariable: getAllAttributes(context), + } + + for k, v := range variables { + if k == ldContextVariable { + c.logConfigWarning(key, "config variables contains 'ldctx', which is reserved and cannot be overwritten") + continue + } + mergedVariables[k] = v + } + + builder := NewConfig(). + WithModelId(parsed.Model.Id). + WithProviderId(parsed.Provider.Id). + WithEnabled(parsed.Meta.Enabled) + + for i, msg := range parsed.Messages { + content, err := interpolateTemplate(msg.Content, mergedVariables) + if err != nil { + c.logConfigWarning(key, + "malformed message at index %d: %v", i, err, + ) + return defaultValue, &Tracker{} + } + builder.WithMessage(content, msg.Role) + } + + cfg := builder.Build() + return cfg, newTracker(key, c.sdk, &cfg, context, c.logger) +} + +func getAllAttributes(context ldcontext.Context) map[string]interface{} { + if !context.Multiple() { + return addContextAttributes(context, false) + } + + attributes := map[string]interface{}{ + "kind": context.Kind(), + "key": context.FullyQualifiedKey(), + } + + for _, ctx := range context.GetAllIndividualContexts(nil) { + attributes[string(ctx.Kind())] = addContextAttributes(ctx, true) + } + + return attributes +} + +func addContextAttributes(context ldcontext.Context, omitKind bool) map[string]interface{} { + attributes := map[string]interface{}{ + "key": context.Key(), + "anonymous": context.Anonymous(), + } + + if !omitKind { + attributes["kind"] = context.Kind() + } + + for _, attr := range context.GetOptionalAttributeNames(nil) { + attributes[attr] = context.GetValue(attr).AsArbitraryValue() + } + + return attributes } -func New(sdk ServerSDK) *Client { - return &Client{} +func interpolateTemplate(template string, variables map[string]interface{}) (string, error) { + m := mustache.New() + if err := m.ParseString(template); err != nil { + return "", err + } + return m.RenderString(variables) } diff --git a/ldai/client_test.go b/ldai/client_test.go index e29a8fe8..391e68d8 100644 --- a/ldai/client_test.go +++ b/ldai/client_test.go @@ -1,7 +1,403 @@ package ldai -import "testing" +import ( + "errors" + "testing" + + "github.com/launchdarkly/go-server-sdk/ldai/datamodel" + + "github.com/launchdarkly/go-sdk-common/v3/ldcontext" + "github.com/launchdarkly/go-sdk-common/v3/ldlog" + "github.com/launchdarkly/go-sdk-common/v3/ldlogtest" + "github.com/launchdarkly/go-sdk-common/v3/ldvalue" + "github.com/launchdarkly/go-server-sdk/v7/interfaces" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type mockServerSDK struct { + log *ldlogtest.MockLog + json []byte + err error +} + +func newMockSDK(json []byte, err error) *mockServerSDK { + return &mockServerSDK{json: json, err: err, log: ldlogtest.NewMockLog()} +} + +func (m *mockServerSDK) JSONVariation( + key string, + context ldcontext.Context, + defaultVal ldvalue.Value, +) (ldvalue.Value, error) { + + if m.err != nil { + return defaultVal, m.err + } + + return ldvalue.Parse(m.json), nil +} + +func (m *mockServerSDK) Loggers() interfaces.LDLoggers { + return m.log.Loggers +} + +func (m *mockServerSDK) TrackMetric(eventName string, context ldcontext.Context, metricValue float64, data ldvalue.Value) error { + return nil +} + +func TestNewClientReturnsErrorWhenSDKIsNil(t *testing.T) { + _, err := NewClient(nil) + require.Error(t, err) +} func TestNewClient(t *testing.T) { - _ = New(nil) + client, err := NewClient(newMockSDK(nil, nil)) + require.NoError(t, err) + require.NotNil(t, client) +} + +func TestEvalErrorReturnsDefault(t *testing.T) { + client, err := NewClient(newMockSDK(nil, errors.New("client is offline"))) + require.NoError(t, err) + require.NotNil(t, client) + + defaultVal := NewConfig().Enable().WithMessage("hello", datamodel.User).Build() + + cfg, tracker := client.Config("key", ldcontext.New("user"), defaultVal, nil) + assert.NotNil(t, tracker) + assert.Equal(t, defaultVal, cfg) +} + +func TestInvalidConfigReturnsDefault(t *testing.T) { + tests := []struct { + name string + json []byte + }{ + {"null value", []byte("null")}, + {"invalid json", []byte("invalid")}, + {"is a number", []byte("42")}, + {"is a string", []byte(`"hello"`)}, + {"is an array", []byte(`["hello"]`)}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + sdk := newMockSDK(test.json, nil) + client, err := NewClient(sdk) + require.NoError(t, err) + require.NotNil(t, client) + + defaultVal := NewConfig().Enable().WithMessage("hello", datamodel.User).Build() + + cfg, _ := client.Config("key", ldcontext.New("user"), defaultVal, nil) + assert.Equal(t, defaultVal, cfg) + + sdk.log.AssertMessageMatch(t, true, ldlog.Warn, "AI config 'key':") + }) + } +} + +func TestDisabledConfigs(t *testing.T) { + tests := []struct { + name string + json []byte + }{ + {"empty object", []byte("{}")}, + {"missing meta field", []byte(`{"model": {}, "messages": []}`)}, + {"meta disabled explicitly", []byte(`{"meta": {"enabled": false, "versionKey": "1"}, "model": {}, "messages": []}`)}, + {"meta disable implicitly", []byte(`{"meta": { "versionKey": "1"}, "model": {}, "messages": []}`)}, + } + + defaultVal := NewConfig().Enable().WithMessage("hello", datamodel.User).Build() + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + client, err := NewClient(newMockSDK(test.json, nil)) + require.NoError(t, err) + require.NotNil(t, client) + + cfg, _ := client.Config("key", ldcontext.New("user"), defaultVal, nil) + + // We *shouldn't* be getting the default value, because these are all valid configs that should + // be parsed as disabled. + assert.False(t, cfg.Enabled()) + }) + } +} + +func TestCanSetDefaultConfigFields(t *testing.T) { + client, err := NewClient(newMockSDK(nil, nil)) + require.NoError(t, err) + require.NotNil(t, client) + + defaultVal := NewConfig().Enable(). + WithMessage("hello", datamodel.User). + WithMessage("world", datamodel.System). + WithProviderId("provider"). + WithModelId("model").Build() + + cfg, _ := client.Config("key", ldcontext.New("user"), defaultVal, nil) + + assert.True(t, cfg.Enabled()) + assert.Equal(t, "provider", cfg.ProviderId()) + assert.Equal(t, "model", cfg.ModelId()) + assert.Equal(t, 2, len(cfg.Messages())) + + msg := cfg.Messages() + assert.Equal(t, "hello", msg[0].Content) + assert.Equal(t, datamodel.User, msg[0].Role) + assert.Equal(t, "world", msg[1].Content) + assert.Equal(t, datamodel.System, msg[1].Role) +} + +func TestCanSetModelParameters(t *testing.T) { + client, err := NewClient(newMockSDK(nil, nil)) + require.NoError(t, err) + require.NotNil(t, client) + + defaultVal := NewConfig().WithModelParam("foo", ldvalue.String("bar")).Build() + cfg, _ := client.Config("key", ldcontext.New("user"), defaultVal, nil) + + t.Run("param is present", func(t *testing.T) { + p, ok := cfg.ModelParam("foo") + assert.True(t, ok) + assert.Equal(t, "bar", p.StringValue()) + }) + + t.Run("param is missing", func(t *testing.T) { + p, ok := cfg.ModelParam("missing") + assert.False(t, ok) + assert.Equal(t, ldvalue.Null(), p) + }) +} + +func TestCanSetCustomModelParameters(t *testing.T) { + client, err := NewClient(newMockSDK(nil, nil)) + require.NoError(t, err) + require.NotNil(t, client) + + defaultVal := NewConfig().WithCustomModelParam("foo", ldvalue.String("bar")).Build() + cfg, _ := client.Config("key", ldcontext.New("user"), defaultVal, nil) + + t.Run("param is present", func(t *testing.T) { + p, ok := cfg.CustomModelParam("foo") + assert.True(t, ok) + assert.Equal(t, "bar", p.StringValue()) + }) + + t.Run("param is missing", func(t *testing.T) { + p, ok := cfg.CustomModelParam("missing") + assert.False(t, ok) + assert.Equal(t, ldvalue.Null(), p) + }) +} + +func TestNormalAndCustomParamsDoNotInterfere(t *testing.T) { + client, err := NewClient(newMockSDK(nil, nil)) + require.NoError(t, err) + require.NotNil(t, client) + + defaultVal := NewConfig(). + WithModelParam("foo", ldvalue.String("bar")). + WithCustomModelParam("foo", ldvalue.String("baz")).Build() + + cfg, _ := client.Config("key", ldcontext.New("user"), defaultVal, nil) + + foo1, ok := cfg.ModelParam("foo") + require.True(t, ok) + assert.Equal(t, "bar", foo1.StringValue()) + + foo2, ok := cfg.CustomModelParam("foo") + require.True(t, ok) + assert.Equal(t, "baz", foo2.StringValue()) +} + +func TestCannotOverwriteMessages(t *testing.T) { + client, err := NewClient(newMockSDK(nil, nil)) + require.NoError(t, err) + require.NotNil(t, client) + + defaultVal := NewConfig(). + WithMessage("hello", datamodel.Assistant).Build() + + cfg, _ := client.Config("key", ldcontext.New("user"), defaultVal, nil) + + cfg.Messages()[0].Content = "changed" + cfg.Messages()[0].Role = datamodel.User + + assert.ElementsMatch(t, []datamodel.Message{{Content: "hello", Role: datamodel.Assistant}}, cfg.Messages()) +} + +func eval(t *testing.T, prompt string, ctx ldcontext.Context, variables map[string]interface{}) (string, error) { + t.Helper() + json := []byte(`{ + "_ldMeta": {"versionKey": "1", "enabled": true}, + "messages": [ + {"content": "` + prompt + `", "role": "user"} + ] + }`) + + client, err := NewClient(newMockSDK(json, nil)) + require.NoError(t, err) + cfg, _ := client.Config("key", ctx, Disabled(), variables) + if len(cfg.Messages()) == 0 { + return "", errors.New("no messages interpolated") + } + return cfg.Messages()[0].Content, nil +} + +func TestInterpolation(t *testing.T) { + t.Run("missing variables", func(t *testing.T) { + cases := []string{ + "{{ adjective }}", + "{{ adjective.nested.deep }}", + "{{ ldctx.this_is_not_a_variable }}", + } + + for _, c := range cases { + t.Run(c, func(t *testing.T) { + result, err := eval(t, "I am an ("+c+") LLM", ldcontext.New("user"), nil) + require.NoError(t, err) + assert.Equal(t, "I am an () LLM", result) + }) + } + }) + + t.Run("simple variables", func(t *testing.T) { + cases := []string{ + "awesome", + "slow", + "all powerful", + } + + for _, c := range cases { + t.Run(c, func(t *testing.T) { + result, err := eval(t, "I am an {{ adjective }} LLM", ldcontext.New("user"), map[string]interface{}{"adjective": c}) + require.NoError(t, err) + assert.Equal(t, "I am an "+c+" LLM", result) + }) + } + }) + + t.Run("multiple variables", func(t *testing.T) { + vars := map[string]interface{}{ + "adjective": "awesome", + "noun": "robot", + "stats": map[string]interface{}{ + "power": "9000", + }, + } + result, err := eval(t, "I am an {{ adjective }} {{ noun }} with power over {{ stats.power }}", ldcontext.New("user"), vars) + require.NoError(t, err) + assert.Equal(t, "I am an awesome robot with power over 9000", result) + }) + + t.Run("interpolation with array indices does not work", func(t *testing.T) { + vars := map[string]interface{}{ + "adjectives": []string{"awesome", "slow", "all powerful"}, + } + + t.Run("dot syntax interpolates as empty string", func(t *testing.T) { + result, err := eval(t, "I am an ({{ adjectives.0 }}) LLM", ldcontext.New("user"), vars) + require.NoError(t, err) + assert.Equal(t, "I am an () LLM", result) + }) + + t.Run("bracket syntax returns error", func(t *testing.T) { + _, err := eval(t, "I am an ({{ adjectives[0] }}) LLM", ldcontext.New("user"), vars) + assert.Error(t, err) + }) + }) + + t.Run("array sections", func(t *testing.T) { + vars := map[string]interface{}{ + "adjectives": []string{"hello", "world", "!"}, + } + + result, err := eval(t, "{{#adjectives }}{{ . }} {{/adjectives }}", ldcontext.New("user"), vars) + require.NoError(t, err) + assert.Equal(t, "hello world ! ", result) + }) + + t.Run("malformed syntax", func(t *testing.T) { + _, err := eval(t, "This is a {{ malformed }]} prompt", ldcontext.New("user"), nil) + require.Error(t, err) + }) + + t.Run("interpolate single kind context", func(t *testing.T) { + context := ldcontext.NewBuilder("123").Name("Sandy").Build() + result, err := eval(t, "I'm a {{ ldctx.kind}} with key {{ ldctx.key }}, named {{ ldctx.name }}", context, nil) + require.NoError(t, err) + assert.Equal(t, "I'm a user with key 123, named Sandy", result) + }) + + t.Run("interpolation with nested context attributes", func(t *testing.T) { + context := ldcontext.NewBuilder("123"). + SetValue("stats", ldvalue.ObjectBuild().Set("power", ldvalue.Int(9000)).Build()).Build() + result, err := eval(t, "I can ingest over {{ ldctx.stats.power }} tokens per second!", context, nil) + require.NoError(t, err) + assert.Equal(t, "I can ingest over 9000 tokens per second!", result) + }) + + t.Run("interpolation with multi kind context", func(t *testing.T) { + user := ldcontext.NewBuilder("123"). + SetValue("cat_ownership", ldvalue.ObjectBuild().Set("count", ldvalue.Int(12)).Build()).Build() + + cat := ldcontext.NewBuilder("456").Kind("cat"). + SetValue("health", ldvalue.ObjectBuild().Set("hunger", ldvalue.String("off the charts")).Build()).Build() + + context := ldcontext.NewMulti(user, cat) + + result, err := eval(t, "As an owner of {{ ldctx.user.cat_ownership.count }} cats, I must report that my cat's hunger level is {{ ldctx.cat.health.hunger }}!", context, nil) + require.NoError(t, err) + assert.Equal(t, "As an owner of 12 cats, I must report that my cat's hunger level is off the charts!", result) + }) + + t.Run("interpolation with multi kind context does not have anonymous attribute", func(t *testing.T) { + user := ldcontext.NewBuilder("123"). + SetValue("cat_ownership", ldvalue.ObjectBuild().Set("count", ldvalue.Int(12)).Build()).Build() + + cat := ldcontext.NewBuilder("456").Kind("cat"). + SetValue("health", ldvalue.ObjectBuild().Set("hunger", ldvalue.String("off the charts")).Build()).Build() + + context := ldcontext.NewMulti(user, cat) + + result, err := eval(t, "anonymous=<{{ ldctx.anonymous }}>", context, nil) + require.NoError(t, err) + assert.Equal(t, "anonymous=<>", result) + }) + + t.Run("interpolation with multi kind context has kind multi", func(t *testing.T) { + user := ldcontext.NewBuilder("123"). + SetValue("cat_ownership", ldvalue.ObjectBuild().Set("count", ldvalue.Int(12)).Build()).Build() + + cat := ldcontext.NewBuilder("456").Kind("cat"). + SetValue("health", ldvalue.ObjectBuild().Set("hunger", ldvalue.String("off the charts")).Build()).Build() + + context := ldcontext.NewMulti(user, cat) + + result, err := eval(t, "kind=<{{ ldctx.kind }}>", context, nil) + require.NoError(t, err) + assert.Equal(t, "kind=", result) + }) + + t.Run("interpolation with multi kind context does not have child kinds", func(t *testing.T) { + + // The idea here is that in a multi-kind context, we can access ldctx.kind (== "multi"), but you can't + // access the kind field of the individual nested contexts since this doesn't match the actual data model. + // That is, you can't access ldctx.user.kind or ldctx.cat.kind, only ldctx.kind. + + user := ldcontext.NewBuilder("123"). + SetValue("cat_ownership", ldvalue.ObjectBuild().Set("count", ldvalue.Int(12)).Build()).Build() + + cat := ldcontext.NewBuilder("456").Kind("cat"). + SetValue("health", ldvalue.ObjectBuild().Set("hunger", ldvalue.String("off the charts")).Build()).Build() + + context := ldcontext.NewMulti(user, cat) + + result, err := eval(t, "user_kind=<{{ ldctx.user.kind}}>,cat_kind=<{{ ldctx.cat.kind }}>", context, nil) + require.NoError(t, err) + assert.Equal(t, "user_kind=<>,cat_kind=<>", result) + }) } diff --git a/ldai/config.go b/ldai/config.go new file mode 100644 index 00000000..97a9c83b --- /dev/null +++ b/ldai/config.go @@ -0,0 +1,152 @@ +package ldai + +import ( + "github.com/launchdarkly/go-sdk-common/v3/ldvalue" + "github.com/launchdarkly/go-server-sdk/ldai/datamodel" + "golang.org/x/exp/maps" + "golang.org/x/exp/slices" +) + +// Config represents an AI config. +type Config struct { + c datamodel.Config +} + +// VersionKey is used internally by LaunchDarkly. +func (c *Config) VersionKey() string { + return c.c.Meta.VersionKey +} + +// Messages returns the messages defined by the config. The series of messages may be +// passed to an AI model provider. +func (c *Config) Messages() []datamodel.Message { + return slices.Clone(c.c.Messages) +} + +// Enabled returns whether the config is enabled. +func (c *Config) Enabled() bool { + return c.c.Meta.Enabled +} + +// ProviderId returns the provider ID associated with the config. +func (c *Config) ProviderId() string { + return c.c.Provider.Id +} + +// ModelId returns the model ID associated with the config. +func (c *Config) ModelId() string { + return c.c.Model.Id +} + +// ModelParam returns the model parameter named by key. The second parameter is true if the key exists. +func (c *Config) ModelParam(key string) (ldvalue.Value, bool) { + val, ok := c.c.Model.Parameters[key] + return val, ok +} + +// CustomModelParam returns the custom model parameter named by key. The second parameter is true if the key exists. +func (c *Config) CustomModelParam(key string) (ldvalue.Value, bool) { + val, ok := c.c.Model.Custom[key] + return val, ok +} + +// AsLdValue is used internally. +func (c *Config) AsLdValue() ldvalue.Value { + return ldvalue.FromJSONMarshal(c.c) +} + +// ConfigBuilder is used to define a default AI config, returned when LaunchDarkly is unreachable or there +// is an error evaluating the config. +type ConfigBuilder struct { + messages []datamodel.Message + enabled bool + providerId string + modelId string + modelParams map[string]ldvalue.Value + modelCustomParams map[string]ldvalue.Value +} + +// NewConfig returns a new ConfigBuilder. By default, the config is disabled. +func NewConfig() *ConfigBuilder { + return &ConfigBuilder{ + modelParams: make(map[string]ldvalue.Value), + modelCustomParams: make(map[string]ldvalue.Value), + } +} + +// Disabled is a helper that returns a built Config that is disabled and contains no messages. +func Disabled() Config { + return NewConfig().Disable().Build() +} + +// WithMessage appends a message to the config with the given role. +func (cb *ConfigBuilder) WithMessage(content string, role datamodel.Role) *ConfigBuilder { + cb.messages = append(cb.messages, datamodel.Message{ + Content: content, + Role: role, + }) + return cb +} + +// WithEnabled sets whether the config is enabled. See also Enable and Disable. +func (cb *ConfigBuilder) WithEnabled(enabled bool) *ConfigBuilder { + cb.enabled = enabled + return cb +} + +// Enable enables the config. +func (cb *ConfigBuilder) Enable() *ConfigBuilder { + return cb.WithEnabled(true) +} + +// Disable disables the config. +func (cb *ConfigBuilder) Disable() *ConfigBuilder { + return cb.WithEnabled(false) +} + +// WithModelId sets the model ID associated with the config. +func (cb *ConfigBuilder) WithModelId(modelId string) *ConfigBuilder { + cb.modelId = modelId + return cb +} + +// WithProviderId sets the provider ID associated with the config. +func (cb *ConfigBuilder) WithProviderId(providerId string) *ConfigBuilder { + cb.providerId = providerId + return cb +} + +// WithModelParam sets a model parameter named by key to the given value. If the key already exists, it will be +// overwritten. Model parameters are generally set by LaunchDarkly; for custom parameters not recognized by +// LaunchDarkly, use WithModelCustomParam. +func (cb *ConfigBuilder) WithModelParam(key string, value ldvalue.Value) *ConfigBuilder { + cb.modelParams[key] = value + return cb +} + +// WithCustomModelParam sets a custom model parameter named by key to the given value. If the key already exists, it +// will be overwritten. +func (cb *ConfigBuilder) WithCustomModelParam(key string, value ldvalue.Value) *ConfigBuilder { + cb.modelCustomParams[key] = value + return cb +} + +// Build creates a Config from the current builder state. +func (cb *ConfigBuilder) Build() Config { + return Config{ + c: datamodel.Config{ + Messages: slices.Clone(cb.messages), + Meta: datamodel.Meta{ + Enabled: cb.enabled, + }, + Model: datamodel.Model{ + Id: cb.modelId, + Parameters: maps.Clone(cb.modelParams), + Custom: maps.Clone(cb.modelCustomParams), + }, + Provider: datamodel.Provider{ + Id: cb.providerId, + }, + }, + } +} diff --git a/ldai/datamodel/datamodel.go b/ldai/datamodel/datamodel.go new file mode 100644 index 00000000..d4dcd063 --- /dev/null +++ b/ldai/datamodel/datamodel.go @@ -0,0 +1,68 @@ +package datamodel + +import "github.com/launchdarkly/go-sdk-common/v3/ldvalue" + +// Meta defines the serialization format for config metadata. +type Meta struct { + // VersionKey is the version key. + VersionKey string `json:"versionKey,omitempty"` + + // Enabled is true if the config is enabled. + Enabled bool `json:"enabled,omitempty"` +} + +// Model defines the serialization format for a model. +type Model struct { + // Id identifies the model. + Id string `json:"id"` + + // Parameters are the model parameters, generally provided by LaunchDarkly. + Parameters map[string]ldvalue.Value `json:"parameters,omitempty"` + + // Custom are custom model parameters, generally provided by the user. + Custom map[string]ldvalue.Value `json:"custom,omitempty"` +} + +// Provider defines the serialization format for a model provider. +type Provider struct { + // Id identifies the provider. + Id string `json:"id"` +} + +// Role defines the role of a message. +type Role string + +const ( + // User represents the user. + User Role = "user" + + // System represents the system. + System Role = "system" + + // Assistant represents an assistant. + Assistant Role = "assistant" +) + +// Message defines the serialization format for a message which may be passed to an AI model provider. +type Message struct { + // Content is the message content. + Content string `json:"content"` + + // Role is the role of the message. + Role Role `json:"role"` +} + +// Config defines the serialization format for an AI config. +type Config struct { + // Messages is a list of messages. The messages received from LaunchDarkly are uninterpolated. + Messages []Message `json:"messages,omitempty"` + + // Meta is the config metadata. + Meta Meta `json:"_ldMeta,omitempty"` + + // Model is the model. + Model Model `json:"model,omitempty"` + + // Provider is the provider. + Provider Provider `json:"provider,omitempty"` +} diff --git a/ldai/go.mod b/ldai/go.mod index 08aa0707..b2392185 100644 --- a/ldai/go.mod +++ b/ldai/go.mod @@ -1,3 +1,22 @@ module github.com/launchdarkly/go-server-sdk/ldai go 1.18 + +require ( + github.com/alexkappa/mustache v1.0.0 + github.com/launchdarkly/go-sdk-common/v3 v3.2.0 + github.com/launchdarkly/go-server-sdk/v7 v7.7.0 + github.com/stretchr/testify v1.9.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/google/uuid v1.1.1 // indirect + github.com/josharian/intern v1.0.0 // indirect + github.com/launchdarkly/go-jsonstream/v3 v3.1.0 // indirect + github.com/launchdarkly/go-sdk-events/v3 v3.4.0 // indirect + github.com/mailru/easyjson v0.7.7 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/ldai/go.sum b/ldai/go.sum new file mode 100644 index 00000000..f80bba5b --- /dev/null +++ b/ldai/go.sum @@ -0,0 +1,32 @@ +github.com/alexkappa/mustache v1.0.0 h1:GeF7AKKpKKVq8emIwYRQPsKPDWSGYQGWWSsddge62M4= +github.com/alexkappa/mustache v1.0.0/go.mod h1:6v0WNoCZEQ8K5OZAv82ScIARg2bDqFD+Jl0LWxnApas= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= +github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/launchdarkly/go-jsonstream/v3 v3.1.0 h1:U/7/LplZO72XefBQ+FzHf6o4FwLHVqBE+4V58Ornu/E= +github.com/launchdarkly/go-jsonstream/v3 v3.1.0/go.mod h1:2Pt4BR5AwWgsuVTCcIpB6Os04JFIKWfoA+7faKkZB5E= +github.com/launchdarkly/go-sdk-common/v3 v3.2.0 h1:LzwlrXRBPC7NjdbnDxio8YGHMvDrNb4i6lbjpLgwsyk= +github.com/launchdarkly/go-sdk-common/v3 v3.2.0/go.mod h1:mXFmDGEh4ydK3QilRhrAyKuf9v44VZQWnINyhqbbOd0= +github.com/launchdarkly/go-sdk-events/v3 v3.4.0 h1:22sVSEDEXpdOEK3UBtmThwsUHqc+cbbe/pJfsliBAA4= +github.com/launchdarkly/go-sdk-events/v3 v3.4.0/go.mod h1:oepYWQ2RvvjfL2WxkE1uJJIuRsIMOP4WIVgUpXRPcNI= +github.com/launchdarkly/go-server-sdk/v7 v7.7.0 h1:UZ1Fn28UiIsINLcxnKEmOvBUgrqR2f4zY4WNPaScL0A= +github.com/launchdarkly/go-server-sdk/v7 v7.7.0/go.mod h1:rf/K2E4s5OjkB8Nn3ATDOR6W6S3U7D8FJ3WAKLxSTIQ= +github.com/launchdarkly/go-test-helpers/v3 v3.0.2 h1:rh0085g1rVJM5qIukdaQ8z1XTWZztbJ49vRZuveqiuU= +github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa h1:ELnwvuAXPNtPk1TJRuGkI9fDTwym6AYBu0qzT8AcHdI= +golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa/go.mod h1:akd2r19cwCdwSwWeIdzYQGa/EZZyqcOdwWiwj5L5eKQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/ldai/tracker.go b/ldai/tracker.go new file mode 100644 index 00000000..8cec8152 --- /dev/null +++ b/ldai/tracker.go @@ -0,0 +1,247 @@ +package ldai + +import ( + "fmt" + "time" + + "github.com/launchdarkly/go-sdk-common/v3/ldcontext" + "github.com/launchdarkly/go-sdk-common/v3/ldvalue" + "github.com/launchdarkly/go-server-sdk/v7/interfaces" +) + +const ( + duration = "$ld:ai:duration:total" + feedbackPositive = "$ld:ai:feedback:user:positive" + feedbackNegative = "$ld:ai:feedback:user:negative" + generation = "$ld:ai:generation" + tokenTotal = "$ld:ai:tokens:total" + tokenInput = "$ld:ai:tokens:input" + tokenOutput = "$ld:ai:tokens:output" +) + +// TokenUsage represents the token usage returned by a model provider for a specific request. +type TokenUsage struct { + // Total is the total number of tokens used. + Total int + // Input is the number of input tokens used. + Input int + // Output is the number of output tokens used. + Output int +} + +// Set returns true if any of the fields are non-zero. +func (t TokenUsage) Set() bool { + return t.Total > 0 || t.Input > 0 || t.Output > 0 +} + +// Metrics represents the metrics returned by a model provider for a specific request. +type Metrics struct { + // Latency is the latency of the request. + Latency time.Duration +} + +// Set returns true if the latency is non-zero. +func (m Metrics) Set() bool { + return m.Latency != 0 +} + +// ProviderResponse represents the response from a model provider for a specific request. +type ProviderResponse struct { + // Usage is the token usage. + Usage TokenUsage + // Metrics is the request metrics. + Metrics Metrics +} + +// Feedback represents the feedback provided by a user for a model evaluation. +type Feedback string + +const ( + // Positive feedback (result was good). + Positive Feedback = "positive" + // Negative feedback (result was bad). + Negative Feedback = "negative" +) + +// EventSink represents the Tracker's requirements for delivering analytic events. This is generally satisfied +// by the LaunchDarkly SDK's TrackMetric method. +type EventSink interface { + // TrackMetric sends a named analytic event to LaunchDarkly relevant to a particular context, and containing a + // metric value and additional data. + TrackMetric( + eventName string, + context ldcontext.Context, + metricValue float64, + data ldvalue.Value, + ) error +} + +// Stopwatch is used to measure the duration of a task. Start will always be called before Stop. +// If an implementation is not provided, the Tracker uses a default implementation that delegates to +// time.Now and time.Since. +type Stopwatch interface { + // Start starts the stopwatch. + Start() + // Stop stops the stopwatch and returns the duration since Start was called. + Stop() time.Duration +} + +// Tracker is used to track metrics for AI config evaluation. +// Unless otherwise noted, the Tracker's method are not safe for concurrent use. +type Tracker struct { + key string + config *Config + context ldcontext.Context + events EventSink + trackData ldvalue.Value + logger interfaces.LDLoggers + stopwatch Stopwatch +} + +// Used if a custom Stopwatch is not provided. +type defaultStopwatch struct { + start time.Time +} + +// Start saves the current time using time.Now. +func (d *defaultStopwatch) Start() { + d.start = time.Now() +} + +// Stop returns the duration since Start was called using time.Since. +func (d *defaultStopwatch) Stop() time.Duration { + return time.Since(d.start) +} + +// newTracker creates a new Tracker with the specified key, event sink, config, context, and loggers. +func newTracker(key string, events EventSink, config *Config, ctx ldcontext.Context, loggers interfaces.LDLoggers) *Tracker { + return newTrackerWithStopwatch(key, events, config, ctx, loggers, &defaultStopwatch{}) +} + +// newTrackerWithStopwatch creates a new Tracker with the specified key, event sink, config, context, loggers, and +// stopwatch. This method is used for testing purposes. +func newTrackerWithStopwatch(key string, events EventSink, config *Config, ctx ldcontext.Context, loggers interfaces.LDLoggers, stopwatch Stopwatch) *Tracker { + if config == nil { + panic("LaunchDarkly SDK programmer error: config must never be nil") + } + + trackData := ldvalue.ObjectBuild(). + Set("versionKey", ldvalue.String(config.VersionKey())). + Set("configKey", ldvalue.String(key)).Build() + + return &Tracker{ + key: key, + config: config, + trackData: trackData, + events: events, + context: ctx, + logger: loggers, + stopwatch: stopwatch, + } +} + +func (t *Tracker) logWarning(format string, args ...interface{}) { + prefix := "AI config tracker for '" + t.key + "': " + t.logger.Warnf(prefix+format, args...) +} + +// TrackDuration tracks the duration of a task. For example, the duration of a model evaluation request may be +// tracked here. See also TrackRequest. +// The duration in milliseconds must fit within a float64. +func (t *Tracker) TrackDuration(dur time.Duration) error { + return t.events.TrackMetric(duration, t.context, float64(dur.Milliseconds()), t.trackData) +} + +// TrackFeedback tracks the feedback provided by a user for a model evaluation. If the feedback is not +// Positive or Negative, returns an error and does not track anything. +func (t *Tracker) TrackFeedback(feedback Feedback) error { + switch feedback { + case Positive: + return t.events.TrackMetric(feedbackPositive, t.context, 1, t.trackData) + case Negative: + return t.events.TrackMetric(feedbackNegative, t.context, 1, t.trackData) + default: + return fmt.Errorf("tracker: unexpected feedback value: %v", feedback) + } +} + +// TrackSuccess tracks a successful model evaluation. +func (t *Tracker) TrackSuccess() error { + return t.events.TrackMetric(generation, t.context, 1, t.trackData) +} + +// TrackUsage tracks the token usage for a model evaluation. +func (t *Tracker) TrackUsage(usage TokenUsage) error { + var failed bool + + if usage.Total > 0 { + if err1 := t.events.TrackMetric(tokenTotal, t.context, float64(usage.Total), t.trackData); err1 != nil { + t.logWarning("error tracking total token usage: %v", err1) + failed = true + } + } + if usage.Input > 0 { + if err2 := t.events.TrackMetric(tokenInput, t.context, float64(usage.Input), t.trackData); err2 != nil { + t.logWarning("error tracking input token usage: %v", err2) + failed = true + } + } + if usage.Output > 0 { + if err3 := t.events.TrackMetric(tokenOutput, t.context, float64(usage.Output), t.trackData); err3 != nil { + t.logWarning("error tracking output token usage: %v", err3) + failed = true + } + } + + if failed { + return fmt.Errorf("tracker: error tracking token usage, logs contain more information") + } + + return nil +} + +func measureDurationOfTask[T any](stopwatch Stopwatch, task func() (T, error)) (T, time.Duration, error) { + stopwatch.Start() + result, err := task() + return result, stopwatch.Stop(), err +} + +// TrackRequest tracks metrics for a model evaluation request. The task function should return a ProviderResponse +// which can be used to specify request metrics and token usage. +// +// All fields of the ProviderResponse are optional. +// +// If the task returns an error, then the request is not considered successful and no metrics are tracked. +// Otherwise, the following metrics are tracked: +// 1. Successful model evaluation. +// 2. Any metrics that were that set in the ProviderResponse +// 2a) If Latency was not set in the ProviderResponse's Metrics field, an automatically measured duration. +// 3. Any token usage that was set in the ProviderResponse. +func (t *Tracker) TrackRequest(task func() (ProviderResponse, error)) (ProviderResponse, error) { + usage, duration, err := measureDurationOfTask(t.stopwatch, task) + + if err != nil { + t.logWarning("error executing request: %v", err) + return ProviderResponse{}, err + } + if err := t.TrackSuccess(); err != nil { + t.logWarning("error tracking success metric for request: %v", err) + } + + if usage.Metrics.Set() { + if err := t.TrackDuration(usage.Metrics.Latency); err != nil { + t.logWarning("error tracking duration metric (user provided) for request: %v", err) + } + } else { + if err := t.TrackDuration(duration); err != nil { + t.logWarning("error tracking duration metric (automatically measured) for request: %v", err) + } + } + + if usage.Usage.Set() { + // TrackUsage logs errors. + _ = t.TrackUsage(usage.Usage) + } + + return usage, nil +} diff --git a/ldai/tracker_test.go b/ldai/tracker_test.go new file mode 100644 index 00000000..5a74ec7b --- /dev/null +++ b/ldai/tracker_test.go @@ -0,0 +1,239 @@ +package ldai + +import ( + "testing" + "time" + + "github.com/launchdarkly/go-sdk-common/v3/ldcontext" + "github.com/launchdarkly/go-sdk-common/v3/ldlogtest" + "github.com/launchdarkly/go-sdk-common/v3/ldvalue" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type mockEvents struct { + log *ldlogtest.MockLog + events []trackEvent +} + +type trackEvent struct { + name string + context ldcontext.Context + metricValue float64 + data ldvalue.Value +} + +func newMockEvents() *mockEvents { + return &mockEvents{log: ldlogtest.NewMockLog()} +} + +func (m *mockEvents) TrackMetric(eventName string, context ldcontext.Context, metricValue float64, data ldvalue.Value) error { + m.events = append(m.events, trackEvent{name: eventName, context: context, metricValue: metricValue, data: data}) + return nil +} + +func TestTracker_NewPanicsWithNilConfig(t *testing.T) { + assert.Panics(t, func() { + newTracker("key", newMockEvents(), nil, ldcontext.New("key"), nil) + }) +} + +func TestTracker_NewDoesNotPanicWithConfig(t *testing.T) { + assert.NotPanics(t, func() { + newTracker("key", newMockEvents(), &Config{}, ldcontext.New("key"), nil) + }) +} + +func makeTrackData(configKey, versionKey string) ldvalue.Value { + return ldvalue.ObjectBuild(). + Set("versionKey", ldvalue.String(versionKey)). + Set("configKey", ldvalue.String(configKey)).Build() +} + +func TestTracker_TrackSuccess(t *testing.T) { + events := newMockEvents() + tracker := newTracker("key", events, &Config{}, ldcontext.New("key"), nil) + assert.NoError(t, tracker.TrackSuccess()) + + expectedEvent := trackEvent{ + name: "$ld:ai:generation", + context: ldcontext.New("key"), + metricValue: 1.0, + data: makeTrackData("key", ""), + } + + assert.ElementsMatch(t, []trackEvent{expectedEvent}, events.events) +} + +func TestTracker_TrackRequest(t *testing.T) { + events := newMockEvents() + tracker := newTracker("key", events, &Config{}, ldcontext.New("key"), nil) + + expectedResponse := ProviderResponse{ + Usage: TokenUsage{ + Total: 1, + }, + Metrics: Metrics{ + Latency: 10 * time.Millisecond, + }, + } + + r, err := tracker.TrackRequest(func() (ProviderResponse, error) { + return expectedResponse, nil + }) + + assert.NoError(t, err) + assert.Equal(t, expectedResponse, r) + + expectedSuccessEvent := trackEvent{ + name: "$ld:ai:generation", + context: ldcontext.New("key"), + metricValue: 1, + data: makeTrackData("key", ""), + } + + expectedDurationEvent := trackEvent{ + name: "$ld:ai:duration:total", + context: ldcontext.New("key"), + metricValue: 10.0, + data: makeTrackData("key", ""), + } + + expectedTokenUsageEvent := trackEvent{ + name: "$ld:ai:tokens:total", + context: ldcontext.New("key"), + metricValue: 1, + data: makeTrackData("key", ""), + } + + expectedEvents := []trackEvent{expectedSuccessEvent, expectedDurationEvent, expectedTokenUsageEvent} + assert.ElementsMatch(t, expectedEvents, events.events) +} + +type mockStopwatch time.Duration + +func (m mockStopwatch) Start() {} + +func (m mockStopwatch) Stop() time.Duration { + return time.Duration(m) +} + +func TestTracker_LatencyMeasuredIfNotProvided(t *testing.T) { + events := newMockEvents() + + tracker := newTrackerWithStopwatch( + "key", events, &Config{}, ldcontext.New("key"), nil, mockStopwatch(42*time.Millisecond)) + + expectedResponse := ProviderResponse{ + Usage: TokenUsage{ + Total: 1, + }, + } + + r, err := tracker.TrackRequest(func() (ProviderResponse, error) { + return expectedResponse, nil + }) + + assert.NoError(t, err) + assert.Equal(t, expectedResponse, r) + + require.Equal(t, 3, len(events.events)) + gotEvent := events.events[1] + assert.Equal(t, "$ld:ai:duration:total", gotEvent.name) + assert.Equal(t, 42.0, gotEvent.metricValue) +} + +func TestTracker_TrackDuration(t *testing.T) { + events := newMockEvents() + tracker := newTracker("key", events, &Config{}, ldcontext.New("key"), nil) + + assert.NoError(t, tracker.TrackDuration(time.Millisecond*10)) + + expectedEvent := trackEvent{ + name: "$ld:ai:duration:total", + context: ldcontext.New("key"), + metricValue: 10.0, + data: makeTrackData("key", ""), + } + + assert.ElementsMatch(t, []trackEvent{expectedEvent}, events.events) +} + +func TestTracker_TrackFeedback(t *testing.T) { + events := newMockEvents() + tracker := newTracker("key", events, &Config{}, ldcontext.New("key"), nil) + + assert.NoError(t, tracker.TrackFeedback(Positive)) + assert.NoError(t, tracker.TrackFeedback(Negative)) + assert.Error(t, tracker.TrackFeedback("not a valid feedback value")) + + expectedPositiveEvent := trackEvent{ + name: "$ld:ai:feedback:user:positive", + context: ldcontext.New("key"), + metricValue: 1.0, + data: makeTrackData("key", ""), + } + + expectedNegativeEvent := trackEvent{ + name: "$ld:ai:feedback:user:negative", + context: ldcontext.New("key"), + metricValue: 1.0, + data: makeTrackData("key", ""), + } + + assert.ElementsMatch(t, []trackEvent{expectedPositiveEvent, expectedNegativeEvent}, events.events) +} + +func TestTracker_TrackUsage(t *testing.T) { + t.Run("only one field set, only one event", func(t *testing.T) { + events := newMockEvents() + tracker := newTracker("key", events, &Config{}, ldcontext.New("key"), nil) + + assert.NoError(t, tracker.TrackUsage(TokenUsage{ + Total: 42, + })) + + expectedEvent := trackEvent{ + name: "$ld:ai:tokens:total", + context: ldcontext.New("key"), + metricValue: 42.0, + data: makeTrackData("key", ""), + } + + assert.ElementsMatch(t, []trackEvent{expectedEvent}, events.events) + }) + + t.Run("all fields set, all events", func(t *testing.T) { + events := newMockEvents() + tracker := newTracker("key", events, &Config{}, ldcontext.New("key"), nil) + + assert.NoError(t, tracker.TrackUsage(TokenUsage{ + Total: 42, + Input: 20, + Output: 22, + })) + + expectedTotal := trackEvent{ + name: "$ld:ai:tokens:total", + context: ldcontext.New("key"), + metricValue: 42.0, + data: makeTrackData("key", ""), + } + + expectedInput := trackEvent{ + name: "$ld:ai:tokens:input", + context: ldcontext.New("key"), + metricValue: 20.0, + data: makeTrackData("key", ""), + } + + expectedOutput := trackEvent{ + name: "$ld:ai:tokens:output", + context: ldcontext.New("key"), + metricValue: 22.0, + data: makeTrackData("key", ""), + } + + assert.ElementsMatch(t, []trackEvent{expectedTotal, expectedInput, expectedOutput}, events.events) + }) +}