Commit

move get token to provider and add estimate for input_tokens
Signed-off-by: oilbeater <[email protected]>
oilbeater committed Oct 2, 2024
1 parent e765d74 commit a8dc6f1
Showing 5 changed files with 42 additions and 21 deletions.
31 changes: 11 additions & 20 deletions src/middlewares/analytics.ts
@@ -4,15 +4,20 @@ export function recordAnalytics(
   c: Context,
   endpoint: string,
   duration: number,
-  prompt_tokens: number,
-  completion_tokens: number
 ) {
+
+  const getModelName = c.get('getModelName');
+  const modelName = typeof getModelName === 'function' ? getModelName(c) : 'unknown';
+
+  const getTokenCount = c.get('getTokenCount');
+  const { input_tokens, output_tokens } = typeof getTokenCount === 'function' ? getTokenCount(c) : { input_tokens: 0, output_tokens: 0 };
+
+  // console.log(endpoint, c.req.path, modelName, input_tokens, output_tokens, c.get('malacca-cache-status') || 'miss', c.res.status);
+
   if (c.env.MALACCA) {
-    const getModelName = c.get('getModelName');
-    const modelName = typeof getModelName === 'function' ? getModelName(c) : 'unknown';
     c.env.MALACCA.writeDataPoint({
       'blobs': [endpoint, c.req.path, c.res.status, c.get('malacca-cache-status') || 'miss', modelName],
-      'doubles': [duration, prompt_tokens, completion_tokens],
+      'doubles': [duration, input_tokens, output_tokens],
       'indexes': [endpoint],
     });
   }
@@ -25,24 +30,10 @@ export const metricsMiddleware: MiddlewareHandler = async (c, next) => {
 
   c.executionCtx.waitUntil((async () => {
     await c.get('bufferPromise')
-    const buf = c.get('buffer')
     const endTime = Date.now();
     const duration = endTime - startTime;
     const endpoint = c.get('endpoint') || 'unknown';
-    let prompt_tokens = 0;
-    let completion_tokens = 0;
-    if (c.res.status === 200) {
-      if (c.res.headers.get('content-type') === 'application/json') {
-        const usage = JSON.parse(buf)['usage'];
-        if (usage) {
-          prompt_tokens = usage['prompt_tokens'] | 0;
-          completion_tokens = usage['completion_tokens'] | 0;
-        }
-      } else {
-        completion_tokens = buf.split('\n\n').length - 1;
-      }
-    }
-    recordAnalytics(c, endpoint, duration, prompt_tokens, completion_tokens);
+    recordAnalytics(c, endpoint, duration);
   })());
 };

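The analytics middleware now retrieves both callbacks from the Hono context and falls back to neutral defaults when a provider registers none, instead of parsing the response body itself. A minimal standalone sketch of that contract, in which the wildcard middleware and zero-value fallbacks are illustrative rather than taken from this repository:

import { Hono } from 'hono';
import type { Context, Next } from 'hono';

type TokenCount = { input_tokens: number, output_tokens: number };

const app = new Hono();

// Each provider's init middleware registers its own counter on the context.
app.use('*', async (c: Context, next: Next) => {
  c.set('getTokenCount', (c: Context): TokenCount => ({ input_tokens: 0, output_tokens: 0 }));
  await next();
});

// Consumers stay provider-agnostic and defensive, as recordAnalytics does above.
function readTokenCount(c: Context): TokenCount {
  const getTokenCount = c.get('getTokenCount');
  return typeof getTokenCount === 'function'
    ? getTokenCount(c)
    : { input_tokens: 0, output_tokens: 0 };
}
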
3 changes: 3 additions & 0 deletions src/middlewares/buffer.ts
@@ -8,6 +8,9 @@ export const bufferMiddleware: MiddlewareHandler = async (c: Context, next: Next) => {
   })
   c.set('bufferPromise', bufferPromise)
 
+  const reqBuffer: string = await c.req.text() || ''
+  c.set('reqBuffer', reqBuffer)
+
   await next()
 
   const originalResponse = c.res
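
The reqBuffer added here matters because a fetch-API request body is a one-shot stream: whichever middleware calls c.req.text() first consumes it. Caching the text in the context once lets logging and the token estimator read the same copy. A standalone illustration (not repository code):

// A request body can be consumed only once, hence the cached copy.
async function demo() {
  const req = new Request('https://example.com', {
    method: 'POST',
    body: '{"messages":[{"role":"user","content":"hi"}]}',
  });

  const cached = await req.text(); // first read drains the stream
  console.log(req.bodyUsed);       // true: a second req.text() would throw
  console.log(cached);
}

demo();
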
2 changes: 1 addition & 1 deletion src/middlewares/logging.ts
@@ -5,7 +5,7 @@ export const loggingMiddleware = async (c: Context, next: Next) => {
 
   // Log request and response
   c.executionCtx.waitUntil((async () => {
-    const requestBody = await c.req.text().catch(() => ({}));
+    const requestBody = c.get('reqBuffer') || '';
     console.log('Request:', {
       body: requestBody,
     });
26 changes: 26 additions & 0 deletions src/providers/azureOpenAI.ts
@@ -15,6 +15,7 @@ const azureOpenAIRoute = new Hono();
 const initMiddleware = async (c: Context, next: Next) => {
   c.set('endpoint', ProviderName);
   c.set('getModelName', getModelName);
+  c.set('getTokenCount', getTokenCount);
   await next();
 };

@@ -30,6 +31,7 @@ export const azureOpenAIProvider: AIProvider = {
   basePath: BasePath,
   route: azureOpenAIRoute,
   getModelName: getModelName,
+  getTokenCount: getTokenCount,
   handleRequest: async (c: Context) => {
     const resourceName = c.req.param('resource_name') || '';
     const deploymentName = c.req.param('deployment_name') || '';
@@ -78,4 +80,28 @@ function getModelName(c: Context): string {
     }
   }
   return "unknown"
 }
+
+function getTokenCount(c: Context): { input_tokens: number, output_tokens: number } {
+  const buf = c.get('buffer') || ""
+  if (c.res.status === 200) {
+    if (c.res.headers.get('content-type') === 'application/json') {
+      const usage = JSON.parse(buf)['usage'];
+      if (usage) {
+        const input_tokens = usage['prompt_tokens'] || 0;
+        const output_tokens = usage['completion_tokens'] || 0;
+        return { input_tokens, output_tokens }
+      }
+    } else {
+      // Streaming responses from Azure OpenAI carry no usage field, so estimate input tokens as (word count of the request messages) * 4/3
+      const requestBody = c.get('reqBuffer') || '{}'
+      const messages = JSON.stringify(JSON.parse(requestBody).messages || []);
+      const input_tokens = Math.ceil(messages.split(/\s+/).length * 4 / 3);
+
+      // Count each '\n\n'-terminated chunk of the stream as one output token
+      const output_tokens = buf.split('\n\n').length - 1;
+      return { input_tokens, output_tokens }
+    }
+  }
+  return { input_tokens: 0, output_tokens: 0 }
+}
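
To sanity-check the two estimates with concrete numbers (a standalone sketch; the payloads are invented): the serialized messages below split into 6 whitespace-delimited words, so the input estimate is ceil(6 * 4/3) = 8, and a stream of three '\n\n'-terminated chunks counts as 3 output tokens.

// Standalone sketch of getTokenCount's arithmetic (sample data invented).
const messages = JSON.stringify([
  { role: 'user', content: 'Write a haiku about the sea' },
]);

// Word-based input estimate: roughly 4 tokens per 3 words.
const input_tokens = Math.ceil(messages.split(/\s+/).length * 4 / 3); // 8

// Streaming output estimate: one token per '\n\n'-terminated chunk.
const stream = 'data: {"c":"hel"}\n\ndata: {"c":"lo"}\n\ndata: [DONE]\n\n';
const output_tokens = stream.split('\n\n').length - 1; // 3

console.log({ input_tokens, output_tokens });
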
1 change: 1 addition & 0 deletions src/types.ts
@@ -10,6 +10,7 @@ export interface AIProvider {
   name: string;
   handleRequest: (c: Context) => Promise<Response>;
   getModelName: (c: Context) => string;
+  getTokenCount: (c: Context) => {input_tokens: number, output_tokens: number};
   basePath: string;
   route: Hono;
 }
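
Since getTokenCount is now a required member of AIProvider, every provider must ship an implementation. A minimal conforming stub for a hypothetical provider (the name, path, and import below are illustrative):

import { Hono } from 'hono';
import type { Context } from 'hono';
import type { AIProvider } from '../types';

// Hypothetical provider: returns zeros when it has no token information,
// matching the neutral defaults recordAnalytics falls back to.
export const exampleProvider: AIProvider = {
  name: 'example',
  basePath: '/example',
  route: new Hono(),
  handleRequest: async (c: Context) => c.text('not implemented', 501),
  getModelName: (_c: Context) => 'unknown',
  getTokenCount: (_c: Context) => ({ input_tokens: 0, output_tokens: 0 }),
};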
