From e5791465736a7c6b8e670daacb59716cf16c24fb Mon Sep 17 00:00:00 2001 From: Jackson Chen <541898146chen@gmail.com> Date: Sat, 28 Dec 2024 03:10:28 -0600 Subject: [PATCH] feat(backend): refactor model downloading and configuration handling --- backend/src/config/common-path.ts | 19 +- backend/src/config/config-loader.spec.ts | 237 ------------------ backend/src/config/config-loader.ts | 37 ++- backend/src/main.ts | 4 +- .../__tests__/loadAllChatsModels.spec.ts | 4 +- backend/src/model/downloader/const.ts | 73 ++++++ backend/src/model/downloader/downloader.ts | 69 +++++ backend/src/model/model-downloader.ts | 31 --- backend/src/model/model-status.ts | 4 +- backend/src/model/utils.ts | 128 +++++++--- 10 files changed, 282 insertions(+), 324 deletions(-) delete mode 100644 backend/src/config/config-loader.spec.ts create mode 100644 backend/src/model/downloader/const.ts create mode 100644 backend/src/model/downloader/downloader.ts delete mode 100644 backend/src/model/model-downloader.ts diff --git a/backend/src/config/common-path.ts b/backend/src/config/common-path.ts index 6f40ab2..3ab5d60 100644 --- a/backend/src/config/common-path.ts +++ b/backend/src/config/common-path.ts @@ -1,10 +1,10 @@ import * as path from 'path'; -import { existsSync, mkdirSync, promises } from 'fs-extra'; +import { existsSync, mkdirSync, promises, writeFileSync } from 'fs-extra'; // Constants for base directories const APP_NAME = 'codefox'; // TODO: hack way to get the root directory of the workspace -const WORKSPACE_ROOT = path.resolve(__dirname, '../../../'); +const WORKSPACE_ROOT = path.resolve(__dirname, '../../../../'); const ROOT_DIR = path.join(WORKSPACE_ROOT, `.${APP_NAME}`); export const TEMPLATE_PATH = path.join(WORKSPACE_ROOT, 'backend/template'); @@ -27,10 +27,17 @@ const ensureDir = (dirPath: string): string => { export const getRootDir = (): string => ensureDir(ROOT_DIR); // Configuration Paths -export const getConfigDir = (): string => - ensureDir(path.join(getRootDir(), 
'config')); -export const getConfigPath = (configName: string): string => - path.join(getConfigDir(), `${configName}.json`); +export const getConfigPath = (): string => { + const rootPath = ensureDir(path.join(getRootDir())); + return path.join(rootPath, 'config.json'); +}; + +export const getModelStatusPath = (): string => { + const rootPath = ensureDir(getRootDir()); + const modelStatusPath = path.join(rootPath, 'model-status.json'); + if (!existsSync(modelStatusPath)) writeFileSync(modelStatusPath, '{}'); /* seed only when missing so persisted status survives restarts */ + return modelStatusPath; +}; // Models Directory export const getModelsDir = (): string => diff --git a/backend/src/config/config-loader.spec.ts b/backend/src/config/config-loader.spec.ts deleted file mode 100644 index 43dc3d1..0000000 --- a/backend/src/config/config-loader.spec.ts +++ /dev/null @@ -1,237 +0,0 @@ -// config-loader.ts -import * as fs from 'fs'; -import * as path from 'path'; -import * as _ from 'lodash'; -import { getConfigPath } from './common-path'; - -export interface ChatConfig { - model: string; - endpoint?: string; - token?: string; - default?: boolean; - task?: string; -} - -export interface EmbeddingConfig { - model: string; - endpoint?: string; - token?: string; -} - -export interface AppConfig { - chats?: ChatConfig[]; - embeddings?: EmbeddingConfig; -} - -export const exampleConfigContent = `{ - // Chat models configuration - // You can configure multiple chat models - "chats": [ - // Example of OpenAI GPT configuration - { - "model": "gpt-3.5-turbo", - "endpoint": "https://api.openai.com/v1", - "token": "your-openai-token", // Replace with your OpenAI token - "default": true // Set as default chat model - }, - - // Example of local model configuration - { - "model": "llama2", - "endpoint": "http://localhost:11434/v1", - "task": "chat" - } - ], - - // Embedding model configuration (optional) - "embeddings": { - "model": "text-embedding-ada-002", - "endpoint": "https://api.openai.com/v1", - "token": "your-openai-token" // Replace with your OpenAI token - } -}`; - 
-export class ConfigLoader { - private static instance: ConfigLoader; - private config: AppConfig; - private readonly configPath: string; - - private constructor(configPath?: string) { - this.configPath = configPath || getConfigPath('config'); - this.loadConfig(); - } - - public static getInstance(configPath?: string): ConfigLoader { - if (!ConfigLoader.instance) { - ConfigLoader.instance = new ConfigLoader(configPath); - } - return ConfigLoader.instance; - } - - public static initConfigFile(configPath: string): void { - if (fs.existsSync(configPath)) { - throw new Error('Config file already exists'); - } - - const configDir = path.dirname(configPath); - if (!fs.existsSync(configDir)) { - fs.mkdirSync(configDir, { recursive: true }); - } - - fs.writeFileSync(configPath, exampleConfigContent, 'utf-8'); - } - - public reload(): void { - this.loadConfig(); - } - - private loadConfig() { - try { - const file = fs.readFileSync(this.configPath, 'utf-8'); - const jsonContent = file.replace( - /\\"|"(?:\\"|[^"])*"|(\/\/.*|\/\*[\s\S]*?\*\/)/g, - (m, g) => (g ? 
'' : m), - ); - this.config = JSON.parse(jsonContent); - this.validateConfig(); - } catch (error) { - if ( - error.code === 'ENOENT' || - error.message.includes('Unexpected end of JSON input') - ) { - this.config = {}; - this.saveConfig(); - } else { - throw error; - } - } - } - - get(path?: string): T { - if (!path) { - return this.config as unknown as T; - } - return _.get(this.config, path) as T; - } - - set(path: string, value: any) { - _.set(this.config, path, value); - this.saveConfig(); - } - - private saveConfig() { - const configDir = path.dirname(this.configPath); - if (!fs.existsSync(configDir)) { - fs.mkdirSync(configDir, { recursive: true }); - } - fs.writeFileSync( - this.configPath, - JSON.stringify(this.config, null, 2), - 'utf-8', - ); - } - - getAllChatConfigs(): ChatConfig[] { - return this.config.chats || []; - } - - getChatConfig(modelName?: string): ChatConfig | null { - if (!this.config.chats || !Array.isArray(this.config.chats)) { - return null; - } - - const chats = this.config.chats; - - if (modelName) { - const foundChat = chats.find((chat) => chat.model === modelName); - if (foundChat) { - return foundChat; - } - } - - return ( - chats.find((chat) => chat.default) || (chats.length > 0 ? 
chats[0] : null) - ); - } - - addChatConfig(config: ChatConfig) { - if (!this.config.chats) { - this.config.chats = []; - } - - const index = this.config.chats.findIndex( - (chat) => chat.model === config.model, - ); - if (index !== -1) { - this.config.chats.splice(index, 1); - } - - if (config.default) { - this.config.chats.forEach((chat) => { - chat.default = false; - }); - } - - this.config.chats.push(config); - this.saveConfig(); - } - - removeChatConfig(modelName: string): boolean { - if (!this.config.chats) { - return false; - } - - const initialLength = this.config.chats.length; - this.config.chats = this.config.chats.filter( - (chat) => chat.model !== modelName, - ); - - if (this.config.chats.length !== initialLength) { - this.saveConfig(); - return true; - } - - return false; - } - - getEmbeddingConfig(): EmbeddingConfig | null { - return this.config.embeddings || null; - } - - validateConfig() { - if (!this.config) { - this.config = {}; - } - - if (typeof this.config !== 'object') { - throw new Error('Invalid configuration: Must be an object'); - } - - if (this.config.chats) { - if (!Array.isArray(this.config.chats)) { - throw new Error("Invalid configuration: 'chats' must be an array"); - } - - this.config.chats.forEach((chat, index) => { - if (!chat.model) { - throw new Error( - `Invalid chat configuration at index ${index}: 'model' is required`, - ); - } - }); - - const defaultChats = this.config.chats.filter((chat) => chat.default); - if (defaultChats.length > 1) { - throw new Error( - 'Invalid configuration: Multiple default chat configurations found', - ); - } - } - - if (this.config.embeddings) { - if (!this.config.embeddings.model) { - throw new Error("Invalid embedding configuration: 'model' is required"); - } - } - } -} diff --git a/backend/src/config/config-loader.ts b/backend/src/config/config-loader.ts index 43dc3d1..378921a 100644 --- a/backend/src/config/config-loader.ts +++ b/backend/src/config/config-loader.ts @@ -1,8 +1,8 @@ -// 
config-loader.ts import * as fs from 'fs'; import * as path from 'path'; import * as _ from 'lodash'; import { getConfigPath } from './common-path'; +import { Logger } from '@nestjs/common'; export interface ChatConfig { model: string; @@ -56,29 +56,32 @@ export class ConfigLoader { private config: AppConfig; private readonly configPath: string; - private constructor(configPath?: string) { - this.configPath = configPath || getConfigPath('config'); + private constructor() { + this.configPath = getConfigPath(); + this.initConfigFile(); this.loadConfig(); } - public static getInstance(configPath?: string): ConfigLoader { + public static getInstance(): ConfigLoader { if (!ConfigLoader.instance) { - ConfigLoader.instance = new ConfigLoader(configPath); + ConfigLoader.instance = new ConfigLoader(); } return ConfigLoader.instance; } - public static initConfigFile(configPath: string): void { - if (fs.existsSync(configPath)) { - throw new Error('Config file already exists'); - } + public initConfigFile(): void { + Logger.log('Creating example config file', 'ConfigLoader'); - const configDir = path.dirname(configPath); - if (!fs.existsSync(configDir)) { - fs.mkdirSync(configDir, { recursive: true }); + const config = getConfigPath(); + if (fs.existsSync(config)) { + return; } - fs.writeFileSync(configPath, exampleConfigContent, 'utf-8'); + if (!fs.existsSync(config)) { + //make file + fs.writeFileSync(config, exampleConfigContent, 'utf-8'); + } + Logger.log('Creating example config file', 'ConfigLoader'); } public reload(): void { @@ -87,6 +90,10 @@ export class ConfigLoader { private loadConfig() { try { + Logger.log( + `Loading configuration from ${this.configPath}`, + 'ConfigLoader', + ); const file = fs.readFileSync(this.configPath, 'utf-8'); const jsonContent = file.replace( /\\"|"(?:\\"|[^"])*"|(\/\/.*|\/\*[\s\S]*?\*\/)/g, @@ -234,4 +241,8 @@ export class ConfigLoader { } } } + + getConfig(): AppConfig { + return this.config; + } } diff --git a/backend/src/main.ts 
b/backend/src/main.ts index ea049c9..598613c 100644 --- a/backend/src/main.ts +++ b/backend/src/main.ts @@ -1,7 +1,7 @@ import { NestFactory } from '@nestjs/core'; import { AppModule } from './app.module'; import 'reflect-metadata'; -import { downloadAllModels } from './model/utils'; +import { checkAndDownloadAllModels } from './model/utils'; async function bootstrap() { const app = await NestFactory.create(AppModule); @@ -17,7 +17,7 @@ async function bootstrap() { 'Access-Control-Allow-Credentials', ], }); - await downloadAllModels(); + await checkAndDownloadAllModels(); await app.listen(process.env.PORT ?? 3000); } diff --git a/backend/src/model/__tests__/loadAllChatsModels.spec.ts b/backend/src/model/__tests__/loadAllChatsModels.spec.ts index c5f702e..f6f51ce 100644 --- a/backend/src/model/__tests__/loadAllChatsModels.spec.ts +++ b/backend/src/model/__tests__/loadAllChatsModels.spec.ts @@ -1,9 +1,9 @@ import path from 'path'; import * as fs from 'fs'; import { ConfigLoader } from '../../config/config-loader'; -import { ModelDownloader } from '../model-downloader'; +import { ModelDownloader } from '../downloader/downloader'; import { downloadAllModels } from '../utils'; -import { getConfigDir, getConfigPath } from 'src/config/common-path'; +import { getConfigPath } from 'src/config/common-path'; const originalIsArray = Array.isArray; diff --git a/backend/src/model/downloader/const.ts b/backend/src/model/downloader/const.ts new file mode 100644 index 0000000..3f89569 --- /dev/null +++ b/backend/src/model/downloader/const.ts @@ -0,0 +1,73 @@ +export const REMOTE_MODEL_LISTS = [ + 'gpt-4-0125-preview', + 'gpt-4-1106-preview', + 'gpt-4-vision-preview', + 'gpt-4', + 'gpt-4-32k', + 'gpt-3.5-turbo-0125', + 'gpt-3.5-turbo-1106', + 'gpt-3.5-turbo', + 'gpt-3.5-turbo-16k', + 'claude-3-opus-20240229', + 'claude-3-sonnet-20240229', + 'claude-3-haiku-20240229', + 'claude-2.1', + 'claude-2.0', + 'claude-instant-1.2', + 'gemini-1.0-pro', + 'gemini-1.0-pro-vision', 
+ 'gemini-1.0-ultra', + 'mistral-large-latest', + 'mistral-medium-latest', + 'mistral-small-latest', + 'command', + 'command-light', + 'command-nightly', + 'azure-gpt-4', + 'azure-gpt-35-turbo', + 'azure-gpt-4-32k', + 'yi-34b-chat', + 'yi-34b-200k', + 'yi-34b', + 'yi-6b-chat', + 'yi-6b-200k', + 'yi-6b', + 'text-bison-001', + 'chat-bison-001', + 'claude-2.0', + 'claude-1.2', + 'claude-1.0', + 'claude-instant-1.2', + 'llama-2-70b-chat', + 'llama-2-13b-chat', + 'llama-2-7b-chat', + 'qwen-72b-chat', + 'qwen-14b-chat', + 'qwen-7b-chat', + 'deepseek-67b-chat', + 'deepseek-33b-chat', + 'deepseek-7b-chat', + 'mixtral-8x7b-32k', + 'mixtral-8x7b', + 'baichuan-2-53b', + 'baichuan-2-13b', + 'baichuan-2-7b', + 'xverse-65b-chat', + 'xverse-13b-chat', + 'xverse-7b-chat', + 'command-r', + 'command-light-r', + 'claude-instant', + 'vicuna-13b', + 'vicuna-7b', + 'falcon-40b', + 'falcon-7b', + 'stablelm-7b', + 'mpt-7b', + 'dolly-12b', + 'alpaca-13b', + 'pythia-12b', +]; + +export const isRemoteModel = (model: string): boolean => + REMOTE_MODEL_LISTS.includes(model); diff --git a/backend/src/model/downloader/downloader.ts b/backend/src/model/downloader/downloader.ts new file mode 100644 index 0000000..1d584d8 --- /dev/null +++ b/backend/src/model/downloader/downloader.ts @@ -0,0 +1,69 @@ +import { Logger } from '@nestjs/common'; +import { PipelineType, pipeline, env } from '@huggingface/transformers'; +import { getModelPath, getModelsDir } from 'src/config/common-path'; +import { isRemoteModel } from './const'; +import { ModelStatusManager } from '../model-status'; + +env.allowLocalModels = true; +env.localModelPath = getModelsDir(); + +export class ModelDownloader { + readonly logger = new Logger(ModelDownloader.name); + private static instance: ModelDownloader; + private readonly statusManager: ModelStatusManager; + + private constructor() { + this.statusManager = ModelStatusManager.getInstance(); + } + + public static getInstance(): ModelDownloader { + if (!ModelDownloader.instance) 
{ + ModelDownloader.instance = new ModelDownloader(); + } + return ModelDownloader.instance; + } + + async downloadModel(task: string, model: string): Promise { + if (isRemoteModel(model)) { + this.logger.log(`Remote model detected: ${model}, marking as downloaded`); + this.statusManager.updateStatus(model, true); + return null; + } + + this.logger.log(`Starting download for local model: ${model}`); + try { + const pipelineInstance = await pipeline(task as PipelineType, model, { + cache_dir: getModelsDir(), + }); + this.logger.log(`Successfully downloaded local model: ${model}`); + this.statusManager.updateStatus(model, true); + return pipelineInstance; + } catch (error) { + this.logger.error(`Failed to download model ${model}: ${error.message}`); + this.statusManager.updateStatus(model, false); + throw error; + } + } + + public async getLocalModel(task: string, model: string): Promise { + if (isRemoteModel(model)) { + this.logger.log(`Remote model detected: ${model}, marking as downloaded`); + this.statusManager.updateStatus(model, true); + return null; + } + + this.logger.log(`Checking local model: ${model}`); + try { + const pipelineInstance = await pipeline(task as PipelineType, model, { + local_files_only: true, + revision: 'local', + }); + this.statusManager.updateStatus(model, true); + return pipelineInstance; + } catch (error) { + this.logger.error(`Failed to get local model ${model}: ${error.message}`); + this.statusManager.updateStatus(model, false); + throw error; + } + } +} diff --git a/backend/src/model/model-downloader.ts b/backend/src/model/model-downloader.ts deleted file mode 100644 index bc754bf..0000000 --- a/backend/src/model/model-downloader.ts +++ /dev/null @@ -1,31 +0,0 @@ -import { Logger } from '@nestjs/common'; -import { PipelineType, pipeline, env } from '@huggingface/transformers'; -import { getModelPath, getModelsDir } from 'src/config/common-path'; -env.allowLocalModels = true; -env.localModelPath = getModelsDir(); -export class 
ModelDownloader { - readonly logger = new Logger(ModelDownloader.name); - private static instance: ModelDownloader; - public static getInstance(): ModelDownloader { - if (!ModelDownloader.instance) { - ModelDownloader.instance = new ModelDownloader(); - } - return ModelDownloader.instance; - } - - async downloadModel(task: string, model: string): Promise { - const pipelineInstance = await pipeline(task as PipelineType, model, { - cache_dir: getModelsDir(), - }); - return pipelineInstance; - } - - public async getLocalModel(task: string, model: string): Promise { - const pipelineInstance = await pipeline(task as PipelineType, model, { - local_files_only: true, - revision: 'local', - }); - - return pipelineInstance; - } -} diff --git a/backend/src/model/model-status.ts b/backend/src/model/model-status.ts index e1dc043..efae453 100644 --- a/backend/src/model/model-status.ts +++ b/backend/src/model/model-status.ts @@ -1,6 +1,6 @@ import * as fs from 'fs'; import * as path from 'path'; -import { getConfigPath } from 'src/config/common-path'; +import { getConfigPath, getModelStatusPath } from 'src/config/common-path'; export interface ModelStatus { isDownloaded: boolean; @@ -13,7 +13,7 @@ export class ModelStatusManager { private readonly statusPath: string; private constructor() { - this.statusPath = getConfigPath('model-status'); + this.statusPath = getModelStatusPath(); this.loadStatus(); } diff --git a/backend/src/model/utils.ts b/backend/src/model/utils.ts index 8943a0f..00bc2ca 100644 --- a/backend/src/model/utils.ts +++ b/backend/src/model/utils.ts @@ -1,74 +1,140 @@ -// model-utils.ts -import { ModelDownloader } from './model-downloader'; +import { ModelDownloader } from './downloader/downloader'; import { Logger } from '@nestjs/common'; import { ModelStatusManager } from './model-status'; import { ConfigLoader } from 'src/config/config-loader'; const logger = new Logger('model-utils'); -export async function downloadAllModels(): Promise { +export async function 
checkAndDownloadAllModels(): Promise { const configLoader = ConfigLoader.getInstance(); const statusManager = ModelStatusManager.getInstance(); const downloader = ModelDownloader.getInstance(); const chatConfigs = configLoader.getAllChatConfigs(); - logger.log('Loaded chat configurations:', chatConfigs); + const embeddingConfig = configLoader.getEmbeddingConfig(); + + logger.log('Checking and downloading configured models...'); if (!chatConfigs.length) { logger.warn('No chat models configured'); - return; + } else { + const chatDownloadTasks = chatConfigs.map(async (chatConfig) => { + const { model, task } = chatConfig; + const status = statusManager.getStatus(model); + + if (status?.isDownloaded) { + try { + await downloader.getLocalModel(task || 'chat', model); + logger.log( + `Model ${model} is already downloaded and verified, skipping...`, + ); + return; + } catch (error) { + logger.warn( + `Model ${model} was marked as downloaded but not found locally, re-downloading...`, + ); + } + } + + try { + logger.log( + `Downloading chat model: ${model} for task: ${task || 'chat'}`, + ); + await downloader.downloadModel(task || 'chat', model); + logger.log(`Successfully downloaded model: ${model}`); + } catch (error) { + logger.error(`Failed to download model ${model}:`, error.message); + throw error; + } + }); + + try { + await Promise.all(chatDownloadTasks); + logger.log('All chat models downloaded successfully.'); + } catch (error) { + logger.error('One or more chat models failed to download'); + throw error; + } } - const downloadTasks = chatConfigs.map(async (chatConfig) => { - const { model, task } = chatConfig; + if (embeddingConfig) { + const { model } = embeddingConfig; const status = statusManager.getStatus(model); - // Skip if already downloaded if (status?.isDownloaded) { - logger.log(`Model ${model} is already downloaded, skipping...`); - return; + try { + await downloader.getLocalModel('feature-extraction', model); + logger.log( + `Embedding model ${model} 
is already downloaded and verified, skipping...`, + ); + return; + } catch (error) { + logger.warn( + `Embedding model ${model} was marked as downloaded but not found locally, re-downloading...`, + ); + } } try { - logger.log(`Downloading model: ${model} for task: ${task || 'chat'}`); - await downloader.downloadModel(task || 'chat', model); - - statusManager.updateStatus(model, true); - logger.log(`Successfully downloaded model: ${model}`); + logger.log(`Downloading embedding model: ${model}`); + await downloader.downloadModel('feature-extraction', model); + logger.log(`Successfully downloaded embedding model: ${model}`); } catch (error) { - logger.error(`Failed to download model ${model}:`, error.message); - statusManager.updateStatus(model, false); + logger.error( + `Failed to download embedding model ${model}:`, + error.message, + ); throw error; } - }); - - try { - await Promise.all(downloadTasks); - logger.log('All models downloaded successfully.'); - } catch (error) { - logger.error('One or more models failed to download'); - throw error; + } else { + logger.warn('No embedding model configured'); } } -export async function downloadModel(modelName: string): Promise { +export async function downloadModel( + modelName: string, + isEmbedding = false, +): Promise { const configLoader = ConfigLoader.getInstance(); const statusManager = ModelStatusManager.getInstance(); const downloader = ModelDownloader.getInstance(); - const chatConfig = configLoader.getChatConfig(modelName); - if (!chatConfig) { + let modelConfig; + let task: string; + + if (isEmbedding) { + modelConfig = configLoader.getEmbeddingConfig(); + task = 'feature-extraction'; + } else { + modelConfig = configLoader.getChatConfig(modelName); + task = modelConfig?.task || 'chat'; + } + + if (!modelConfig) { throw new Error(`Model configuration not found for: ${modelName}`); } + const status = statusManager.getStatus(modelName); + if (status?.isDownloaded) { + try { + await downloader.getLocalModel(task, 
modelName); + logger.log( + `Model ${modelName} is already downloaded and verified, skipping...`, + ); + return; + } catch (error) { + logger.warn( + `Model ${modelName} was marked as downloaded but not found locally, re-downloading...`, + ); + } + } + try { logger.log(`Downloading model: ${modelName}`); - await downloader.downloadModel(chatConfig.task || 'chat', modelName); - statusManager.updateStatus(modelName, true); + await downloader.downloadModel(task, modelName); logger.log(`Successfully downloaded model: ${modelName}`); } catch (error) { logger.error(`Failed to download model ${modelName}:`, error.message); - statusManager.updateStatus(modelName, false); throw error; } }