Skip to content

Commit

Permalink
feat: Add deployment name and assembled streaming response to chat logs
Browse files Browse the repository at this point in the history
#346 (#347)

Co-authored-by: Aliaksandr Stsiapanay <[email protected]>
  • Loading branch information
astsiapanay and astsiapanay authored May 31, 2024
1 parent 376fff2 commit f05fdad
Show file tree
Hide file tree
Showing 2 changed files with 325 additions and 1 deletion.
154 changes: 153 additions & 1 deletion src/main/java/com/epam/aidial/core/log/GfLogStore.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
import com.epam.deltix.gflog.api.LogFactory;
import com.epam.deltix.gflog.api.LogLevel;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import io.netty.buffer.ByteBufInputStream;
import io.vertx.core.Vertx;
import io.vertx.core.buffer.Buffer;
import io.vertx.core.http.HttpMethod;
Expand All @@ -22,6 +26,7 @@
import java.time.ZoneId;
import java.time.format.DateTimeFormatter;
import java.util.List;
import java.util.Scanner;

@Slf4j
public class GfLogStore implements LogStore {
Expand Down Expand Up @@ -92,6 +97,10 @@ private void append(ProxyContext context, LogEntry entry) throws JsonProcessingE
append(entry, "}", false);
}

append(entry, ",\"deployment\":\"", false);
append(entry, context.getDeployment().getName(), true);
append(entry, "\"", false);

String sourceDeployment = context.getSourceDeployment();
if (sourceDeployment != null) {
append(entry, ",\"parent_deployment\":\"", false);
Expand All @@ -105,6 +114,15 @@ private void append(ProxyContext context, LogEntry entry) throws JsonProcessingE
append(entry, ProxyUtil.MAPPER.writeValueAsString(executionPath), false);
}

append(entry, ",\"assembled_response\":\"", false);
Buffer responseBody = context.getResponseBody();
if (isStreamingResponse(responseBody)) {
append(entry, assembleStreamingResponse(responseBody), true);
} else {
append(entry, responseBody);
}
append(entry, "\"", false);

append(entry, ",\"trace\":{\"trace_id\":\"", false);
append(entry, context.getTraceId(), true);

Expand Down Expand Up @@ -146,7 +164,6 @@ private void append(ProxyContext context, LogEntry entry) throws JsonProcessingE
append(entry, "\"}}", false);
}


private static void append(LogEntry entry, Buffer buffer) {
if (buffer != null) {
byte[] bytes = buffer.getBytes();
Expand Down Expand Up @@ -199,4 +216,139 @@ private static String formatTimestamp(long timestamp) {
return LocalDateTime.ofInstant(Instant.ofEpochMilli(timestamp), ZoneId.of("UTC"))
.format(DateTimeFormatter.ISO_DATE_TIME);
}

/**
 * Assembles a streaming chat completion response into a single <code>chat.completion</code> JSON document.
 * <p>
 * The stream is split into chunks on the separator <code>\n+data: *</code>, i.e. a <code>data:</code> prefix
 * that starts a new line. A <code>data:</code> sequence in the middle of a line (for example a
 * <code>data:image/png;base64,...</code> URI inside a JSON string) is NOT treated as a separator.
 * Processing stops at the <code>[DONE]</code> marker.
 * </p>
 * <p>
 * Merge rules:
 * <ul>
 *   <li><code>id</code> and <code>created</code> are taken from the last parsed chunk;</li>
 *   <li><code>model</code> is taken from the first chunk that carries it;</li>
 *   <li><code>usage</code>, <code>statistics</code> and <code>system_fingerprint</code> are taken from the
 *   latest chunk that carries them;</li>
 *   <li>textual <code>delta.content</code> values of the first choice are concatenated;</li>
 *   <li><code>custom_content</code>, <code>tool_calls</code> and <code>function_call</code> are taken from the
 *   latest delta that carries them (they are overwritten, not merged);</li>
 *   <li>chunks without a usable first choice or delta (e.g. error chunks, empty <code>choices</code>)
 *   contribute only their top-level fields.</li>
 * </ul>
 * </p>
 *
 * @param response buffer holding the raw streaming response to be assembled.
 * @return assembled response as a JSON string; <code>{}</code> if no chunk is found or the response
 *         can't be parsed.
 */
static String assembleStreamingResponse(Buffer response) {
    // Decode explicitly as UTF-8: without a charset Scanner falls back to the platform default.
    try (Scanner scanner = new Scanner(new ByteBufInputStream(response.getByteBuf()),
            java.nio.charset.StandardCharsets.UTF_8)) {
        StringBuilder content = new StringBuilder();
        ObjectNode last = null;
        ObjectNode choice = ProxyUtil.MAPPER.createObjectNode();
        ObjectNode message = ProxyUtil.MAPPER.createObjectNode();
        choice.set("message", message);
        JsonNode usage = null;
        JsonNode statistics = null;
        JsonNode systemFingerprint = null;
        JsonNode model = null;
        // Each chunk starts on a new line with the prefix 'data:'. Requiring at least one new line keeps
        // a 'data:' sequence inside a JSON payload from being mistaken for a chunk boundary.
        scanner.useDelimiter("\n+data: *");
        while (scanner.hasNext()) {
            String chunk = scanner.next().stripLeading();
            // The very first chunk is not preceded by a new line, so its prefix is still attached.
            if (chunk.startsWith("data:")) {
                chunk = chunk.substring("data:".length()).stripLeading();
            }
            if (chunk.isEmpty()) {
                continue;
            }
            if (chunk.startsWith("[DONE]")) {
                break;
            }
            ObjectNode tree = (ObjectNode) ProxyUtil.MAPPER.readTree(chunk);
            if (tree.get("usage") != null) {
                usage = tree.get("usage");
            }
            if (tree.get("statistics") != null) {
                statistics = tree.get("statistics");
            }
            if (tree.get("system_fingerprint") != null) {
                systemFingerprint = tree.get("system_fingerprint");
            }
            if (model == null && tree.get("model") != null) {
                model = tree.get("model");
            }
            last = tree;
            JsonNode choices = tree.get("choices");
            if (choices == null || !choices.isArray() || choices.size() == 0) {
                // skip error message or a chunk that carries no choice
                continue;
            }
            JsonNode curChoice = choices.get(0);
            choice.set("finish_reason", curChoice.get("finish_reason"));
            JsonNode delta = curChoice.get("delta");
            if (delta == null) {
                continue;
            }
            if (delta.get("custom_content") != null) {
                message.set("custom_content", delta.get("custom_content"));
            }
            if (delta.get("tool_calls") != null) {
                message.set("tool_calls", delta.get("tool_calls"));
            }
            if (delta.get("function_call") != null) {
                message.set("function_call", delta.get("function_call"));
            }
            JsonNode contentNode = delta.get("content");
            // A JSON null content has no text value; appending it would produce the literal "null".
            if (contentNode != null && contentNode.isTextual()) {
                content.append(contentNode.textValue());
            }
        }

        if (last == null) {
            log.warn("no chunk is found in streaming response");
            return "{}";
        }

        ObjectNode result = ProxyUtil.MAPPER.createObjectNode();
        result.set("id", last.get("id"));
        result.put("object", "chat.completion");
        result.set("created", last.get("created"));
        result.set("model", model);

        if (usage != null) {
            result.set("usage", usage);
        }
        if (statistics != null) {
            result.set("statistics", statistics);
        }
        if (systemFingerprint != null) {
            result.set("system_fingerprint", systemFingerprint);
        }

        if (content.isEmpty() && message.size() == 0) {
            // error: neither text nor any other message payload has been received
            return ProxyUtil.convertToString(result);
        }

        ArrayNode choices = ProxyUtil.MAPPER.createArrayNode();
        result.set("choices", choices);
        choices.add(choice);
        choice.put("index", 0);
        message.put("role", "assistant");
        message.put("content", content.toString());

        return ProxyUtil.convertToString(result);
    } catch (Throwable e) {
        // Assembling is best-effort and is used for logging only: never let it fail the caller.
        log.warn("Can't assemble streaming response", e);
        return "{}";
    }
}

/**
 * Determines if the given response is streaming.
 * <p>
 * Streaming response is split into chunks. Each chunk starts with a new line and has a prefix: 'data:'.
 * For example<br/>
 * <code>
 * data: {content: "some text"}
 * \n\ndata: {content: "some text"}
 * \ndata: [DONE]
 * </code>
 * </p>
 * <p>
 * The check skips leading whitespace bytes and then compares the following bytes with the token
 * <code>data:</code>.
 * </p>
 *
 * @param response buffer holding the response; may be <code>null</code>.
 * @return <code>true</code> if the response is streaming; <code>false</code> otherwise,
 *         including when the response is <code>null</code>.
 */
static boolean isStreamingResponse(Buffer response) {
    // A missing body is not a streaming response; guard here so callers don't have to.
    if (response == null) {
        return false;
    }
    int length = response.length();
    int start = 0;
    // skip leading whitespace
    while (start < length && Character.isWhitespace(response.getByte(start))) {
        start++;
    }
    String dataToken = "data:";
    if (length - start < dataToken.length()) {
        // not enough bytes left to hold the token
        return false;
    }
    for (int k = 0; k < dataToken.length(); k++) {
        if (dataToken.charAt(k) != response.getByte(start + k)) {
            return false;
        }
    }
    return true;
}
}
Loading

0 comments on commit f05fdad

Please sign in to comment.