Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: mesh streaming response support (#894) #895

Merged
merged 1 commit into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package com.xiaomi.data.push.uds.processor;

/**
 * Callback handed to {@code UdsProcessor.processStream} so a processor can report
 * the progress of a streamed response.
 *
 * <p>NOTE(review): the callers of these methods are not visible here; the
 * per-method descriptions below are inferred from the method names and should be
 * confirmed against the stream processor implementations.
 *
 * @author [email protected]
 * @date 2024/11/7 11:56
 */
public interface StreamCallback {

    /**
     * Presumably called once per chunk of streamed content.
     *
     * @param content the content of the chunk
     */
    void onContent(String content);

    /** Presumably called once when the stream has finished normally. */
    void onComplete();

    /**
     * Presumably called when the stream fails.
     *
     * @param error the cause of the failure
     */
    void onError(Throwable error);

}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,16 @@ public interface UdsProcessor<Request, Response> {

Response processRequest(Request request);

/**
 * Added: tells whether this processor is a streaming processor.
 *
 * @return {@code false} by default; streaming processors are expected to override this
 */
default boolean isStreamProcessor() {
return false;
}

/**
 * Added: streaming variant of request processing.
 *
 * <p>The default implementation does not support streaming and always throws.
 *
 * @param request  the request to process
 * @param callback receiver of the streamed result
 * @throws UnsupportedOperationException always, unless overridden
 */
default void processStream(Request request, StreamCallback callback) {
throw new UnsupportedOperationException("Stream processing not supported");
}


default String cmd() {
return "";
Expand Down
2 changes: 1 addition & 1 deletion jcommon/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-core</artifactId>
<version>1.1.2</version>
<version>1.2.3</version>
<scope>provided</scope>
</dependency>
<dependency>
Expand Down
4 changes: 3 additions & 1 deletion jcommon/rcurve/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,6 @@
+ A high-performance underlying communication framework for the mesh.
+ Support UDS communication and TCP communication.
+ Support hessian gson protostuff encoding.
+ The performance is pretty good.
+ The performance is pretty good.
+ JVM options:
+ --add-opens=java.base/java.lang=ALL-UNNAMED --add-opens=java.base/java.util=ALL-UNNAMED --add-opens=java.base/java.math=ALL-UNNAMED --add-opens=java.base/sun.reflect=ALL-UNNAMED --add-exports=java.base/sun.reflect.annotation=ALL-UNNAMED --add-exports=java.base/sun.reflect.generics.reflectiveObjects=ALL-UNNAMED --enable-preview
17 changes: 14 additions & 3 deletions jcommon/rcurve/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
<dependency>
<groupId>run.mone</groupId>
<artifactId>api</artifactId>
<version>1.4.1-jdk20-SNAPSHOT</version>
<version>${submodule-release.version}</version>
</dependency>
<dependency>
<groupId>run.mone</groupId>
Expand All @@ -41,17 +41,28 @@
<artifactId>easy</artifactId>
<version>1.6.0-jdk21-SNAPSHOT</version>
</dependency>


<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-all</artifactId>
<version>4.1.48.Final</version>
<version>4.1.114.Final</version>
</dependency>

<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-transport-native-kqueue</artifactId>
<version>4.1.114.Final</version>
<classifier>osx-aarch_64</classifier>
</dependency>


<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
<version>1.2.3</version>
<scope>provided</scope>
</dependency>

</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
public class NetUtils {

public static EventLoopGroup getEventLoopGroup() {
if (CommonUtils.isMac() && CommonUtils.isArch64()) {
return new NioEventLoopGroup();
}
// if (CommonUtils.isMac() && CommonUtils.isArch64()) {
// return new NioEventLoopGroup();
// }
if (CommonUtils.isWindows()) {
return new NioEventLoopGroup();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import com.xiaomi.data.push.uds.context.TraceContext;
import com.xiaomi.data.push.uds.context.TraceEvent;
import com.xiaomi.data.push.uds.context.UdsClientContext;
import com.xiaomi.data.push.uds.handler.MessageTypes;
import com.xiaomi.data.push.uds.handler.ClientStreamCallback;
import com.xiaomi.data.push.uds.handler.UdsClientConnetManageHandler;
import com.xiaomi.data.push.uds.handler.UdsClientHandler;
import com.xiaomi.data.push.uds.po.UdsCommand;
Expand All @@ -41,7 +43,6 @@
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;

/**
* @author [email protected]
Expand Down Expand Up @@ -134,6 +135,20 @@ public void call(Object msg) {
Send.send(this.channel, command);
}


/**
 * Sends an OpenAI streaming request and registers the callback that receives its responses.
 *
 * <p>The callback is stored in the channel's {@link UdsClientHandler} under the stream id found
 * in the command attachments ({@link MessageTypes#STREAM_ID_KEY}) before the command is sent.
 * If sending throws, the registration is rolled back so the callback is not leaked.
 *
 * @param command  the request to send; its attachments must contain a stream id
 * @param callback receiver of the streamed content, completion and error notifications
 * @throws NullPointerException     if {@code callback} is null
 * @throws IllegalArgumentException if the command carries no stream id attachment
 * @throws IllegalStateException    if no {@link UdsClientHandler} is installed in the pipeline
 */
public void stream(UdsCommand command, ClientStreamCallback callback) {
    // The callback map is a ConcurrentHashMap, which rejects null keys and values with a
    // message-less NullPointerException; fail early with a message that names the problem.
    if (callback == null) {
        throw new NullPointerException("callback must not be null");
    }
    Map<String, String> attachments = command.getAttachments();
    String streamId = attachments == null ? null : attachments.get(MessageTypes.STREAM_ID_KEY);
    if (streamId == null) {
        throw new IllegalArgumentException(
                "stream command has no '" + MessageTypes.STREAM_ID_KEY + "' attachment");
    }

    // Look the handler up by type instead of assuming it is the last one in the pipeline.
    UdsClientHandler handler = channel.pipeline().get(UdsClientHandler.class);
    if (handler == null) {
        throw new IllegalStateException("UdsClientHandler is not installed in the channel pipeline");
    }

    // Register the callback first so that no response can arrive before it is known.
    Map<String, ClientStreamCallback> callbacks = handler.getStreamCallbacks();
    callbacks.put(streamId, callback);
    try {
        // Send the request.
        Send.send(this.channel, command);
    } catch (RuntimeException e) {
        // Nothing was sent, so no terminal response will ever remove the callback.
        callbacks.remove(streamId, callback);
        throw e;
    }
}


@Override
public UdsCommand call(UdsCommand req) {
Stopwatch sw = Stopwatch.createStarted();
Expand All @@ -142,11 +157,11 @@ public UdsCommand call(UdsCommand req) {
long id = req.getId();
try {
CompletableFuture<Object> future = new CompletableFuture<>();
HashMap<String,Object> hashMap = new HashMap<>();
HashMap<String, Object> hashMap = new HashMap<>();
hashMap.put("future", future);
hashMap.put("async", req.isAsync());
hashMap.put("returnType", req.getReturnClass());
reqMap.put(req.getId(),hashMap);
reqMap.put(req.getId(), hashMap);
Channel channel = this.channel;
if (null == channel || !channel.isOpen()) {
log.warn("client channel is close");
Expand All @@ -160,7 +175,7 @@ public UdsCommand call(UdsCommand req) {
wheelTimer.newTimeout(() -> {
log.warn("check async udsClient time out auto close:{},{}", req.getId(), req.getTimeout());
reqMap.remove(req.getId());
}, req.getTimeout()+350);
}, req.getTimeout() + 350);
return req;
}
return (UdsCommand) future.get(req.getTimeout(), TimeUnit.MILLISECONDS);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package com.xiaomi.data.push.uds.handler;

/**
 * Client-side callback for a streamed response, registered per stream id through
 * {@code UdsClient.stream} and invoked by {@code UdsClientHandler.handleOpenAIResponse}.
 *
 * @author [email protected]
 * @date 2024/11/7 10:35
 */
public interface ClientStreamCallback {

    /**
     * Called for every stream response whose status is neither "complete" nor "error".
     *
     * @param content the value of the content attachment of the response
     */
    void onContent(String content);

    /**
     * Called when a response with status "complete" arrives; the callback is
     * unregistered for that stream.
     */
    void onComplete();

    /**
     * Called when a response with status "error" arrives; the callback is
     * unregistered for that stream.
     *
     * @param error a {@link RuntimeException} whose message is the content attachment
     */
    void onError(Throwable error);

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package com.xiaomi.data.push.uds.handler;

/**
 * Attachment keys and message-type values used to tell stream responses apart
 * from normal ones.
 *
 * <p>This is a constants holder only: it is final and cannot be instantiated.
 *
 * @author [email protected]
 * @date 2024/11/6 17:41
 */
public final class MessageTypes {

    /** Attachment key holding the message type; a missing value is treated as {@link #TYPE_NORMAL}. */
    public static final String TYPE_KEY = "messageType";

    /** Message type of an ordinary request/response exchange. */
    public static final String TYPE_NORMAL = "normal";

    /** Message type of an OpenAI streaming response. */
    public static final String TYPE_OPENAI = "openai";

    /** Attachment key holding the id that ties stream responses to a registered callback. */
    public static final String STREAM_ID_KEY = "streamId";

    /** Attachment key for the prompt. */
    public static final String PROMPT_KEY = "prompt";

    /** Attachment key holding the content of a stream response. */
    public static final String CONTENT_KEY = "content";

    /** Attachment key holding the status of a stream response. */
    public static final String STATUS_KEY = "status";

    /** Not instantiable: this class only groups constants. */
    private MessageTypes() {
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;

import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
Expand All @@ -42,6 +44,9 @@ public class UdsClientHandler extends SimpleChannelInboundHandler<ByteBuf> {

private ConcurrentHashMap<String, Pair<UdsProcessor<UdsCommand, UdsCommand>,ExecutorService>> processorMap;

// Callbacks of the streams in flight, keyed by stream id. Entries are added by
// UdsClient.stream through the generated getter and looked up / removed in
// handleOpenAIResponse.
@Getter
private final Map<String, ClientStreamCallback> streamCallbacks = new ConcurrentHashMap<>();


public UdsClientHandler(ConcurrentHashMap<String, Pair<UdsProcessor<UdsCommand, UdsCommand>,ExecutorService>> processorMap) {
this.processorMap = processorMap;
Expand Down Expand Up @@ -70,28 +75,67 @@ protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Excep
log.warn("processor is null cmd:{}", command.getCmd());
}
} else {
Optional.ofNullable(UdsClient.reqMap.get(command.getId())).ifPresent(f -> {
if (Boolean.TRUE.toString().equals(String.valueOf(f.get("async")))) {
Object res = null;
try {
res = processResult(command, (Class<?>) f.get("returnType"));
if (command.getCode() == 0) {
((CompletableFuture)f.get("future")).complete(res);
} else {
((CompletableFuture)f.get("future")).completeExceptionally(new RuntimeException(res.toString()));
}
} catch (Exception e) {
log.error("async response error,", e);
((CompletableFuture)f.get("future")).completeExceptionally(e);
String messageType = command.getAttachments()
.getOrDefault(MessageTypes.TYPE_KEY, MessageTypes.TYPE_NORMAL);

//流式的操作
if (MessageTypes.TYPE_OPENAI.equals(messageType)) {
handleOpenAIResponse(command);
} else {
handleNormalResponse(command);
}

}
}

/**
 * Dispatches one OpenAI stream response to the callback registered for its stream id.
 *
 * <p>A status of "complete" or "error" ends the stream: the callback is unregistered and
 * then notified through {@code onComplete} or {@code onError}. Any other status delivers the
 * content attachment through {@code onContent}. Responses without a stream id, or for a
 * stream with no registered callback, are dropped.
 *
 * @param command the response; its attachments carry the stream id, content and status
 */
private void handleOpenAIResponse(UdsCommand command) {
    Map<String, String> attachments = command.getAttachments();
    String streamId = attachments.get(MessageTypes.STREAM_ID_KEY);
    if (streamId == null) {
        // streamCallbacks is a ConcurrentHashMap; looking up a null key would throw.
        log.warn("stream response without stream id, cmd:{}", command.getCmd());
        return;
    }
    String content = attachments.get(MessageTypes.CONTENT_KEY);
    String status = attachments.get(MessageTypes.STATUS_KEY);

    boolean completed = "complete".equals(status);
    boolean failed = "error".equals(status);

    // Unregister before notifying on a terminal status, so that a callback which throws
    // cannot leave its entry behind in the map.
    ClientStreamCallback callback = (completed || failed)
            ? streamCallbacks.remove(streamId)
            : streamCallbacks.get(streamId);
    if (callback == null) {
        return;
    }

    if (completed) {
        callback.onComplete();
    } else if (failed) {
        callback.onError(new RuntimeException(content));
    } else {
        callback.onContent(content);
    }
}


private void handleNormalResponse(UdsCommand command) {
// 保持原有的处理逻辑不变
Optional.ofNullable(UdsClient.reqMap.get(command.getId())).ifPresent(f -> {
if (Boolean.TRUE.toString().equals(String.valueOf(f.get("async")))) {
Object res = null;
try {
res = processResult(command, (Class<?>) f.get("returnType"));
if (command.getCode() == 0) {
((CompletableFuture)f.get("future")).complete(res);
} else {
((CompletableFuture)f.get("future")).completeExceptionally(
new RuntimeException(res.toString())
);
}
UdsClient.reqMap.remove(command.getId());
} else {
((CompletableFuture)f.get("future")).complete(command);
} catch (Exception e) {
log.error("async response error,", e);
((CompletableFuture)f.get("future")).completeExceptionally(e);
}
});
}
UdsClient.reqMap.remove(command.getId());
} else {
((CompletableFuture)f.get("future")).complete(command);
}
});
}



@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
log.error("client channelInactive:{}",ctx.channel().id());
Expand Down
Loading
Loading