From b1a3a8c8e8be0c0cc092ad5329b33a07019e8119 Mon Sep 17 00:00:00 2001 From: Casey Waldren Date: Fri, 22 Nov 2024 14:19:49 -0800 Subject: [PATCH] feat: update AI SDK with latest spec changes (#50) * Updates `Model` config to contain `Id`, `Parameters` and `Custom`, which are both dictionaries from string -> `LdValue` * Adds `Provider` config which contains an `Id` * Renames `Prompt` -> `Messages`, update the getters There are no changes to the tracking methods. --- pkgs/sdk/server-ai/src/Config/LdAiConfig.cs | 146 +++++++++++++----- pkgs/sdk/server-ai/src/DataModel/DataModel.cs | 50 +++++- .../src/Interfaces/ILdAiConfigTracker.cs | 2 +- pkgs/sdk/server-ai/src/LdAiClient.cs | 11 +- pkgs/sdk/server-ai/src/LdAiConfigTracker.cs | 2 +- .../src/{Provider => Tracking}/Feedback.cs | 2 +- .../src/{Provider => Tracking}/Usage.cs | 2 +- pkgs/sdk/server-ai/test/InterpolationTests.cs | 8 +- pkgs/sdk/server-ai/test/LdAiClientTest.cs | 135 ++++++++++++++-- pkgs/sdk/server-ai/test/LdAiConfigTest.cs | 32 ++-- .../server-ai/test/LdAiConfigTrackerTest.cs | 2 +- 11 files changed, 313 insertions(+), 79 deletions(-) rename pkgs/sdk/server-ai/src/{Provider => Tracking}/Feedback.cs (85%) rename pkgs/sdk/server-ai/src/{Provider => Tracking}/Usage.cs (94%) diff --git a/pkgs/sdk/server-ai/src/Config/LdAiConfig.cs b/pkgs/sdk/server-ai/src/Config/LdAiConfig.cs index 321dae7b..eb97273d 100644 --- a/pkgs/sdk/server-ai/src/Config/LdAiConfig.cs +++ b/pkgs/sdk/server-ai/src/Config/LdAiConfig.cs @@ -32,6 +32,51 @@ internal Message(string content, Role role) } } + + /// + /// Information about the model provider. + /// + public record ModelProvider + { + /// + /// The ID of the model provider. + /// + public readonly string Id; + + internal ModelProvider(string id) + { + Id = id; + } + } + + /// + /// Information about the model. + /// + public record ModelConfiguration + { + /// + /// The ID of the model. + /// + public readonly string Id; + + /// + /// The model's built-in parameters provided by LaunchDarkly. + /// + public readonly IReadOnlyDictionary Parameters; + + /// + /// The model's custom parameters provided by the user. + /// + public readonly IReadOnlyDictionary Custom; + + internal ModelConfiguration(string id, IReadOnlyDictionary parameters, IReadOnlyDictionary custom) + { + Id = id; + Parameters = parameters; + Custom = custom; + } + } + /// /// Builder for constructing an LdAiConfig instance, which can be passed as the default /// value to the AI Client's method. @@ -39,25 +84,29 @@ internal Message(string content, Role role) public class Builder { private bool _enabled; - private readonly List _prompt; - private readonly Dictionary _modelParams; + private readonly List _messages; + private readonly Dictionary _modelParams; + private readonly Dictionary _customModelParams; + private string _providerId; internal Builder() { _enabled = false; - _prompt = new List(); - _modelParams = new Dictionary(); + _messages = new List(); + _modelParams = new Dictionary(); + _customModelParams = new Dictionary(); + _providerId = ""; } /// - /// Adds a prompt message with the given content and role. The default role is . + /// Adds a message with the given content and role. The default role is . 
/// /// the content, which may contain Mustache templates /// the role /// a new builder - public Builder AddPromptMessage(string content, Role role = Role.User) + public Builder AddMessage(string content, Role role = Role.User) { - _prompt.Add(new Message(content, role)); + _messages.Add(new Message(content, role)); return this; } @@ -85,66 +134,74 @@ public Builder SetEnabled(bool enabled) } /// - /// Sets a parameter for the model. The value may be any object. + /// Sets a parameter for the model. /// /// the parameter name /// the parameter value /// the builder - public Builder SetModelParam(string name, object value) + public Builder SetModelParam(string name, LdValue value) { _modelParams[name] = value; return this; } + /// + /// Sets a custom parameter for the model. + /// + /// the custom parameter name + /// the custom parameter value + /// the builder + public Builder SetCustomModelParam(string name, LdValue value) + { + _customModelParams[name] = value; + return this; + } + + /// + /// Sets the model provider's ID. By default, this will be the empty string. + /// + /// the ID + /// + public Builder SetModelProviderId(string id) + { + _providerId = id; + return this; + } + /// /// Builds the LdAiConfig instance. /// /// a new LdAiConfig public LdAiConfig Build() { - return new LdAiConfig(_enabled, _prompt, new Meta(), _modelParams); + return new LdAiConfig(_enabled, _messages, new Meta(), new Model {Parameters = _modelParams, Custom = _customModelParams}, new Provider{ Id = _providerId }); } } /// /// The prompts associated with the config. /// - public readonly IReadOnlyList Prompt; + public readonly IReadOnlyList Messages; /// /// The model parameters associated with the config. /// - public readonly IReadOnlyDictionary Model; - + public readonly ModelConfiguration Model; + /// + /// Information about the model provider. + /// + public readonly ModelProvider Provider; - internal LdAiConfig(bool enabled, IEnumerable prompt, Meta meta, IReadOnlyDictionary model) + internal LdAiConfig(bool enabled, IEnumerable messages, Meta meta, Model model, Provider provider) { - Model = model ?? new Dictionary(); - Prompt = prompt?.ToList() ?? new List(); + Model = new ModelConfiguration(model?.Id ?? "", model?.Parameters ?? new Dictionary(), + model?.Custom ?? new Dictionary()); + Messages = messages?.ToList() ?? new List(); VersionKey = meta?.VersionKey ?? ""; Enabled = enabled; + Provider = new ModelProvider(provider?.Id ?? 
""); } - - private static LdValue ObjectToValue(object obj) - { - if (obj == null) - { - return LdValue.Null; - } - - return obj switch - { - bool b => LdValue.Of(b), - double d => LdValue.Of(d), - string s => LdValue.Of(s), - IEnumerable list => LdValue.ArrayFrom(list.Select(ObjectToValue)), - IDictionary dict => LdValue.ObjectFrom(dict.ToDictionary(kv => kv.Key, - kv => ObjectToValue(kv.Value))), - _ => LdValue.Null - }; - } - internal LdValue ToLdValue() { return LdValue.ObjectFrom(new Dictionary @@ -155,12 +212,20 @@ internal LdValue ToLdValue() { "versionKey", LdValue.Of(VersionKey) }, { "enabled", LdValue.Of(Enabled) } }) }, - { "prompt", LdValue.ArrayFrom(Prompt.Select(m => LdValue.ObjectFrom(new Dictionary + { "messages", LdValue.ArrayFrom(Messages.Select(m => LdValue.ObjectFrom(new Dictionary { { "content", LdValue.Of(m.Content) }, { "role", LdValue.Of(m.Role.ToString()) } }))) }, - { "model", ObjectToValue(Model) } + { "model", LdValue.ObjectFrom(new Dictionary + { + { "parameters", LdValue.ObjectFrom(Model.Parameters) }, + { "custom", LdValue.ObjectFrom(Model.Custom) } + }) }, + {"provider", LdValue.ObjectFrom(new Dictionary + { + {"id", LdValue.Of(Provider.Id)} + })} }); } @@ -176,7 +241,6 @@ internal LdValue ToLdValue() /// true if enabled public bool Enabled { get; } - /// /// This field meant for internal LaunchDarkly usage. /// @@ -185,7 +249,5 @@ internal LdValue ToLdValue() /// /// Convenient helper that returns a disabled LdAiConfig. /// - public static LdAiConfig Disabled = New().Disable().Build(); - - + public static LdAiConfig Disabled => New().Disable().Build(); } diff --git a/pkgs/sdk/server-ai/src/DataModel/DataModel.cs b/pkgs/sdk/server-ai/src/DataModel/DataModel.cs index a5f4a9e6..9ed72239 100644 --- a/pkgs/sdk/server-ai/src/DataModel/DataModel.cs +++ b/pkgs/sdk/server-ai/src/DataModel/DataModel.cs @@ -69,8 +69,8 @@ public class AiConfig /// /// The prompt. /// - [JsonPropertyName("prompt")] - public List Prompt { get; set; } + [JsonPropertyName("messages")] + public List Messages { get; set; } /// /// LaunchDarkly metadata. @@ -79,8 +79,50 @@ public class AiConfig public Meta Meta { get; set; } /// - /// The model params; + /// The model configuration. /// [JsonPropertyName("model")] - public Dictionary Model { get; set; } + public Model Model { get; set; } + + /// + /// The model provider. + /// + [JsonPropertyName("provider")] + public Provider Provider { get; set; } +} + +/// +/// Represents the JSON serialization of a model. +/// +public class Model +{ + /// + /// The model's ID. + /// + [JsonPropertyName("id")] + public string Id { get; set; } + + /// + /// The model's parameters. These are provided by LaunchDarkly. + /// + [JsonPropertyName("parameters")] + public Dictionary Parameters { get; set; } + + /// + /// The model's custom parameters. These are arbitrary and provided by the user. + /// + [JsonPropertyName("custom")] + public Dictionary Custom { get; set; } +} + +/// +/// Represents the JSON serialization of a model provider. +/// +public class Provider +{ + /// + /// The provider's ID. 
+ /// + [JsonPropertyName("id")] + public string Id { get; set; } } diff --git a/pkgs/sdk/server-ai/src/Interfaces/ILdAiConfigTracker.cs b/pkgs/sdk/server-ai/src/Interfaces/ILdAiConfigTracker.cs index a8d2c4ed..aa652d14 100644 --- a/pkgs/sdk/server-ai/src/Interfaces/ILdAiConfigTracker.cs +++ b/pkgs/sdk/server-ai/src/Interfaces/ILdAiConfigTracker.cs @@ -1,7 +1,7 @@ using System; using System.Threading.Tasks; using LaunchDarkly.Sdk.Server.Ai.Config; -using LaunchDarkly.Sdk.Server.Ai.Provider; +using LaunchDarkly.Sdk.Server.Ai.Tracking; namespace LaunchDarkly.Sdk.Server.Ai.Interfaces; diff --git a/pkgs/sdk/server-ai/src/LdAiClient.cs b/pkgs/sdk/server-ai/src/LdAiClient.cs index 522bcb68..404f6a2f 100644 --- a/pkgs/sdk/server-ai/src/LdAiClient.cs +++ b/pkgs/sdk/server-ai/src/LdAiClient.cs @@ -57,7 +57,6 @@ public ILdAiConfigTracker ModelConfig(string key, Context context, LdAiConfig de return new LdAiConfigTracker(_client, key, defaultValue, context); } - var mergedVariables = new Dictionary { { LdContextVariable, GetAllAttributes(context) } }; if (variables != null) { @@ -75,14 +74,14 @@ public ILdAiConfigTracker ModelConfig(string key, Context context, LdAiConfig de var prompt = new List(); - if (parsed.Prompt != null) + if (parsed.Messages != null) { - for (var i = 0; i < parsed.Prompt.Count; i++) + for (var i = 0; i < parsed.Messages.Count; i++) { try { - var content = InterpolateTemplate(parsed.Prompt[i].Content, mergedVariables); - prompt.Add(new LdAiConfig.Message(content, parsed.Prompt[i].Role)); + var content = InterpolateTemplate(parsed.Messages[i].Content, mergedVariables); + prompt.Add(new LdAiConfig.Message(content, parsed.Messages[i].Role)); } catch (Exception ex) { @@ -93,7 +92,7 @@ public ILdAiConfigTracker ModelConfig(string key, Context context, LdAiConfig de } } - return new LdAiConfigTracker(_client, key, new LdAiConfig(parsed.Meta?.Enabled ?? false, prompt, parsed.Meta, parsed.Model), context); + return new LdAiConfigTracker(_client, key, new LdAiConfig(parsed.Meta?.Enabled ?? false, prompt, parsed.Meta, parsed.Model, parsed.Provider), context); } diff --git a/pkgs/sdk/server-ai/src/LdAiConfigTracker.cs b/pkgs/sdk/server-ai/src/LdAiConfigTracker.cs index 2b548df1..b8dc33b4 100644 --- a/pkgs/sdk/server-ai/src/LdAiConfigTracker.cs +++ b/pkgs/sdk/server-ai/src/LdAiConfigTracker.cs @@ -4,7 +4,7 @@ using System.Threading.Tasks; using LaunchDarkly.Sdk.Server.Ai.Config; using LaunchDarkly.Sdk.Server.Ai.Interfaces; -using LaunchDarkly.Sdk.Server.Ai.Provider; +using LaunchDarkly.Sdk.Server.Ai.Tracking; namespace LaunchDarkly.Sdk.Server.Ai; diff --git a/pkgs/sdk/server-ai/src/Provider/Feedback.cs b/pkgs/sdk/server-ai/src/Tracking/Feedback.cs similarity index 85% rename from pkgs/sdk/server-ai/src/Provider/Feedback.cs rename to pkgs/sdk/server-ai/src/Tracking/Feedback.cs index 8383bb6c..67e1b1a3 100644 --- a/pkgs/sdk/server-ai/src/Provider/Feedback.cs +++ b/pkgs/sdk/server-ai/src/Tracking/Feedback.cs @@ -1,4 +1,4 @@ -namespace LaunchDarkly.Sdk.Server.Ai.Provider; +namespace LaunchDarkly.Sdk.Server.Ai.Tracking; /// /// Feedback about the generated content. 
diff --git a/pkgs/sdk/server-ai/src/Provider/Usage.cs b/pkgs/sdk/server-ai/src/Tracking/Usage.cs similarity index 94% rename from pkgs/sdk/server-ai/src/Provider/Usage.cs rename to pkgs/sdk/server-ai/src/Tracking/Usage.cs index 633632d5..85a8e6d5 100644 --- a/pkgs/sdk/server-ai/src/Provider/Usage.cs +++ b/pkgs/sdk/server-ai/src/Tracking/Usage.cs @@ -1,4 +1,4 @@ -namespace LaunchDarkly.Sdk.Server.Ai.Provider; +namespace LaunchDarkly.Sdk.Server.Ai.Tracking; /// /// Represents metrics returned by a model provider. diff --git a/pkgs/sdk/server-ai/test/InterpolationTests.cs b/pkgs/sdk/server-ai/test/InterpolationTests.cs index eb5c6b46..b2dbe5d3 100644 --- a/pkgs/sdk/server-ai/test/InterpolationTests.cs +++ b/pkgs/sdk/server-ai/test/InterpolationTests.cs @@ -23,10 +23,10 @@ private string Eval(string prompt, Context context, IReadOnlyDictionary", - "role": "System" + "role": "system" } ] } @@ -41,7 +41,7 @@ private string Eval(string prompt, Context context, IReadOnlyDictionary(); + + var mockLogger = new Mock(); + + mockClient.Setup(x => + x.JsonVariation("foo", It.IsAny(), It.IsAny())).Returns(LdValue.Null); + + mockClient.Setup(x => x.GetLogger()).Returns(mockLogger.Object); + + var client = new LdAiClient(mockClient.Object); + + var tracker = client.ModelConfig("foo", Context.New(ContextKind.Default, "key"), + LdAiConfig.New(). + AddMessage("foo"). + SetModelParam("foo", LdValue.Of("bar")). + SetCustomModelParam("foo", LdValue.Of("baz")). + SetModelProviderId("amazing-provider"). + SetEnabled(true).Build()); + + Assert.True(tracker.Config.Enabled); + Assert.Collection(tracker.Config.Messages, + message => + { + Assert.Equal("foo", message.Content); + Assert.Equal(Role.User, message.Role); + }); + Assert.Equal("amazing-provider", tracker.Config.Provider.Id); + Assert.Equal("bar", tracker.Config.Model.Parameters["foo"].AsString); + Assert.Equal("baz", tracker.Config.Model.Custom["foo"].AsString); + } + [Fact] public void ConfigEnabledReturnsInstance() { @@ -111,8 +145,7 @@ public void ConfigEnabledReturnsInstance() const string json = """ { "_ldMeta": {"versionKey": "1", "enabled": true}, - "model": {}, - "prompt": [{"content": "Hello!", "role": "system"}] + "messages": [{"content": "Hello!", "role": "system"}] } """; @@ -126,13 +159,97 @@ public void ConfigEnabledReturnsInstance() // We shouldn't get this default. 
var tracker = client.ModelConfig("foo", context, - LdAiConfig.New().AddPromptMessage("Goodbye!").Build()); + LdAiConfig.New().AddMessage("Goodbye!").Build()); - Assert.Collection(tracker.Config.Prompt, + Assert.Collection(tracker.Config.Messages, message => { Assert.Equal("Hello!", message.Content); Assert.Equal(Role.System, message.Role); }); + + Assert.Equal("", tracker.Config.Provider.Id); + Assert.Equal("", tracker.Config.Model.Id); + Assert.Empty(tracker.Config.Model.Custom); + Assert.Empty(tracker.Config.Model.Parameters); + } + + + [Fact] + public void ModelParametersAreParsed() + { + + var mockClient = new Mock(); + + var mockLogger = new Mock(); + + const string json = """ + { + "_ldMeta": {"versionKey": "1", "enabled": true}, + "model" : { + "id": "model-foo", + "parameters": { + "foo": "bar", + "baz": 42 + }, + "custom": { + "foo": "baz", + "baz": 43 + } + } + } + """; + + + mockClient.Setup(x => + x.JsonVariation("foo", It.IsAny(), It.IsAny())).Returns(LdValue.Parse(json)); + + mockClient.Setup(x => x.GetLogger()).Returns(mockLogger.Object); + + var context = Context.New(ContextKind.Default, "key"); + var client = new LdAiClient(mockClient.Object); + + // We shouldn't get this default. + var tracker = client.ModelConfig("foo", context, + LdAiConfig.New().AddMessage("Goodbye!").Build()); + + Assert.Equal("model-foo", tracker.Config.Model.Id); + Assert.Equal("bar", tracker.Config.Model.Parameters["foo"].AsString); + Assert.Equal(42, tracker.Config.Model.Parameters["baz"].AsInt); + Assert.Equal("baz", tracker.Config.Model.Custom["foo"].AsString); + Assert.Equal(43, tracker.Config.Model.Custom["baz"].AsInt); + } + + [Fact] + public void ProviderConfigIsParsed() + { + + var mockClient = new Mock(); + + var mockLogger = new Mock(); + + const string json = """ + { + "_ldMeta": {"versionKey": "1", "enabled": true}, + "provider": { + "id": "amazing-provider" + } + } + """; + + + mockClient.Setup(x => + x.JsonVariation("foo", It.IsAny(), It.IsAny())).Returns(LdValue.Parse(json)); + + mockClient.Setup(x => x.GetLogger()).Returns(mockLogger.Object); + + var context = Context.New(ContextKind.Default, "key"); + var client = new LdAiClient(mockClient.Object); + + // We shouldn't get this default. 
+ var tracker = client.ModelConfig("foo", context, + LdAiConfig.New().AddMessage("Goodbye!").Build()); + + Assert.Equal("amazing-provider", tracker.Config.Provider.Id); } } diff --git a/pkgs/sdk/server-ai/test/LdAiConfigTest.cs b/pkgs/sdk/server-ai/test/LdAiConfigTest.cs index 1bc9eb97..abb53ddd 100644 --- a/pkgs/sdk/server-ai/test/LdAiConfigTest.cs +++ b/pkgs/sdk/server-ai/test/LdAiConfigTest.cs @@ -32,12 +32,12 @@ public void CanDisableAndEnableConfig() public void CanAddPromptMessages() { var config = LdAiConfig.New() - .AddPromptMessage("Hello") - .AddPromptMessage("World", Role.System) - .AddPromptMessage("!", Role.Assistant) + .AddMessage("Hello") + .AddMessage("World", Role.System) + .AddMessage("!", Role.Assistant) .Build(); - Assert.Collection(config.Prompt, + Assert.Collection(config.Messages, message => { Assert.Equal("Hello", message.Content); @@ -55,16 +55,30 @@ public void CanAddPromptMessages() }); } - [Fact] public void CanSetModelParams() { var config = LdAiConfig.New() - .SetModelParam("foo", "bar") - .SetModelParam("baz", 42) + .SetModelParam("foo", LdValue.Of("bar")) + .SetModelParam("baz", LdValue.Of(42)) + .SetCustomModelParam("foo", LdValue.Of("baz")) + .SetCustomModelParam("baz", LdValue.Of(43)) + .Build(); + + Assert.Equal(LdValue.Of("bar"), config.Model.Parameters["foo"]); + Assert.Equal(LdValue.Of(42), config.Model.Parameters["baz"]); + + Assert.Equal(LdValue.Of("baz"), config.Model.Custom["foo"]); + Assert.Equal(LdValue.Of(43), config.Model.Custom["baz"]); + } + + [Fact] + public void CanSetModelProviderId() + { + var config = LdAiConfig.New() + .SetModelProviderId("amazing-provider") .Build(); - Assert.Equal("bar", config.Model["foo"]); - Assert.Equal(42, config.Model["baz"]); + Assert.Equal("amazing-provider", config.Provider.Id); } } diff --git a/pkgs/sdk/server-ai/test/LdAiConfigTrackerTest.cs b/pkgs/sdk/server-ai/test/LdAiConfigTrackerTest.cs index a3f629b4..7c672213 100644 --- a/pkgs/sdk/server-ai/test/LdAiConfigTrackerTest.cs +++ b/pkgs/sdk/server-ai/test/LdAiConfigTrackerTest.cs @@ -2,7 +2,7 @@ using System.Threading.Tasks; using LaunchDarkly.Sdk.Server.Ai.Config; using LaunchDarkly.Sdk.Server.Ai.Interfaces; -using LaunchDarkly.Sdk.Server.Ai.Provider; +using LaunchDarkly.Sdk.Server.Ai.Tracking; using Moq; using Xunit;
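
For orientation, here is a hedged sketch of how the renamed builder API and the new `Model`/`Provider` accessors introduced by this patch fit together from application code. The builder and accessor names (`AddMessage`, `SetModelParam`, `SetCustomModelParam`, `SetModelProviderId`, `ModelConfig`, `Config.Messages`, `Config.Model`, `Config.Provider`) are taken from the diff above; the `using` directives, the `Run` wrapper, and the sample parameter names and values are assumptions for illustration only, not part of this PR.

```csharp
// Sketch only: namespaces and sample values below are assumptions based on this diff.
using System;
using LaunchDarkly.Sdk;                      // LdValue, Context, ContextKind
using LaunchDarkly.Sdk.Server.Ai;            // LdAiClient
using LaunchDarkly.Sdk.Server.Ai.Config;     // LdAiConfig
using LaunchDarkly.Sdk.Server.Ai.DataModel;  // Role (assumed namespace)

public static class AiConfigExample
{
    // Constructing the LdAiClient is out of scope here; the caller supplies one.
    public static void Run(LdAiClient aiClient)
    {
        // Build a fallback config with the renamed builder methods: AddPromptMessage is
        // now AddMessage, and model parameters take LdValue instead of object.
        var defaultConfig = LdAiConfig.New()
            .AddMessage("You are a helpful assistant.", Role.System)
            .AddMessage("Hello, {{ ldctx.name }}!")               // Role.User is the default
            .SetModelParam("temperature", LdValue.Of(0.2))        // read back via Config.Model.Parameters
            .SetCustomModelParam("region", LdValue.Of("us-east")) // read back via Config.Model.Custom
            .SetModelProviderId("amazing-provider")               // read back via Config.Provider.Id
            .SetEnabled(true)
            .Build();

        var context = Context.New(ContextKind.Default, "user-key");

        // Evaluate the AI config flag, falling back to defaultConfig if it is missing or malformed.
        var tracker = aiClient.ModelConfig("my-ai-config", context, defaultConfig);

        // The evaluated config exposes Messages, Model, and Provider instead of the old
        // Prompt list and untyped model dictionary.
        foreach (var message in tracker.Config.Messages)
        {
            Console.WriteLine($"{message.Role}: {message.Content}");
        }
        Console.WriteLine(tracker.Config.Model.Id);     // "" when the flag omits a model id
        Console.WriteLine(tracker.Config.Provider.Id);  // "" when the flag omits a provider
    }
}
```

As the PR description notes, the tracking methods are unchanged; `Feedback` and `Usage` simply move from the `LaunchDarkly.Sdk.Server.Ai.Provider` namespace to `LaunchDarkly.Sdk.Server.Ai.Tracking`, so existing tracking calls only need their `using` directive updated.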