diff --git a/backend/src/build-system/__tests__/test.model-provider.spec.ts b/backend/src/build-system/__tests__/test.model-provider.spec.ts new file mode 100644 index 0000000..f66df10 --- /dev/null +++ b/backend/src/build-system/__tests__/test.model-provider.spec.ts @@ -0,0 +1,9 @@ +import { EmbeddingProvider } from "src/common/embedding-provider"; + +describe('Model Provider Test', () => { + let embProvider = EmbeddingProvider.getInstance(); + it('should generate a response from the model provider', async () => { + let res = await embProvider.generateEmbResponse("Your text string goes here", "text-embedding-3-small"); + console.log(res); + }); +}); \ No newline at end of file diff --git a/backend/src/config/common-path.ts b/backend/src/config/common-path.ts index bc88ea6..ccf30c9 100644 --- a/backend/src/config/common-path.ts +++ b/backend/src/config/common-path.ts @@ -1,6 +1,7 @@ import * as path from 'path'; -import { existsSync, mkdirSync, promises } from 'fs-extra'; -import { cwd } from 'process'; + +import { existsSync, mkdirSync, promises, writeFileSync } from 'fs-extra'; +import { ConfigType } from '@nestjs/config'; // Constants for base directories const APP_NAME = 'codefox'; @@ -29,10 +30,17 @@ const ensureDir = (dirPath: string): string => { export const getRootDir = (): string => ensureDir(ROOT_DIR); // Configuration Paths -export const getConfigDir = (): string => - ensureDir(path.join(getRootDir(), 'config')); -export const getConfigPath = (configName: string): string => - path.join(getConfigDir(), `${configName}.json`); +export const getConfigPath = (): string => { + const rootPath = ensureDir(path.join(getRootDir())); + return path.join(rootPath, 'config.json'); +}; + +export const getModelStatusPath = (): string => { + const rootPath = ensureDir(getRootDir()); + const modelStatusPath = path.join(rootPath, `model-status.json`); + writeFileSync(modelStatusPath, '{}'); + return modelStatusPath; +}; // Models Directory export const getModelsDir = (): string => @@ -40,6 +48,12 @@ export const getModelsDir = (): string => export const getModelPath = (modelName: string): string => path.join(getModelsDir(), modelName); +// Embs Directory +export const getEmbDir = (): string => + ensureDir(path.join(getRootDir(), 'embeddings')); +export const getEmbPath = (modelName: string): string => + path.join(getModelsDir(), modelName); + // Project-Specific Paths export const getProjectsDir = (): string => ensureDir(path.join(getRootDir(), 'projects')); diff --git a/backend/src/config/config-loader.ts b/backend/src/config/config-loader.ts index b8c390b..650e207 100644 --- a/backend/src/config/config-loader.ts +++ b/backend/src/config/config-loader.ts @@ -1,7 +1,10 @@ import * as fs from 'fs'; import * as _ from 'lodash'; import { getConfigPath } from './common-path'; -export interface ChatConfig { +import { ConfigType } from 'src/downloader/universal-utils'; +import { Logger } from '@nestjs/common'; + +export interface ModelConfig { model: string; endpoint?: string; token?: string; @@ -9,45 +12,242 @@ export interface ChatConfig { task?: string; } -export class ConfigLoader { - private chatsConfig: ChatConfig[]; +export interface EmbeddingConfig { + model: string; + endpoint?: string; + default?: boolean; + token?: string; +} +export interface AppConfig { + models?: ModelConfig[]; + embeddings?: EmbeddingConfig[]; +} + +export const exampleConfigContent = `{ + // Chat models configuration + // You can configure multiple chat models + "models": [ + // Example of OpenAI GPT configuration + { + 
"model": "gpt-3.5-turbo", + "endpoint": "https://api.openai.com/v1", + "token": "your-openai-token", // Replace with your OpenAI token + "default": true // Set as default chat model + }, + + // Example of local model configuration + { + "model": "llama2", + "endpoint": "http://localhost:11434/v1" + } + ], + + // Embedding model configuration (optional) + "embeddings": [{ + "model": "text-embedding-ada-002", + "endpoint": "https://api.openai.com/v1", + "token": "your-openai-token", // Replace with your OpenAI token + "default": true // Set as default embedding + }] +}`; + +export class ConfigLoader { + readonly logger = new Logger(ConfigLoader.name); + private type: string; + private static instances: Map = new Map(); + private static config: AppConfig; private readonly configPath: string; - constructor() { - this.configPath = getConfigPath('config'); + private constructor(type: ConfigType) { + this.type = type; + this.configPath = getConfigPath(); + this.initConfigFile(); + this.loadConfig(); + } + + public static getInstance(type: ConfigType): ConfigLoader { + if (!ConfigLoader.instances.has(type)) { + ConfigLoader.instances.set(type, new ConfigLoader(type)); + } + return ConfigLoader.instances.get(type)!; + } + + public initConfigFile(): void { + Logger.log('Creating example config file', 'ConfigLoader'); + + const config = getConfigPath(); + if (fs.existsSync(config)) { + return; + } + + if (!fs.existsSync(config)) { + //make file + fs.writeFileSync(config, exampleConfigContent, 'utf-8'); + } + Logger.log('Creating example config file', 'ConfigLoader'); + } + + public reload(): void { this.loadConfig(); } private loadConfig() { - const file = fs.readFileSync(this.configPath, 'utf-8'); + try { + Logger.log( + `Loading configuration from ${this.configPath}`, + 'ConfigLoader', + ); + const file = fs.readFileSync(this.configPath, 'utf-8'); + const jsonContent = file.replace( + /\\"|"(?:\\"|[^"])*"|(\/\/.*|\/\*[\s\S]*?\*\/)/g, + (m, g) => (g ? 
'' : m), + ); + ConfigLoader.config = JSON.parse(jsonContent); + this.validateConfig(); + } catch (error) { + if ( + error.code === 'ENOENT' || + error.message.includes('Unexpected end of JSON input') + ) { + ConfigLoader.config = {}; + this.saveConfig(); + } else { + throw error; + } + } + + this.logger.log(ConfigLoader.config); - this.chatsConfig = JSON.parse(file); } - get(path: string) { + get(path?: string): T { if (!path) { - return this.chatsConfig as unknown as T; + return ConfigLoader.config as unknown as T; } - return _.get(this.chatsConfig, path) as T; + return _.get(ConfigLoader.config, path) as T; } set(path: string, value: any) { - _.set(this.chatsConfig, path, value); + _.set(ConfigLoader.config, path, value); this.saveConfig(); } private saveConfig() { + const configDir = path.dirname(this.configPath); + if (!fs.existsSync(configDir)) { + fs.mkdirSync(configDir, { recursive: true }); + } fs.writeFileSync( this.configPath, - JSON.stringify(this.chatsConfig, null, 4), + JSON.stringify(ConfigLoader.config, null, 2), 'utf-8', ); } + addConfig(config: ModelConfig | EmbeddingConfig) { + if (!ConfigLoader.config[this.type]) { + ConfigLoader.config[this.type] = []; + } + this.logger.log(ConfigLoader.config); + const index = ConfigLoader.config[this.type].findIndex( + (chat) => chat.model === config.model, + ); + if (index !== -1) { + ConfigLoader.config[this.type].splice(index, 1); + } + + if (config.default) { + ConfigLoader.config.models.forEach((chat) => { + chat.default = false; + }); + } + + ConfigLoader.config[this.type].push(config); + this.saveConfig(); + } + + removeConfig(modelName: string): boolean { + if (!ConfigLoader.config[this.type]) { + return false; + } + + const initialLength = ConfigLoader.config[this.type].length; + ConfigLoader.config.models = ConfigLoader.config[this.type].filter( + (chat) => chat.model !== modelName, + ); + + if (ConfigLoader.config[this.type].length !== initialLength) { + this.saveConfig(); + return true; + } + + return false; + } + + getAllConfigs(): EmbeddingConfig[] | ModelConfig[] | null { + const res = ConfigLoader.config[this.type]; + return Array.isArray(res) ? 
res : null; + } + validateConfig() { - if (!this.chatsConfig) { - throw new Error("Invalid configuration: 'chats' section is missing."); + if (!ConfigLoader.config) { + ConfigLoader.config = {}; + } + + if (typeof ConfigLoader.config !== 'object') { + throw new Error('Invalid configuration: Must be an object'); + } + + if (ConfigLoader.config.models) { + if (!Array.isArray(ConfigLoader.config.models)) { + throw new Error("Invalid configuration: 'chats' must be an array"); + } + + ConfigLoader.config.models.forEach((chat, index) => { + if (!chat.model) { + throw new Error( + `Invalid chat configuration at index ${index}: 'model' is required`, + ); + } + }); + + const defaultChats = ConfigLoader.config.models.filter( + (chat) => chat.default, + ); + if (defaultChats.length > 1) { + throw new Error( + 'Invalid configuration: Multiple default chat configurations found', + ); + } } + + if (ConfigLoader.config[ConfigType.EMBEDDINGS]) { + this.logger.log(ConfigLoader.config[ConfigType.EMBEDDINGS]); + if (!Array.isArray(ConfigLoader.config[ConfigType.EMBEDDINGS])) { + throw new Error("Invalid configuration: 'embeddings' must be an array"); + } + + ConfigLoader.config.models.forEach((emb, index) => { + if (!emb.model) { + throw new Error( + `Invalid chat configuration at index ${index}: 'model' is required`, + ); + } + }); + + const defaultChats = ConfigLoader.config[ConfigType.EMBEDDINGS].filter( + (chat) => chat.default, + ); + if (defaultChats.length > 1) { + throw new Error( + 'Invalid configuration: Multiple default emb configurations found', + ); + } + } + } + + getConfig(): AppConfig { + return ConfigLoader.config; } } diff --git a/backend/src/model/__tests__/loadAllChatsModels.spec.ts b/backend/src/downloader/__tests__/loadAllChatsModels.spec.ts similarity index 61% rename from backend/src/model/__tests__/loadAllChatsModels.spec.ts rename to backend/src/downloader/__tests__/loadAllChatsModels.spec.ts index c5f702e..8a437a8 100644 --- a/backend/src/model/__tests__/loadAllChatsModels.spec.ts +++ b/backend/src/downloader/__tests__/loadAllChatsModels.spec.ts @@ -1,9 +1,12 @@ import path from 'path'; import * as fs from 'fs'; -import { ConfigLoader } from '../../config/config-loader'; -import { ModelDownloader } from '../model-downloader'; -import { downloadAllModels } from '../utils'; -import { getConfigDir, getConfigPath } from 'src/config/common-path'; +import { + ConfigLoader, + ModelConfig, + EmbeddingConfig, +} from '../../config/config-loader'; +import { UniversalDownloader } from '../model-downloader'; +import { ConfigType, downloadAll, TaskType } from '../universal-utils'; const originalIsArray = Array.isArray; @@ -36,27 +39,35 @@ Array.isArray = jest.fn((type: any): type is any[] => { // }); describe('loadAllChatsModels with real model loading', () => { - let configLoader: ConfigLoader; + let modelConfigLoader: ConfigLoader; + let embConfigLoader: ConfigLoader; beforeAll(async () => { - const testConfig = [ - { - model: 'Felladrin/onnx-flan-alpaca-base', - task: 'text2text-generation', - }, - ]; - const configPath = getConfigPath('config'); - fs.writeFileSync(configPath, JSON.stringify(testConfig, null, 2), 'utf8'); + modelConfigLoader = ConfigLoader.getInstance(ConfigType.CHATS); + embConfigLoader = ConfigLoader.getInstance(ConfigType.EMBEDDINGS); + const modelConfig: ModelConfig = { + model: 'Xenova/flan-t5-small', + endpoint: 'http://localhost:11434/v1', + token: 'your-token-here', + task: 'text2text-generation', + }; + modelConfigLoader.addConfig(modelConfig); - configLoader = new 
ConfigLoader(); - await downloadAllModels(); - }, 600000); + const embConfig: EmbeddingConfig = { + model: 'fast-bge-base-en-v1.5', + endpoint: 'http://localhost:11434/v1', + token: 'your-token-here', + }; + embConfigLoader.addConfig(embConfig); + console.log('preload starts'); + await downloadAll(); + console.log('preload successfully'); + }, 60000000); it('should load real models specified in config', async () => { - const downloader = ModelDownloader.getInstance(); - + const downloader = UniversalDownloader.getInstance(); const chat1Model = await downloader.getLocalModel( - 'text2text-generation', - 'Felladrin/onnx-flan-alpaca-base', + TaskType.CHAT, + 'Xenova/flan-t5-small', ); expect(chat1Model).toBeDefined(); console.log('Loaded Model:', chat1Model); @@ -82,7 +93,7 @@ describe('loadAllChatsModels with real model loading', () => { } catch (error) { console.error('Error during model inference:', error); } - }, 600000); + }, 6000000); }); afterAll(() => { diff --git a/backend/src/downloader/const.ts b/backend/src/downloader/const.ts new file mode 100644 index 0000000..3f89569 --- /dev/null +++ b/backend/src/downloader/const.ts @@ -0,0 +1,73 @@ +export const REMOTE_MODEL_LISTS = [ + 'gpt-4-0125-preview', + 'gpt-4-1106-preview', + 'gpt-4-vision-preview', + 'gpt-4', + 'gpt-4-32k', + 'gpt-3.5-turbo-0125', + 'gpt-3.5-turbo-1106', + 'gpt-3.5-turbo', + 'gpt-3.5-turbo-16k', + 'claude-3-opus-20240229', + 'claude-3-sonnet-20240229', + 'claude-3-haiku-20240229', + 'claude-2.1', + 'claude-2.0', + 'claude-instant-1.2', + 'gemini-1.0-pro', + 'gemini-1.0-pro-vision', + 'gemini-1.0-ultra', + 'mistral-large-latest', + 'mistral-medium-latest', + 'mistral-small-latest', + 'command', + 'command-light', + 'command-nightly', + 'azure-gpt-4', + 'azure-gpt-35-turbo', + 'azure-gpt-4-32k', + 'yi-34b-chat', + 'yi-34b-200k', + 'yi-34b', + 'yi-6b-chat', + 'yi-6b-200k', + 'yi-6b', + 'text-bison-001', + 'chat-bison-001', + 'claude-2.0', + 'claude-1.2', + 'claude-1.0', + 'claude-instant-1.2', + 'llama-2-70b-chat', + 'llama-2-13b-chat', + 'llama-2-7b-chat', + 'qwen-72b-chat', + 'qwen-14b-chat', + 'qwen-7b-chat', + 'deepseek-67b-chat', + 'deepseek-33b-chat', + 'deepseek-7b-chat', + 'mixtral-8x7b-32k', + 'mixtral-8x7b', + 'baichuan-2-53b', + 'baichuan-2-13b', + 'baichuan-2-7b', + 'xverse-65b-chat', + 'xverse-13b-chat', + 'xverse-7b-chat', + 'command-r', + 'command-light-r', + 'claude-instant', + 'vicuna-13b', + 'vicuna-7b', + 'falcon-40b', + 'falcon-7b', + 'stablelm-7b', + 'mpt-7b', + 'dolly-12b', + 'alpaca-13b', + 'pythia-12b', +]; + +export const isRemoteModel = (model: string): boolean => + REMOTE_MODEL_LISTS.includes(model); diff --git a/backend/src/downloader/embedding-downloader.ts b/backend/src/downloader/embedding-downloader.ts new file mode 100644 index 0000000..fe00624 --- /dev/null +++ b/backend/src/downloader/embedding-downloader.ts @@ -0,0 +1,39 @@ +import { Logger } from '@nestjs/common'; +import { UniversalStatusManager } from './universal-status'; +import { EmbeddingModel, FlagEmbedding } from 'fastembed'; +import { getEmbDir } from 'src/config/common-path'; +export class EmbeddingDownloader { + readonly logger = new Logger(EmbeddingDownloader.name); + private static instance: EmbeddingDownloader; + private readonly statusManager = UniversalStatusManager.getInstance(); + + public static getInstance(): EmbeddingDownloader { + if (!EmbeddingDownloader.instance) { + EmbeddingDownloader.instance = new EmbeddingDownloader(); + } + + return EmbeddingDownloader.instance; + } + + async getPipeline(model: string): 
Promise { + if (!Object.values(EmbeddingModel).includes(model as EmbeddingModel)) { + this.logger.error( + `Invalid model: ${model} is not a valid EmbeddingModel.`, + ); + return null; + } + try { + const embeddingModel = await FlagEmbedding.init({ + model: model as EmbeddingModel, + cacheDir: getEmbDir(), + }); + this.statusManager.updateStatus(model, true); + return embeddingModel; + } catch (error) { + this.logger.error( + `Failed to load local model: ${model} with error: ${error.message || error}`, + ); + return null; + } + } +} diff --git a/backend/src/downloader/model-downloader.ts b/backend/src/downloader/model-downloader.ts new file mode 100644 index 0000000..6564663 --- /dev/null +++ b/backend/src/downloader/model-downloader.ts @@ -0,0 +1,60 @@ +import { Logger } from '@nestjs/common'; +import { PipelineType, pipeline, env, cat } from '@huggingface/transformers'; +import { getEmbDir, getModelPath, getModelsDir } from 'src/config/common-path'; +import { isRemoteModel } from './const'; +import { UniversalStatusManager } from './universal-status'; + +env.allowLocalModels = true; +env.localModelPath = getModelsDir(); +export class UniversalDownloader { + readonly logger = new Logger(UniversalDownloader.name); + private static instance: UniversalDownloader; + private readonly statusManager = UniversalStatusManager.getInstance(); + + public static getInstance(): UniversalDownloader { + if (!UniversalDownloader.instance) { + UniversalDownloader.instance = new UniversalDownloader(); + } + return UniversalDownloader.instance; + } + + async getPipeline(task: string, model: string, path: string): Promise { + if (isRemoteModel(model)) { + this.logger.log(`Remote model detected: ${model}, marking as downloaded`); + console.log(this.statusManager); + this.statusManager.updateStatus(model, true); + return null; + } + + this.logger.log(`Starting download for local model: ${model}`); + try { + console.log(path); + const pipelineInstance = await pipeline(task as PipelineType, model, { + cache_dir: path, + }); + this.logger.log(`Successfully downloaded local model: ${model}`); + this.statusManager.updateStatus(model, true); + return pipelineInstance; + } catch (error) { + this.logger.error(`Failed to download model ${model}: ${error.message}`); + this.statusManager.updateStatus(model, false); + return null; + } + } + + public async getLocalModel(task: string, model: string): Promise { + try { + const pipelineInstance = await pipeline(task as PipelineType, model, { + local_files_only: true, + revision: 'local', + cache_dir: getModelsDir(), + }); + return pipelineInstance; + } catch (error) { + this.logger.error( + `Failed to load local model: ${model} with error: ${error.message || error}`, + ); + return null; + } + } +} diff --git a/backend/src/downloader/universal-status.ts b/backend/src/downloader/universal-status.ts new file mode 100644 index 0000000..188e5ac --- /dev/null +++ b/backend/src/downloader/universal-status.ts @@ -0,0 +1,75 @@ +import * as fs from 'fs'; +import * as path from 'path'; +import { getModelStatusPath } from 'src/config/common-path'; + +export interface UniversalStatus { + isDownloaded: boolean; + lastChecked: Date; +} + +export class UniversalStatusManager { + private static instance: UniversalStatusManager; + private status: Record; + private readonly statusPath: string; + + private constructor() { + this.statusPath = getModelStatusPath(); + this.loadStatus(); + } + + public static getInstance(): UniversalStatusManager { + if (!UniversalStatusManager.instance) { + 
UniversalStatusManager.instance = new UniversalStatusManager(); + } + return UniversalStatusManager.instance; + } + + private loadStatus() { + try { + const file = fs.readFileSync(this.statusPath, 'utf-8'); + const data = JSON.parse(file); + this.status = Object.entries(data).reduce( + (acc, [key, value]: [string, any]) => { + acc[key] = { + ...value, + lastChecked: value.lastChecked + ? new Date(value.lastChecked) + : new Date(), + }; + return acc; + }, + {} as Record, + ); + } catch (error) { + this.status = {}; + } + } + + private saveStatus() { + const statusDir = path.dirname(this.statusPath); + if (!fs.existsSync(statusDir)) { + fs.mkdirSync(statusDir, { recursive: true }); + } + fs.writeFileSync( + this.statusPath, + JSON.stringify(this.status, null, 2), + 'utf-8', + ); + } + + updateStatus(UniversalName: string, isDownloaded: boolean) { + this.status[UniversalName] = { + isDownloaded, + lastChecked: new Date(), + }; + this.saveStatus(); + } + + getStatus(UniversalName: string): UniversalStatus | undefined { + return this.status[UniversalName]; + } + + getAllStatus(): Record { + return { ...this.status }; + } +} diff --git a/backend/src/downloader/universal-utils.ts b/backend/src/downloader/universal-utils.ts new file mode 100644 index 0000000..ece20c7 --- /dev/null +++ b/backend/src/downloader/universal-utils.ts @@ -0,0 +1,115 @@ +// model-utils.ts +import { UniversalDownloader } from './model-downloader'; +import { Logger } from '@nestjs/common'; +import { UniversalStatusManager } from './universal-status'; +import { ModelConfig } from '../config/config-loader'; +import { ConfigLoader } from 'src/config/config-loader'; +import { getEmbDir, getModelsDir } from 'src/config/common-path'; +import { EmbeddingDownloader } from './embedding-downloader'; + +const logger = new Logger('model-utils'); + +export enum ConfigType { + EMBEDDINGS = 'embeddings', + CHATS = 'models', +} +export enum TaskType { + CHAT = 'text2text-generation', + EMBEDDING = 'feature-extraction', +} + +export async function downloadAll() { + await checkAndDownloadAllModels(ConfigType.CHATS); + + console.log('embedding load starts'); + await checkAndDownloadAllModels(ConfigType.EMBEDDINGS); + console.log('embedding load ends'); +} +async function downloadModelForType( + type: ConfigType, + modelName: string, + task: string, +) { + const statusManager = UniversalStatusManager.getInstance(); + const storePath = + type === ConfigType.EMBEDDINGS ? 
getEmbDir() : getModelsDir(); + + if (type === ConfigType.EMBEDDINGS) { + const embeddingDownloader = EmbeddingDownloader.getInstance(); + try { + logger.log(`Downloading embedding model: ${modelName}`); + await embeddingDownloader.getPipeline(modelName); + statusManager.updateStatus(modelName, true); + logger.log(`Successfully downloaded embedding model: ${modelName}`); + + console.log('embedding load finished'); + } catch (error) { + logger.error( + `Failed to download embedding model ${modelName}:`, + error.message, + ); + statusManager.updateStatus(modelName, false); + throw error; + } + } else { + const downloader = UniversalDownloader.getInstance(); + const status = statusManager.getStatus(modelName); + + if (status?.isDownloaded) { + try { + await downloader.getLocalModel(task, modelName); + logger.log( + `Model ${modelName} is already downloaded and verified, skipping...`, + ); + return; + } catch (error) { + logger.warn( + `Model ${modelName} was marked as downloaded but not found locally, re-downloading...`, + ); + } + } + + try { + logger.log( + `Downloading chat model: ${modelName} for task: ${task} in path: ${storePath}`, + ); + await downloader.getPipeline(task, modelName, storePath); + statusManager.updateStatus(modelName, true); + logger.log(`Successfully downloaded chat model: ${modelName}`); + } catch (error) { + logger.error(`Failed to download model ${modelName}:`, error.message); + statusManager.updateStatus(modelName, false); + throw error; + } + } +} + +export async function checkAndDownloadAllModels( + type: ConfigType, +): Promise { + const configLoader = ConfigLoader.getInstance(type); + const modelsConfig = configLoader.getAllConfigs(); + + logger.log('Checking and downloading configured models...'); + + if (!modelsConfig.length) { + logger.warn(`No ${type} models configured`); + return; + } + + const downloadTasks = modelsConfig.map(async (config) => { + const { model, task } = config; + const taskType = + task || + (type === ConfigType.EMBEDDINGS ? 
TaskType.EMBEDDING : TaskType.CHAT); + await downloadModelForType(type, model, taskType); + }); + + try { + await Promise.all(downloadTasks); + logger.log(`All ${type} models downloaded successfully.`); + } catch (error) { + logger.error(`One or more ${type} models failed to download.`); + throw error; + } +} diff --git a/backend/src/embedding/__tests__/loadAllEmbModels.spec.ts b/backend/src/embedding/__tests__/loadAllEmbModels.spec.ts new file mode 100644 index 0000000..b2d1ea8 --- /dev/null +++ b/backend/src/embedding/__tests__/loadAllEmbModels.spec.ts @@ -0,0 +1,39 @@ +import { localEmbProvider } from '../local-embedding-provider'; +import { EmbeddingModel } from 'fastembed'; +import { openAIEmbProvider } from '../../../../llm-server/src/embedding/openai-embedding-provider'; +const originalIsArray = Array.isArray; + +Array.isArray = jest.fn((type: any): type is any[] => { + if ( + type && + type.constructor && + (type.constructor.name === 'Float32Array' || + type.constructor.name === 'BigInt64Array') + ) { + return true; + } + return originalIsArray(type); +}) as unknown as (arg: any) => arg is any[]; + +describe('testing embedding provider', () => { + it('should load real models specified in config', async () => { + const documents = [ + 'passage: Hello, World!', + 'query: Hello, World!', + 'passage: This is an example passage.', + // You can leave out the prefix but it's recommended + 'fastembed-js is licensed under MIT', + ]; + + await localEmbProvider.generateEmbResponse( + EmbeddingModel.BGEBaseENV15, + documents, + ); + }, 6000000); + + +}); + +afterAll(() => { + Array.isArray = originalIsArray; +}); diff --git a/backend/src/embedding/local-embedding-provider.ts b/backend/src/embedding/local-embedding-provider.ts new file mode 100644 index 0000000..901b055 --- /dev/null +++ b/backend/src/embedding/local-embedding-provider.ts @@ -0,0 +1,27 @@ +import { Logger } from '@nestjs/common'; +import { EmbeddingModel } from 'fastembed'; +import { EmbeddingDownloader } from 'src/downloader/embedding-downloader'; + +export class localEmbProvider { + private static logger = new Logger(localEmbProvider.name); + + static async generateEmbResponse(model: string, message: string[]) { + const embLoader = EmbeddingDownloader.getInstance(); + try { + const embeddingModel = await embLoader.getPipeline(model); + const embeddings = embeddingModel.embed(message); + + for await (const batch of embeddings) { + console.log(batch); + } + } catch (error) { + this.logger.log(`error when using ${model} api`); + } + } + + static async getEmbList() { + Object.values(EmbeddingModel).forEach((model) => { + this.logger.log(model); + }); + } +} diff --git a/backend/src/embedding/openai-embbeding-provider.ts b/backend/src/embedding/openai-embbeding-provider.ts new file mode 100644 index 0000000..d8ab576 --- /dev/null +++ b/backend/src/embedding/openai-embbeding-provider.ts @@ -0,0 +1,33 @@ +import { Logger } from '@nestjs/common'; +import { EmbeddingModel } from 'fastembed'; +import openai, { OpenAI } from 'openai'; +import { EmbeddingDownloader } from 'src/downloader/embedding-downloader'; + +export class openAIEmbProvider { + private static logger = new Logger(openAIEmbProvider.name); + + private static openai = () => { + return new OpenAI({ + apiKey: process.env.OPEN_API_KEY, + }); + }; + + static async generateEmbResponse(model: string, message: string) { + const embedding = await this.openai().embeddings.create({ + model: model, + input: message, + encoding_format: 'float', + }); + 
console.log(embedding.data[0].embedding); + } + + static async getEmbList() { + try { + const models = await this.openai().models.list(); + Object.values(models).filter((model) => model.object === 'embedding'); + this.logger.log(`Models fetched: ${models.data.length}`); + } catch (error) { + this.logger.error('Error fetching models:', error); + } + } +} diff --git a/backend/src/main.ts b/backend/src/main.ts index ea049c9..8d523fe 100644 --- a/backend/src/main.ts +++ b/backend/src/main.ts @@ -1,9 +1,11 @@ import { NestFactory } from '@nestjs/core'; import { AppModule } from './app.module'; import 'reflect-metadata'; -import { downloadAllModels } from './model/utils'; +import { downloadAll } from './downloader/universal-utils'; +import * as dotenv from 'dotenv'; async function bootstrap() { + dotenv.config(); // 加载 .env 文件中的环境变量 const app = await NestFactory.create(AppModule); app.enableCors({ origin: '*', @@ -17,7 +19,7 @@ async function bootstrap() { 'Access-Control-Allow-Credentials', ], }); - await downloadAllModels(); + await downloadAll(); await app.listen(process.env.PORT ?? 3000); } diff --git a/backend/src/model/__tests__/app.e2e-spec.ts b/backend/src/model/__tests__/app.e2e-spec.ts deleted file mode 100644 index f7e956e..0000000 --- a/backend/src/model/__tests__/app.e2e-spec.ts +++ /dev/null @@ -1,24 +0,0 @@ -import { Test, TestingModule } from '@nestjs/testing'; -import { INestApplication } from '@nestjs/common'; -import * as request from 'supertest'; -import { AppModule } from '../../app.module'; - -describe('AppController (e2e)', () => { - let app: INestApplication; - - beforeEach(async () => { - const moduleFixture: TestingModule = await Test.createTestingModule({ - imports: [AppModule], - }).compile(); - - app = moduleFixture.createNestApplication(); - await app.init(); - }); - - it('/ (GET)', () => { - return request(app.getHttpServer()) - .get('/') - .expect(200) - .expect('Hello World!'); - }); -}); diff --git a/backend/src/model/__tests__/jest-e2e.json b/backend/src/model/__tests__/jest-e2e.json deleted file mode 100644 index e9d912f..0000000 --- a/backend/src/model/__tests__/jest-e2e.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "moduleFileExtensions": ["js", "json", "ts"], - "rootDir": ".", - "testEnvironment": "node", - "testRegex": ".e2e-spec.ts$", - "transform": { - "^.+\\.(t|j)s$": "ts-jest" - } -} diff --git a/backend/src/model/model-downloader.ts b/backend/src/model/model-downloader.ts deleted file mode 100644 index bc754bf..0000000 --- a/backend/src/model/model-downloader.ts +++ /dev/null @@ -1,31 +0,0 @@ -import { Logger } from '@nestjs/common'; -import { PipelineType, pipeline, env } from '@huggingface/transformers'; -import { getModelPath, getModelsDir } from 'src/config/common-path'; -env.allowLocalModels = true; -env.localModelPath = getModelsDir(); -export class ModelDownloader { - readonly logger = new Logger(ModelDownloader.name); - private static instance: ModelDownloader; - public static getInstance(): ModelDownloader { - if (!ModelDownloader.instance) { - ModelDownloader.instance = new ModelDownloader(); - } - return ModelDownloader.instance; - } - - async downloadModel(task: string, model: string): Promise { - const pipelineInstance = await pipeline(task as PipelineType, model, { - cache_dir: getModelsDir(), - }); - return pipelineInstance; - } - - public async getLocalModel(task: string, model: string): Promise { - const pipelineInstance = await pipeline(task as PipelineType, model, { - local_files_only: true, - revision: 'local', - }); - - return 
pipelineInstance; - } -} diff --git a/backend/src/model/utils.ts b/backend/src/model/utils.ts deleted file mode 100644 index 3c2f7f6..0000000 --- a/backend/src/model/utils.ts +++ /dev/null @@ -1,22 +0,0 @@ -import { ChatConfig, ConfigLoader } from 'src/config/config-loader'; -import { ModelDownloader } from './model-downloader'; - -export async function downloadAllModels(): Promise { - const configLoader = new ConfigLoader(); - configLoader.validateConfig(); - const chats = configLoader.get(''); - const downloader = ModelDownloader.getInstance(); - console.log('Loaded config:', chats); - const loadPromises = chats.map(async (chatConfig: ChatConfig) => { - const { model, task } = chatConfig; - try { - downloader.logger.log(model, task); - const pipelineInstance = await downloader.downloadModel(task, model); - } catch (error) { - downloader.logger.error(`Failed to load model ${model}:`, error.message); - } - }); - await Promise.all(loadPromises); - - downloader.logger.log('All models loaded.'); -} diff --git a/llm-server/package.json b/llm-server/package.json index dd5f7f2..0bf91d8 100644 --- a/llm-server/package.json +++ b/llm-server/package.json @@ -18,6 +18,7 @@ "dependencies": { "@nestjs/common": "^10.4.5", "express": "^4.21.1", + "fastembed": "^1.14.1", "node-fetch": "^3.3.2", "node-llama-cpp": "^3.1.1", "nodemon": "^3.1.7", @@ -31,7 +32,7 @@ "eslint": "^8.57.1", "eslint-config-prettier": "^9.0.0", "eslint-plugin-prettier": "^5.0.0", - "openai": "^4.68.1", + "openai": "^4.77.0", "prettier": "^3.0.0", "ts-loader": "^9.5.1", "ts-node": "^10.4.0", diff --git a/llm-server/src/emb-provider.ts b/llm-server/src/emb-provider.ts new file mode 100644 index 0000000..c27fabd --- /dev/null +++ b/llm-server/src/emb-provider.ts @@ -0,0 +1,153 @@ +import { Response } from 'express'; +import { openAIEmbProvider } from './embedding/openai-embedding-provider'; +import { LlamaModelProvider } from './model/llama-model-provider'; // 如果支持Llama模型 +import { Logger } from '@nestjs/common'; +import { + ModelProviderType, + ModelProviderOptions, + ModelError, + GenerateMessageParams, +} from './types'; +import { ModelProvider } from './model/model-provider'; +import { EmbeddingProvider } from './embedding/emb-provider'; + +export interface EmbeddingInput { + content: string; +} + +export interface Embedding { + model: string; + embedding: number[]; +} + +export class EmbeddingModelProvider { + private readonly logger = new Logger(EmbeddingModelProvider.name); + private modelProvider: EmbeddingProvider; + private readonly options: ModelProviderOptions; + private initialized: boolean = false; + + constructor( + modelProviderType: ModelProviderType = 'openai', + options: ModelProviderOptions = {}, + ) { + this.options = { + maxConcurrentRequests: 5, + maxRetries: 3, + retryDelay: 1000, + ...options, + }; + + this.modelProvider = this.createModelProvider(modelProviderType); + } + + private createModelProvider(type: ModelProviderType): EmbeddingProvider { + switch (type) { + case 'openai': + return new openAIEmbProvider({apiKey: process.env.OPEN_API_KEY}); + // case 'llama': + // + // // return new LlamaModelProvider(); + default: + throw new Error(`Unsupported embedding model provider type: ${type}`); + } + } + + async initialize(): Promise { + try { + this.logger.log('Initializing embedding provider...'); + await this.modelProvider.initialize(); + this.initialized = true; + this.logger.log('Embedding provider fully initialized and ready.'); + } catch (error) { + const modelError = this.normalizeError(error); + 
this.logger.error('Failed to initialize embedding provider:', modelError); + throw modelError; + } + } + + async generateEmbeddingResponse( + params: GenerateMessageParams, + res: Response, + ): Promise { + this.ensureInitialized(); + + try { + await this.modelProvider.generateEmbResponse(params, res); + } catch (error) { + const modelError = this.normalizeError(error); + this.logger.error('Error in generating embedding response:', modelError); + + if (!res.writableEnded) { + this.sendErrorResponse(res, modelError); + } + } + } + + async getEmbeddingModels(res: Response): Promise { + this.ensureInitialized(); + + try { + await this.modelProvider.getEmbList(res); + } catch (error) { + const modelError = this.normalizeError(error); + this.logger.error('Error getting embedding models:', modelError); + + if (!res.writableEnded) { + this.sendErrorResponse(res, modelError); + } + } + } + + private ensureInitialized(): void { + if (!this.initialized) { + throw new Error('Embedding provider not initialized. Call initialize() first.'); + } + } + + private normalizeError(error: any): ModelError { + if (error instanceof Error) { + return { + ...error, + code: (error as any).code || 'UNKNOWN_ERROR', + retryable: (error as any).retryable || false, + }; + } + + return { + name: 'Error', + message: String(error), + code: 'UNKNOWN_ERROR', + retryable: false, + }; + } + + private sendErrorResponse(res: Response, error: ModelError): void { + const errorResponse = { + error: { + message: error.message, + code: error.code, + details: error.details, + }, + }; + + if (res.headersSent) { + res.write(`data: ${JSON.stringify(errorResponse)}\n\n`); + res.write('data: [DONE]\n\n'); + res.end(); + } else { + res.status(500).json(errorResponse); + } + } + + isInitialized(): boolean { + return this.initialized; + } + + getCurrentProvider(): string { + return this.modelProvider.constructor.name; + } + + getProviderOptions(): ModelProviderOptions { + return { ...this.options }; + } +} diff --git a/llm-server/src/embedding/emb-provider.ts b/llm-server/src/embedding/emb-provider.ts new file mode 100644 index 0000000..84d7e75 --- /dev/null +++ b/llm-server/src/embedding/emb-provider.ts @@ -0,0 +1,9 @@ +import { Response } from 'express'; +import { GenerateMessageParams } from '../types'; + +export interface EmbeddingProvider { + initialize(): Promise; + generateEmbResponse(params: GenerateMessageParams, + res: Response,): Promise; + getEmbList(res: Response): Promise; +} diff --git a/llm-server/src/embedding/openai-embedding-provider.ts b/llm-server/src/embedding/openai-embedding-provider.ts new file mode 100644 index 0000000..084a139 --- /dev/null +++ b/llm-server/src/embedding/openai-embedding-provider.ts @@ -0,0 +1,93 @@ + +import { Logger } from '@nestjs/common'; +import OpenAI from 'openai'; +import PQueue from 'p-queue'; +import { Response } from 'express' +import { EmbeddingProvider } from './emb-provider'; +import { GenerateMessageParams } from '../types'; + +export class openAIEmbProvider implements EmbeddingProvider { + + private logger = new Logger(openAIEmbProvider.name); + + private openai: OpenAI; + private requestQueue: PQueue; + private readonly options: { + maxConcurrentRequests: number; + maxRetries: number; + retryDelay: number; + queueInterval: number; + intervalCap: number; + apiKey: string; + }; + + constructor(options: { apiKey?: string } = {}) { + this.options = { + maxConcurrentRequests: 5, + maxRetries: 3, + retryDelay: 1000, + queueInterval: 1000, + intervalCap: 10, + apiKey: 
process.env.OPEN_API_KEY || options.apiKey, + ...options, + }; + + this.requestQueue = new PQueue({ + concurrency: this.options.maxConcurrentRequests, + interval: this.options.queueInterval, + intervalCap: this.options.intervalCap, + }); + + this.requestQueue.on('active', () => { + this.logger.debug(`Queue size: ${this.requestQueue.size}, Pending: ${this.requestQueue.pending}`); + }); + } + + async initialize(): Promise { + this.logger.log('Initializing OpenAI model...'); + if (!this.options.apiKey) { + throw new Error('OpenAI API key is required'); + } + + this.openai = new OpenAI({ + apiKey: this.options.apiKey, + }); + + this.logger.log(`OpenAI model initialized with options: ${JSON.stringify(this.options)}`); + } + + async generateEmbResponse( + params: GenerateMessageParams, + res: Response, + ): Promise { + try { + const embedding = await this.openai.embeddings.create({ + model: params.model, + input: params.message, + encoding_format: "float", + }); + console.log(embedding.data[0].embedding); + res.json({ + embedding: embedding.data[0].embedding + }) + } catch (error) { + this.logger.error(`Error generating embedding for model ${params.model}:`, error); + res.status(500).json({ + error: error, + }) + } + } + + async getEmbList(res: Response): Promise { + try { + const models = await this.openai.models.list(); + const embeddingModels = models.data.filter(model => (model.object as string) === "embedding"); + + res.json({ models: embeddingModels }); + this.logger.log(`Fetched ${embeddingModels.length} embedding models.`); + } catch (error) { + this.logger.error('Error fetching models:', error); + res.status(500).json({ error: 'Error fetching models', message: error.message }); + } + } +} diff --git a/llm-server/src/llm-provider.ts b/llm-server/src/llm-provider.ts index dc53a22..42bd497 100644 --- a/llm-server/src/llm-provider.ts +++ b/llm-server/src/llm-provider.ts @@ -22,7 +22,7 @@ export class LLMProvider { private initialized: boolean = false; constructor( - modelProviderType: ModelProviderType = 'llama', + modelProviderType: ModelProviderType = 'openai', options: ModelProviderOptions = {}, ) { this.options = { diff --git a/llm-server/src/main.ts b/llm-server/src/main.ts index bb2c8e7..d8398b9 100644 --- a/llm-server/src/main.ts +++ b/llm-server/src/main.ts @@ -2,28 +2,60 @@ import { Logger, Module } from '@nestjs/common'; import { ChatMessageInput, LLMProvider } from './llm-provider'; import express, { Express, Request, Response } from 'express'; import { GenerateMessageParams } from './types'; +import { EmbeddingModelProvider } from './emb-provider'; export class App { private readonly logger = new Logger(App.name); private app: Express; private readonly PORT: number; private llmProvider: LLMProvider; + private embProvider: EmbeddingModelProvider; - constructor(llmProvider: LLMProvider) { + constructor(llmProvider: LLMProvider, embProvider: EmbeddingModelProvider) { this.app = express(); this.app.use(express.json()); this.PORT = parseInt(process.env.PORT || '3001', 10); this.llmProvider = llmProvider; + this.embProvider = embProvider; this.logger.log(`App initialized with PORT: ${this.PORT}`); } setupRoutes(): void { this.logger.log('Setting up routes...'); this.app.post('/chat/completion', this.handleChatRequest.bind(this)); + this.app.post('/embedding', this.handleEmbRequest.bind(this)); this.app.get('/tags', this.handleModelTagsRequest.bind(this)); this.logger.log('Routes set up successfully.'); } + private async handleEmbRequest(req: Request, res: Response): Promise { + 
this.logger.log('Received embedding request.'); + try { + this.logger.debug(JSON.stringify(req.body)); + const { content, model } = req.body as ChatMessageInput & { + model: string; + }; + if (!content || !model) { + res.status(400).json({ error: 'Content and model are required' }); + } + + this.logger.log(`Received chat request for model: ${model}`); + const params: GenerateMessageParams = { + model: model || 'text-embedding-ada-002', + message: content + }; + + this.logger.debug(`Request content: "${content}"`); + res.setHeader('Content-Type', 'application/json'); + res.setHeader('Cache-Control', 'no-cache'); + this.logger.debug('Response headers set for streaming.'); + await this.embProvider.generateEmbeddingResponse(params, res); + } catch (error) { + this.logger.error('Error in chat endpoint:', error); + res.status(500).json({ error: 'Internal server error' }); + } + } + private async handleChatRequest(req: Request, res: Response): Promise { this.logger.log('Received chat request.'); try { @@ -76,7 +108,10 @@ async function main() { try { const llmProvider = new LLMProvider('openai'); await llmProvider.initialize(); - const app = new App(llmProvider); + + const embProvider = new EmbeddingModelProvider('openai'); + await embProvider.initialize(); + const app = new App(llmProvider, embProvider); await app.start(); } catch (error) { logger.error('Failed to start the application:', error);
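A minimal smoke test of the new /embedding route added in llm-server/src/main.ts, sketched here as an illustration rather than part of the change set. It assumes the llm-server from this diff is running on its default port 3001, that the runtime provides a global fetch (Node 18+), and that OPEN_API_KEY (the variable name used by openAIEmbProvider) is set so the provider can initialize; the helper name requestEmbedding and the sample model name are placeholders. The request body shape ({ content, model }) follows handleEmbRequest, and the response shape ({ embedding: number[] }) follows openAIEmbProvider.generateEmbResponse.

// Sketch only: POST a string to the /embedding route and print the size of the
// returned vector. Assumes llm-server is listening on http://localhost:3001 and
// OPEN_API_KEY is configured; names below are illustrative placeholders.
async function requestEmbedding(
  content: string,
  model = 'text-embedding-ada-002',
): Promise<number[]> {
  const res = await fetch('http://localhost:3001/embedding', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    // handleEmbRequest reads { content, model } from the JSON body.
    body: JSON.stringify({ content, model }),
  });
  if (!res.ok) {
    throw new Error(`Embedding request failed with status ${res.status}`);
  }
  // openAIEmbProvider.generateEmbResponse replies with { embedding: number[] }.
  const { embedding } = (await res.json()) as { embedding: number[] };
  return embedding;
}

requestEmbedding('Your text string goes here')
  .then((vector) =>
    console.log(`Received a ${vector.length}-dimensional embedding`),
  )
  .catch(console.error);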