Skip to content

Commit

Permalink
Merge branch 'feat-adding-embedding-config' of https://github.com/Sma…
Browse files Browse the repository at this point in the history
…1lboy/codefox into feat-adding-embedding-config
  • Loading branch information
NarwhalChen committed Dec 30, 2024
2 parents fb9be7d + 2c2027d commit 0b85081
Show file tree
Hide file tree
Showing 11 changed files with 12,049 additions and 14,832 deletions.
18 changes: 10 additions & 8 deletions backend/src/config/config-loader.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import { getConfigPath } from './common-path';
import { ConfigType } from 'src/downloader/universal-utils';
import { Logger } from '@nestjs/common';


export interface ModelConfig {
model: string;
endpoint?: string;
Expand Down Expand Up @@ -54,14 +53,13 @@ export const exampleConfigContent = `{
}]
}`;


export class ConfigLoader {
readonly logger = new Logger(ConfigLoader.name);
private type: string;
private static instances: Map<ConfigType, ConfigLoader> = new Map();
private static config: AppConfig;
private readonly configPath: string;

private constructor(type: ConfigType) {
this.type = type;
this.configPath = getConfigPath();
Expand Down Expand Up @@ -119,7 +117,7 @@ export class ConfigLoader {
throw error;
}
}

this.logger.log(ConfigLoader.config);
}

Expand Down Expand Up @@ -188,7 +186,7 @@ export class ConfigLoader {
}

getAllConfigs(): EmbeddingConfig[] | ModelConfig[] | null {
let res = ConfigLoader.config[this.type];
const res = ConfigLoader.config[this.type];
return Array.isArray(res) ? res : null;
}

Expand All @@ -214,7 +212,9 @@ export class ConfigLoader {
}
});

const defaultChats = ConfigLoader.config.models.filter((chat) => chat.default);
const defaultChats = ConfigLoader.config.models.filter(
(chat) => chat.default,
);
if (defaultChats.length > 1) {
throw new Error(
'Invalid configuration: Multiple default chat configurations found',
Expand All @@ -223,7 +223,7 @@ export class ConfigLoader {
}

if (ConfigLoader.config[ConfigType.EMBEDDINGS]) {
this.logger.log(ConfigLoader.config[ConfigType.EMBEDDINGS])
this.logger.log(ConfigLoader.config[ConfigType.EMBEDDINGS]);
if (!Array.isArray(ConfigLoader.config[ConfigType.EMBEDDINGS])) {
throw new Error("Invalid configuration: 'embeddings' must be an array");
}
Expand All @@ -236,7 +236,9 @@ export class ConfigLoader {
}
});

const defaultChats = ConfigLoader.config[ConfigType.EMBEDDINGS].filter((chat) => chat.default);
const defaultChats = ConfigLoader.config[ConfigType.EMBEDDINGS].filter(
(chat) => chat.default,
);
if (defaultChats.length > 1) {
throw new Error(
'Invalid configuration: Multiple default emb configurations found',
Expand Down
12 changes: 8 additions & 4 deletions backend/src/downloader/__tests__/loadAllChatsModels.spec.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import path from 'path';
import * as fs from 'fs';
import { ConfigLoader, ModelConfig, EmbeddingConfig } from '../../config/config-loader';
import {
ConfigLoader,
ModelConfig,
EmbeddingConfig,
} from '../../config/config-loader';
import { UniversalDownloader } from '../model-downloader';
import { ConfigType, downloadAll, TaskType } from '../universal-utils';

Expand Down Expand Up @@ -41,7 +45,7 @@ describe('loadAllChatsModels with real model loading', () => {
modelConfigLoader = ConfigLoader.getInstance(ConfigType.CHATS);
embConfigLoader = ConfigLoader.getInstance(ConfigType.EMBEDDINGS);
const modelConfig: ModelConfig = {
model: "Xenova/flan-t5-small",
model: 'Xenova/flan-t5-small',
endpoint: 'http://localhost:11434/v1',
token: 'your-token-here',
task: 'text2text-generation',
Expand All @@ -54,9 +58,9 @@ describe('loadAllChatsModels with real model loading', () => {
token: 'your-token-here',
};
embConfigLoader.addConfig(embConfig);
console.log("preload starts");
console.log('preload starts');
await downloadAll();
console.log("preload successfully");
console.log('preload successfully');
}, 60000000);

it('should load real models specified in config', async () => {
Expand Down
25 changes: 14 additions & 11 deletions backend/src/downloader/embedding-downloader.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { Logger } from '@nestjs/common';
import { UniversalStatusManager } from './universal-status';
import { EmbeddingModel, FlagEmbedding } from "fastembed";
import { EmbeddingModel, FlagEmbedding } from 'fastembed';
import { getEmbDir } from 'src/config/common-path';
export class EmbeddingDownloader {
readonly logger = new Logger(EmbeddingDownloader.name);
Expand All @@ -9,28 +9,31 @@ export class EmbeddingDownloader {

public static getInstance(): EmbeddingDownloader {
if (!EmbeddingDownloader.instance) {
EmbeddingDownloader.instance = new EmbeddingDownloader();
EmbeddingDownloader.instance = new EmbeddingDownloader();
}

return EmbeddingDownloader.instance;
}

async getPipeline(model: string): Promise<any> {
if(!Object.values(EmbeddingModel).includes(model as EmbeddingModel)){
this.logger.error(`Invalid model: ${model} is not a valid EmbeddingModel.`);
if (!Object.values(EmbeddingModel).includes(model as EmbeddingModel)) {
this.logger.error(
`Invalid model: ${model} is not a valid EmbeddingModel.`,
);
return null;
}
try{
try {
const embeddingModel = await FlagEmbedding.init({
model: model as EmbeddingModel,
cacheDir: getEmbDir(),
model: model as EmbeddingModel,
cacheDir: getEmbDir(),
});
this.statusManager.updateStatus(model, true);
return embeddingModel;
}catch(error){
this.logger.error(`Failed to load local model: ${model} with error: ${error.message || error}`);
} catch (error) {
this.logger.error(
`Failed to load local model: ${model} with error: ${error.message || error}`,
);
return null;
}

}
}
8 changes: 5 additions & 3 deletions backend/src/downloader/model-downloader.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ export class UniversalDownloader {

public static getInstance(): UniversalDownloader {
if (!UniversalDownloader.instance) {
UniversalDownloader.instance = new UniversalDownloader();
UniversalDownloader.instance = new UniversalDownloader();
}
return UniversalDownloader.instance;
}
Expand Down Expand Up @@ -47,11 +47,13 @@ export class UniversalDownloader {
const pipelineInstance = await pipeline(task as PipelineType, model, {
local_files_only: true,
revision: 'local',
cache_dir: getModelsDir()
cache_dir: getModelsDir(),
});
return pipelineInstance;
} catch (error) {
this.logger.error(`Failed to load local model: ${model} with error: ${error.message || error}`);
this.logger.error(
`Failed to load local model: ${model} with error: ${error.message || error}`,
);
return null;
}
}
Expand Down
2 changes: 1 addition & 1 deletion backend/src/downloader/universal-status.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import * as fs from 'fs';
import * as path from 'path';
import { getModelStatusPath } from 'src/config/common-path';

export interface UniversalStatus {
isDownloaded: boolean;
lastChecked: Date;
Expand Down
58 changes: 38 additions & 20 deletions backend/src/downloader/universal-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import { UniversalDownloader } from './model-downloader';
import { Logger } from '@nestjs/common';
import { UniversalStatusManager } from './universal-status';
import {ModelConfig} from '../config/config-loader';
import { ModelConfig } from '../config/config-loader';
import { ConfigLoader } from 'src/config/config-loader';
import { getEmbDir, getModelsDir } from 'src/config/common-path';
import { EmbeddingDownloader } from './embedding-downloader';
Expand All @@ -13,33 +13,41 @@ export enum ConfigType {
EMBEDDINGS = 'embeddings',
CHATS = 'models',
}
export enum TaskType{
CHAT = "text2text-generation",
EMBEDDING = "feature-extraction",
export enum TaskType {
CHAT = 'text2text-generation',
EMBEDDING = 'feature-extraction',
}

export async function downloadAll(){
export async function downloadAll() {
await checkAndDownloadAllModels(ConfigType.CHATS);
console.log("embedding load starts");

console.log('embedding load starts');
await checkAndDownloadAllModels(ConfigType.EMBEDDINGS);
console.log("embedding load ends");
console.log('embedding load ends');
}
async function downloadModelForType(type: ConfigType, modelName: string, task: string) {
async function downloadModelForType(
type: ConfigType,
modelName: string,
task: string,
) {
const statusManager = UniversalStatusManager.getInstance();
const storePath = type === ConfigType.EMBEDDINGS ? getEmbDir() : getModelsDir();

const storePath =
type === ConfigType.EMBEDDINGS ? getEmbDir() : getModelsDir();

if (type === ConfigType.EMBEDDINGS) {
const embeddingDownloader = EmbeddingDownloader.getInstance();
try {
logger.log(`Downloading embedding model: ${modelName}`);
await embeddingDownloader.getPipeline(modelName);
statusManager.updateStatus(modelName, true);
logger.log(`Successfully downloaded embedding model: ${modelName}`);
console.log("embedding load finished");

console.log('embedding load finished');
} catch (error) {
logger.error(`Failed to download embedding model ${modelName}:`, error.message);
logger.error(
`Failed to download embedding model ${modelName}:`,
error.message,
);
statusManager.updateStatus(modelName, false);
throw error;
}
Expand All @@ -50,15 +58,21 @@ async function downloadModelForType(type: ConfigType, modelName: string, task: s
if (status?.isDownloaded) {
try {
await downloader.getLocalModel(task, modelName);
logger.log(`Model ${modelName} is already downloaded and verified, skipping...`);
logger.log(
`Model ${modelName} is already downloaded and verified, skipping...`,
);
return;
} catch (error) {
logger.warn(`Model ${modelName} was marked as downloaded but not found locally, re-downloading...`);
logger.warn(
`Model ${modelName} was marked as downloaded but not found locally, re-downloading...`,
);
}
}

try {
logger.log(`Downloading chat model: ${modelName} for task: ${task} in path: ${storePath}`);
logger.log(
`Downloading chat model: ${modelName} for task: ${task} in path: ${storePath}`,
);
await downloader.getPipeline(task, modelName, storePath);
statusManager.updateStatus(modelName, true);
logger.log(`Successfully downloaded chat model: ${modelName}`);
Expand All @@ -70,7 +84,9 @@ async function downloadModelForType(type: ConfigType, modelName: string, task: s
}
}

export async function checkAndDownloadAllModels(type: ConfigType): Promise<void> {
export async function checkAndDownloadAllModels(
type: ConfigType,
): Promise<void> {
const configLoader = ConfigLoader.getInstance(type);
const modelsConfig = configLoader.getAllConfigs();

Expand All @@ -83,7 +99,9 @@ export async function checkAndDownloadAllModels(type: ConfigType): Promise<void>

const downloadTasks = modelsConfig.map(async (config) => {
const { model, task } = config;
const taskType = task || (type === ConfigType.EMBEDDINGS ? TaskType.EMBEDDING : TaskType.CHAT);
const taskType =
task ||
(type === ConfigType.EMBEDDINGS ? TaskType.EMBEDDING : TaskType.CHAT);
await downloadModelForType(type, model, taskType);
});

Expand All @@ -94,4 +112,4 @@ export async function checkAndDownloadAllModels(type: ConfigType): Promise<void>
logger.error(`One or more ${type} models failed to download.`);
throw error;
}
}
}
19 changes: 11 additions & 8 deletions backend/src/embedding/__tests__/loadAllEmbModels.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,18 @@ Array.isArray = jest.fn((type: any): type is any[] => {

describe('testing embedding provider', () => {
it('should load real models specified in config', async () => {
let documents = [
"passage: Hello, World!",
"query: Hello, World!",
"passage: This is an example passage.",
const documents = [
'passage: Hello, World!',
'query: Hello, World!',
'passage: This is an example passage.',
// You can leave out the prefix but it's recommended
"fastembed-js is licensed under MIT"
];

await localEmbProvider.generateEmbResponse(EmbeddingModel.BGEBaseENV15, documents);
'fastembed-js is licensed under MIT',
];

await localEmbProvider.generateEmbResponse(
EmbeddingModel.BGEBaseENV15,
documents,
);
}, 6000000);


Expand Down
14 changes: 6 additions & 8 deletions backend/src/embedding/local-embedding-provider.ts
Original file line number Diff line number Diff line change
@@ -1,29 +1,27 @@
import { Logger } from '@nestjs/common';
import { EmbeddingModel} from "fastembed";
import { EmbeddingModel } from 'fastembed';
import { EmbeddingDownloader } from 'src/downloader/embedding-downloader';

export class localEmbProvider {
private static logger = new Logger(localEmbProvider.name);

static async generateEmbResponse(
model: string, message: string[]
){
let embLoader = EmbeddingDownloader.getInstance();
try{
static async generateEmbResponse(model: string, message: string[]) {
const embLoader = EmbeddingDownloader.getInstance();
try {
const embeddingModel = await embLoader.getPipeline(model);
const embeddings = embeddingModel.embed(message);

for await (const batch of embeddings) {
console.log(batch);
}
}catch(error){
} catch (error) {
this.logger.log(`error when using ${model} api`);
}
}

static async getEmbList() {
Object.values(EmbeddingModel).forEach((model) => {
this.logger.log(model);
})
});
}
}
Loading

0 comments on commit 0b85081

Please sign in to comment.