Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pydantic_parser parameter to retrieve_and_generate #44

Merged
merged 22 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
c5a0675
fix: run tests on pull_request_target
KevKibe Jun 3, 2024
5f26806
Merge branch 'master' of https://github.com/KevKibe/docindex into fix…
KevKibe Jun 3, 2024
dad166b
fix: add secrets to dependabot configuration
KevKibe Jun 3, 2024
475bd8a
fix: revert back to on pull_request
KevKibe Jun 3, 2024
36fedb3
fix: GH secrets access in GA run
KevKibe Jun 3, 2024
b5c6d4d
Merge branch 'master' of https://github.com/KevKibe/docindex into fix…
KevKibe Jun 3, 2024
9cb1930
chore: add gitignore
KevKibe Jun 3, 2024
b1b968f
Merge branch 'master' of https://github.com/KevKibe/docindex into fix…
KevKibe Jun 3, 2024
8965926
Merf
KevKibe Jun 3, 2024
742c0de
chore(dependencies): rename dependabot test run
KevKibe Jun 3, 2024
5c3cb99
chore(dependabot): add permissions field
KevKibe Jun 3, 2024
3a18e28
chore(dependabot): add test trigger on pull request target master branch
KevKibe Jun 4, 2024
8c5a0c3
Merge branch 'master' into fix-test-workflow
KevKibe Jun 4, 2024
6fdecf2
update: add dictionary to get reranker models in constant time comple…
KevKibe Jun 5, 2024
a0b1cbe
fix: remove api_provider parameter from retrieve_and_generate method
KevKibe Jun 5, 2024
9ed15a9
chore(dependencies): add rerankers package
KevKibe Jun 5, 2024
b244dec
fix: get reranker method and implementation in test
KevKibe Jun 5, 2024
329083c
Merge branch 'master' of https://github.com/KevKibe/docindex into opt…
KevKibe Jun 5, 2024
74cf04d
feat: Add pydantic_parser parameter to retrieve_and_generate
KevKibe Jun 11, 2024
7243904
feat: Add pydantic_parser parameter to retrieve_and_generate
KevKibe Jun 11, 2024
68a6bba
fix: pydantic parser to false
KevKibe Jun 11, 2024
836d950
fix: source documents list
KevKibe Jun 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions src/_cohere/doc_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ def retrieve_and_generate(
query: str,
vector_store: str,
top_k: int =3,
pydantic_parser: bool = True,
rerank_model: str = 'flashrank',
model_type: Optional[str] = None,
lang: Optional[str] = None,
Expand All @@ -251,6 +252,7 @@ def retrieve_and_generate(
query (str): The query from the user.
vector_store (str): The name of the Pinecone index.
top_k (int, optional): The number of documents to retrieve from the index (default is 3).
pydantic_parser (bool, optional): Whether to use Pydantic parsing for the generated response (default is True).
rerank_model (str, optional): The name or path of the model to use for ranking (default is 'flashrank').
model_type (str, optional): The type of the model (e.g., 'cross-encoder', 'flashrank', 't5', etc.).
lang (str, optional): The language for multilingual models.
Expand Down Expand Up @@ -280,13 +282,22 @@ def retrieve_and_generate(
base_retriever=retriever
)

rag_chain = (
{"context": itemgetter("query")| compression_retriever,
"query": itemgetter("query"),
}
| rag_prompt
| llm
| parser
)
if pydantic_parser:
rag_chain = (
{"context": itemgetter("query")| compression_retriever,
"query": itemgetter("query"),
}
| rag_prompt
| llm
| parser
)
else:
rag_chain = (
{"context": itemgetter("query")| compression_retriever,
"query": itemgetter("query"),
}
| rag_prompt
| llm
)

return rag_chain.invoke({"query": query})
29 changes: 19 additions & 10 deletions src/_openai/doc_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ def retrieve_and_generate(
query: str,
vector_store: str,
top_k: int =3,
pydantic_parser: bool = True,
rerank_model: str = 'flashrank',
model_type: Optional[str] = None,
lang: Optional[str] = None,
Expand All @@ -263,6 +264,7 @@ def retrieve_and_generate(
query (str): The query from the user.
vector_store (str): The name of the Pinecone index.
top_k (int, optional): The number of documents to retrieve from the index (default is 3).
pydantic_parser (bool, optional): Whether to use Pydantic parsing for the generated response (default is True).
rerank_model (str, optional): The name or path of the model to use for ranking (default is 'flashrank').
model_type (str, optional): The type of the model (e.g., 'cross-encoder', 'flashrank', 't5', etc.).
lang (str, optional): The language for multilingual models.
Expand Down Expand Up @@ -291,16 +293,23 @@ def retrieve_and_generate(
base_compressor=compressor,
base_retriever=retriever
)

rag_chain = (
{"context": itemgetter("query")| compression_retriever,
"query": itemgetter("query"),
}
| rag_prompt
| llm
| parser
)

if pydantic_parser:
rag_chain = (
{"context": itemgetter("query")| compression_retriever,
"query": itemgetter("query"),
}
| rag_prompt
| llm
| parser
)
else:
rag_chain = (
{"context": itemgetter("query")| compression_retriever,
"query": itemgetter("query"),
}
| rag_prompt
| llm
)
return rag_chain.invoke({"query": query})


Expand Down
27 changes: 14 additions & 13 deletions src/tests/cohereindex_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,19 +54,20 @@ def test_03_initialize_vectorstore(self):
vectorstore = self.indexer.initialize_vectorstore(self.index_name)
self.assertIsInstance(vectorstore, PineconeVectorStore)

# def test_04_retrieve_and_generate(self):
# """
# Test initializing the vector store and assert its type.
# """
# vector_store = self.indexer.initialize_vectorstore(self.index_name)
# response = self.indexer.retrieve_and_generate(
# query = "give a short summary of the introduction",
# vector_store = vector_store,
# top_k = 3,
# rerank_model = "t5"
# )
# print(response)
# self.assertIsNotNone(response, "The retriever response should not be None.")
def test_04_retrieve_and_generate(self):
"""
Test that retrieve_and_generate returns a non-None response.

Initializes the vector store for the test index, then runs a single
query through retrieve_and_generate with top_k=1, the 'flashrank'
reranker, and pydantic_parser disabled.
"""
vector_store = self.indexer.initialize_vectorstore(self.index_name)
response = self.indexer.retrieve_and_generate(
query = "give a short summary of the introduction",
vector_store = vector_store,
top_k = 1,
pydantic_parser=False,
rerank_model = "flashrank"
)
print(response)
self.assertIsNotNone(response, "The retriever response should not be None.")

@patch('sys.stdout', new_callable=StringIO)
def test_05_delete_index(self, mock_stdout):
Expand Down
4 changes: 3 additions & 1 deletion src/tests/openaiindex_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,10 @@ def test_04_retrieve_and_generate(self):
vector_store = vectorstore,
top_k = 3,
# lang= "en",
rerank_model = "flashrank"
rerank_model = "flashrank",
pydantic_parser=False
)
print(response)
self.assertIsNotNone(response, "The retriever response should not be None.")

@patch('sys.stdout', new_callable=StringIO)
Expand Down
Loading