diff --git a/src/providers/workersAI.ts b/src/providers/workersAI.ts index 6a80968..5e6a090 100644 --- a/src/providers/workersAI.ts +++ b/src/providers/workersAI.ts @@ -4,7 +4,7 @@ import { AIProvider } from '../types'; const ProviderName = 'workers-ai'; const BasePath = '/workers-ai'; const workersAIRoute = new Hono(); -workersAIRoute.all('/*', async (c) => { +workersAIRoute.post('/:provider/:repo/:model', async (c) => { return workersAIProvider.handleRequest(c); }); @@ -15,12 +15,17 @@ export const workersAIProvider: AIProvider = { getModelName: getModelName, getTokenCount: getTokenCount, handleRequest: async (c: Context<{ Bindings: Env }>) => { - const response = await c.env.AI.run("@cf/meta/llama-3.1-8b-instruct", + const provider = c.req.param('provider'); + const repo = c.req.param('repo'); + const model = c.req.param('model') + const response = await c.env.AI.run(`${provider}/${repo}/${model}`, await c.req.json()); + if (response instanceof ReadableStream) { return new Response(response); } - return new Response(response.response); + + return new Response(JSON.stringify(response)); } }; @@ -31,4 +36,3 @@ function getModelName(c: Context) { function getTokenCount(c: Context) { return { input_tokens: 0, output_tokens: 0 }; } - \ No newline at end of file