diff --git a/README.md b/README.md index 9a0fcd3..de33a53 100644 --- a/README.md +++ b/README.md @@ -87,6 +87,7 @@ response = pinecone_indexer.retrieve_and_generate( vector_store = vectorstore, query = query, top_k = "number of sources to retrieve", # Default is 3 + pydantic_parser=True, # Whether to use Pydantic parsing for the generated response (default is True) rerank_model = "reranking model" # Default is 'flashrank' Other models available Docs:https://github.com/AnswerDotAI/rerankers ) response diff --git a/src/_google/doc_index.py b/src/_google/doc_index.py index eb32129..55fffc9 100644 --- a/src/_google/doc_index.py +++ b/src/_google/doc_index.py @@ -243,7 +243,7 @@ def initialize_vectorstore(self, index_name: str) -> PineconeVectorStore: model="models/embedding-001", google_api_key=self.google_api_key ) - vectorstore = PineconeVectorStore(index, embed, "text") + vectorstore = PineconeVectorStore(index=index, index_name=index_name, embedding=embed) return vectorstore @@ -252,6 +252,7 @@ def retrieve_and_generate( query: str, vector_store: str, top_k: int =3, + pydantic_parser: bool = True, rerank_model: str = 'flashrank', model_type: Optional[str] = None, lang: Optional[str] = None, @@ -264,6 +265,7 @@ def retrieve_and_generate( query (str): The query from the user. vector_store (str): The name of the Pinecone index. top_k (int, optional): The number of documents to retrieve from the index (default is 3). + pydantic_parser (bool, optional): Whether to use Pydantic parsing for the generated response (default is True). rerank_model (str, optional): The name or path of the model to use for ranking (default is 'flashrank'). model_type (str, optional): The type of the model (e.g., 'cross-encoder', 'flashrank', 't5', etc.). lang (str, optional): The language for multilingual models. @@ -275,7 +277,7 @@ Raises: ValueError: If an unsupported model_type is provided. 
""" - llm = ChatGoogleGenerativeAI(model = Config.default_google_model, google_api_key=self.google_api_key) + llm = ChatGoogleGenerativeAI(model = Config.default_google_model, google_api_key=self.google_api_key, temperature=0.7, top_p=0.85, convert_system_message_to_human=True) parser = PydanticOutputParser(pydantic_object=QueryResult) rag_prompt = PromptTemplate(template = Config.template_str, input_variables = ["query", "context"], @@ -283,23 +285,30 @@ def retrieve_and_generate( retriever = vector_store.as_retriever() ranker = RerankerConfig.get_ranker( rerank_model, - model_type, lang, - api_key + api_key, + model_type, ) compressor = ranker.as_langchain_compressor(k=top_k) compression_retriever = ContextualCompressionRetriever( base_compressor=compressor, base_retriever=retriever ) - - rag_chain = ( - {"context": itemgetter("query")| compression_retriever, - "query": itemgetter("query"), - } - | rag_prompt - | llm - | parser - ) - - return rag_chain.invoke({"query": query}) + if pydantic_parser: + rag_chain = ( + {"context": itemgetter("query")| compression_retriever, + "query": itemgetter("query"), + } + | rag_prompt + | llm + | parser + ) + else: + rag_chain = ( + {"context": itemgetter("query")| compression_retriever, + "query": itemgetter("query"), + } + | rag_prompt + | llm + ) + return rag_chain.invoke({"query": query}).content diff --git a/src/tests/googleindex_test.py b/src/tests/googleindex_test.py index 110dddf..774eccb 100644 --- a/src/tests/googleindex_test.py +++ b/src/tests/googleindex_test.py @@ -59,19 +59,19 @@ def test_03_initialize_vectorstore(self): vectorstore = self.indexer.initialize_vectorstore(self.index_name) self.assertIsInstance(vectorstore, PineconeVectorStore) - # def test_04_retrieve_and_generate(self): - # """ - # Test initializing the vector store and assert its type. 
- # """ - # vector_store = self.indexer.initialize_vectorstore(self.index_name) - # response = self.indexer.retrieve_and_generate( - # query = "tell me something from the context texts", - # vector_store = vector_store, - # top_k = 3, - # rerank_model = "t5" - # ) - # print(response) - # self.assertIsNotNone(response, "The retriever response should not be None.") + def test_04_retrieve_and_generate(self): + """ + Test initializing the vector store and assert its type. + """ + vectorstore = self.indexer.initialize_vectorstore(self.index_name) + response = self.indexer.retrieve_and_generate( + query = "give a short summary of the introduction", + vector_store = vectorstore, + top_k = 3, + pydantic_parser=False, + rerank_model = "flashrank" + ) + self.assertIsNotNone(response, "The retriever response should not be None.") @patch('sys.stdout', new_callable=StringIO) def test_05_delete_index(self, mock_stdout):