Skip to content

Commit

Permalink
chore: adding uitls ufc
Browse files Browse the repository at this point in the history
  • Loading branch information
Sma1lboy committed Dec 16, 2024
1 parent 72c6dbc commit 1340532
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 44 deletions.
24 changes: 0 additions & 24 deletions backend/src/model/__tests__/app.e2e-spec.ts

This file was deleted.

9 changes: 0 additions & 9 deletions backend/src/model/__tests__/jest-e2e.json

This file was deleted.

71 changes: 60 additions & 11 deletions backend/src/model/utils.ts
Original file line number Diff line number Diff line change
@@ -1,25 +1,74 @@
import { ChatConfig, ConfigLoader } from 'src/config/config-loader';
// model-utils.ts
import { ModelDownloader } from './model-downloader';
import { Logger } from '@nestjs/common';
import { ModelStatusManager } from './model-status';
import { ConfigLoader } from 'src/config/config-loader';

const logger = new Logger('model-utils');

export async function downloadAllModels(): Promise<void> {
// TODO: verify or download embedding model to
const configLoader = ConfigLoader.getInstance();
configLoader.validateConfig();
const chats = configLoader.get<ChatConfig[]>('');
const statusManager = ModelStatusManager.getInstance();
const downloader = ModelDownloader.getInstance();
logger.log('Loaded config:', chats);
const loadPromises = chats.map(async (chatConfig: ChatConfig) => {

const chatConfigs = configLoader.getAllChatConfigs();
logger.log('Loaded chat configurations:', chatConfigs);

if (!chatConfigs.length) {
logger.warn('No chat models configured');
return;
}

const downloadTasks = chatConfigs.map(async (chatConfig) => {
const { model, task } = chatConfig;
const status = statusManager.getStatus(model);

// Skip if already downloaded
if (status?.isDownloaded) {
logger.log(`Model ${model} is already downloaded, skipping...`);
return;
}

try {
downloader.logger.log(model, task);
await downloader.downloadModel(task, model);
logger.log(`Downloading model: ${model} for task: ${task || 'chat'}`);
await downloader.downloadModel(task || 'chat', model);

statusManager.updateStatus(model, true);
logger.log(`Successfully downloaded model: ${model}`);
} catch (error) {
downloader.logger.error(`Failed to load model ${model}:`, error.message);
logger.error(`Failed to download model ${model}:`, error.message);
statusManager.updateStatus(model, false);
throw error;
}
});
await Promise.all(loadPromises);

downloader.logger.log('All models loaded.');
try {
await Promise.all(downloadTasks);
logger.log('All models downloaded successfully.');
} catch (error) {
logger.error('One or more models failed to download');
throw error;
}
}

export async function downloadModel(modelName: string): Promise<void> {
const configLoader = ConfigLoader.getInstance();
const statusManager = ModelStatusManager.getInstance();
const downloader = ModelDownloader.getInstance();

const chatConfig = configLoader.getChatConfig(modelName);
if (!chatConfig) {
throw new Error(`Model configuration not found for: ${modelName}`);
}

try {
logger.log(`Downloading model: ${modelName}`);
await downloader.downloadModel(chatConfig.task || 'chat', modelName);
statusManager.updateStatus(modelName, true);
logger.log(`Successfully downloaded model: ${modelName}`);
} catch (error) {
logger.error(`Failed to download model ${modelName}:`, error.message);
statusManager.updateStatus(modelName, false);
throw error;
}
}

0 comments on commit 1340532

Please sign in to comment.