Skip to content

Commit

Permalink
fix: propagate parsed model params into returned AI config
Browse files Browse the repository at this point in the history
  • Loading branch information
cwaldren-ld committed Dec 9, 2024
1 parent ebdc281 commit 45e7b66
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 2 deletions.
8 changes: 8 additions & 0 deletions ldai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,14 @@ func (c *Client) Config(
WithProviderName(parsed.Provider.Name).
WithEnabled(parsed.Meta.Enabled)

for k, v := range parsed.Model.Parameters {
builder.WithModelParam(k, v)
}

for k, v := range parsed.Model.Custom {
builder.WithCustomModelParam(k, v)
}

for i, msg := range parsed.Messages {
content, err := interpolateTemplate(msg.Content, mergedVariables)
if err != nil {
Expand Down
140 changes: 138 additions & 2 deletions ldai/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,77 @@ func TestEvalErrorReturnsDefault(t *testing.T) {
assert.Equal(t, defaultVal, cfg)
}

func TestInvalidConfigReturnsDefault(t *testing.T) {
func TestParseMultipleMessages(t *testing.T) {
json := []byte(`{
"_ldMeta": {"versionKey": "1", "enabled": true},
"messages": [
{"content": "hello", "role": "user"},
{"content": "world", "role": "system"}
]
}`)

client, err := NewClient(newMockSDK(json, nil))
require.NoError(t, err)
require.NotNil(t, client)

cfg, _ := client.Config("key", ldcontext.New("user"), Disabled(), nil)

assert.ElementsMatch(t, cfg.Messages(), []datamodel.Message{
{Content: "hello", Role: datamodel.User},
{Content: "world", Role: datamodel.System},
})
}

func TestParseModelName(t *testing.T) {
tests := []struct {
name string
json []byte
expected string
}{
{"missing", []byte(`{"model": {}}`), ""},
{"empty string", []byte(`{"model": {"name": ""}}`), ""},
{"non-empty string", []byte(`{"model": {"name": "my-model"}}`), "my-model"},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
client, err := NewClient(newMockSDK(test.json, nil))
require.NoError(t, err)
require.NotNil(t, client)

defaultVal := NewConfig().Enable().WithMessage("hello", datamodel.User).Build()
cfg, _ := client.Config("key", ldcontext.New("user"), defaultVal, nil)

assert.Equal(t, test.expected, cfg.ModelName())
})
}
}

func TestParseProviderName(t *testing.T) {
tests := []struct {
name string
json []byte
expected string
}{
{"missing", []byte(`{"provider": {}}`), ""},
{"empty string", []byte(`{"provider": {"name": ""}}`), ""},
{"non-empty string", []byte(`{"provider": {"name": "my-provider"}}`), "my-provider"}}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
client, err := NewClient(newMockSDK(test.json, nil))
require.NoError(t, err)
require.NotNil(t, client)

defaultVal := NewConfig().Enable().WithMessage("hello", datamodel.User).Build()
cfg, _ := client.Config("key", ldcontext.New("user"), defaultVal, nil)

assert.Equal(t, test.expected, cfg.ProviderName())
})
}
}

func TestParseInvalidConfigReturnsDefault(t *testing.T) {
tests := []struct {
name string
json []byte
Expand Down Expand Up @@ -98,7 +168,7 @@ func TestInvalidConfigReturnsDefault(t *testing.T) {
}
}

func TestDisabledConfigs(t *testing.T) {
func TestParseDisabledConfigs(t *testing.T) {
tests := []struct {
name string
json []byte
Expand Down Expand Up @@ -126,6 +196,72 @@ func TestDisabledConfigs(t *testing.T) {
}
}

func TestParseModelParams(t *testing.T) {
tests := []struct {
name string
json []byte
expected map[string]ldvalue.Value
}{
{"omitted", []byte(`{"model": {"name": "model"}}`), nil},
{"empty", []byte(`{"model": {"name": "model", "parameters": {}}}`), map[string]ldvalue.Value{}},
{"single", []byte(`{"model": {"name": "model", "parameters": {"foo": "bar"}}}`),
map[string]ldvalue.Value{"foo": ldvalue.String("bar")}},
{"multiple", []byte(`{"model": {"name": "model", "parameters": {"foo": "bar", "baz": 42}}}`),
map[string]ldvalue.Value{"foo": ldvalue.String("bar"), "baz": ldvalue.Int(42)}},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
client, err := NewClient(newMockSDK(test.json, nil))
require.NoError(t, err)
require.NotNil(t, client)

defaultVal := NewConfig().Enable().WithMessage("hello", datamodel.User).Build()
cfg, _ := client.Config("key", ldcontext.New("user"), defaultVal, nil)

for k, v := range test.expected {
p, ok := cfg.ModelParam(k)
if assert.True(t, ok) {
assert.Equal(t, v, p)
}
}
})
}
}

func TestParseCustomModelParams(t *testing.T) {
tests := []struct {
name string
json []byte
expected map[string]ldvalue.Value
}{
{"omitted", []byte(`{"model": {"name": "model"}}`), nil},
{"empty", []byte(`{"model": {"name": "model", "custom": {}}}`), map[string]ldvalue.Value{}},
{"single", []byte(`{"model": {"name": "model", "custom": {"foo": "bar"}}}`),
map[string]ldvalue.Value{"foo": ldvalue.String("bar")}},
{"multiple", []byte(`{"model": {"name": "model", "custom": {"foo": "bar", "baz": 42}}}`),
map[string]ldvalue.Value{"foo": ldvalue.String("bar"), "baz": ldvalue.Int(42)}},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
client, err := NewClient(newMockSDK(test.json, nil))
require.NoError(t, err)
require.NotNil(t, client)

defaultVal := NewConfig().Enable().WithMessage("hello", datamodel.User).Build()
cfg, _ := client.Config("key", ldcontext.New("user"), defaultVal, nil)

for k, v := range test.expected {
p, ok := cfg.CustomModelParam(k)
if assert.True(t, ok) {
assert.Equal(t, v, p)
}
}
})
}
}

func TestCanSetDefaultConfigFields(t *testing.T) {
client, err := NewClient(newMockSDK(nil, nil))
require.NoError(t, err)
Expand Down

0 comments on commit 45e7b66

Please sign in to comment.