Commit
refactor: update Anthropic provider to accumulate usage and send in final chunk
kevin-on committed Nov 21, 2024
1 parent 2f293a4 commit 3f0afb5
Showing 4 changed files with 33 additions and 46 deletions.
6 changes: 1 addition & 5 deletions src/components/chat-view/Chat.tsx
@@ -39,7 +39,6 @@ import {
import { readTFileContent } from '../../utils/obsidian'
import { openSettingsModalWithError } from '../../utils/openSettingsModal'
import { PromptGenerator } from '../../utils/promptGenerator'
import { sumTokenUsages } from '../../utils/token'

import AssistantMessageActions from './AssistantMessageActions'
import ChatUserInput, { ChatUserInputRef } from './chat-input/ChatUserInput'
@@ -298,10 +297,7 @@ const Chat = forwardRef<ChatRef, ChatProps>((props, ref) => {
content: message.content + content,
metadata: {
...message.metadata,
usage: sumTokenUsages([
message.metadata?.usage,
chunk.usage,
]),
usage: chunk.usage ?? message.metadata?.usage, // Keep existing usage if chunk has no usage data
model: chatModel,
},
}
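The UI-side change above swaps per-chunk summation for a "last value wins" merge. A minimal sketch of that reducer follows; ResponseUsage, StreamChunk, MessageMetadata, and mergeChunk are simplified stand-ins for the project's actual types and component code, under the assumption that providers now report cumulative usage at most once per stream.

interface ResponseUsage {
  prompt_tokens: number
  completion_tokens: number
  total_tokens: number
}

interface StreamChunk {
  content: string
  usage?: ResponseUsage // present only when the provider attaches statistics
}

interface MessageMetadata {
  usage?: ResponseUsage
}

// Because usage is now cumulative and arrives (at most) in the final chunk,
// keeping the latest reported value replaces the removed sumTokenUsages helper.
function mergeChunk(
  message: { content: string; metadata: MessageMetadata },
  chunk: StreamChunk,
): { content: string; metadata: MessageMetadata } {
  return {
    content: message.content + chunk.content,
    metadata: {
      ...message.metadata,
      usage: chunk.usage ?? message.metadata.usage,
    },
  }
}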
50 changes: 27 additions & 23 deletions src/core/llm/anthropic.ts
Expand Up @@ -14,6 +14,7 @@ import {
import {
LLMResponseNonStreaming,
LLMResponseStreaming,
ResponseUsage,
} from '../../types/llm/response'

import { BaseLLMProvider } from './base'
@@ -114,23 +115,22 @@ export class AnthropicProvider implements BaseLLMProvider {
async function* streamResponse(): AsyncIterable<LLMResponseStreaming> {
let messageId = ''
let model = ''
let usage: ResponseUsage = {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
}

for await (const chunk of stream) {
if (chunk.type === 'message_start') {
messageId = chunk.message.id
model = chunk.message.model

yield {
id: messageId,
choices: [],
object: 'chat.completion.chunk',
model: model,
usage: {
prompt_tokens: chunk.message.usage.input_tokens,
completion_tokens: chunk.message.usage.output_tokens,
total_tokens:
chunk.message.usage.input_tokens +
chunk.message.usage.output_tokens,
},
usage = {
prompt_tokens: chunk.message.usage.input_tokens,
completion_tokens: chunk.message.usage.output_tokens,
total_tokens:
chunk.message.usage.input_tokens +
chunk.message.usage.output_tokens,
}
} else if (chunk.type === 'content_block_delta') {
yield AnthropicProvider.parseStreamingResponseChunk(
@@ -139,19 +139,23 @@ export class AnthropicProvider implements BaseLLMProvider {
model,
)
} else if (chunk.type === 'message_delta') {
yield {
id: messageId,
choices: [],
object: 'chat.completion.chunk',
model: model,
usage: {
prompt_tokens: 0,
completion_tokens: chunk.usage.output_tokens,
total_tokens: chunk.usage.output_tokens,
},
usage = {
prompt_tokens: usage.prompt_tokens,
completion_tokens:
usage.completion_tokens + chunk.usage.output_tokens,
total_tokens: usage.total_tokens + chunk.usage.output_tokens,
}
}
}

// After the stream is complete, yield the final usage
yield {
id: messageId,
choices: [],
object: 'chat.completion.chunk',
model: model,
usage: usage,
}
}

return streamResponse()
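With the provider change above, every content_block_delta is forwarded without usage, and a single trailing chunk carries the accumulated totals once the Anthropic stream ends. A minimal consumer sketch, assuming LLMResponseStreaming declares usage as an optional, OpenAI-style field (the import path matches the provider's and assumes the sketch lives alongside it in src/core/llm):

import {
  LLMResponseStreaming,
  ResponseUsage,
} from '../../types/llm/response'

// Drains a streaming response and returns whatever usage the provider
// reported; with this commit, for Anthropic that is the single final chunk.
async function collectUsage(
  stream: AsyncIterable<LLMResponseStreaming>,
): Promise<ResponseUsage | undefined> {
  let usage: ResponseUsage | undefined
  for await (const chunk of stream) {
    if (chunk.usage) {
      usage = chunk.usage // later chunks carry the more complete totals
    }
  }
  return usage
}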
5 changes: 5 additions & 0 deletions src/core/llm/manager.ts
@@ -83,6 +83,11 @@ class LLMManager implements LLMManagerInterface {
request: LLMRequestStreaming,
options?: LLMOptions,
): Promise<AsyncIterable<LLMResponseStreaming>> {
/*
* OpenAI, OpenAI-compatible, and Anthropic providers include token usage statistics
* in the final chunk of the stream (following OpenAI's behavior).
* Groq and Ollama currently do not support usage statistics for streaming responses.
*/
switch (model.provider) {
case 'openai':
return await this.openaiProvider.streamResponse(model, request, options)
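Since the comment added above notes that Groq and Ollama attach no usage statistics to streaming responses, downstream code has to tolerate a stream that never yields usage at all. A hedged sketch of that handling (describeUsage is an illustrative helper, not part of this commit; the import path assumes the file sits next to manager.ts):

import { ResponseUsage } from '../../types/llm/response'

// usage stays undefined for providers that never report streaming statistics.
function describeUsage(usage: ResponseUsage | undefined): string {
  if (!usage) {
    return 'Token usage unavailable for this provider'
  }
  return `${usage.prompt_tokens} prompt + ${usage.completion_tokens} completion = ${usage.total_tokens} total tokens`
}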
18 changes: 0 additions & 18 deletions src/utils/token.ts
@@ -1,7 +1,5 @@
import { getEncoding } from 'js-tiktoken'

import { ResponseUsage } from '../types/llm/response'

// TODO: Replace js-tiktoken with tiktoken library for better performance
// Note: tiktoken uses WebAssembly, requiring esbuild configuration

@@ -11,19 +9,3 @@ export async function tokenCount(text: string): Promise<number> {
const encoder = getEncoding('cl100k_base')
return encoder.encode(text).length
}

export function sumTokenUsages(
usages: (ResponseUsage | undefined)[],
): ResponseUsage | undefined {
return usages.reduce((total, current) => {
if (!total && !current) return undefined // If both are undefined, return undefined
if (!total) return current
if (!current) return total

return {
prompt_tokens: total.prompt_tokens + current.prompt_tokens,
completion_tokens: total.completion_tokens + current.completion_tokens,
total_tokens: total.total_tokens + current.total_tokens,
}
})
}
