From afb05fa8218732a38053da208b94713c1e296c58 Mon Sep 17 00:00:00 2001
From: oilbeater
Date: Tue, 15 Oct 2024 14:19:57 +0800
Subject: [PATCH] support OpenAI provider

Signed-off-by: oilbeater
---
 src/providers/index.ts  |   5 +-
 src/providers/openai.ts | 115 +++++++++++++++++++++++++++++++++
 test/openai.test.ts     | 136 ++++++++++++++++++++++++++++++++++++++++
 3 files changed, 254 insertions(+), 2 deletions(-)
 create mode 100644 src/providers/openai.ts
 create mode 100644 test/openai.test.ts

diff --git a/src/providers/index.ts b/src/providers/index.ts
index e724dd7..a9d4638 100644
--- a/src/providers/index.ts
+++ b/src/providers/index.ts
@@ -1,9 +1,10 @@
 import { azureOpenAIProvider } from './azureOpenAI';
 import { workersAIProvider } from './workersAI';
 import { deepseekProvider } from './deepseek';
-
+import { openaiProvider } from './openai';
 export const providers = {
     'azure-openai': azureOpenAIProvider,
     'workers-ai': workersAIProvider,
     'deepseek': deepseekProvider,
-};
\ No newline at end of file
+    'openai': openaiProvider,
+};
diff --git a/src/providers/openai.ts b/src/providers/openai.ts
new file mode 100644
index 0000000..5d2c2d5
--- /dev/null
+++ b/src/providers/openai.ts
@@ -0,0 +1,115 @@
+import { Hono, Context, Next } from 'hono';
+import { AIProvider } from '../types';
+import {
+    cacheMiddleware,
+    metricsMiddleware,
+    bufferMiddleware,
+    loggingMiddleware,
+    virtualKeyMiddleware,
+    rateLimiterMiddleware,
+    guardMiddleware,
+    fallbackMiddleware
+} from '../middlewares';
+
+const BasePath = '/openai';
+const ProviderName = 'openai';
+const openaiRoute = new Hono();
+
+const initMiddleware = async (c: Context, next: Next) => {
+    c.set('endpoint', ProviderName);
+    c.set('getModelName', getModelName);
+    c.set('getTokenCount', getTokenCount);
+    c.set('getVirtualKey', getVirtualKey);
+    await next();
+};
+
+openaiRoute.use(
+    initMiddleware,
+    metricsMiddleware,
+    loggingMiddleware,
+    bufferMiddleware,
+    virtualKeyMiddleware,
+    rateLimiterMiddleware,
+    guardMiddleware,
+    cacheMiddleware,
+    fallbackMiddleware
+);
+
+openaiRoute.post('/*', async (c: Context) => {
+    return openaiProvider.handleRequest(c);
+});
+
+openaiRoute.get('/*', async (c: Context) => {
+    return c.text('OpenAI endpoint on Malacca.', 200, { 'Content-Type': 'text/plain' });
+});
+
+export const openaiProvider: AIProvider = {
+    name: ProviderName,
+    basePath: BasePath,
+    route: openaiRoute,
+    getModelName: getModelName,
+    getTokenCount: getTokenCount,
+    handleRequest: async (c: Context) => {
+        const functionName = c.req.path.slice(`/openai/`.length);
+        const openaiEndpoint = `https://api.openai.com/${functionName}`;
+        console.log('openaiEndpoint', openaiEndpoint);
+        const headers = new Headers(c.req.header());
+        if (c.get('middlewares')?.includes('virtualKey')) {
+            const apiKey: string = c.get('realKey');
+            if (apiKey) {
+                headers.set('Authorization', `Bearer ${apiKey}`);
+            }
+        }
+
+        const response = await fetch(openaiEndpoint, {
+            method: c.req.method,
+            body: JSON.stringify(await c.req.json()),
+            headers: headers
+        });
+
+        return response;
+    }
+};
+
+function getModelName(c: Context): string {
+    const body = c.get('reqBuffer') || '{}';
+    const model = JSON.parse(body).model;
+    return model || "unknown";
+}
+
+function getTokenCount(c: Context): { input_tokens: number, output_tokens: number } {
+    const buf = c.get('buffer') || "";
+    if (c.res.status === 200) {
+        if (c.res.headers.get('content-type')?.includes('application/json')) {
+            try {
+                const jsonResponse = JSON.parse(buf);
+                const usage = jsonResponse.usage;
+                if (usage) {
+                    return {
+                        input_tokens: usage.prompt_tokens || 0,
+                        output_tokens: usage.completion_tokens || 0
+                    };
+                }
+            } catch (error) {
+                console.error("Error parsing response:", error);
+            }
+        }
+        else {
+            const output = buf.trim().split('\n\n').at(-2);
+            if (output && output.startsWith('data: ')) {
+                const usage_message = JSON.parse(output.slice('data: '.length));
+                return {
+                    input_tokens: usage_message.usage?.prompt_tokens || 0,
+                    output_tokens: usage_message.usage?.completion_tokens || 0
+                };
+            }
+        }
+    }
+    return { input_tokens: 0, output_tokens: 0 };
+}
+
+function getVirtualKey(c: Context): string {
+    const authHeader = c.req.header('Authorization') || '';
+    return authHeader.startsWith('Bearer ') ? authHeader.slice(7) : '';
+}
+
diff --git a/test/openai.test.ts b/test/openai.test.ts
new file mode 100644
index 0000000..a807a00
--- /dev/null
+++ b/test/openai.test.ts
@@ -0,0 +1,136 @@
+/// <reference types="@cloudflare/workers-types/experimental" />
+/// <reference types="@cloudflare/vitest-pool-workers" />
+import { env, SELF } from 'cloudflare:test';
+import { beforeAll, describe, it, expect } from 'vitest';
+
+declare module "cloudflare:test" {
+    interface ProvidedEnv extends Env { }
+}
+
+beforeAll(async () => {
+    await env.MALACCA_USER.put('openai', import.meta.env.VITE_OPENAI_API_KEY);
+});
+
+const url = `https://example.com/openai/v1/chat/completions`;
+
+const createRequestBody = (stream: boolean, placeholder: string) => `
+{
+    "model": "gpt-4o-mini",
+    "messages": [
+        {
+            "role": "system",
+            "content": "You are an AI assistant that helps people find information."
+        },
+        {
+            "role": "user",
+            "content": "Tell me a very short story about ${placeholder}"
+        }
+    ],
+    "temperature": 0.7,
+    "top_p": 0.95,
+    "max_tokens": 100,
+    "stream": ${stream}
+}`;
+
+describe('Test Virtual Key', () => {
+    it('should return 401 for invalid api key', async () => {
+        const response = await SELF.fetch(url, {
+            method: 'POST',
+            body: createRequestBody(true, 'Malacca'),
+            headers: { 'Content-Type': 'application/json', 'Authorization': 'Bearer invalid-key' }
+        });
+
+        expect(response.status).toBe(401);
+    });
+});
+
+describe('Test Guard', () => {
+    it('should return 403 for denied request', async () => {
+        const response = await SELF.fetch(url, {
+            method: 'POST',
+            body: createRequestBody(true, 'password'),
+            headers: { 'Content-Type': 'application/json', 'Authorization': `Bearer openai` }
+        });
+
+        expect(response.status).toBe(403);
+    });
+});
+
+describe('Test Cache', () => {
+    it('first response should have no malacca-cache-status header and the following response should hit the cache', async () => {
+        const body = createRequestBody(false, 'Malacca');
+        let start = Date.now();
+        let response = await SELF.fetch(url, { method: 'POST', body: body, headers: { 'Content-Type': 'application/json', 'Authorization': `Bearer openai` } });
+        const value = await response.json();
+        const duration = Date.now() - start;
+
+        expect(response.status).toBe(200);
+        expect(response.headers.get('content-type')).toContain('application/json');
+        expect(response.headers.get('malacca-cache-status')).toBeNull();
+
+        start = Date.now();
+        response = await SELF.fetch(url, { method: 'POST', body: body, headers: { 'Content-Type': 'application/json', 'Authorization': `Bearer openai` } });
+        const cacheValue = await response.json();
+        const cacheDuration = Date.now() - start;
+
+        expect(response.status).toBe(200);
+        expect(response.headers.get('malacca-cache-status')).toBe('hit');
+        expect(response.headers.get('content-type')).toBe('application/json');
+        expect(value).toEqual(cacheValue);
+        expect(duration / 2).toBeGreaterThan(cacheDuration);
+    });
+
+    it('should serve streaming responses from cache', async () => {
+        const body = createRequestBody(true, 'Malacca');
+        let start = Date.now();
+        let response = await SELF.fetch(url, { method: 'POST', body: body, headers: { 'Content-Type': 'application/json', 'Authorization': `Bearer openai` } });
+        const value = await response.text();
+        const duration = Date.now() - start;
+
+        expect(response.status).toBe(200);
+        expect(response.headers.get('content-type')).toContain('text/event-stream');
+        expect(response.headers.get('malacca-cache-status')).toBeNull();
+
+        start = Date.now();
+        response = await SELF.fetch(url, { method: 'POST', body: body, headers: { 'Content-Type': 'application/json', 'Authorization': `Bearer openai` } });
+        const cacheValue = await response.text();
+        const cacheDuration = Date.now() - start;
+
+        expect(response.status).toBe(200);
+        expect(response.headers.get('malacca-cache-status')).toBe('hit');
+        expect(response.headers.get('content-type')).toContain('text/event-stream');
+        expect(value).toEqual(cacheValue);
+        expect(duration / 2).toBeGreaterThan(cacheDuration);
+    });
+
+    it('should not cache non-200 responses', async () => {
+        const invalidBody = JSON.stringify({
+            messages: [{ role: "user", content: "This is an invalid request" }],
+            model: "invalid-model",
+            temperature: 0.7,
+            top_p: 0.95,
+            max_tokens: 800,
+        });
+
+        // First request - should return a non-200 response
+        let response = await SELF.fetch(url, {
+            method: 'POST',
+            body: invalidBody,
+            headers: { 'Content-Type': 'application/json', 'Authorization': `Bearer openai` }
+        });
+
+        expect(response.status).not.toBe(200);
+        expect(response.headers.get('malacca-cache-status')).toBeNull();
+
+        // Second request with the same invalid body
+        response = await SELF.fetch(url, {
+            method: 'POST',
+            body: invalidBody,
+            headers: { 'Content-Type': 'application/json', 'Authorization': `Bearer openai` }
+        });
+
+        expect(response.status).not.toBe(200);
+        // Should still be a cache miss, as non-200 responses are not cached
+        expect(response.headers.get('malacca-cache-status')).toBeNull();
+    });
+});
\ No newline at end of file
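
Usage note: with this provider registered, a client talks to the gateway's
/openai route exactly as it would talk to api.openai.com, except the
Authorization bearer token carries the Malacca virtual key instead of a real
OpenAI key; virtualKeyMiddleware resolves it to the real key before the
upstream fetch. Below is a minimal TypeScript sketch of such a call. The
example.com host and the 'openai' virtual key are illustrative assumptions
borrowed from the tests above (the key matches the MALACCA_USER entry seeded
in beforeAll), not values defined by this patch.

    // Minimal client call through Malacca's /openai route (illustrative).
    const response = await fetch('https://example.com/openai/v1/chat/completions', {
        method: 'POST',
        headers: {
            'Content-Type': 'application/json',
            // Virtual key: swapped for the real OpenAI key by the gateway.
            'Authorization': 'Bearer openai',
        },
        body: JSON.stringify({
            model: 'gpt-4o-mini',
            messages: [{ role: 'user', content: 'Say hello in five words.' }],
            stream: false,
        }),
    });
    console.log(await response.json());

Repeating the identical request should be answered by cacheMiddleware and
carry the malacca-cache-status: hit response header, as exercised by the
cache tests above.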