Skip to content

Commit

Permalink
Fix: GooglePineconeIndexer class (#45)
Browse files Browse the repository at this point in the history
* fix: run tests on pull_request_target

* fix: add secrets to dependabot configuration

* fix: revert back to on pull_request

* fix: GH secrets access in GA run

* chore: add gitignore

* Merf

* chore(dependencies): rename dependabot test run

* chore(dependabot): add permissions field

* chore(dependabot): add test trigger on pull request target master branch

* update: add dictionary to get reranker models in constant time complexity

* fix: remove api_provider parameter from retrieve_and_generate method

* chore(dependencies): add rerankers package

* fix: get reranker method and implementation in test

* feat: Add pydantic_parser parameter to retrieve_and_generate

* feat: Add pydantic_parser parameter to retrieve_and_generate

* fix: pydantic parser to false

* fix: source documents list

* fix: google rag pipeline

* update: pydantic_parser parameter
  • Loading branch information
KevKibe authored Jun 11, 2024
1 parent 103deb6 commit b5e6dd3
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 28 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ response = pinecone_indexer.retrieve_and_generate(
vector_store = vectorstore,
query = query,
top_k = "number of sources to retrieve", # Default is 3
pydantic_parser = True, # Whether to use Pydantic parsing for the generated response (default is True)
rerank_model = "reranking model" # Default is 'flashrank' Other models available Docs:https://github.com/AnswerDotAI/rerankers
)
response
Expand Down
39 changes: 24 additions & 15 deletions src/_google/doc_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def initialize_vectorstore(self, index_name: str) -> PineconeVectorStore:
model="models/embedding-001",
google_api_key=self.google_api_key
)
vectorstore = PineconeVectorStore(index, embed, "text")
vectorstore = PineconeVectorStore(index=index, index_name=index_name, embedding=embed)
return vectorstore


Expand All @@ -252,6 +252,7 @@ def retrieve_and_generate(
query: str,
vector_store: str,
top_k: int =3,
pydantic_parser: bool = True,
rerank_model: str = 'flashrank',
model_type: Optional[str] = None,
lang: Optional[str] = None,
Expand All @@ -264,6 +265,7 @@ def retrieve_and_generate(
query (str): The query from the user.
vector_store (str): The name of the Pinecone index.
top_k (int, optional): The number of documents to retrieve from the index (default is 3).
pydantic_parser (bool, optional): Whether to use Pydantic parsing for the generated response (default is True).
rerank_model (str, optional): The name or path of the model to use for ranking (default is 'flashrank').
model_type (str, optional): The type of the model (e.g., 'cross-encoder', 'flashrank', 't5', etc.).
lang (str, optional): The language for multilingual models.
Expand All @@ -275,31 +277,38 @@ def retrieve_and_generate(
Raises:
ValueError: If an unsupported model_type is provided.
"""
llm = ChatGoogleGenerativeAI(model = Config.default_google_model, google_api_key=self.google_api_key)
llm = ChatGoogleGenerativeAI(model = Config.default_google_model, google_api_key=self.google_api_key, temperature=0.7, top_p=0.85, convert_system_message_to_human=True)
parser = PydanticOutputParser(pydantic_object=QueryResult)
rag_prompt = PromptTemplate(template = Config.template_str,
input_variables = ["query", "context"],
partial_variables={"format_instructions": parser.get_format_instructions()})
retriever = vector_store.as_retriever()
ranker = RerankerConfig.get_ranker(
rerank_model,
model_type,
lang,
api_key
api_key,
model_type,
)
compressor = ranker.as_langchain_compressor(k=top_k)
compression_retriever = ContextualCompressionRetriever(
base_compressor=compressor,
base_retriever=retriever
)

rag_chain = (
{"context": itemgetter("query")| compression_retriever,
"query": itemgetter("query"),
}
| rag_prompt
| llm
| parser
)

return rag_chain.invoke({"query": query})
if pydantic_parser:
rag_chain = (
{"context": itemgetter("query")| compression_retriever,
"query": itemgetter("query"),
}
| rag_prompt
| llm
| parser
)
else:
rag_chain = (
{"context": itemgetter("query")| compression_retriever,
"query": itemgetter("query"),
}
| rag_prompt
| llm
)
return rag_chain.invoke({"query": query}).content
26 changes: 13 additions & 13 deletions src/tests/googleindex_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,19 +59,19 @@ def test_03_initialize_vectorstore(self):
vectorstore = self.indexer.initialize_vectorstore(self.index_name)
self.assertIsInstance(vectorstore, PineconeVectorStore)

# def test_04_retrieve_and_generate(self):
# """
# Test initializing the vector store and assert its type.
# """
# vector_store = self.indexer.initialize_vectorstore(self.index_name)
# response = self.indexer.retrieve_and_generate(
# query = "tell me something from the context texts",
# vector_store = vector_store,
# top_k = 3,
# rerank_model = "t5"
# )
# print(response)
# self.assertIsNotNone(response, "The retriever response should not be None.")
def test_04_retrieve_and_generate(self):
    """
    Test end-to-end retrieval and generation against the Pinecone index.

    Initializes a vector store from the test index, then runs
    retrieve_and_generate with pydantic_parser=False (raw LLM text output,
    no structured parsing) and the default 'flashrank' reranker, asserting
    that a non-None response comes back.
    """
    # Build the vector store for the index created by the earlier tests.
    # NOTE(review): hits the live Pinecone service — requires valid credentials.
    vectorstore = self.indexer.initialize_vectorstore(self.index_name)
    response = self.indexer.retrieve_and_generate(
        query = "give a short summary of the introduction",
        vector_store = vectorstore,
        top_k = 3,  # number of source documents to retrieve
        pydantic_parser=False,  # skip Pydantic parsing; response is plain text
        rerank_model = "flashrank"
    )
    self.assertIsNotNone(response, "The retriever response should not be None.")

@patch('sys.stdout', new_callable=StringIO)
def test_05_delete_index(self, mock_stdout):
Expand Down

0 comments on commit b5e6dd3

Please sign in to comment.