From 1f5066da021187b62de986c369ee019479ca749d Mon Sep 17 00:00:00 2001
From: ZHallen122
Date: Mon, 28 Oct 2024 13:16:41 -0400
Subject: [PATCH] add /tags api

---
 llm-server/src/llm-provider.ts                | 10 ++++--
 llm-server/src/main.ts                        | 23 ++++++++++++-
 llm-server/src/model/llama-model-provider.ts  | 32 +++++++++++++++++
 llm-server/src/model/model-provider.ts        |  2 ++
 llm-server/src/model/openai-model-provider.ts | 34 ++++++++++++++++++-
 5 files changed, 96 insertions(+), 5 deletions(-)

diff --git a/llm-server/src/llm-provider.ts b/llm-server/src/llm-provider.ts
index 58ff4b6..2286a9d 100644
--- a/llm-server/src/llm-provider.ts
+++ b/llm-server/src/llm-provider.ts
@@ -1,7 +1,7 @@
 import express, { Express, Request, Response } from 'express';
-import { ModelProvider } from './model/model-provider.js';
-import { OpenAIModelProvider } from './model/openai-model-provider.js';
-import { LlamaModelProvider } from './model/llama-model-provider.js';
+import { ModelProvider } from './model/model-provider';
+import { OpenAIModelProvider } from './model/openai-model-provider';
+import { LlamaModelProvider } from './model/llama-model-provider';
 import { Logger } from '@nestjs/common';
 
 export interface ChatMessageInput {
@@ -37,4 +37,8 @@ export class LLMProvider {
   ): Promise<void> {
     await this.modelProvider.generateStreamingResponse(content, res);
   }
+
+  async getModelTags(res: Response): Promise<void> {
+    await this.modelProvider.getModelTagsResponse(res);
+  }
 }
diff --git a/llm-server/src/main.ts b/llm-server/src/main.ts
index 4ed2bb8..f062f05 100644
--- a/llm-server/src/main.ts
+++ b/llm-server/src/main.ts
@@ -1,5 +1,5 @@
 import { Logger } from '@nestjs/common';
-import { ChatMessageInput, LLMProvider } from './llm-provider.js';
+import { ChatMessageInput, LLMProvider } from './llm-provider';
 import express, { Express, Request, Response } from 'express';
 
 export class App {
@@ -19,6 +19,7 @@ export class App {
   setupRoutes(): void {
     this.logger.log('Setting up routes...');
     this.app.post('/chat/completion', this.handleChatRequest.bind(this));
+    this.app.get('/tags', this.handleModelTagsRequest.bind(this));
     this.logger.log('Routes set up successfully.');
   }
 
@@ -39,6 +40,26 @@ export class App {
     }
   }
 
+  private async handleModelTagsRequest(
+    req: Request,
+    res: Response,
+  ): Promise<void> {
+    this.logger.log('Received model tags request.');
+    try {
+      this.logger.debug(JSON.stringify(req.body));
+      const { content } = req.body as ChatMessageInput;
+      this.logger.debug(`Request content: "${content}"`);
+      res.setHeader('Content-Type', 'text/event-stream');
+      res.setHeader('Cache-Control', 'no-cache');
+      res.setHeader('Connection', 'keep-alive');
+      this.logger.debug('Response headers set for streaming.');
+      await this.llmProvider.getModelTags(res);
+    } catch (error) {
+      this.logger.error('Error in tags endpoint:', error);
+      res.status(500).json({ error: 'Internal server error' });
+    }
+  }
+
   async start(): Promise<void> {
     this.setupRoutes();
     this.app.listen(this.PORT, () => {
diff --git a/llm-server/src/model/llama-model-provider.ts b/llm-server/src/model/llama-model-provider.ts
index 1bec324..07a24b7 100644
--- a/llm-server/src/model/llama-model-provider.ts
+++ b/llm-server/src/model/llama-model-provider.ts
@@ -68,4 +68,36 @@ export class LlamaModelProvider extends ModelProvider {
       res.end();
     }
   }
+
+  async getModelTagsResponse(res: Response): Promise<void> {
+    this.logger.log('Fetching available models from the Llama provider...');
+    // Set SSE headers
+    res.writeHead(200, {
+      'Content-Type': 'text/event-stream',
+      'Cache-Control': 'no-cache',
+      Connection: 'keep-alive',
+    });
+    try {
+      const startTime = Date.now();
+      const models: string[] = []; // Placeholder: local model discovery not implemented yet
+
+      const response = {
+        models: models, // Wrap the models in the required structure
+      };
+
+      const endTime = Date.now();
+      this.logger.log(
+        `Model fetching completed. Total models: ${models.length}`,
+      );
+      this.logger.log(`Fetch time: ${endTime - startTime}ms`);
+      res.write(JSON.stringify(response));
+      res.end();
+      this.logger.log('Model tags response ended.');
+    } catch (error) {
+      this.logger.error('Error during Llama model fetching:', error);
+      res.write(`data: ${JSON.stringify({ error: 'Generation failed' })}\n\n`);
+      res.write(`data: [DONE]\n\n`);
+      res.end();
+    }
+  }
 }
diff --git a/llm-server/src/model/model-provider.ts b/llm-server/src/model/model-provider.ts
index 441f477..07f6a0b 100644
--- a/llm-server/src/model/model-provider.ts
+++ b/llm-server/src/model/model-provider.ts
@@ -6,4 +6,6 @@ export abstract class ModelProvider {
     content: string,
     res: Response,
   ): Promise<void>;
+
+  abstract getModelTagsResponse(res: Response): Promise<void>;
 }
diff --git a/llm-server/src/model/openai-model-provider.ts b/llm-server/src/model/openai-model-provider.ts
index 5cbc99d..e448d64 100644
--- a/llm-server/src/model/openai-model-provider.ts
+++ b/llm-server/src/model/openai-model-provider.ts
@@ -1,6 +1,6 @@
 import { Response } from 'express';
 import OpenAI from 'openai';
-import { ModelProvider } from './model-provider.js';
+import { ModelProvider } from './model-provider';
 import { Logger } from '@nestjs/common';
 export class OpenAIModelProvider extends ModelProvider {
   private readonly logger = new Logger(OpenAIModelProvider.name);
@@ -56,4 +56,36 @@ export class OpenAIModelProvider extends ModelProvider {
       res.end();
     }
   }
+
+  async getModelTagsResponse(res: Response): Promise<void> {
+    this.logger.log('Fetching available models from OpenAI...');
+    // Set SSE headers
+    res.writeHead(200, {
+      'Content-Type': 'text/event-stream',
+      'Cache-Control': 'no-cache',
+      Connection: 'keep-alive',
+    });
+    try {
+      const startTime = Date.now();
+      const models = await this.openai.models.list();
+
+      const response = {
+        models: models, // Wrap the models in the required structure
+      };
+
+      const endTime = Date.now();
+      this.logger.log(
+        `Model fetching completed. Total models: ${models.data.length}`,
+      );
+      this.logger.log(`Fetch time: ${endTime - startTime}ms`);
+      res.write(JSON.stringify(response));
+      res.end();
+      this.logger.log('Model tags response ended.');
+    } catch (error) {
+      this.logger.error('Error during OpenAI model fetching:', error);
+      res.write(`data: ${JSON.stringify({ error: 'Generation failed' })}\n\n`);
+      res.write(`data: [DONE]\n\n`);
+      res.end();
+    }
+  }
 }
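
For quick verification of the new endpoint, a minimal TypeScript client sketch follows; the localhost port is an assumption (the actual PORT is configured in App and not shown in this diff). Note that although the handler sets text/event-stream headers, it writes a single JSON object and then ends the response, so the body can be read once and parsed.

// Hypothetical smoke test for GET /tags; assumes the llm-server from this
// patch is listening on http://localhost:3000 (the real port comes from
// App's PORT setting). Requires Node 18+ for the global fetch API.
async function fetchModelTags(baseUrl = 'http://localhost:3000'): Promise<void> {
  const res = await fetch(`${baseUrl}/tags`);
  // The handler writes one JSON payload of shape { models: ... } and then
  // ends the response, so it can be consumed as plain text and parsed once.
  const body = await res.text();
  const parsed = JSON.parse(body) as { models: unknown };
  console.log('Status:', res.status);
  console.log('Models payload:', parsed.models);
}

fetchModelTags().catch((err) => console.error('tags request failed:', err));

An equivalent check from the shell, under the same port assumption, is `curl http://localhost:3000/tags`.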