From c197796e1de5a87170be68383b03e71c3a6b6aad Mon Sep 17 00:00:00 2001 From: KevKibe Date: Tue, 30 Apr 2024 17:39:32 +0300 Subject: [PATCH] update retrieve_and_generate method with Query result pydantic model --- src/_cohere/doc_index.py | 24 ++++++++++++++---------- src/_google/doc_index.py | 34 ++++++++++++++-------------------- src/_openai/doc_index.py | 35 +++++++++++++++-------------------- 3 files changed, 43 insertions(+), 50 deletions(-) diff --git a/src/_cohere/doc_index.py b/src/_cohere/doc_index.py index 82d8685..8304f7b 100644 --- a/src/_cohere/doc_index.py +++ b/src/_cohere/doc_index.py @@ -16,7 +16,9 @@ from langchain_core.prompts import PromptTemplate from langchain_core.output_parsers import StrOutputParser from operator import itemgetter -from _cohere.config import Config +from utils.config import Config +from utils.response_model import QueryResult +from langchain.output_parsers import PydanticOutputParser class CoherePineconeIndexer: """ @@ -218,7 +220,7 @@ def initialize_vectorstore(self, index_name): vectorstore = PineconeVectorStore(index,embed, "text") return vectorstore - def retrieve_and_generate(self,query: str, index_name: str, model_name: str = 'gpt-3.5-turbo-1106', top_k: int =5): + def retrieve_and_generate(self,query: str, vector_store: str, model_name: str = 'gpt-3.5-turbo-1106', top_k: int =5): """ Retrieve documents from the Pinecone index and generate a response. Args: @@ -228,17 +230,19 @@ def retrieve_and_generate(self,query: str, index_name: str, model_name: str = 'g top_k: The number of documents to retrieve from the index : defaults to 5 """ llm = Cohere(model="command", cohere_api_key = self.cohere_api_key) - rag_prompt = PromptTemplate(template = Config.template_str, input_variables = ["query", "context"]) - - vector_store = self.initialize_vectorstore(index_name) + parser = PydanticOutputParser(pydantic_object=QueryResult) + rag_prompt = PromptTemplate(template = Config.template_str, + input_variables = ["query", "context"], + partial_variables={"format_instructions": parser.get_format_instructions()}) retriever = vector_store.as_retriever(search_kwargs = {"k": top_k}) + rag_chain = ( {"context": itemgetter("query")| retriever, - "query": itemgetter("query"), - } - | rag_prompt - | llm - | StrOutputParser() + "query": itemgetter("query"), + } + | rag_prompt + | llm + | parser ) return rag_chain.invoke({"query": query}) diff --git a/src/_google/doc_index.py b/src/_google/doc_index.py index 71e49f8..46e2fdb 100644 --- a/src/_google/doc_index.py +++ b/src/_google/doc_index.py @@ -14,11 +14,11 @@ from langchain_community.document_loaders import UnstructuredHTMLLoader from langchain_pinecone import PineconeVectorStore from langchain_core.prompts import PromptTemplate -from langchain_core.output_parsers import StrOutputParser from operator import itemgetter from langchain_google_genai import ChatGoogleGenerativeAI -from _google.config import Config - +from utils.config import Config +from utils.response_model import QueryResult +from langchain.output_parsers import PydanticOutputParser class GooglePineconeIndexer: """ @@ -97,31 +97,22 @@ def load_document(self, file_url: str) -> List[str]: """ pages = [] file_path = Path(file_url) - - # Determine file type and use the appropriate loader file_extension = file_path.suffix - - # Load and split PDF files if file_extension == ".pdf": loader = PyPDFLoader(file_url) pages = loader.load_and_split() - # Load and split DOCX and DOC files elif file_extension in ('.docx', '.doc'): loader = UnstructuredWordDocumentLoader(file_url) pages = loader.load_and_split() - # Load and split Markdown files elif file_extension == '.md': loader = UnstructuredMarkdownLoader(file_url) pages = loader.load_and_split() - # Load and split HTML files - elif file_extension == '.html': loader = UnstructuredHTMLLoader(file_url) pages = loader.load_and_split() - # Return the list of pages return pages def tiktoken_len(self, text: str) -> int: @@ -242,7 +233,7 @@ def initialize_vectorstore(self, index_name): return vectorstore - def retrieve_and_generate(self,query: str, index_name: str, model_name: str = 'gemini-pro', top_k: int =5): + def retrieve_and_generate(self,query: str, vector_store: str, model_name: str = 'gemini-pro', top_k: int =5): """ Retrieve documents from the Pinecone index and generate a response. Args: @@ -252,16 +243,19 @@ def retrieve_and_generate(self,query: str, index_name: str, model_name: str = 'g top_k: The number of documents to retrieve from the index : defaults to 5 """ llm = ChatGoogleGenerativeAI(model = Config.default_google_model, google_api_key=self.google_api_key) - rag_prompt = PromptTemplate(template = Config.template_str, input_variables = ["query", "context"]) - vector_store = self.initialize_vectorstore(index_name) + parser = PydanticOutputParser(pydantic_object=QueryResult) + rag_prompt = PromptTemplate(template = Config.template_str, + input_variables = ["query", "context"], + partial_variables={"format_instructions": parser.get_format_instructions()}) retriever = vector_store.as_retriever(search_kwargs = {"k": top_k}) + rag_chain = ( {"context": itemgetter("query")| retriever, - "query": itemgetter("query"), - } - | rag_prompt - | llm - | StrOutputParser() + "query": itemgetter("query"), + } + | rag_prompt + | llm + | parser ) return rag_chain.invoke({"query": query}) diff --git a/src/_openai/doc_index.py b/src/_openai/doc_index.py index c0f1f4e..3decd5e 100644 --- a/src/_openai/doc_index.py +++ b/src/_openai/doc_index.py @@ -13,12 +13,11 @@ from langchain_community.document_loaders import UnstructuredHTMLLoader from langchain_pinecone import PineconeVectorStore from langchain_core.prompts import PromptTemplate -from langchain_core.output_parsers import StrOutputParser from operator import itemgetter from langchain_openai import ChatOpenAI -from _openai.config import Config - - +from utils.config import Config +from utils.response_model import QueryResult +from langchain.output_parsers import PydanticOutputParser class OpenaiPineconeIndexer: """ @@ -99,30 +98,23 @@ def load_document(self, file_url: str) -> List[str]: pages = [] file_path = Path(file_url) - # Determine file type and use the appropriate loader file_extension = file_path.suffix - # Load and split PDF files if file_extension == ".pdf": loader = PyPDFLoader(file_url) pages = loader.load_and_split() - # Load and split DOCX and DOC files elif file_extension in ('.docx', '.doc'): loader = UnstructuredWordDocumentLoader(file_url) pages = loader.load_and_split() - # Load and split Markdown files elif file_extension == '.md': loader = UnstructuredMarkdownLoader(file_url) pages = loader.load_and_split() - # Load and split HTML files elif file_extension == '.html': loader = UnstructuredHTMLLoader(file_url) pages = loader.load_and_split() - - # Return the list of pages return pages @@ -240,7 +232,7 @@ def initialize_vectorstore(self, index_name): return vectorstore - def retrieve_and_generate(self,query: str, index_name: str, model_name: str = 'gpt-3.5-turbo-1106', top_k: int =5): + def retrieve_and_generate(self,query: str, vector_store: str, model_name: str = 'gpt-3.5-turbo-1106', top_k: int =5): """ Retrieve documents from the Pinecone index and generate a response. Args: @@ -250,17 +242,19 @@ def retrieve_and_generate(self,query: str, index_name: str, model_name: str = 'g top_k: The number of documents to retrieve from the index : defaults to 5 """ llm = ChatOpenAI(model = Config.default_openai_model, openai_api_key = self.openai_api_key) - rag_prompt = PromptTemplate(template = Config.template_str, input_variables = ["query", "context"]) - - vector_store = self.initialize_vectorstore(index_name) + parser = PydanticOutputParser(pydantic_object=QueryResult) + rag_prompt = PromptTemplate(template = Config.template_str, + input_variables = ["query", "context"], + partial_variables={"format_instructions": parser.get_format_instructions()}) retriever = vector_store.as_retriever(search_kwargs = {"k": top_k}) + rag_chain = ( {"context": itemgetter("query")| retriever, - "query": itemgetter("query"), - } - | rag_prompt - | llm - | StrOutputParser() + "query": itemgetter("query"), + } + | rag_prompt + | llm + | parser ) return rag_chain.invoke({"query": query}) @@ -270,3 +264,4 @@ def retrieve_and_generate(self,query: str, index_name: str, model_name: str = 'g +