add /tags api
ZHallen122 committed Oct 28, 2024
1 parent 7af5bdf commit 1f5066d
Showing 5 changed files with 96 additions and 5 deletions.
10 changes: 7 additions & 3 deletions llm-server/src/llm-provider.ts
@@ -1,7 +1,7 @@
 import express, { Express, Request, Response } from 'express';
-import { ModelProvider } from './model/model-provider.js';
-import { OpenAIModelProvider } from './model/openai-model-provider.js';
-import { LlamaModelProvider } from './model/llama-model-provider.js';
+import { ModelProvider } from './model/model-provider';
+import { OpenAIModelProvider } from './model/openai-model-provider';
+import { LlamaModelProvider } from './model/llama-model-provider';
 import { Logger } from '@nestjs/common';

 export interface ChatMessageInput {
@@ -37,4 +37,8 @@ export class LLMProvider {
   ): Promise<void> {
     await this.modelProvider.generateStreamingResponse(content, res);
   }
+
+  async getModelTags(res: Response): Promise<void> {
+    await this.modelProvider.getModelTagsResponse(res);
+  }
 }
23 changes: 22 additions & 1 deletion llm-server/src/main.ts
@@ -1,5 +1,5 @@
 import { Logger } from '@nestjs/common';
-import { ChatMessageInput, LLMProvider } from './llm-provider.js';
+import { ChatMessageInput, LLMProvider } from './llm-provider';
 import express, { Express, Request, Response } from 'express';

 export class App {
@@ -19,6 +19,7 @@ export class App {
   setupRoutes(): void {
     this.logger.log('Setting up routes...');
     this.app.post('/chat/completion', this.handleChatRequest.bind(this));
+    this.app.get('/tags', this.handleModelTagsRequest.bind(this));
     this.logger.log('Routes set up successfully.');
   }

@@ -39,6 +40,26 @@
     }
   }

+  private async handleModelTagsRequest(
+    req: Request,
+    res: Response,
+  ): Promise<void> {
+    this.logger.log('Received model tags request.');
+    try {
+      this.logger.debug(JSON.stringify(req.body));
+      const { content } = req.body as ChatMessageInput;
+      this.logger.debug(`Request content: "${content}"`);
+      res.setHeader('Content-Type', 'text/event-stream');
+      res.setHeader('Cache-Control', 'no-cache');
+      res.setHeader('Connection', 'keep-alive');
+      this.logger.debug('Response headers set for streaming.');
+      await this.llmProvider.getModelTags(res);
+    } catch (error) {
+      this.logger.error('Error in tags endpoint:', error);
+      res.status(500).json({ error: 'Internal server error' });
+    }
+  }
+
   async start(): Promise<void> {
     this.setupRoutes();
     this.app.listen(this.PORT, () => {
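With the route registered above, the new endpoint can be exercised from a small client script. A minimal sketch, assuming the server listens on localhost:3000 (the actual PORT is set in a truncated part of main.ts), and noting that on success the handler writes a single JSON payload despite the event-stream content type:

// sketch: calling the new GET /tags endpoint from a Node 18+ script
// (baseUrl is an assumption; adjust to the server's configured PORT)
async function fetchModelTags(baseUrl = 'http://localhost:3000'): Promise<unknown> {
  const res = await fetch(`${baseUrl}/tags`);
  if (!res.ok) {
    throw new Error(`Tags request failed with status ${res.status}`);
  }
  // The handler writes one JSON object of the form { models: ... }
  return res.json();
}

fetchModelTags()
  .then((tags) => console.log('Available models:', tags))
  .catch((err) => console.error(err));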
32 changes: 32 additions & 0 deletions llm-server/src/model/llama-model-provider.ts
@@ -68,4 +68,36 @@ export class LlamaModelProvider extends ModelProvider {
       res.end();
     }
   }
+
+  async getModelTagsResponse(res: Response): Promise<void> {
+    this.logger.log('Fetching available models from Llama...');
+    // Set SSE headers
+    res.writeHead(200, {
+      'Content-Type': 'text/event-stream',
+      'Cache-Control': 'no-cache',
+      Connection: 'keep-alive',
+    });
+    try {
+      const startTime = Date.now();
+      const models: string[] = []; // placeholder: local model listing not yet implemented
+
+      const response = {
+        models: models, // Wrap the models in the required structure
+      };
+
+      const endTime = Date.now();
+      this.logger.log(
+        `Model fetching completed. Total models: ${models.length}`,
+      );
+      this.logger.log(`Fetch time: ${endTime - startTime}ms`);
+      res.write(JSON.stringify(response));
+      res.end();
+      this.logger.log('Model tags response ended.');
+    } catch (error) {
+      this.logger.error('Error during Llama model tags generation:', error);
+      res.write(`data: ${JSON.stringify({ error: 'Model listing failed' })}\n\n`);
+      res.write(`data: [DONE]\n\n`);
+      res.end();
+    }
+  }
 }
2 changes: 2 additions & 0 deletions llm-server/src/model/model-provider.ts
@@ -6,4 +6,6 @@ export abstract class ModelProvider {
     content: string,
     res: Response,
   ): Promise<void>;
+
+  abstract getModelTagsResponse(res: Response): Promise<void>;
 }
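Because the method is abstract, every concrete provider must now implement it. A minimal sketch of a hypothetical provider (the StaticModelProvider name and its hard-coded tag are illustrative only, and the base class may declare additional members outside this hunk):

import { Response } from 'express';
import { ModelProvider } from './model-provider';

// Hypothetical provider, for illustration; implements only the two
// abstract methods visible in this diff.
export class StaticModelProvider extends ModelProvider {
  // Streams the prompt straight back as a single SSE event.
  async generateStreamingResponse(content: string, res: Response): Promise<void> {
    res.write(`data: ${JSON.stringify({ echo: content })}\n\n`);
    res.end();
  }

  // Reports one hard-coded model tag in the { models: ... } shape.
  async getModelTagsResponse(res: Response): Promise<void> {
    res.writeHead(200, { 'Content-Type': 'text/event-stream' });
    res.write(JSON.stringify({ models: ['static-model-v1'] }));
    res.end();
  }
}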
34 changes: 33 additions & 1 deletion llm-server/src/model/openai-model-provider.ts
@@ -1,6 +1,6 @@
 import { Response } from 'express';
 import OpenAI from 'openai';
-import { ModelProvider } from './model-provider.js';
+import { ModelProvider } from './model-provider';
 import { Logger } from '@nestjs/common';
 export class OpenAIModelProvider extends ModelProvider {
   private readonly logger = new Logger(OpenAIModelProvider.name);
@@ -56,4 +56,36 @@ export class OpenAIModelProvider extends ModelProvider {
       res.end();
     }
   }
+
+  async getModelTagsResponse(res: Response): Promise<void> {
+    this.logger.log('Fetching available models from OpenAI...');
+    // Set SSE headers
+    res.writeHead(200, {
+      'Content-Type': 'text/event-stream',
+      'Cache-Control': 'no-cache',
+      Connection: 'keep-alive',
+    });
+    try {
+      const startTime = Date.now();
+      const models = await this.openai.models.list();
+
+      const response = {
+        models: models, // Wrap the models in the required structure
+      };
+
+      const endTime = Date.now();
+      this.logger.log(
+        `Model fetching completed. Total models: ${models.data.length}`,
+      );
+      this.logger.log(`Fetch time: ${endTime - startTime}ms`);
+      res.write(JSON.stringify(response));
+      res.end();
+      this.logger.log('Model tags response ended.');
+    } catch (error) {
+      this.logger.error('Error while fetching OpenAI models:', error);
+      res.write(`data: ${JSON.stringify({ error: 'Model listing failed' })}\n\n`);
+      res.write(`data: [DONE]\n\n`);
+      res.end();
+    }
+  }
 }
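For reference, this.openai.models.list() resolves to a paginated page whose data array holds Model objects, so the payload written to the client looks roughly like the following (a sketch abridged from the OpenAI SDK's documented Model fields, not an exact type):

// Approximate shape of the JSON written by getModelTagsResponse above.
interface TagsPayload {
  models: {
    object: 'list';
    data: Array<{
      id: string;      // model identifier, e.g. 'gpt-4o'
      object: 'model';
      created: number; // unix timestamp
      owned_by: string;
    }>;
  };
}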
