-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: oilbeater <[email protected]>
- Loading branch information
Showing
3 changed files
with
254 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,10 @@ | ||
import { azureOpenAIProvider } from './azureOpenAI'; | ||
import { workersAIProvider } from './workersAI'; | ||
import { deepseekProvider } from './deepseek'; | ||
|
||
import { openaiProvider } from './openai'; | ||
/**
 * Lookup table from provider key to provider object.
 *
 * Each key is the string identifier of one provider module imported above;
 * the `'openai'` entry registers the OpenAI provider alongside the others.
 */
export const providers = {
  'azure-openai': azureOpenAIProvider,
  'workers-ai': workersAIProvider,
  'deepseek': deepseekProvider,
  'openai': openaiProvider,
};
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
import { Hono, Context, Next } from 'hono'; | ||
import { AIProvider } from '../types'; | ||
import { | ||
cacheMiddleware, | ||
metricsMiddleware, | ||
bufferMiddleware, | ||
loggingMiddleware, | ||
virtualKeyMiddleware, | ||
rateLimiterMiddleware, | ||
guardMiddleware, | ||
fallbackMiddleware | ||
} from '../middlewares'; | ||
|
||
// Path prefix and provider key; both are also exposed on the provider object.
const BasePath = '/openai';
const ProviderName = 'openai';
// Hono app that the middleware chain and the request handlers are attached to.
const openaiRoute = new Hono();

/**
 * Stores the provider key and the provider-specific helper functions on the
 * request context, then hands control to the next middleware.
 *
 * @param c - Request context that receives the values.
 * @param next - Continuation that runs the rest of the chain.
 */
const initMiddleware = async (c: Context, next: Next) => {
  const contextValues: Array<[string, unknown]> = [
    ['endpoint', ProviderName],
    ['getModelName', getModelName],
    ['getTokenCount', getTokenCount],
    ['getVirtualKey', getVirtualKey],
  ];
  for (const [key, value] of contextValues) {
    c.set(key, value);
  }
  await next();
};
|
||
// Middleware is registered in the listed order, starting with initMiddleware.
openaiRoute.use(
  initMiddleware,
  metricsMiddleware, loggingMiddleware, bufferMiddleware,
  virtualKeyMiddleware, rateLimiterMiddleware, guardMiddleware,
  cacheMiddleware, fallbackMiddleware
);

// Every POST under this route is delegated to the provider's request handler.
openaiRoute.post('/*', async (c: Context) => openaiProvider.handleRequest(c));

// Every GET under this route answers with a fixed plain-text banner.
openaiRoute.get('/*', async (c: Context) =>
  c.text('OpenAI endpoint on Malacca.', 200, { 'Content-Type': 'text/plain' })
);
|
||
/**
 * Provider object for the OpenAI API.
 *
 * Exposes the provider key, base path, route and helper functions, plus
 * `handleRequest`, which forwards the incoming request to
 * `https://api.openai.com/` and returns the upstream response unchanged.
 */
export const openaiProvider: AIProvider = {
  name: ProviderName,
  basePath: BasePath,
  route: openaiRoute,
  getModelName: getModelName,
  getTokenCount: getTokenCount,
  /**
   * Forwards the request to OpenAI.
   *
   * @param c - Request context; its path, method, headers and JSON body are forwarded.
   * @returns The response returned by `fetch` for the upstream call.
   */
  handleRequest: async (c: Context) => {
    // Everything after "<BasePath>/" is used as the upstream path.
    const functionName = c.req.path.slice(`${BasePath}/`.length);
    const openaiEndpoint = `https://api.openai.com/${functionName}`;
    console.log('openaiEndpoint', openaiEndpoint);

    const headers = new Headers(c.req.header());
    // The incoming Host names this gateway rather than api.openai.com, and the
    // incoming Content-Length describes the client's bytes rather than the
    // re-serialized body sent below. Drop both so fetch derives correct values.
    headers.delete('host');
    headers.delete('content-length');

    // When the virtual-key middleware is active, replace the caller's token
    // with the real key it placed on the context.
    if (c.get('middlewares')?.includes('virtualKey')) {
      const apiKey: string = c.get('realKey');
      if (apiKey) {
        headers.set('Authorization', `Bearer ${apiKey}`);
      }
    }

    const response = await fetch(openaiEndpoint, {
      method: c.req.method,
      body: JSON.stringify(await c.req.json()),
      headers: headers
    });

    return response;
  }
};
|
||
/**
 * Reads the requested model name from the buffered request body.
 *
 * @param c - Request context; the raw JSON body is read from its `reqBuffer` value.
 * @returns The body's `model` field when it is a non-empty string, otherwise
 *   "unknown" (missing body, malformed JSON, non-object JSON, or no usable `model`).
 */
function getModelName(c: Context): string {
  const body = c.get('reqBuffer') || '{}';
  try {
    const parsed: unknown = JSON.parse(body);
    // Optional chaining covers bodies such as `null` that parse to a non-object.
    const model = (parsed as { model?: unknown } | null)?.model;
    return typeof model === 'string' && model ? model : "unknown";
  } catch {
    // A malformed body must not break the caller; report "unknown" instead.
    return "unknown";
  }
}
|
||
/**
 * Extracts token usage from the buffered upstream response.
 *
 * For a JSON response the whole buffer is parsed. For any other content type
 * the buffer is treated as "\n\n"-separated `data: ` events and the
 * second-to-last event is parsed. Only responses with status 200 are inspected.
 *
 * @param c - Request context; reads `c.res` and the `buffer` value.
 * @returns `usage.prompt_tokens` / `usage.completion_tokens` from the parsed
 *   payload, or zeros when the status is not 200, the payload cannot be
 *   parsed, or it carries no `usage` object.
 */
function getTokenCount(c: Context): { input_tokens: number, output_tokens: number } {
  if (c.res.status !== 200) {
    return { input_tokens: 0, output_tokens: 0 };
  }

  const buf: string = c.get('buffer') || "";
  const contentType: string = (c.res.headers.get('content-type') || '').toLowerCase();

  try {
    let payload: unknown;
    // `includes` also accepts values with parameters, e.g. "application/json; charset=utf-8".
    if (contentType.includes('application/json')) {
      payload = JSON.parse(buf);
    } else {
      const events = buf.trim().split('\n\n');
      // The second-to-last event is the one inspected for usage data.
      const output = events[events.length - 2];
      if (!output || !output.startsWith('data: ')) {
        return { input_tokens: 0, output_tokens: 0 };
      }
      payload = JSON.parse(output.slice('data: '.length));
    }

    // The payload may have no `usage` (or `usage: null`); fall back to zeros
    // instead of dereferencing it.
    const usage = (payload as {
      usage?: { prompt_tokens?: number; completion_tokens?: number } | null;
    } | null)?.usage;
    if (usage) {
      return {
        input_tokens: usage.prompt_tokens || 0,
        output_tokens: usage.completion_tokens || 0
      };
    }
  } catch (error) {
    console.error("Error parsing response:", error);
  }
  return { input_tokens: 0, output_tokens: 0 };
}
|
||
/**
 * Extracts the bearer token from the request's `Authorization` header.
 *
 * The scheme name is matched case-insensitively, so `Bearer`, `bearer` and
 * `BEARER` are all accepted.
 *
 * @param c - Request context whose `Authorization` header is read.
 * @returns Everything after the "Bearer " prefix, or an empty string when the
 *   header is missing or uses a different scheme.
 */
function getVirtualKey(c: Context): string {
  const authHeader = c.req.header('Authorization') || '';
  const prefix = 'bearer ';
  const hasBearerScheme = authHeader.slice(0, prefix.length).toLowerCase() === prefix;
  return hasBearerScheme ? authHeader.slice(prefix.length) : '';
}
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
/// <reference types="vite/client" /> | ||
/// <reference path="../worker-configuration.d.ts" /> | ||
import { env, SELF } from 'cloudflare:test'; | ||
import { beforeAll, describe, it, expect } from 'vitest'; | ||
|
||
declare module "cloudflare:test" { | ||
interface ProvidedEnv extends Env { } | ||
} | ||
|
||
// Before any test runs, store the value of VITE_OPENAI_API_KEY in the
// MALACCA_USER binding under the key 'openai'; the tests send "openai" as
// their bearer token.
beforeAll(async () => {
  const upstreamKey = import.meta.env.VITE_OPENAI_API_KEY;
  await env.MALACCA_USER.put('openai', upstreamKey);
});

// Chat-completions URL under the /openai prefix, used by every test below.
const url = 'https://example.com/openai/v1/chat/completions';
|
||
/**
 * Builds the JSON text of a chat-completions request.
 *
 * @param stream - Value written to the body's `stream` field.
 * @param placeholder - Topic appended to the user prompt. It is JSON-escaped,
 *   so quotes, backslashes and newlines in it still yield valid JSON.
 * @returns The request body as a JSON string.
 */
const createRequestBody = (stream: boolean, placeholder: string) => `
{
  "model": "gpt-4o-mini",
  "messages": [
    {
      "role": "system",
      "content": "You are an AI assistant that helps people find information."
    },
    {
      "role": "user",
      "content": ${JSON.stringify(`Tell me a very short story about ${placeholder}`)}
    }
  ],
  "temperature": 0.7,
  "top_p": 0.95,
  "max_tokens": 100,
  "stream": ${stream}
}`;
|
||
describe('Test Virtual Key', () => {
  // 'invalid-key' is not the key seeded in beforeAll, so a 401 is expected.
  it('should return 401 for invalid api key', async () => {
    const headers = { 'Content-Type': 'application/json', 'Authorization': 'Bearer invalid-key' };
    const payload = createRequestBody(true, 'Malacca');

    const response = await SELF.fetch(url, { method: 'POST', body: payload, headers });

    expect(response.status).toBe(401);
  });
});
|
||
describe('Test Guard', () => {
  // The prompt contains the word "password"; the request is expected to be
  // answered with 403 (presumably by the guard middleware's deny rule).
  it('should return 403 for deny request', async () => {
    const headers = { 'Content-Type': 'application/json', 'Authorization': `Bearer openai` };
    const payload = createRequestBody(true, 'password');

    const response = await SELF.fetch(url, { method: 'POST', body: payload, headers });

    expect(response.status).toBe(403);
  });
});
|
||
describe('Test Cache', () => {
  // POSTs `body` to the endpoint with the seeded virtual key.
  const post = (body: string) =>
    SELF.fetch(url, {
      method: 'POST',
      body: body,
      headers: { 'Content-Type': 'application/json', 'Authorization': `Bearer openai` }
    });

  // Sends the request, reads the whole response through `read`, and reports
  // how long request plus read took in milliseconds.
  const timedPost = async <T>(
    body: string,
    read: (response: Awaited<ReturnType<typeof post>>) => Promise<T>
  ) => {
    const startedAt = Date.now();
    const response = await post(body);
    const payload = await read(response);
    const elapsed = Date.now() - startedAt;
    return { response, payload, elapsed };
  };

  it('with cache first response should with no header malacca-cache-status and following response with hit', async () => {
    const body = createRequestBody(false, 'Malacca');

    // First call: no cache-status header is expected.
    const first = await timedPost(body, (response) => response.json());
    expect(first.response.status).toBe(200);
    expect(first.response.headers.get('content-type')).toContain('application/json');
    expect(first.response.headers.get('malacca-cache-status')).toBeNull();

    // Same body again: expected to be a cache hit with an identical payload,
    // taking less than half the time of the first call.
    const second = await timedPost(body, (response) => response.json());
    expect(second.response.status).toBe(200);
    expect(second.response.headers.get('malacca-cache-status')).toBe('hit');
    expect(second.response.headers.get('content-type')).toBe('application/json');
    expect(first.payload).toEqual(second.payload)
    expect(first.elapsed / 2).toBeGreaterThan(second.elapsed)
  });

  it('Test stream with cache', async () => {
    const body = createRequestBody(true, 'Malacca');

    // First streamed call: no cache-status header is expected.
    const first = await timedPost(body, (response) => response.text());
    expect(first.response.status).toBe(200);
    expect(first.response.headers.get('content-type')).toContain('text/event-stream');
    expect(first.response.headers.get('malacca-cache-status')).toBeNull();

    // Same streamed request again: expected to be a cache hit with the same text.
    const second = await timedPost(body, (response) => response.text());
    expect(second.response.status).toBe(200);
    expect(second.response.headers.get('malacca-cache-status')).toBe('hit');
    expect(second.response.headers.get('content-type')).toContain('text/event-stream');
    expect(first.payload).toEqual(second.payload)
    expect(first.elapsed / 2).toBeGreaterThan(second.elapsed)
  });

  it('should not cache non-200 responses', async () => {
    // Deliberately invalid: there is no `model` and `stream` is a string.
    const invalidBody = JSON.stringify({
      messages: [{ role: "user", content: "This is an invalid request" }],
      stream: "invalid-model",
      temperature: 0.7,
      top_p: 0.95,
      max_tokens: 800,
    });

    // Both the first and the repeated request must be non-200 and must carry
    // no cache-status header, because non-200 responses are not cached.
    for (let attempt = 0; attempt < 2; attempt++) {
      const response = await post(invalidBody);

      expect(response.status).not.toBe(200);
      expect(response.headers.get('malacca-cache-status')).toBeNull();
    }
  });
});