Skip to content

Commit

Permalink
support openai
Browse files Browse the repository at this point in the history
Signed-off-by: oilbeater <[email protected]>
  • Loading branch information
oilbeater committed Oct 15, 2024
1 parent fff7b6a commit afb05fa
Show file tree
Hide file tree
Showing 3 changed files with 254 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/providers/index.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import { azureOpenAIProvider } from './azureOpenAI';
import { workersAIProvider } from './workersAI';
import { deepseekProvider } from './deepseek';

import { openaiProvider } from './openai';
export const providers = {
'azure-openai': azureOpenAIProvider,
'workers-ai': workersAIProvider,
'deepseek': deepseekProvider,
};
'openai': openaiProvider,
};
115 changes: 115 additions & 0 deletions src/providers/openai.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import { Hono, Context, Next } from 'hono';
import { AIProvider } from '../types';
import {
cacheMiddleware,
metricsMiddleware,
bufferMiddleware,
loggingMiddleware,
virtualKeyMiddleware,
rateLimiterMiddleware,
guardMiddleware,
fallbackMiddleware
} from '../middlewares';

const BasePath = '/openai';
const ProviderName = 'openai';
const openaiRoute = new Hono();

const initMiddleware = async (c: Context, next: Next) => {
c.set('endpoint', ProviderName);
c.set('getModelName', getModelName);
c.set('getTokenCount', getTokenCount);
c.set('getVirtualKey', getVirtualKey);
await next();
};

openaiRoute.use(
initMiddleware,
metricsMiddleware,
loggingMiddleware,
bufferMiddleware,
virtualKeyMiddleware,
rateLimiterMiddleware,
guardMiddleware,
cacheMiddleware,
fallbackMiddleware
);

openaiRoute.post('/*', async (c: Context) => {
return openaiProvider.handleRequest(c);
});

openaiRoute.get('/*', async (c: Context) => {
return c.text('OpenAI endpoint on Malacca.', 200, { 'Content-Type': 'text/plain' });
});

export const openaiProvider: AIProvider = {
name: ProviderName,
basePath: BasePath,
route: openaiRoute,
getModelName: getModelName,
getTokenCount: getTokenCount,
handleRequest: async (c: Context) => {
const functionName = c.req.path.slice(`/openai/`.length);
const openaiEndpoint = `https://api.openai.com/${functionName}`;
console.log('openaiEndpoint', openaiEndpoint);
const headers = new Headers(c.req.header());
if (c.get('middlewares')?.includes('virtualKey')) {
const apiKey: string = c.get('realKey');
if (apiKey) {
headers.set('Authorization', `Bearer ${apiKey}`);
}
}

const response = await fetch(openaiEndpoint, {
method: c.req.method,
body: JSON.stringify(await c.req.json()),
headers: headers
});

return response;
}
};

function getModelName(c: Context): string {
const body = c.get('reqBuffer') || '{}';
const model = JSON.parse(body).model;
return model || "unknown";
}

function getTokenCount(c: Context): { input_tokens: number, output_tokens: number } {
const buf = c.get('buffer') || "";
if (c.res.status === 200) {
if (c.res.headers.get('content-type') === 'application/json') {
try {
const jsonResponse = JSON.parse(buf);
const usage = jsonResponse.usage;
if (usage) {
return {
input_tokens: usage.prompt_tokens || 0,
output_tokens: usage.completion_tokens || 0
};
}
} catch (error) {
console.error("Error parsing response:", error);
}
}
else {
const output = buf.trim().split('\n\n').at(-2);
if (output && output.startsWith('data: ')) {
const usage_message = JSON.parse(output.slice('data: '.length));
return {
input_tokens: usage_message.usage.prompt_tokens || 0,
output_tokens: usage_message.usage.completion_tokens || 0
};
}
}
}
return { input_tokens: 0, output_tokens: 0 };
}

function getVirtualKey(c: Context): string {
const authHeader = c.req.header('Authorization') || '';
return authHeader.startsWith('Bearer ') ? authHeader.slice(7) : '';
}

136 changes: 136 additions & 0 deletions test/openai.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/// <reference types="vite/client" />
/// <reference path="../worker-configuration.d.ts" />
import { env, SELF } from 'cloudflare:test';
import { beforeAll, describe, it, expect } from 'vitest';

declare module "cloudflare:test" {
interface ProvidedEnv extends Env { }
}

beforeAll(async () => {
await env.MALACCA_USER.put('openai', import.meta.env.VITE_OPENAI_API_KEY);
});

const url = `https://example.com/openai/v1/chat/completions`;

const createRequestBody = (stream: boolean, placeholder: string) => `
{
"model": "gpt-4o-mini",
"messages": [
{
"role": "system",
"content": "You are an AI assistant that helps people find information."
},
{
"role": "user",
"content": "Tell me a very short story about ${placeholder}"
}
],
"temperature": 0.7,
"top_p": 0.95,
"max_tokens": 100,
"stream": ${stream}
}`;

describe('Test Virtual Key', () => {
it('should return 401 for invalid api key', async () => {
const response = await SELF.fetch(url, {
method: 'POST',
body: createRequestBody(true, 'Malacca'),
headers: { 'Content-Type': 'application/json', 'Authorization': 'Bearer invalid-key' }
});

expect(response.status).toBe(401);
});
});

describe('Test Guard', () => {
it('should return 403 for deny request', async () => {
const response = await SELF.fetch(url, {
method: 'POST',
body: createRequestBody(true, 'password'),
headers: { 'Content-Type': 'application/json', 'Authorization': `Bearer openai` }
});

expect(response.status).toBe(403);
});
});

describe('Test Cache', () => {
it('with cache first response should with no header malacca-cache-status and following response with hit', async () => {
const body = createRequestBody(false, 'Malacca');
let start = Date.now();
let response = await SELF.fetch(url, { method: 'POST', body: body, headers: { 'Content-Type': 'application/json', 'Authorization': `Bearer openai` } });
const value = await response.json()
const duration = Date.now() - start

expect(response.status).toBe(200);
expect(response.headers.get('content-type')).toContain('application/json');
expect(response.headers.get('malacca-cache-status')).toBeNull();

start = Date.now();
response = await SELF.fetch(url, { method: 'POST', body: body, headers: { 'Content-Type': 'application/json', 'Authorization': `Bearer openai` } });
const cacheValue = await response.json()
const cacheDuration = Date.now() - start

expect(response.status).toBe(200);
expect(response.headers.get('malacca-cache-status')).toBe('hit');
expect(response.headers.get('content-type')).toBe('application/json');
expect(value).toEqual(cacheValue)
expect(duration / 2).toBeGreaterThan(cacheDuration)
});

it('Test stream with cache', async () => {
const body = createRequestBody(true, 'Malacca');
let start = Date.now();
let response = await SELF.fetch(url, { method: 'POST', body: body, headers: { 'Content-Type': 'application/json', 'Authorization': `Bearer openai` } });
const value = await response.text()
const duration = Date.now() - start

expect(response.status).toBe(200);
expect(response.headers.get('content-type')).toContain('text/event-stream');
expect(response.headers.get('malacca-cache-status')).toBeNull();

start = Date.now();
response = await SELF.fetch(url, { method: 'POST', body: body, headers: { 'Content-Type': 'application/json', 'Authorization': `Bearer openai` } });
const cacheValue = await response.text()
const cacheDuration = Date.now() - start

expect(response.status).toBe(200);
expect(response.headers.get('malacca-cache-status')).toBe('hit');
expect(response.headers.get('content-type')).toContain('text/event-stream');
expect(value).toEqual(cacheValue)
expect(duration / 2).toBeGreaterThan(cacheDuration)
});

it('should not cache non-200 responses', async () => {
const invalidBody = JSON.stringify({
messages: [{ role: "user", content: "This is an invalid request" }],
stream: "invalid-model",
temperature: 0.7,
top_p: 0.95,
max_tokens: 800,
});

// First request - should return a non-200 response
let response = await SELF.fetch(url, {
method: 'POST',
body: invalidBody,
headers: { 'Content-Type': 'application/json', 'Authorization': `Bearer openai` }
});

expect(response.status).not.toBe(200);
expect(response.headers.get('malacca-cache-status')).toBeNull();

// Second request with the same invalid body
response = await SELF.fetch(url, {
method: 'POST',
body: invalidBody,
headers: { 'Content-Type': 'application/json', 'Authorization': `Bearer openai` }
});

expect(response.status).not.toBe(200);
// Should still be a cache miss, as non-200 responses are not cached
expect(response.headers.get('malacca-cache-status')).toBeNull();
});
});

0 comments on commit afb05fa

Please sign in to comment.