Skip to content

Commit

Permalink
OpenAI-DotNet 8.1.0 (#334)
Browse files Browse the repository at this point in the history
- Fixed streaming event race conditions where the subscriber to the stream would finish before steam events were executed
- Refactored streaming events callbacks from `Action<IServerSentEvent>` to `Func<IServerSentEvent, Task>`
- Added `Exception` data to `OpenAI.Error` response
- Added `ChatEndpoint.StreamCompletionAsync` with `Func<ChatResponse, Task>` overload
  • Loading branch information
StephenHodgson authored Jun 21, 2024
1 parent 3b0f89a commit 3fdd926
Show file tree
Hide file tree
Showing 20 changed files with 295 additions and 87 deletions.
84 changes: 76 additions & 8 deletions OpenAI-DotNet-Tests/TestFixture_03_Threads.cs
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ public async Task Test_03_03_01_CreateRun_Streaming()
var message = await thread.CreateMessageAsync("I need to solve the equation `3x + 11 = 14`. Can you help me?");
Assert.NotNull(message);

var run = await thread.CreateRunAsync(assistant, streamEvent =>
var run = await thread.CreateRunAsync(assistant, async streamEvent =>
{
Console.WriteLine(streamEvent.ToJsonString());

Expand Down Expand Up @@ -300,11 +300,9 @@ public async Task Test_03_03_01_CreateRun_Streaming()
case Error errorEvent:
Assert.NotNull(errorEvent);
break;
//default:
// handle event not already processed by library
// var @event = JsonSerializer.Deserialize<T>(streamEvent.ToJsonString());
//break;
}

await Task.CompletedTask;
});

Assert.IsNotNull(run);
Expand Down Expand Up @@ -343,7 +341,7 @@ public async Task Test_03_03_02_CreateRun_Streaming_ToolCalls()

try
{
async void StreamEventHandler(IServerSentEvent streamEvent)
async Task StreamEventHandler(IServerSentEvent streamEvent)
{
try
{
Expand Down Expand Up @@ -469,9 +467,10 @@ public async Task Test_04_02_CreateThreadAndRun_Streaming()
try
{
var run = await assistant.CreateThreadAndRunAsync("I need to solve the equation `3x + 11 = 14`. Can you help me?",
@event =>
async @event =>
{
Console.WriteLine(@event.ToJsonString());
await Task.CompletedTask;
});
Assert.IsNotNull(run);
thread = await run.GetThreadAsync();
Expand Down Expand Up @@ -500,7 +499,76 @@ public async Task Test_04_02_CreateThreadAndRun_Streaming()
}

[Test]
public async Task Test_04_03_CreateThreadAndRun_SubmitToolOutput()
public async Task Test_04_03_CreateThreadAndRun_Streaming_ToolCalls()
{
Assert.NotNull(OpenAIClient.ThreadsEndpoint);

var tools = new List<Tool>
{
Tool.GetOrCreateTool(typeof(DateTimeUtility), nameof(DateTimeUtility.GetDateTime))
};
var assistantRequest = new CreateAssistantRequest(
instructions: "You are a helpful assistant.",
tools: tools);
var assistant = await OpenAIClient.AssistantsEndpoint.CreateAssistantAsync(assistantRequest);
Assert.IsNotNull(assistant);
ThreadResponse thread = null;
// check if any exceptions thrown in stream event handler
var exceptionThrown = false;

try
{
async Task StreamEventHandler(IServerSentEvent streamEvent)
{
Console.WriteLine($"{streamEvent.ToJsonString()}");

try
{
switch (streamEvent)
{
case ThreadResponse threadResponse:
thread = threadResponse;
break;
case RunResponse runResponse:
if (runResponse.Status == RunStatus.RequiresAction)
{
var toolOutputs = await assistant.GetToolOutputsAsync(runResponse);
var toolRun = await runResponse.SubmitToolOutputsAsync(toolOutputs, StreamEventHandler);
Assert.NotNull(toolRun);
Assert.IsTrue(toolRun.Status == RunStatus.Completed);
}
break;
case Error errorResponse:
throw errorResponse.Exception ?? new Exception(errorResponse.Message);
}
}
catch (Exception e)
{
Console.WriteLine(e);
exceptionThrown = true;
}
}

var run = await assistant.CreateThreadAndRunAsync("What date is it?", StreamEventHandler);
Assert.NotNull(thread);
Assert.IsNotNull(run);
Assert.IsFalse(exceptionThrown);
Assert.IsTrue(run.Status == RunStatus.Completed);
}
finally
{
await assistant.DeleteAsync(deleteToolResources: thread == null);

if (thread != null)
{
var isDeleted = await thread.DeleteAsync(deleteToolResources: true);
Assert.IsTrue(isDeleted);
}
}
}

[Test]
public async Task Test_04_04_CreateThreadAndRun_SubmitToolOutput()
{
var tools = new List<Tool>
{
Expand Down
10 changes: 8 additions & 2 deletions OpenAI-DotNet-Tests/TestFixture_04_Chat.cs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,10 @@ public async Task Test_02_01_GetChatToolCompletion()
Console.WriteLine($"{message.Role}: {message.Content}");
}

var tools = Tool.GetAllAvailableTools(false, forceUpdate: true, clearCache: true);
var tools = new List<Tool>
{
Tool.GetOrCreateTool(typeof(WeatherService), nameof(WeatherService.GetCurrentWeatherAsync))
};
var chatRequest = new ChatRequest(messages, tools: tools, toolChoice: "none");
var response = await OpenAIClient.ChatEndpoint.GetCompletionAsync(chatRequest);
Assert.IsNotNull(response);
Expand Down Expand Up @@ -211,7 +214,10 @@ public async Task Test_02_02_GetChatToolCompletion_Streaming()
Console.WriteLine($"{message.Role}: {message.Content}");
}

var tools = Tool.GetAllAvailableTools(false);
var tools = new List<Tool>
{
Tool.GetOrCreateTool(typeof(WeatherService), nameof(WeatherService.GetCurrentWeatherAsync))
};
var chatRequest = new ChatRequest(messages, tools: tools, toolChoice: "none");
var response = await OpenAIClient.ChatEndpoint.StreamCompletionAsync(chatRequest, partialResponse =>
{
Expand Down
7 changes: 7 additions & 0 deletions OpenAI-DotNet-Tests/TestServices/WeatherService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,11 @@ public static async Task<string> GetCurrentWeatherAsync(

public static int CelsiusToFahrenheit(int celsius) => (celsius * 9 / 5) + 32;
}

internal static class DateTimeUtility
{
[Function("Get the current date and time.")]
public static async Task<string> GetDateTime()
=> await Task.FromResult(DateTimeOffset.Now.ToString());
}
}
12 changes: 10 additions & 2 deletions OpenAI-DotNet/Assistants/AssistantExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,23 @@ from vectorStoreId in assistant.ToolResources?.FileSearch?.VectorStoreIds
return deleteTasks.TrueForAll(task => task.Result);
}

[Obsolete("use new overload with Func<IServerSentEvent, Task> instead.")]
public static async Task<RunResponse> CreateThreadAndRunAsync(this AssistantResponse assistant, CreateThreadRequest request, Action<IServerSentEvent> streamEventHandler, CancellationToken cancellationToken = default)
=> await CreateThreadAndRunAsync(assistant, request, streamEventHandler == null ? null : async serverSentEvent =>
{
streamEventHandler.Invoke(serverSentEvent);
await Task.CompletedTask;
}, cancellationToken);

/// <summary>
/// Create a thread and run it.
/// </summary>
/// <param name="assistant"><see cref="AssistantResponse"/>.</param>
/// <param name="request">Optional, <see cref="CreateThreadRequest"/>.</param>
/// <param name="streamEventHandler">Optional, <see cref="Action{IStreamEvent}"/> stream callback handler.</param>
/// <param name="streamEventHandler">Optional, <see cref="Func{IServerSentEvent, Task}"/> stream callback handler.</param>
/// <param name="cancellationToken">Optional, <see cref="CancellationToken"/>.</param>
/// <returns><see cref="RunResponse"/>.</returns>
public static async Task<RunResponse> CreateThreadAndRunAsync(this AssistantResponse assistant, CreateThreadRequest request = null, Action<IServerSentEvent> streamEventHandler = null, CancellationToken cancellationToken = default)
public static async Task<RunResponse> CreateThreadAndRunAsync(this AssistantResponse assistant, CreateThreadRequest request = null, Func<IServerSentEvent, Task> streamEventHandler = null, CancellationToken cancellationToken = default)
=> await assistant.Client.ThreadsEndpoint.CreateThreadAndRunAsync(new CreateThreadAndRunRequest(assistant.Id, createThreadRequest: request), streamEventHandler, cancellationToken).ConfigureAwait(false);

#region Tools
Expand Down
31 changes: 26 additions & 5 deletions OpenAI-DotNet/Chat/ChatEndpoint.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public async Task<ChatResponse> GetCompletionAsync(ChatRequest chatRequest, Canc
/// Created a completion for the chat message and stream the results to the <paramref name="resultHandler"/> as they come in.
/// </summary>
/// <param name="chatRequest">The chat request which contains the message content.</param>
/// <param name="resultHandler">An action to be called as each new result arrives.</param>
/// <param name="resultHandler">An <see cref="Action{ChatResponse}"/> to be invoked as each new result arrives.</param>
/// <param name="streamUsage">
/// Optional, If set, an additional chunk will be streamed before the 'data: [DONE]' message.
/// The 'usage' field on this chunk shows the token usage statistics for the entire request,
Expand All @@ -54,12 +54,34 @@ public async Task<ChatResponse> GetCompletionAsync(ChatRequest chatRequest, Canc
/// <param name="cancellationToken">Optional, <see cref="CancellationToken"/>.</param>
/// <returns><see cref="ChatResponse"/>.</returns>
public async Task<ChatResponse> StreamCompletionAsync(ChatRequest chatRequest, Action<ChatResponse> resultHandler, bool streamUsage = false, CancellationToken cancellationToken = default)
=> await StreamCompletionAsync(chatRequest, async response =>
{
resultHandler.Invoke(response);
await Task.CompletedTask;
}, streamUsage, cancellationToken);

/// <summary>
/// Created a completion for the chat message and stream the results to the <paramref name="resultHandler"/> as they come in.
/// </summary>
/// <param name="chatRequest">The chat request which contains the message content.</param>
/// <param name="resultHandler">A <see cref="Func{ChatResponse, Task}"/> to to be invoked as each new result arrives.</param>
/// <param name="streamUsage">
/// Optional, If set, an additional chunk will be streamed before the 'data: [DONE]' message.
/// The 'usage' field on this chunk shows the token usage statistics for the entire request,
/// and the 'choices' field will always be an empty array. All other chunks will also include a 'usage' field,
/// but with a null value.
/// </param>
/// <param name="cancellationToken">Optional, <see cref="CancellationToken"/>.</param>
/// <returns><see cref="ChatResponse"/>.</returns>
public async Task<ChatResponse> StreamCompletionAsync(ChatRequest chatRequest, Func<ChatResponse, Task> resultHandler, bool streamUsage = false, CancellationToken cancellationToken = default)
{
if (chatRequest == null) { throw new ArgumentNullException(nameof(chatRequest)); }
if (resultHandler == null) { throw new ArgumentNullException(nameof(resultHandler)); }
chatRequest.Stream = true;
chatRequest.StreamOptions = streamUsage ? new StreamOptions() : null;
ChatResponse chatResponse = null;
using var payload = JsonSerializer.Serialize(chatRequest, OpenAIClient.JsonSerializationOptions).ToJsonStringContent();
using var response = await this.StreamEventsAsync(GetUrl("/completions"), payload, (sseResponse, ssEvent) =>
using var response = await this.StreamEventsAsync(GetUrl("/completions"), payload, async (sseResponse, ssEvent) =>
{
var partialResponse = sseResponse.Deserialize<ChatResponse>(ssEvent, client);

Expand All @@ -72,13 +94,12 @@ public async Task<ChatResponse> StreamCompletionAsync(ChatRequest chatRequest, A
chatResponse.AppendFrom(partialResponse);
}

resultHandler?.Invoke(partialResponse);

await resultHandler.Invoke(partialResponse);
}, cancellationToken);

if (chatResponse == null) { return null; }
chatResponse.SetResponseData(response.Headers, client);
resultHandler?.Invoke(chatResponse);
await resultHandler.Invoke(chatResponse);
return chatResponse;
}

Expand Down
9 changes: 8 additions & 1 deletion OpenAI-DotNet/Chat/ChatResponse.cs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,14 @@ internal void AppendFrom(ChatResponse other)
{
if (other is null) { return; }

if (!string.IsNullOrWhiteSpace(other.Id))
if (!string.IsNullOrWhiteSpace(Id))
{
if (Id != other.Id)
{
throw new InvalidOperationException($"Attempting to append a different object than the original! {Id} != {other.Id}");
}
}
else
{
Id = other.Id;
}
Expand Down
13 changes: 13 additions & 0 deletions OpenAI-DotNet/Common/Error.cs
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
// Licensed under the MIT License. See LICENSE in the project root for license information.

using System;
using System.Text;
using System.Text.Json.Serialization;

namespace OpenAI
{
public sealed class Error : BaseResponse, IServerSentEvent
{
public Error() { }

internal Error(Exception e)
{
Type = e.GetType().Name;
Message = e.Message;
Exception = e;
}

/// <summary>
/// An error code identifying the error type.
/// </summary>
Expand Down Expand Up @@ -50,6 +60,9 @@ public sealed class Error : BaseResponse, IServerSentEvent
[JsonIgnore]
public string Object => "error";

[JsonIgnore]
public Exception Exception { get; }

public override string ToString()
{
var builder = new StringBuilder();
Expand Down
3 changes: 3 additions & 0 deletions OpenAI-DotNet/Common/Function.cs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ public static Function FromFunc<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, TR
/// </summary>
[JsonInclude]
[JsonPropertyName("description")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public string Description { get; private set; }

private string parametersString;
Expand All @@ -166,6 +167,7 @@ public static Function FromFunc<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, TR
/// </summary>
[JsonInclude]
[JsonPropertyName("parameters")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public JsonNode Parameters
{
get
Expand All @@ -190,6 +192,7 @@ public JsonNode Parameters
/// </summary>
[JsonInclude]
[JsonPropertyName("arguments")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public JsonNode Arguments
{
get
Expand Down
15 changes: 7 additions & 8 deletions OpenAI-DotNet/Extensions/BaseEndpointExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ internal static class BaseEndpointExtensions
/// <summary>
/// https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events
/// </summary>
public static async Task<HttpResponseMessage> StreamEventsAsync(this OpenAIBaseEndpoint baseEndpoint, string endpoint, StringContent payload, Action<HttpResponseMessage, ServerSentEvent> eventCallback, CancellationToken cancellationToken)
public static async Task<HttpResponseMessage> StreamEventsAsync(this OpenAIBaseEndpoint baseEndpoint, string endpoint, StringContent payload, Func<HttpResponseMessage, ServerSentEvent, Task> eventCallback, CancellationToken cancellationToken)
{
using var request = new HttpRequestMessage(HttpMethod.Post, endpoint);
request.Content = payload;
Expand All @@ -34,7 +34,7 @@ public static async Task<HttpResponseMessage> StreamEventsAsync(this OpenAIBaseE

try
{
while (await reader.ReadLineAsync() is { } streamData)
while (await reader.ReadLineAsync().ConfigureAwait(false) is { } streamData)
{
if (isEndOfStream)
{
Expand All @@ -56,11 +56,10 @@ public static async Task<HttpResponseMessage> StreamEventsAsync(this OpenAIBaseE
string value;
string data;

Match match = matches[i];
var match = matches[i];

// If the field type is not provided, treat it as a comment
type = ServerSentEvent.EventMap.GetValueOrDefault(match.Groups[nameof(type)].Value.Trim(), ServerSentEventKind.Comment);

// The UTF-8 decode algorithm strips one leading UTF-8 Byte Order Mark (BOM), if any.
value = match.Groups[nameof(value)].Value.TrimStart(' ');
data = match.Groups[nameof(data)].Value;
Expand Down Expand Up @@ -104,17 +103,17 @@ public static async Task<HttpResponseMessage> StreamEventsAsync(this OpenAIBaseE
{
var previousEvent = events.Pop();
previousEvent.Data = @event.Value;
eventCallback?.Invoke(response, previousEvent);
events.Push(previousEvent);
await eventCallback.Invoke(response, previousEvent).ConfigureAwait(false);
}
else
{
events.Push(@event);

if (type != ServerSentEventKind.Event)
{
eventCallback?.Invoke(response, @event);
await eventCallback.Invoke(response, @event).ConfigureAwait(false);
}

events.Push(@event);
}
}
}
Expand Down
7 changes: 6 additions & 1 deletion OpenAI-DotNet/OpenAI-DotNet.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,13 @@ More context [on Roger Pincombe's blog](https://rogerpincombe.com/openai-dotnet-
<AssemblyOriginatorKeyFile>OpenAI-DotNet.pfx</AssemblyOriginatorKeyFile>
<IncludeSymbols>True</IncludeSymbols>
<TreatWarningsAsErrors>True</TreatWarningsAsErrors>
<Version>8.0.3</Version>
<Version>8.1.0</Version>
<PackageReleaseNotes>
Version 8.1.0
- Fixed streaming event race conditions where the subscriber to the stream would finish before steam events were executed
- Refactored streaming events callbacks from Action&lt;IServerSentEvent&gt; to Func&lt;IServerSentEvent, Task&gt;
- Added Exception data to OpenAI.Error response
- Added ChatEndpoint.StreamCompletionAsync with Func&lt;ChatResponse, Task&gt; overload
Version 8.0.3
- Fixed Thread.MessageResponse and Thread.RunStepResponse Delta stream event objects not being properly populated
- Added Thread.MessageDelta.PrintContent()
Expand Down
Loading

0 comments on commit 3fdd926

Please sign in to comment.