diff --git a/ldai/client.go b/ldai/client.go index 5650fd67..1d665c7d 100644 --- a/ldai/client.go +++ b/ldai/client.go @@ -102,6 +102,14 @@ func (c *Client) Config( WithProviderName(parsed.Provider.Name). WithEnabled(parsed.Meta.Enabled) + for k, v := range parsed.Model.Parameters { + builder.WithModelParam(k, v) + } + + for k, v := range parsed.Model.Custom { + builder.WithCustomModelParam(k, v) + } + for i, msg := range parsed.Messages { content, err := interpolateTemplate(msg.Content, mergedVariables) if err != nil { diff --git a/ldai/client_test.go b/ldai/client_test.go index 9745aa27..873bc2c6 100644 --- a/ldai/client_test.go +++ b/ldai/client_test.go @@ -69,7 +69,77 @@ func TestEvalErrorReturnsDefault(t *testing.T) { assert.Equal(t, defaultVal, cfg) } -func TestInvalidConfigReturnsDefault(t *testing.T) { +func TestParseMultipleMessages(t *testing.T) { + json := []byte(`{ + "_ldMeta": {"versionKey": "1", "enabled": true}, + "messages": [ + {"content": "hello", "role": "user"}, + {"content": "world", "role": "system"} + ] + }`) + + client, err := NewClient(newMockSDK(json, nil)) + require.NoError(t, err) + require.NotNil(t, client) + + cfg, _ := client.Config("key", ldcontext.New("user"), Disabled(), nil) + + assert.ElementsMatch(t, cfg.Messages(), []datamodel.Message{ + {Content: "hello", Role: datamodel.User}, + {Content: "world", Role: datamodel.System}, + }) +} + +func TestParseModelName(t *testing.T) { + tests := []struct { + name string + json []byte + expected string + }{ + {"missing", []byte(`{"model": {}}`), ""}, + {"empty string", []byte(`{"model": {"name": ""}}`), ""}, + {"non-empty string", []byte(`{"model": {"name": "my-model"}}`), "my-model"}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + client, err := NewClient(newMockSDK(test.json, nil)) + require.NoError(t, err) + require.NotNil(t, client) + + defaultVal := NewConfig().Enable().WithMessage("hello", datamodel.User).Build() + cfg, _ := client.Config("key", ldcontext.New("user"), defaultVal, nil) + + assert.Equal(t, test.expected, cfg.ModelName()) + }) + } +} + +func TestParseProviderName(t *testing.T) { + tests := []struct { + name string + json []byte + expected string + }{ + {"missing", []byte(`{"provider": {}}`), ""}, + {"empty string", []byte(`{"provider": {"name": ""}}`), ""}, + {"non-empty string", []byte(`{"provider": {"name": "my-provider"}}`), "my-provider"}} + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + client, err := NewClient(newMockSDK(test.json, nil)) + require.NoError(t, err) + require.NotNil(t, client) + + defaultVal := NewConfig().Enable().WithMessage("hello", datamodel.User).Build() + cfg, _ := client.Config("key", ldcontext.New("user"), defaultVal, nil) + + assert.Equal(t, test.expected, cfg.ProviderName()) + }) + } +} + +func TestParseInvalidConfigReturnsDefault(t *testing.T) { tests := []struct { name string json []byte @@ -98,7 +168,7 @@ func TestInvalidConfigReturnsDefault(t *testing.T) { } } -func TestDisabledConfigs(t *testing.T) { +func TestParseDisabledConfigs(t *testing.T) { tests := []struct { name string json []byte @@ -126,6 +196,72 @@ func TestDisabledConfigs(t *testing.T) { } } +func TestParseModelParams(t *testing.T) { + tests := []struct { + name string + json []byte + expected map[string]ldvalue.Value + }{ + {"omitted", []byte(`{"model": {"name": "model"}}`), nil}, + {"empty", []byte(`{"model": {"name": "model", "parameters": {}}}`), map[string]ldvalue.Value{}}, + {"single", []byte(`{"model": {"name": "model", "parameters": {"foo": "bar"}}}`), + map[string]ldvalue.Value{"foo": ldvalue.String("bar")}}, + {"multiple", []byte(`{"model": {"name": "model", "parameters": {"foo": "bar", "baz": 42}}}`), + map[string]ldvalue.Value{"foo": ldvalue.String("bar"), "baz": ldvalue.Int(42)}}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + client, err := NewClient(newMockSDK(test.json, nil)) + require.NoError(t, err) + require.NotNil(t, client) + + defaultVal := NewConfig().Enable().WithMessage("hello", datamodel.User).Build() + cfg, _ := client.Config("key", ldcontext.New("user"), defaultVal, nil) + + for k, v := range test.expected { + p, ok := cfg.ModelParam(k) + if assert.True(t, ok) { + assert.Equal(t, v, p) + } + } + }) + } +} + +func TestParseCustomModelParams(t *testing.T) { + tests := []struct { + name string + json []byte + expected map[string]ldvalue.Value + }{ + {"omitted", []byte(`{"model": {"name": "model"}}`), nil}, + {"empty", []byte(`{"model": {"name": "model", "custom": {}}}`), map[string]ldvalue.Value{}}, + {"single", []byte(`{"model": {"name": "model", "custom": {"foo": "bar"}}}`), + map[string]ldvalue.Value{"foo": ldvalue.String("bar")}}, + {"multiple", []byte(`{"model": {"name": "model", "custom": {"foo": "bar", "baz": 42}}}`), + map[string]ldvalue.Value{"foo": ldvalue.String("bar"), "baz": ldvalue.Int(42)}}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + client, err := NewClient(newMockSDK(test.json, nil)) + require.NoError(t, err) + require.NotNil(t, client) + + defaultVal := NewConfig().Enable().WithMessage("hello", datamodel.User).Build() + cfg, _ := client.Config("key", ldcontext.New("user"), defaultVal, nil) + + for k, v := range test.expected { + p, ok := cfg.CustomModelParam(k) + if assert.True(t, ok) { + assert.Equal(t, v, p) + } + } + }) + } +} + func TestCanSetDefaultConfigFields(t *testing.T) { client, err := NewClient(newMockSDK(nil, nil)) require.NoError(t, err)