Skip to content

Commit

Permalink
Remove local storage and enable Elasticsearch hybrid query mode (#60)
Browse files Browse the repository at this point in the history
* Add gpu dockerfile

* Fix bug

* Fix gb2312

* Update embedding batch size

* Set default embedding and llm model

* Update docker tag

* Fix hologres check

* Update registry

* Fix bug

* Fix tests

* Add queue

* Update batch size

* Add async interface

* Fix index conflict

* Add change index parameter for FAISS

* Fix batch size

* Update
  • Loading branch information
moria97 authored Jun 13, 2024
1 parent ba1132a commit daba1f5
Show file tree
Hide file tree
Showing 26 changed files with 1,217 additions and 180 deletions.
10 changes: 6 additions & 4 deletions src/pai_rag/app/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
from typing import List, Dict


class VectorDbConfig(BaseModel):
faiss_path: str | None = None


class RagQuery(BaseModel):
question: str
temperature: float | None = 0.1
vector_topk: int | None = 3
score_threshold: float | None = 0.5
chat_history: List[Dict[str, str]] | None = None
session_id: str | None = None
vector_db: VectorDbConfig | None = None


class LlmQuery(BaseModel):
Expand All @@ -20,8 +23,7 @@ class LlmQuery(BaseModel):

class RetrievalQuery(BaseModel):
question: str
topk: int | None = 3
score_threshold: float | None = 0.5
vector_db: VectorDbConfig | None = None


class RagResponse(BaseModel):
Expand Down
76 changes: 43 additions & 33 deletions src/pai_rag/core/rag_application.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from pai_rag.data.rag_dataloader import RagDataLoader
from pai_rag.utils.oss_cache import OssCache
from pai_rag.modules.module_registry import module_registry
from pai_rag.evaluations.batch_evaluator import BatchEvaluator
from pai_rag.app.api.models import (
Expand All @@ -24,49 +22,34 @@ def uuid_generator() -> str:
class RagApplication:
def __init__(self):
self.name = "RagApplication"
logging.basicConfig(level=logging.INFO) # 将日志级别设置为INFO
self.logger = logging.getLogger(__name__)

def initialize(self, config):
self.config = config

module_registry.init_modules(self.config)
self.index = module_registry.get_module("IndexModule")
self.llm = module_registry.get_module("LlmModule")
self.retriever = module_registry.get_module("RetrieverModule")
self.chat_store = module_registry.get_module("ChatStoreModule")
self.query_engine = module_registry.get_module("QueryEngineModule")
self.chat_engine_factory = module_registry.get_module("ChatEngineFactoryModule")
self.llm_chat_engine_factory = module_registry.get_module(
"LlmChatEngineFactoryModule"
)
self.data_reader_factory = module_registry.get_module("DataReaderFactoryModule")
self.agent = module_registry.get_module("AgentModule")

oss_cache = None
if config.get("oss_cache", None):
oss_cache = OssCache(config.oss_cache)
node_parser = module_registry.get_module("NodeParserModule")

self.data_loader = RagDataLoader(
self.data_reader_factory, node_parser, self.index, oss_cache
)
self.logger.info("RagApplication initialized successfully.")

def reload(self, config):
self.initialize(config)
self.logger.info("RagApplication reloaded successfully.")

# TODO: 大量文件上传实现异步添加
def load_knowledge(self, file_dir, enable_qa_extraction=False):
self.data_loader.load(file_dir, enable_qa_extraction)
async def load_knowledge(self, file_dir, enable_qa_extraction=False):
data_loader = module_registry.get_module_with_config(
"DataLoaderModule", self.config
)
await data_loader.aload(file_dir, enable_qa_extraction)

async def aquery_retrieval(self, query: RetrievalQuery) -> RetrievalResponse:
if not query.question:
return RetrievalResponse(docs=[])

query_bundle = QueryBundle(query.question)
node_results = await self.query_engine.aretrieve(query_bundle)

query_engine = module_registry.get_module_with_config(
"QueryEngineModule", self.config
)
node_results = await query_engine.aretrieve(query_bundle)

docs = [
ContextDoc(
Expand Down Expand Up @@ -96,11 +79,24 @@ async def aquery(self, query: RagQuery) -> RagResponse:
answer="Empty query. Please input your question.", session_id=session_id
)

query_chat_engine = self.chat_engine_factory.get_chat_engine(
sessioned_config = self.config
if query.vector_db and query.vector_db.faiss_path:
sessioned_config = self.config.copy()
sessioned_config.index.update({"persist_path": query.vector_db.faiss_path})
print(sessioned_config)

chat_engine_factory = module_registry.get_module_with_config(
"ChatEngineFactoryModule", sessioned_config
)
query_chat_engine = chat_engine_factory.get_chat_engine(
session_id, query.chat_history
)
response = await query_chat_engine.achat(query.question)
self.chat_store.persist()

chat_store = module_registry.get_module_with_config(
"ChatStoreModule", sessioned_config
)
chat_store.persist()
return RagResponse(answer=response.response, session_id=session_id)

async def aquery_llm(self, query: LlmQuery) -> LlmResponse:
Expand All @@ -122,11 +118,18 @@ async def aquery_llm(self, query: LlmQuery) -> LlmResponse:
answer="Empty query. Please input your question.", session_id=session_id
)

llm_chat_engine = self.llm_chat_engine_factory.get_chat_engine(
llm_chat_engine_factory = module_registry.get_module_with_config(
"LlmChatEngineFactoryModule", self.config
)
llm_chat_engine = llm_chat_engine_factory.get_chat_engine(
session_id, query.chat_history
)
response = await llm_chat_engine.achat(query.question)
self.chat_store.persist()

chat_store = module_registry.get_module_with_config(
"ChatStoreModule", self.config
)
chat_store.persist()
return LlmResponse(answer=response.response, session_id=session_id)

async def aquery_agent(self, query: LlmQuery) -> LlmResponse:
Expand All @@ -143,11 +146,18 @@ async def aquery_agent(self, query: LlmQuery) -> LlmResponse:
if not query.question:
return LlmResponse(answer="Empty query. Please input your question.")

response = await self.agent.achat(query.question)
agent = module_registry.get_module_with_config("AgentModule", self.config)
response = await agent.achat(query.question)
return LlmResponse(answer=response.response)

async def batch_evaluate_retrieval_and_response(self, type):
batch_eval = BatchEvaluator(self.config, self.retriever, self.query_engine)
retriever = module_registry.get_module_with_config(
"RetrieverModule", self.config
)
query_engine = module_registry.get_module_with_config(
"QueryEngineModule", self.config
)
batch_eval = BatchEvaluator(self.config, retriever, query_engine)
df, eval_res_avg = await batch_eval.batch_retrieval_response_aevaluation(
type=type, workers=2, save_to_file=True
)
Expand Down
10 changes: 7 additions & 3 deletions src/pai_rag/core/rag_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from pai_rag.app.web.view_model import view_model
from openinference.instrumentation import using_attributes
from typing import Any, Dict
import logging

logger = logging.getLogger(__name__)


def trace_correlation_id(function):
Expand Down Expand Up @@ -48,14 +51,15 @@ def reload(self, new_config: Any):
self.rag.reload(self.rag_configuration.get_value())
self.rag_configuration.persist()

def add_knowledge_async(
async def add_knowledge_async(
self, task_id: str, file_dir: str, enable_qa_extraction: bool = False
):
self.tasks_status[task_id] = "processing"
try:
self.rag.load_knowledge(file_dir, enable_qa_extraction)
await self.rag.load_knowledge(file_dir, enable_qa_extraction)
self.tasks_status[task_id] = "completed"
except Exception:
except Exception as ex:
logger.error(f"Upload failed: {ex}")
self.tasks_status[task_id] = "failed"

def get_task_status(self, task_id: str) -> str:
Expand Down
4 changes: 2 additions & 2 deletions src/pai_rag/data/rag_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def __init__(
):
self.datareader_factory = datareader_factory
self.node_parser = node_parser
self.index = index
self.oss_cache = oss_cache
self.index = index

if use_local_qa_model:
# API暂不支持此选项
Expand Down Expand Up @@ -111,7 +111,7 @@ async def aload(self, file_directory: str, enable_qa_extraction: bool):

logger.info("[DataReader] Start inserting to index.")

self.index.insert_nodes(nodes)
await self.index.insert_nodes_async(nodes)
self.index.storage_context.persist(persist_dir=store_path.persist_path)
logger.info(f"Inserted {len(nodes)} nodes successfully.")
return
Expand Down
15 changes: 2 additions & 13 deletions src/pai_rag/data/rag_datapipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@
import click
import os
from pathlib import Path
from pai_rag.data.rag_dataloader import RagDataLoader
from pai_rag.core.rag_configuration import RagConfiguration
from pai_rag.utils.oss_cache import OssCache
from pai_rag.modules.module_registry import module_registry


class RagDataPipeline:
def __init__(self, data_loader: RagDataLoader):
def __init__(self, data_loader):
self.data_loader = data_loader

async def ingest_from_folder(self, folder_path: str, enable_qa_extraction: bool):
Expand All @@ -23,16 +21,7 @@ def __init_data_pipeline(use_local_qa_model):
config = RagConfiguration.from_file(config_file).get_value()
module_registry.init_modules(config)

oss_cache = None
if config.get("oss_cache", None):
oss_cache = OssCache(config.oss_cache)
node_parser = module_registry.get_module("NodeParserModule")
index = module_registry.get_module("IndexModule")
data_reader_factory = module_registry.get_module("DataReaderFactoryModule")

data_loader = RagDataLoader(
data_reader_factory, node_parser, index, oss_cache, use_local_qa_model
)
data_loader = module_registry.get_module_with_config("DataLoaderModule", config)
return RagDataPipeline(data_loader)


Expand Down
4 changes: 2 additions & 2 deletions src/pai_rag/evaluations/batch_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,8 @@ def __init_evaluator_pipeline():
config = RagConfiguration.from_file(config_file).get_value()
module_registry.init_modules(config)

retriever = module_registry.get_module("RetrieverModule")
query_engine = module_registry.get_module("QueryEngineModule")
retriever = module_registry.get_module_with_config("RetrieverModule", config)
query_engine = module_registry.get_module_with_config("QueryEngineModule", config)

return BatchEvaluator(config, retriever, query_engine)

Expand Down
23 changes: 20 additions & 3 deletions src/pai_rag/evaluations/dataset_generation/generate_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import os
from pathlib import Path
from pai_rag.core.rag_configuration import RagConfiguration
from pai_rag.modules.module_registry import module_registry
from llama_index.core.prompts.prompt_type import PromptType
Expand All @@ -16,8 +17,13 @@
DEFAULT_TEXT_QA_PROMPT_TMPL,
DEFAULT_QUESTION_GENERATION_QUERY,
)

import json

_BASE_DIR = Path(__file__).parent.parent.parent
DEFAULT_EVAL_CONFIG_FILE = os.path.join(_BASE_DIR, "config/settings.toml")
DEFAULT_EVAL_DATA_FOLDER = "tests/testdata/paul_graham"


class GenerateDatasetPipeline(ModifiedRagDatasetGenerator):
def __init__(
Expand All @@ -29,11 +35,22 @@ def __init__(
show_progress: Optional[bool] = True,
) -> None:
self.name = "GenerateDatasetPipeline"
self.nodes = list(
module_registry.get_module("IndexModule").docstore.docs.values()
self.config = RagConfiguration.from_file(DEFAULT_EVAL_CONFIG_FILE).get_value()

# load nodes
module_registry.init_modules(self.config)
datareader_factory = module_registry.get_module_with_config(
"DataReaderFactoryModule", self.config
)
self.node_parser = module_registry.get_module_with_config(
"NodeParserModule", self.config
)
reader = datareader_factory.get_reader(DEFAULT_EVAL_DATA_FOLDER)
docs = reader.load_data()
self.nodes = self.node_parser.get_nodes_from_documents(docs)

self.num_questions_per_chunk = num_questions_per_chunk
self.llm = module_registry.get_module("LlmModule")
self.llm = module_registry.get_module_with_config("LlmModule", self.config)
self.text_question_template = PromptTemplate(text_question_template_str)
self.text_qa_template = PromptTemplate(
text_qa_template_str, prompt_type=PromptType.QUESTION_ANSWER
Expand Down
5 changes: 5 additions & 0 deletions src/pai_rag/modules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pai_rag.modules.embedding.embedding import EmbeddingModule
from pai_rag.modules.llm.llm_module import LlmModule
from pai_rag.modules.datareader.data_loader import DataLoaderModule
from pai_rag.modules.datareader.datareader_factory import DataReaderFactoryModule
from pai_rag.modules.index.index import IndexModule
from pai_rag.modules.nodeparser.node_parser import NodeParserModule
Expand All @@ -12,10 +13,13 @@
from pai_rag.modules.chat.chat_store import ChatStoreModule
from pai_rag.modules.agent.agent import AgentModule
from pai_rag.modules.tool.tool import ToolModule
from pai_rag.modules.cache.oss_cache import OssCacheModule


ALL_MODULES = [
"EmbeddingModule",
"LlmModule",
"DataLoaderModule",
"DataReaderFactoryModule",
"IndexModule",
"NodeParserModule",
Expand All @@ -28,6 +32,7 @@
"LlmChatEngineFactoryModule",
"AgentModule",
"ToolModule",
"OssCacheModule",
]

__all__ = ALL_MODULES + ["ALL_MODULES"]
26 changes: 5 additions & 21 deletions src/pai_rag/modules/base/configurable_module.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Any
import logging

DEFAULT_INSTANCE_KEY = "__DEFAULT_INSTANCE__"


logger = logging.getLogger(__name__)


class ConfigurableModule(ABC):
"""Configurable Module
Helps to create instances according to configuration.
"""

def __init__(self):
self.__params_map = {}
self.__instance_map = {}

@abstractmethod
def _create_new_instance(self, new_params: Dict[str, Any]):
raise NotImplementedError
Expand All @@ -24,20 +24,4 @@ def get_dependencies() -> List[str]:
raise NotImplementedError

def get_or_create(self, new_params: Dict[str, Any]):
return self.get_or_create_by_name(new_params=new_params)

def get_or_create_by_name(
self, new_params: Dict[str, Any], name: str = DEFAULT_INSTANCE_KEY
):
# Create new instance when initializing or config changed.
if (
self.__params_map.get(name, None) is None
or self.__params_map[name] != new_params
):
print(f"{self.__class__.__name__} param changed, updating")
self.__instance_map[name] = self._create_new_instance(new_params)
self.__params_map[name] = new_params
else:
print(f"{self.__class__.__name__} param unchanged, skipping")

return self.__instance_map[name]
return self._create_new_instance(new_params)
20 changes: 20 additions & 0 deletions src/pai_rag/modules/cache/oss_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import Any, Dict, List
from pai_rag.utils.oss_cache import OssCache
from pai_rag.modules.base.configurable_module import ConfigurableModule
from pai_rag.modules.base.module_constants import MODULE_PARAM_CONFIG
import logging

logger = logging.getLogger(__name__)


class OssCacheModule(ConfigurableModule):
@staticmethod
def get_dependencies() -> List[str]:
return []

def _create_new_instance(self, new_params: Dict[str, Any]):
cache_config = new_params[MODULE_PARAM_CONFIG]
if cache_config:
return OssCache(cache_config)
else:
return None
Loading

0 comments on commit daba1f5

Please sign in to comment.