From c78e7a4ec3c986f748bd6478d62c7561dea31cf1 Mon Sep 17 00:00:00 2001 From: David Luong Date: Fri, 16 Aug 2024 23:19:34 -0400 Subject: [PATCH] Add tools for ollama --- .../Create_Ollama_Agent_With_Tool.cs | 113 ++++++++++++++++++ .../sample/AutoGen.Ollama.Sample/Program.cs | 2 +- .../src/AutoGen.Ollama/Agent/OllamaAgent.cs | 12 +- dotnet/src/AutoGen.Ollama/DTOs/ChatRequest.cs | 7 ++ dotnet/src/AutoGen.Ollama/DTOs/Message.cs | 27 ++++- dotnet/src/AutoGen.Ollama/DTOs/Tools.cs | 52 ++++++++ .../Middlewares/OllamaMessageConnector.cs | 83 +++++++++++-- .../AnthropicClientAgentTest.cs | 6 +- .../AutoGen.Ollama.Tests.csproj | 1 + .../AutoGen.Ollama.Tests/OllamaAgentTests.cs | 67 ++++++++++- .../OllamaTestFunctionCalls.cs | 40 +++++++ .../AutoGen.Ollama.Tests/OllamaTestUtils.cs | 39 ++++++ 12 files changed, 428 insertions(+), 21 deletions(-) create mode 100644 dotnet/sample/AutoGen.Ollama.Sample/Create_Ollama_Agent_With_Tool.cs create mode 100644 dotnet/src/AutoGen.Ollama/DTOs/Tools.cs create mode 100644 dotnet/test/AutoGen.Ollama.Tests/OllamaTestFunctionCalls.cs create mode 100644 dotnet/test/AutoGen.Ollama.Tests/OllamaTestUtils.cs diff --git a/dotnet/sample/AutoGen.Ollama.Sample/Create_Ollama_Agent_With_Tool.cs b/dotnet/sample/AutoGen.Ollama.Sample/Create_Ollama_Agent_With_Tool.cs new file mode 100644 index 00000000000..9b792b1ffd1 --- /dev/null +++ b/dotnet/sample/AutoGen.Ollama.Sample/Create_Ollama_Agent_With_Tool.cs @@ -0,0 +1,113 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Create_Ollama_Agent_With_Tool.cs + +using AutoGen.Core; +using AutoGen.Ollama.Extension; +using FluentAssertions; + +namespace AutoGen.Ollama.Sample; + +#region WeatherFunction +public partial class WeatherFunction +{ + /// + /// Gets the weather based on the location and the unit + /// + /// + /// + /// + [Function] + public async Task GetWeather(string location, string unit) + { + // dummy implementation + return $"The weather in {location} is currently sunny with a tempature of {unit} (s)"; + } +} +#endregion + +public class Create_Ollama_Agent_With_Tool +{ + public static async Task RunAsync() + { + #region define_tool + var tool = new Tool() + { + Function = new Function + { + Name = "get_current_weather", + Description = "Get the current weather for a location", + Parameters = new Parameters + { + Properties = new Dictionary + { + { + "location", + new Properties + { + Type = "string", Description = "The location to get the weather for, e.g. San Francisco, CA" + } + }, + { + "format", new Properties + { + Type = "string", + Description = + "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'", + Enum = new List {"celsius", "fahrenheit"} + } + } + }, + Required = new List { "location", "format" } + } + } + }; + + var weatherFunction = new WeatherFunction(); + var functionMiddleware = new FunctionCallMiddleware( + functions: [ + weatherFunction.GetWeatherFunctionContract, + ], + functionMap: new Dictionary>> + { + { weatherFunction.GetWeatherFunctionContract.Name!, weatherFunction.GetWeatherWrapper }, + }); + + #endregion + + #region create_ollama_agent_llama3.1 + + var agent = new OllamaAgent( + new HttpClient { BaseAddress = new Uri("http://localhost:11434") }, + "MyAgent", + "llama3.1", + tools: [tool]); + #endregion + + // TODO cannot stream + #region register_middleware + var agentWithConnector = agent + .RegisterMessageConnector() + .RegisterPrintMessage() + .RegisterStreamingMiddleware(functionMiddleware); + #endregion register_middleware + + #region single_turn + var question = new TextMessage(Role.Assistant, + "What is the weather like in San Francisco?", + from: "user"); + var functionCallReply = await agentWithConnector.SendAsync(question); + #endregion + + #region Single_turn_verify_reply + functionCallReply.Should().BeOfType(); + #endregion Single_turn_verify_reply + + #region Multi_turn + var finalReply = await agentWithConnector.SendAsync(chatHistory: [question, functionCallReply]); + #endregion Multi_turn + + #region Multi_turn_verify_reply + finalReply.Should().BeOfType(); + #endregion Multi_turn_verify_reply + } +} diff --git a/dotnet/sample/AutoGen.Ollama.Sample/Program.cs b/dotnet/sample/AutoGen.Ollama.Sample/Program.cs index 62c92eebe7e..174b1a95f7a 100644 --- a/dotnet/sample/AutoGen.Ollama.Sample/Program.cs +++ b/dotnet/sample/AutoGen.Ollama.Sample/Program.cs @@ -3,4 +3,4 @@ using AutoGen.Ollama.Sample; -await Chat_With_LLaVA.RunAsync(); +await Create_Ollama_Agent_With_Tool.RunAsync(); diff --git a/dotnet/src/AutoGen.Ollama/Agent/OllamaAgent.cs b/dotnet/src/AutoGen.Ollama/Agent/OllamaAgent.cs index 87b176d8bcc..fb47e035273 100644 --- a/dotnet/src/AutoGen.Ollama/Agent/OllamaAgent.cs +++ b/dotnet/src/AutoGen.Ollama/Agent/OllamaAgent.cs @@ -24,16 +24,23 @@ public class OllamaAgent : IStreamingAgent private readonly string _modelName; private readonly string _systemMessage; private readonly OllamaReplyOptions? _replyOptions; + private readonly Tool[]? _tools; public OllamaAgent(HttpClient httpClient, string name, string modelName, string systemMessage = "You are a helpful AI assistant", - OllamaReplyOptions? replyOptions = null) + OllamaReplyOptions? replyOptions = null, Tool[]? tools = null) { Name = name; _httpClient = httpClient; _modelName = modelName; _systemMessage = systemMessage; _replyOptions = replyOptions; + _tools = tools; + + if (_httpClient.BaseAddress == null) + { + throw new InvalidOperationException($"Please add the base address to httpClient"); + } } public async Task GenerateReplyAsync( @@ -97,7 +104,8 @@ private async Task BuildChatRequest(IEnumerable messages, var request = new ChatRequest { Model = _modelName, - Messages = await BuildChatHistory(messages) + Messages = await BuildChatHistory(messages), + Tools = _tools }; if (options is OllamaReplyOptions replyOptions) diff --git a/dotnet/src/AutoGen.Ollama/DTOs/ChatRequest.cs b/dotnet/src/AutoGen.Ollama/DTOs/ChatRequest.cs index 3b0cf04a1a0..4e2249dbec3 100644 --- a/dotnet/src/AutoGen.Ollama/DTOs/ChatRequest.cs +++ b/dotnet/src/AutoGen.Ollama/DTOs/ChatRequest.cs @@ -50,4 +50,11 @@ public class ChatRequest [JsonPropertyName("keep_alive")] [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] public string? KeepAlive { get; set; } + + /// + /// Tools for the model to use. Not all models currently support tools. + /// Requires stream to be set to false + /// + [JsonPropertyName("tools")] + public IEnumerable? Tools { get; set; } } diff --git a/dotnet/src/AutoGen.Ollama/DTOs/Message.cs b/dotnet/src/AutoGen.Ollama/DTOs/Message.cs index 75f622ff7f0..02c77fe61dc 100644 --- a/dotnet/src/AutoGen.Ollama/DTOs/Message.cs +++ b/dotnet/src/AutoGen.Ollama/DTOs/Message.cs @@ -12,7 +12,7 @@ public Message() { } - public Message(string role, string value) + public Message(string role, string? value = null) { Role = role; Value = value; @@ -27,11 +27,34 @@ public Message(string role, string value) /// the content of the message /// [JsonPropertyName("content")] - public string Value { get; set; } = string.Empty; + public string? Value { get; set; } /// /// (optional): a list of images to include in the message (for multimodal models such as llava) /// [JsonPropertyName("images")] public IList? Images { get; set; } + + /// + /// A list of tools the model wants to use. Not all models currently support tools. + /// Tool call is not supported while streaming. + /// + [JsonPropertyName("tool_calls")] + public IEnumerable? ToolCalls { get; set; } + + public class ToolCall + { + [JsonPropertyName("function")] + public Function? Function { get; set; } + } + + public class Function + { + [JsonPropertyName("name")] + public string? Name { get; set; } + + [JsonPropertyName("arguments")] + public Dictionary? Arguments { get; set; } + } } + diff --git a/dotnet/src/AutoGen.Ollama/DTOs/Tools.cs b/dotnet/src/AutoGen.Ollama/DTOs/Tools.cs new file mode 100644 index 00000000000..62e9cd0f2f5 --- /dev/null +++ b/dotnet/src/AutoGen.Ollama/DTOs/Tools.cs @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Tools.cs + +using System.Collections.Generic; +using System.Text.Json.Serialization; + +namespace AutoGen.Ollama; + +public class Tool +{ + [JsonPropertyName("type")] + public string? Type { get; set; } = "function"; + + [JsonPropertyName("function")] + public Function? Function { get; set; } +} + +public class Function +{ + [JsonPropertyName("name")] + public string? Name { get; set; } + + [JsonPropertyName("description")] + public string? Description { get; set; } + + [JsonPropertyName("parameters")] + public Parameters? Parameters { get; set; } +} + +public class Parameters +{ + [JsonPropertyName("type")] + public string? Type { get; set; } = "object"; + + [JsonPropertyName("properties")] + public Dictionary? Properties { get; set; } + + [JsonPropertyName("required")] + public IEnumerable? Required { get; set; } +} + +public class Properties +{ + [JsonPropertyName("type")] + public string? Type { get; set; } + + [JsonPropertyName("description")] + public string? Description { get; set; } + + [JsonPropertyName("enum")] + public IEnumerable? Enum { get; set; } +} diff --git a/dotnet/src/AutoGen.Ollama/Middlewares/OllamaMessageConnector.cs b/dotnet/src/AutoGen.Ollama/Middlewares/OllamaMessageConnector.cs index 9e85ca12fd9..e6accd6e083 100644 --- a/dotnet/src/AutoGen.Ollama/Middlewares/OllamaMessageConnector.cs +++ b/dotnet/src/AutoGen.Ollama/Middlewares/OllamaMessageConnector.cs @@ -6,6 +6,7 @@ using System.Linq; using System.Net.Http; using System.Runtime.CompilerServices; +using System.Text.Json; using System.Threading; using System.Threading.Tasks; using AutoGen.Core; @@ -24,6 +25,7 @@ public async Task InvokeAsync(MiddlewareContext context, IAgent agent, return reply switch { + IMessage { Content.Message.ToolCalls: not null } messageEnvelope when messageEnvelope.Content.Message.ToolCalls.Any() => ProcessToolCalls(messageEnvelope, agent), IMessage messageEnvelope when messageEnvelope.Content.Message?.Value is string content => new TextMessage(Role.Assistant, content, messageEnvelope.From), IMessage messageEnvelope when messageEnvelope.Content.Message?.Value is null => throw new InvalidOperationException("Message content is null"), _ => reply @@ -73,20 +75,21 @@ private IEnumerable ProcessMessage(IEnumerable messages, IAg { return messages.SelectMany(m => { - if (m is IMessage messageEnvelope) + if (m is IMessage) { return [m]; } - else + + return m switch { - return m switch - { - TextMessage textMessage => ProcessTextMessage(textMessage, agent), - ImageMessage imageMessage => ProcessImageMessage(imageMessage, agent), - MultiModalMessage multiModalMessage => ProcessMultiModalMessage(multiModalMessage, agent), - _ => [m], - }; - } + TextMessage textMessage => ProcessTextMessage(textMessage, agent), + ImageMessage imageMessage => ProcessImageMessage(imageMessage, agent), + ToolCallMessage toolCallMessage => ProcessToolCallMessage(toolCallMessage, agent), + ToolCallResultMessage toolCallResultMessage => ProcessToolCallResultMessage(toolCallResultMessage), + AggregateMessage toolCallAggregateMessage => ProcessToolCallAggregateMessage(toolCallAggregateMessage, agent), + MultiModalMessage multiModalMessage => ProcessMultiModalMessage(multiModalMessage, agent), + _ => [m], + }; }); } @@ -183,4 +186,64 @@ private IEnumerable ProcessTextMessage(TextMessage textMessage, IAgent return [MessageEnvelope.Create(message, agent.Name)]; } } + + private IMessage ProcessToolCalls(IMessage messageEnvelope, IAgent agent) + { + var toolCalls = new List(); + foreach (var messageToolCall in messageEnvelope.Content.Message?.ToolCalls!) + { + toolCalls.Add(new ToolCall( + messageToolCall.Function?.Name ?? string.Empty, + JsonSerializer.Serialize(messageToolCall.Function?.Arguments))); + } + + return new ToolCallMessage(toolCalls, agent.Name) { Content = messageEnvelope.Content.Message.Value }; + } + + private IEnumerable ProcessToolCallMessage(ToolCallMessage toolCallMessage, IAgent agent) + { + var chatMessage = new Message(toolCallMessage.From ?? string.Empty, toolCallMessage.GetContent()) + { + ToolCalls = toolCallMessage.ToolCalls.Select(t => new Message.ToolCall + { + Function = new Message.Function + { + Name = t.FunctionName, + Arguments = JsonSerializer.Deserialize>(t.FunctionArguments), + }, + }), + }; + + return [MessageEnvelope.Create(chatMessage, toolCallMessage.From)]; + } + + private IEnumerable ProcessToolCallResultMessage(ToolCallResultMessage toolCallResultMessage) + { + foreach (var toolCall in toolCallResultMessage.ToolCalls) + { + if (!string.IsNullOrEmpty(toolCall.Result)) + { + return [MessageEnvelope.Create(new Message("tool", toolCall.Result), toolCallResultMessage.From)]; + } + } + + throw new InvalidOperationException("Expected to have at least one tool call result"); + } + + private IEnumerable ProcessToolCallAggregateMessage(AggregateMessage toolCallAggregateMessage, IAgent agent) + { + if (toolCallAggregateMessage.From is { } from && from != agent.Name) + { + var contents = toolCallAggregateMessage.Message2.ToolCalls.Select(t => t.Result); + var messages = + contents.Select(c => new Message("assistant", c ?? throw new ArgumentNullException(nameof(c)))); + + return messages.Select(m => new MessageEnvelope(m, from: from)); + } + + var toolCallMessage = ProcessToolCallMessage(toolCallAggregateMessage.Message1, agent); + var toolCallResult = ProcessToolCallResultMessage(toolCallAggregateMessage.Message2); + + return toolCallMessage.Concat(toolCallResult); + } } diff --git a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientAgentTest.cs b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientAgentTest.cs index 085917d419e..44fb7de87d4 100644 --- a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientAgentTest.cs +++ b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientAgentTest.cs @@ -170,7 +170,7 @@ public async Task AnthropicAgentFunctionCallMessageTest() ) .RegisterMessageConnector(); - var weatherFunctionArgumets = """ + var weatherFunctionArguments = """ { "city": "Philadelphia", "date": "6/14/2024" @@ -178,8 +178,8 @@ public async Task AnthropicAgentFunctionCallMessageTest() """; var function = new AnthropicTestFunctionCalls(); - var functionCallResult = await function.GetWeatherReportWrapper(weatherFunctionArgumets); - var toolCall = new ToolCall(function.WeatherReportFunctionContract.Name!, weatherFunctionArgumets) + var functionCallResult = await function.GetWeatherReportWrapper(weatherFunctionArguments); + var toolCall = new ToolCall(function.WeatherReportFunctionContract.Name!, weatherFunctionArguments) { ToolCallId = "get_weather", Result = functionCallResult, diff --git a/dotnet/test/AutoGen.Ollama.Tests/AutoGen.Ollama.Tests.csproj b/dotnet/test/AutoGen.Ollama.Tests/AutoGen.Ollama.Tests.csproj index c5ca1955624..86625974d9c 100644 --- a/dotnet/test/AutoGen.Ollama.Tests/AutoGen.Ollama.Tests.csproj +++ b/dotnet/test/AutoGen.Ollama.Tests/AutoGen.Ollama.Tests.csproj @@ -11,6 +11,7 @@ + diff --git a/dotnet/test/AutoGen.Ollama.Tests/OllamaAgentTests.cs b/dotnet/test/AutoGen.Ollama.Tests/OllamaAgentTests.cs index 8a416116ea9..41ad109a8b3 100644 --- a/dotnet/test/AutoGen.Ollama.Tests/OllamaAgentTests.cs +++ b/dotnet/test/AutoGen.Ollama.Tests/OllamaAgentTests.cs @@ -6,6 +6,7 @@ using AutoGen.Ollama.Extension; using AutoGen.Tests; using FluentAssertions; +using Xunit; namespace AutoGen.Ollama.Tests; @@ -49,7 +50,7 @@ public async Task GenerateReplyAsync_ReturnsValidJsonMessageContent_WhenCalled() result.Should().BeOfType>(); result.From.Should().Be(ollamaAgent.Name); - string jsonContent = ((MessageEnvelope)result).Content.Message!.Value; + string jsonContent = ((MessageEnvelope)result).Content.Message!.Value ?? string.Empty; bool isValidJson = IsValidJsonMessage(jsonContent); isValidJson.Should().BeTrue(); } @@ -195,6 +196,66 @@ public async Task ItReturnValidStreamingMessageUsingLLavaAsync() update.TotalDuration.Should().BeGreaterThan(0); } + [Fact] + public async Task GenerateReplyAsync_ReturnsValidToolMessage() + { + var host = @" http://localhost:11434"; + var modelName = "llama3.1"; + + var ollamaAgent = BuildOllamaAgent(host, modelName, [OllamaTestUtils.WeatherTool]); + var message = new Message("user", "What is the weather today?"); + var messages = new IMessage[] { MessageEnvelope.Create(message, from: modelName) }; + + var result = await ollamaAgent.GenerateReplyAsync(messages); + + result.Should().BeOfType>(); + var chatResponse = ((MessageEnvelope)result).Content; + chatResponse.Message.Should().BeOfType(); + chatResponse.Message.Should().NotBeNull(); + var toolCall = chatResponse.Message!.ToolCalls!.First(); + toolCall.Function.Should().NotBeNull(); + toolCall.Function!.Name.Should().Be("get_current_weather"); + toolCall.Function!.Arguments.Should().ContainKey("location"); + toolCall.Function!.Arguments!["location"].Should().Be("San Francisco, CA"); + toolCall.Function!.Arguments!.Should().ContainKey("format"); + toolCall.Function!.Arguments!["format"].Should().BeOneOf("celsius", "fahrenheit"); + } + + [Fact] + public async Task OllamaAgentFunctionCallMessageTest() + { + var host = @" http://localhost:11434"; + var modelName = "llama3.1"; + + var weatherFunctionArguments = """ + { + "city": "Philadelphia", + "date": "6/14/2024" + } + """; + + var function = new OllamaTestFunctionCalls(); + var functionCallResult = await function.GetWeatherReportWrapper(weatherFunctionArguments); + var toolCall = new ToolCall(function.WeatherReportFunctionContract.Name!, weatherFunctionArguments) + { + ToolCallId = "get_weather", + Result = functionCallResult, + }; + + var ollamaAgent = BuildOllamaAgent(host, modelName, [OllamaTestUtils.WeatherTool]).RegisterMessageConnector(); + IMessage[] chatHistory = [ + new TextMessage(Role.User, "what's the weather in Philadelphia?"), + new ToolCallMessage([toolCall], from: "assistant"), + new ToolCallResultMessage([toolCall], from: "user"), + ]; + + var reply = await ollamaAgent.SendAsync(chatHistory: chatHistory); + + reply.Should().BeOfType(); + reply.GetContent().Should().Contain("Philadelphia"); + reply.GetContent().Should().Contain("sunny"); + } + private static bool IsValidJsonMessage(string input) { try @@ -213,12 +274,12 @@ private static bool IsValidJsonMessage(string input) } } - private static OllamaAgent BuildOllamaAgent(string host, string modelName) + private static OllamaAgent BuildOllamaAgent(string host, string modelName, Tool[]? tools = null) { var httpClient = new HttpClient { BaseAddress = new Uri(host) }; - return new OllamaAgent(httpClient, "TestAgent", modelName); + return new OllamaAgent(httpClient, "TestAgent", modelName, tools: tools); } } diff --git a/dotnet/test/AutoGen.Ollama.Tests/OllamaTestFunctionCalls.cs b/dotnet/test/AutoGen.Ollama.Tests/OllamaTestFunctionCalls.cs new file mode 100644 index 00000000000..72eb5db051c --- /dev/null +++ b/dotnet/test/AutoGen.Ollama.Tests/OllamaTestFunctionCalls.cs @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// OllamaTestFunctionCalls.cs + +using System.Text.Json; +using System.Text.Json.Serialization; +using AutoGen.Core; + +namespace AutoGen.Ollama.Tests; + +public partial class OllamaTestFunctionCalls +{ + private class GetWeatherSchema + { + [JsonPropertyName("city")] + public string? City { get; set; } + + [JsonPropertyName("date")] + public string? Date { get; set; } + } + + /// + /// Get weather report + /// + /// city + /// date + [Function] + public async Task WeatherReport(string city, string date) + { + return $"Weather report for {city} on {date} is sunny"; + } + + public Task GetWeatherReportWrapper(string arguments) + { + var schema = JsonSerializer.Deserialize( + arguments, + new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }); + + return WeatherReport(schema?.City ?? string.Empty, schema?.Date ?? string.Empty); + } +} diff --git a/dotnet/test/AutoGen.Ollama.Tests/OllamaTestUtils.cs b/dotnet/test/AutoGen.Ollama.Tests/OllamaTestUtils.cs new file mode 100644 index 00000000000..1bc5205b780 --- /dev/null +++ b/dotnet/test/AutoGen.Ollama.Tests/OllamaTestUtils.cs @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// OllamaTestUtils.cs + +namespace AutoGen.Ollama.Tests; + +public static class OllamaTestUtils +{ + public static Tool WeatherTool => new() + { + Function = new Function + { + Name = "get_current_weather", + Description = "Get the current weather for a location", + Parameters = new Parameters + { + Properties = new Dictionary + { + { + "location", + new Properties + { + Type = "string", Description = "The location to get the weather for, e.g. San Francisco, CA" + } + }, + { + "format", new Properties + { + Type = "string", + Description = + "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'", + Enum = new List {"celsius", "fahrenheit"} + } + } + }, + Required = new List { "location", "format" } + } + } + }; +}