Skip to content

Commit

Permalink
fix: optimize code.
Browse files Browse the repository at this point in the history
  • Loading branch information
arthuridea committed Dec 8, 2023
1 parent fa0f18f commit ebaebc5
Show file tree
Hide file tree
Showing 13 changed files with 63 additions and 57 deletions.
12 changes: 6 additions & 6 deletions src/LLMService.Baidu.ErnieVilg/BaiduErnieVilgApiService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public class BaiduErnieVilgApiService : IAIPaintApiService<PaintApplyRequest, Pa
/// </summary>
private readonly ILogger _logger;
/// <summary>
/// Initializes a new instance of the <see cref="BaiduWenxinApiService"/> class.
/// Initializes a new instance of the <see cref="BaiduErnieVilgApiService"/> class.
/// </summary>
/// <param name="factory">The factory.</param>
/// <param name="imgProvider"></param>
Expand All @@ -49,7 +49,7 @@ public BaiduErnieVilgApiService(
private HttpClient GetClient()
{
var client = _httpClientFactory.CreateClient(_api_client_key);
_logger.LogDebug($"[API CLIENT]{_api_client_key} -> {client.BaseAddress}");
_logger.LogDebug("[API CLIENT]{0} -> {1}", _api_client_key, client.BaseAddress);
return client;
}

Expand All @@ -65,10 +65,10 @@ public async Task<PaintResultResponse> Text2Image(PaintApplyRequest request)
request.ConversationId = Guid.NewGuid().ToString();
}

_logger.LogDebug(@$"【CALL ErnieVilg V2】{JsonSerializer.Serialize(request, new JsonSerializerOptions
_logger.LogDebug(@"【CALL ErnieVilg V2】{0}", JsonSerializer.Serialize(request, new JsonSerializerOptions
{
Encoder = System.Text.Encodings.Web.JavaScriptEncoder.UnsafeRelaxedJsonEscaping
})}");
}));

var _client = GetClient();

Expand Down Expand Up @@ -115,8 +115,8 @@ private async Task<PaintResultResponse> challengePaintResult(HttpClient client,

foreach (var image in images)
{
_logger.LogDebug($"{image}");
await _imageProvider.Save(image, $"aigc\\images\\{DateTime.Now.ToString("yyyyMM")}\\{DateTime.Now.ToString("yyyyMMddHHmmssffff")}.jpg");
_logger.LogDebug("{0}", image);
await _imageProvider.Save(image, $"aigc\\images\\{DateTime.Now:yyyyMM}\\{DateTime.Now:yyyyMMddHHmmssffff}.jpg");
}
break;
}
Expand Down
24 changes: 13 additions & 11 deletions src/LLMService.Baidu.Wenxinworkshop/BaiduWenxinApiService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,12 @@ public async Task Chat(ChatRequest request, CancellationToken cancellationToken

await _chatDataProvider.AddChatMessage(conversation, request.Message, "user");

#if DEBUG
_logger.LogDebug(@$"【CALL {request.ModelSchema}{JsonSerializer.Serialize(conversation, new JsonSerializerOptions
{
Encoder = System.Text.Encodings.Web.JavaScriptEncoder.UnsafeRelaxedJsonEscaping
})}");
#endif
#endregion

#region 准备HttpClient和请求对象实体
Expand All @@ -126,15 +128,15 @@ public async Task Chat(ChatRequest request, CancellationToken cancellationToken
{
#region 非流式请求
ChatApiResponse result = new();
var apiResponse = await _client.PostAsJsonAsync(chatApiEndpoint, postdata);
var apiResponse = await _client.PostAsJsonAsync(chatApiEndpoint, postdata, cancellationToken: cancellationToken);
apiResponse.EnsureSuccessStatusCode();

result = await apiResponse.DeserializeAsync<ChatApiResponse>(logger: _logger);
result.ConversationId = request.ConversationId;
result.ModelSchema = request.ModelSchema;

//只有流模式会返回是否结束标识,在非流式请求中直接设置为true.
result.IsEnd = request.Stream ? result.IsEnd : true;
result.IsEnd = !request.Stream || result.IsEnd;

if (!result.NeedClearHistory)
{
Expand All @@ -146,7 +148,7 @@ public async Task Chat(ChatRequest request, CancellationToken cancellationToken
_chatDataProvider.ResetSession(request.ConversationId);
}
response.BuildAIGeneratedResponseFeature();
await response.WriteAsJsonAsync(result);
await response.WriteAsJsonAsync(result, cancellationToken: cancellationToken);

#endregion

Expand All @@ -163,19 +165,19 @@ public async Task Chat(ChatRequest request, CancellationToken cancellationToken
var requestHttpMessage = new HttpRequestMessage
{
Method = HttpMethod.Post,
RequestUri = new Uri($"{_client.BaseAddress}{chatApiEndpoint.Substring(1)}"),
RequestUri = new Uri($"{_client.BaseAddress}{chatApiEndpoint[1..]}"),
Content = content
};
var apiResponse = await _client.SendAsync(requestHttpMessage, HttpCompletionOption.ResponseHeadersRead);
var apiResponse = await _client.SendAsync(requestHttpMessage, HttpCompletionOption.ResponseHeadersRead, cancellationToken);
apiResponse.EnsureSuccessStatusCode();

// 设置响应头
// 设置响应头:返回SSE格式
response.BuildAIGeneratedResponseFeature(true);
// 在读取SSE推流前就开启输出!
await response.Body.FlushAsync();
await response.Body.FlushAsync(cancellationToken);

using var stream = await apiResponse.Content.ReadAsStreamAsync();
using var stream = await apiResponse.Content.ReadAsStreamAsync(cancellationToken);
using var reader = new StreamReader(stream);
string sseSection = string.Empty;
bool isEnd = false;
Expand All @@ -186,15 +188,15 @@ public async Task Chat(ChatRequest request, CancellationToken cancellationToken
sseSection = await reader.ReadLineAsync();
if (!string.IsNullOrEmpty(sseSection))
{
await response.WriteAsync($"{sseSection} \n");
await response.WriteAsync("\n");
await response.Body.FlushAsync();
await response.WriteAsync($"{sseSection} \n", cancellationToken: cancellationToken);
await response.WriteAsync("\n", cancellationToken: cancellationToken);
await response.Body.FlushAsync(cancellationToken);
if (sseSection.Contains("\"is_end\":true"))
{
isEnd = true;
break;
}
await Task.Delay(100);
await Task.Delay(100, cancellationToken);
}
}
#endregion
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ namespace LLMService.Shared.Authentication.Handlers
/// <summary>
///
/// </summary>
/// <seealso cref="LLMServiceHub.Authentication.IAccessTokensCacheManager" />
/// <seealso cref="IAccessTokensCacheManager" />
public class AccessTokensCacheManager : IAccessTokensCacheManager
{
/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public BaiduApiAuthenticationHandler(
/// <param name="accessTokensCacheManager">The access tokens cache manager.</param>
/// <param name="clientCredentials">The client credentials.</param>
/// <param name="accessControlHttpClient">The access control HTTP client.</param>
/// <exception cref="LLMServiceHub.Authentication.AuthenticationHandlerException"></exception>
/// <exception cref="AuthenticationHandlerException"></exception>
public BaiduApiAuthenticationHandler(
IAccessTokensCacheManager accessTokensCacheManager,
ClientCredentials clientCredentials,
Expand All @@ -62,7 +62,7 @@ public BaiduApiAuthenticationHandler(
throw new AuthenticationHandlerException($"{nameof(HttpClient.BaseAddress)} should be set to Identity Server url");
}

if (!(bool)_accessControlHttpClient.BaseAddress?.AbsoluteUri.EndsWith("/"))
if (!(bool)_accessControlHttpClient.BaseAddress?.AbsoluteUri.EndsWith('/'))
{
_accessControlHttpClient.BaseAddress = new Uri(_accessControlHttpClient.BaseAddress.AbsoluteUri + "/");
}
Expand All @@ -78,7 +78,7 @@ public BaiduApiAuthenticationHandler(
/// </returns>
protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
{
HttpResponseMessage response = new HttpResponseMessage(HttpStatusCode.Unauthorized);
HttpResponseMessage response = new(HttpStatusCode.Unauthorized);
try
{
var token = await GetToken();
Expand Down Expand Up @@ -132,30 +132,28 @@ private async Task<TokenResponse> GetToken()
/// </summary>
/// <param name="credentials">The credentials.</param>
/// <returns></returns>
/// <exception cref="LLMServiceHub.Authentication.AuthenticationHandlerException"></exception>
/// <exception cref="AuthenticationHandlerException"></exception>
private async Task<TokenResponse> GetNewToken(ClientCredentials credentials)
{
using (var request = new HttpRequestMessage(HttpMethod.Post, _clientCredentials.TokenEndpoint))
using var request = new HttpRequestMessage(HttpMethod.Post, _clientCredentials.TokenEndpoint);
request.Content = new FormUrlEncodedContent(new[]
{
request.Content = new FormUrlEncodedContent(new[]
{
new KeyValuePair<string, string>("grant_type", "client_credentials"),
new KeyValuePair<string, string>("client_id", credentials.ClientId),
new KeyValuePair<string, string>("client_secret", credentials.ClientSecret),
new KeyValuePair<string, string>("scope", credentials.Scopes)
});

var response = await _accessControlHttpClient.SendAsync(request);

if (response.StatusCode == HttpStatusCode.OK)
{
var tokenResponse = await response.DeserializeAsync<TokenResponse>();
return tokenResponse;
}
var response = await _accessControlHttpClient.SendAsync(request);

var errorMessage = await GetErrorMessageAsync(response);
throw new AuthenticationHandlerException(errorMessage);
if (response.StatusCode == HttpStatusCode.OK)
{
var tokenResponse = await response.DeserializeAsync<TokenResponse>();
return tokenResponse;
}

var errorMessage = await GetErrorMessageAsync(response);
throw new AuthenticationHandlerException(errorMessage);
}

/// <summary>
Expand Down
8 changes: 8 additions & 0 deletions src/LLMService.Shared/LLMServiceConsts.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,14 @@

namespace LLMService.Shared
{
/// <summary>
/// consts of LLM service
/// </summary>
public static class LLMServiceConsts
{
/***** 文心大模型 ****/

#region
//public const string BaiduApiAuthority = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop";
/// <summary>
/// The baidu API authority
Expand All @@ -18,10 +23,12 @@ public static class LLMServiceConsts
/// The baidu wenxin API client name
/// </summary>
public const string BaiduWenxinApiClientName = "_Baidu_Wenxin_Workshop_Client";
#endregion


/***** 智能绘画 ****/

#region
/// <summary>
/// The baidu ernie vilg API authority
/// </summary>
Expand All @@ -30,5 +37,6 @@ public static class LLMServiceConsts
/// The baidu ernie vilg API client name
/// </summary>
public const string BaiduErnieVilgApiClientName = "_Baidu_ErnieVilg_Client";
#endregion
}
}
2 changes: 1 addition & 1 deletion src/LLMService.Shared/Models/BaiduWenxinChatResponse.cs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ public class BaiduWenxinChatResponse /*: IServerSentEventData*/
/// <summary>
///
/// </summary>
/// <seealso cref="LLMServiceHub.Models.BaiduWenxinChatResponse" />
/// <seealso cref="BaiduWenxinChatResponse" />
public class ChatApiResponse : BaiduWenxinChatResponse
{
/// <summary>
Expand Down
2 changes: 1 addition & 1 deletion src/LLMService.Shared/Models/BaiduWenxinMessage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ namespace LLMService.Shared.Models
/// <summary>
///
/// </summary>
/// <seealso cref="LLMServiceHub.Models.IChatMessage" />
/// <seealso cref="IChatMessage" />
public class BaiduWenxinMessage : IChatMessage
{
/// <summary>
Expand Down
2 changes: 1 addition & 1 deletion src/LLMService.Shared/Models/ChatRequest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace LLMService.Shared.Models
/// <summary>
/// 对话请求实体
/// </summary>
/// <seealso cref="LLMServiceHub.Models.AIFeatureModel" />
/// <seealso cref="AIFeatureModel" />
public class ChatRequest: AIFeatureModel
{
/// <summary>
Expand Down
16 changes: 7 additions & 9 deletions src/LLMService.Shared/ServiceInterfaces/IAIChatApiService.cs
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Mvc;

namespace LLMService.Shared.ServiceInterfaces
namespace LLMService.Shared.ServiceInterfaces
{
/// <summary>
/// AI Chat interface
/// <para>NOTE: Should ALWAYS initialize <seealso cref="IChatDataProvider{TChatMessage}"/> first before DI.</para>
/// </summary>
/// <typeparam name="TChatRequest">The type of the chat request.</typeparam>
/// <typeparam name="TChatResponse">The type of the chat response.</typeparam>
public interface IAIChatApiService<TChatRequest, TChatResponse>
{
/// <summary>
Expand Down
14 changes: 7 additions & 7 deletions src/LLMService.Shared/ServiceInterfaces/IAIPaintApiService.cs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace LLMService.Shared.ServiceInterfaces
namespace LLMService.Shared.ServiceInterfaces
{
/// <summary>
/// AI Painting interface.
/// <para>NOTE: Should ALWAYS initialize <seealso cref="IImageStorageProvider"/> first before DI.</para>
/// </summary>
/// <typeparam name="TPaintApplyRequest">The type of the paint apply request.</typeparam>
/// <typeparam name="TPaintResultResponse">The type of the paint result response.</typeparam>
public interface IAIPaintApiService<TPaintApplyRequest, TPaintResultResponse>
where TPaintApplyRequest : new()
where TPaintResultResponse : new()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
namespace LLMService.Shared.ServiceInterfaces
{
/// <summary>
///
/// Since AIGC APIs usually provide a temperally resource url, we may want to save aigc media locally.
/// This provider shows a proper way to download resources.
/// </summary>
public interface IImageStorageProvider
{
Expand Down
2 changes: 1 addition & 1 deletion src/LLMServiceHub/LLMServiceHub.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
</ItemGroup>

<ItemGroup>
<PackageReference Include="Asp.Versioning.Mvc" Version="7.1.0" />
<PackageReference Include="Asp.Versioning.Mvc" Version="7.1.1" />
<PackageReference Include="Asp.Versioning.Mvc.ApiExplorer" Version="7.1.0" />
<PackageReference Include="IdentityModel" Version="6.2.0" />
<PackageReference Include="Microsoft.AspNetCore.Authentication.JwtBearer" Version="8.0.0" />
Expand Down
3 changes: 1 addition & 2 deletions src/LLMServiceHub/Startup.cs
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,7 @@ public void ConfigureServices(IServiceCollection services)
return new[] { api.GroupName };
}

var controllerActionDescriptor = api.ActionDescriptor as ControllerActionDescriptor;
if (controllerActionDescriptor != null)
if (api.ActionDescriptor is ControllerActionDescriptor controllerActionDescriptor)
{
return new[] { controllerActionDescriptor.ControllerName };
}
Expand Down

0 comments on commit ebaebc5

Please sign in to comment.