diff --git a/src/_cohere/doc_index.py b/src/_cohere/doc_index.py index b201b6b..c5685f5 100644 --- a/src/_cohere/doc_index.py +++ b/src/_cohere/doc_index.py @@ -239,6 +239,7 @@ def retrieve_and_generate( query: str, vector_store: str, top_k: int =3, + pydantic_parser: bool = True, rerank_model: str = 'flashrank', model_type: Optional[str] = None, lang: Optional[str] = None, @@ -251,6 +252,7 @@ def retrieve_and_generate( query (str): The query from the user. vector_store (str): The name of the Pinecone index. top_k (int, optional): The number of documents to retrieve from the index (default is 3). + pydantic_parser (bool, optional): Whether to use Pydantic parsing for the generated response (default is True). rerank_model (str, optional): The name or path of the model to use for ranking (default is 'flashrank'). model_type (str, optional): The type of the model (e.g., 'cross-encoder', 'flashrank', 't5', etc.). lang (str, optional): The language for multilingual models. @@ -280,13 +282,22 @@ def retrieve_and_generate( base_retriever=retriever ) - rag_chain = ( - {"context": itemgetter("query")| compression_retriever, - "query": itemgetter("query"), - } - | rag_prompt - | llm - | parser - ) + if pydantic_parser: + rag_chain = ( + {"context": itemgetter("query")| compression_retriever, + "query": itemgetter("query"), + } + | rag_prompt + | llm + | parser + ) + else: + rag_chain = ( + {"context": itemgetter("query")| compression_retriever, + "query": itemgetter("query"), + } + | rag_prompt + | llm + ) return rag_chain.invoke({"query": query}) diff --git a/src/_openai/doc_index.py b/src/_openai/doc_index.py index ef572d5..7f041a1 100644 --- a/src/_openai/doc_index.py +++ b/src/_openai/doc_index.py @@ -251,6 +251,7 @@ def retrieve_and_generate( query: str, vector_store: str, top_k: int =3, + pydantic_parser: bool = True, rerank_model: str = 'flashrank', model_type: Optional[str] = None, lang: Optional[str] = None, @@ -263,6 +264,7 @@ def retrieve_and_generate( query (str): The query from the user. vector_store (str): The name of the Pinecone index. top_k (int, optional): The number of documents to retrieve from the index (default is 3). + pydantic_parser (bool, optional): Whether to use Pydantic parsing for the generated response (default is True). rerank_model (str, optional): The name or path of the model to use for ranking (default is 'flashrank'). model_type (str, optional): The type of the model (e.g., 'cross-encoder', 'flashrank', 't5', etc.). lang (str, optional): The language for multilingual models. @@ -291,16 +293,23 @@ def retrieve_and_generate( base_compressor=compressor, base_retriever=retriever ) - - rag_chain = ( - {"context": itemgetter("query")| compression_retriever, - "query": itemgetter("query"), - } - | rag_prompt - | llm - | parser - ) - + if pydantic_parser: + rag_chain = ( + {"context": itemgetter("query")| compression_retriever, + "query": itemgetter("query"), + } + | rag_prompt + | llm + | parser + ) + else: + rag_chain = ( + {"context": itemgetter("query")| compression_retriever, + "query": itemgetter("query"), + } + | rag_prompt + | llm + ) return rag_chain.invoke({"query": query}) diff --git a/src/tests/cohereindex_test.py b/src/tests/cohereindex_test.py index b284fa4..5b8e67c 100644 --- a/src/tests/cohereindex_test.py +++ b/src/tests/cohereindex_test.py @@ -54,19 +54,20 @@ def test_03_initialize_vectorstore(self): vectorstore = self.indexer.initialize_vectorstore(self.index_name) self.assertIsInstance(vectorstore, PineconeVectorStore) - # def test_04_retrieve_and_generate(self): - # """ - # Test initializing the vector store and assert its type. - # """ - # vector_store = self.indexer.initialize_vectorstore(self.index_name) - # response = self.indexer.retrieve_and_generate( - # query = "give a short summary of the introduction", - # vector_store = vector_store, - # top_k = 3, - # rerank_model = "t5" - # ) - # print(response) - # self.assertIsNotNone(response, "The retriever response should not be None.") + def test_04_retrieve_and_generate(self): + """ + Test initializing the vector store and assert its type. + """ + vector_store = self.indexer.initialize_vectorstore(self.index_name) + response = self.indexer.retrieve_and_generate( + query = "give a short summary of the introduction", + vector_store = vector_store, + top_k = 1, + pydantic_parser=False, + rerank_model = "flashrank" + ) + print(response) + self.assertIsNotNone(response, "The retriever response should not be None.") @patch('sys.stdout', new_callable=StringIO) def test_05_delete_index(self, mock_stdout): diff --git a/src/tests/openaiindex_test.py b/src/tests/openaiindex_test.py index cc9f7fd..5a1a92f 100644 --- a/src/tests/openaiindex_test.py +++ b/src/tests/openaiindex_test.py @@ -68,8 +68,10 @@ def test_04_retrieve_and_generate(self): vector_store = vectorstore, top_k = 3, # lang= "en", - rerank_model = "flashrank" + rerank_model = "flashrank", + pydantic_parser=False ) + print(response) self.assertIsNotNone(response, "The retriever response should not be None.") @patch('sys.stdout', new_callable=StringIO)