From fe1bc437c30d2379162ada87d2a437adc366b015 Mon Sep 17 00:00:00 2001
From: suluyana <110878454+suluyana@users.noreply.github.com>
Date: Thu, 13 Jun 2024 09:41:28 +0800
Subject: [PATCH] Feat/rag use llm (#483)

Co-authored-by: skyline2006
Co-authored-by: Zhikaiiii <1658973216@qq.com>
---
 apps/agentfabric/app.py                    | 10 ++++------
 apps/agentfabric/appBot.py                 |  7 ++-----
 apps/agentfabric/server.py                 | 12 ++++++------
 modelscope_agent/memory/memory_with_rag.py | 16 ++++++++++++----
 modelscope_agent/rag/knowledge.py          | 22 +++++++++++++++++-----
 tests/test_rag.py                          | 12 ++++++++++++
 6 files changed, 53 insertions(+), 26 deletions(-)

diff --git a/apps/agentfabric/app.py b/apps/agentfabric/app.py
index 090705a6f..378511010 100644
--- a/apps/agentfabric/app.py
+++ b/apps/agentfabric/app.py
@@ -603,16 +603,14 @@ def preview_send_message(chatbot, input, _state, uuid_str):
     # get chat history from memory
     history = user_memory.get_history()
 
-    # get knowledge from memory, currently get one file
-    uploaded_file = None
-    if len(append_files) > 0:
-        uploaded_file = append_files[0]
+    use_llm = True if len(user_agent.function_list) else False
     ref_doc = user_memory.run(
         query=input.text,
-        url=uploaded_file,
+        url=append_files,
         max_token=4000,
         top_k=2,
-        checked=True)
+        checked=True,
+        use_llm=use_llm)
 
     response = ''
     try:
diff --git a/apps/agentfabric/appBot.py b/apps/agentfabric/appBot.py
index d3bc83534..9b3bf8346 100644
--- a/apps/agentfabric/appBot.py
+++ b/apps/agentfabric/appBot.py
@@ -124,12 +124,9 @@ def send_message(chatbot, input, _state):
     # get short term memory history
     history = user_memory.get_history()
 
-    # get long term memory knowledge, currently get one file
-    uploaded_file = None
-    if len(append_files) > 0:
-        uploaded_file = append_files[0]
+    use_llm = True if len(user_agent.function_list) else False
     ref_doc = user_memory.run(
-        query=input.text, url=uploaded_file, checked=True)
+        query=input.text, url=append_files, checked=True, use_llm=use_llm)
 
     response = ''
     try:
diff --git a/apps/agentfabric/server.py b/apps/agentfabric/server.py
index 4518e036b..7c6034a31 100644
--- a/apps/agentfabric/server.py
+++ b/apps/agentfabric/server.py
@@ -402,15 +402,15 @@ def generate():
             f'load history method: time consumed {time.time() - start_time}'
         )
 
-        # get knowledge from memory, currently get one file
-        uploaded_file = None
-        if len(file_paths) > 0:
-            uploaded_file = file_paths[0]
+        use_llm = True if len(user_agent.function_list) else False
         ref_doc = user_memory.run(
-            query=input_content, url=uploaded_file, checked=True)
+            query=input_content,
+            url=file_paths,
+            checked=True,
+            use_llm=use_llm)
         logger.info(
             f'load knowledge method: time consumed {time.time() - start_time}, '
-            f'the uploaded_file name is {uploaded_file}')  # noqa
+            f'the uploaded_file name is {file_paths}')  # noqa
 
         response = ''
 
diff --git a/modelscope_agent/memory/memory_with_rag.py b/modelscope_agent/memory/memory_with_rag.py
index de9e773fb..7e540a6b5 100644
--- a/modelscope_agent/memory/memory_with_rag.py
+++ b/modelscope_agent/memory/memory_with_rag.py
@@ -42,13 +42,21 @@ def _run(self,
              query: str = None,
              url: str = None,
              max_token: int = 4000,
-             top_k: int = 3,
              **kwargs) -> Union[str, Iterator[str]]:
         if isinstance(url, str):
             url = [url]
         if url and len(url):
             self.store_knowledge.add(files=url)
         if query:
-            summary_result = self.store_knowledge.run(query, files=url)
-            # limit length
-            return summary_result[0:max_token - 1]
+            summary_result = self.store_knowledge.run(
+                query, files=url, **kwargs)
+            # limit length
+            if isinstance(summary_result, list):
+                single_max_token = int(max_token / len(summary_result))
+                concatenated_records = '\n'.join([
+                    record[0:single_max_token - 1] for record in summary_result
+                ])
+
+                return concatenated_records
+            else:
+                return summary_result[0:max_token - 1]
diff --git a/modelscope_agent/rag/knowledge.py b/modelscope_agent/rag/knowledge.py
index 793ccd14a..bbc15fe34 100644
--- a/modelscope_agent/rag/knowledge.py
+++ b/modelscope_agent/rag/knowledge.py
@@ -10,7 +10,8 @@
 from llama_index.core.postprocessor.types import BaseNodePostprocessor
 from llama_index.core.query_engine import BaseQueryEngine, RetrieverQueryEngine
 from llama_index.core.readers.base import BaseReader
-from llama_index.core.schema import Document, QueryBundle, TransformComponent
+from llama_index.core.schema import (Document, MetadataMode, QueryBundle,
+                                     TransformComponent)
 from llama_index.core.settings import Settings
 from llama_index.core.vector_stores.types import (MetadataFilter,
                                                   MetadataFilters)
@@ -299,15 +300,26 @@ def set_filter(self, files: List[str]):
         ]
         retriever._filters = MetadataFilters(filters=filters)
 
-    def run(self, query: str, files: List[str] = [], **kwargs) -> str:
+    def run(self,
+            query: str,
+            files: List[str] = [],
+            use_llm: bool = True,
+            **kwargs) -> Union[str, List[str]]:
         query_bundle = FileQueryBundle(query)
 
         if isinstance(files, str):
             files = [files]
         if files and len(files) > 0:
             self.set_filter(files)
-
-        return str(self.query_engine.query(query_bundle, **kwargs))
+        if use_llm:
+            return str(self.query_engine.query(query_bundle))
+        else:
+            nodes = self.query_engine.retrieve(query_bundle)
+            msg = [
+                n.node.get_content(metadata_mode=MetadataMode.LLM)
+                for n in nodes
+            ]
+            return msg
 
     def add(self, files: List[str]):
         if isinstance(files, str):
@@ -329,4 +341,4 @@ def add(self, files: List[str]):
     knowledge = BaseKnowledge('./data2', use_cache=False, llm=llm)
 
     knowledge.add(['./data/常见QA.pdf'])
-    print(knowledge.run('高德天气API申请', files=['常见QA.pdf']))
+    print(knowledge.run('高德天气API申请', files=['常见QA.pdf'], use_llm=False))
diff --git a/tests/test_rag.py b/tests/test_rag.py
index d7901b53c..72b9e0000 100644
--- a/tests/test_rag.py
+++ b/tests/test_rag.py
@@ -78,3 +78,15 @@ def test_memory_with_rag_multi_modal():
     summary_str = memory.run('我想看rag的流程图')
     print(summary_str)
     assert 'rag.png' in summary_str
+
+
+def test_memory_with_rag_no_use_llm():
+    memory = MemoryWithRag(use_knowledge_cache=False)
+
+    summary_str = memory.run(
+        query='模型大文件上传失败怎么办',
+        url=['tests/samples/modelscope_qa_2.txt'],
+        use_llm=False)
+    print(summary_str)
+    assert 'file_path' in summary_str
+    assert 'git-lfs' in summary_str
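
Behavior of the new use_llm switch, as the patch defines it: with use_llm=True, BaseKnowledge.run answers through the LlamaIndex query engine (retrieval plus LLM synthesis); with use_llm=False it only retrieves, returning the raw node contents as a list of strings for MemoryWithRag._run to concatenate. The apps derive the flag as use_llm = True if len(user_agent.function_list) else False, i.e. synthesis is enabled only when the agent has registered tools. Below is a minimal usage sketch, assuming MemoryWithRag is importable from modelscope_agent.memory and that the sample file used in the test above exists locally:

from modelscope_agent.memory import MemoryWithRag

memory = MemoryWithRag(use_knowledge_cache=False)

# Retrieval only: returns the raw chunks, skipping the LLM synthesis step.
chunks = memory.run(
    query='模型大文件上传失败怎么办',
    url=['tests/samples/modelscope_qa_2.txt'],
    max_token=4000,
    use_llm=False)
print(chunks)

# Default path: the query engine also synthesizes an answer with the LLM.
# (The file was ingested by the call above, so no url is needed here.)
answer = memory.run(query='模型大文件上传失败怎么办', use_llm=True)
print(answer)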
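When the retrieval-only path returns several records, MemoryWithRag._run splits the max_token budget evenly across them before joining; note the clip operates on characters, not tokens. A standalone sketch of that truncation rule follows, using a hypothetical helper name, clip_summary, not present in the patch:

from typing import List, Union


def clip_summary(summary_result: Union[str, List[str]],
                 max_token: int = 4000) -> str:
    # Mirrors the patched _run: a list of records shares the budget
    # evenly; a plain string is clipped to max_token - 1 characters.
    if isinstance(summary_result, list):
        single_max_token = int(max_token / len(summary_result))
        return '\n'.join(
            record[0:single_max_token - 1] for record in summary_result)
    return summary_result[0:max_token - 1]


# Two 3000-character records under a 4000-character budget are each
# clipped to int(4000 / 2) - 1 = 1999 characters.
first_record = clip_summary(['a' * 3000, 'b' * 3000]).split('\n')[0]
print(len(first_record))  # 1999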