Skip to content

Commit

Permalink
Update: Optimize get reranker method to run in constant time complexity (#39)
Browse files Browse the repository at this point in the history

* fix: run tests on pull_request_target

* fix: add secrets to dependabot configuration

* fix: revert back to on pull_request

* fix: GH secrets access in GA run

* chore: add gitignore

* Merge

* chore(dependencies): rename dependabot test run

* chore(dependabot): add permissions field

* chore(dependabot): add test trigger on pull request target master branch

* update: add dictionary to get reranker models in constant time complexity

* fix: remove api_provider parameter from retrieve_and_generate method

* chore(dependencies): add rerankers package

* fix: get reranker method and implementation in test
  • Loading branch information
KevKibe authored Jun 5, 2024
1 parent bd91c89 commit c7630c8
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 47 deletions.
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ install_requires =
markdown==3.6
langchain-core==0.1.46
langchain-cohere==0.1.4
rerankers[all]==0.2.0
package_dir=
=src

Expand Down
7 changes: 2 additions & 5 deletions src/_cohere/doc_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,7 @@ def retrieve_and_generate(
rerank_model: str = 'flashrank',
model_type: Optional[str] = None,
lang: Optional[str] = None,
api_key: Optional[str] = None,
api_provider: Optional[str] = None,
api_key: Optional[str] = None
) -> QueryResult:
"""
Retrieve documents from the Pinecone index and generate a response.
Expand All @@ -256,7 +255,6 @@ def retrieve_and_generate(
model_type (str, optional): The type of the model (e.g., 'cross-encoder', 'flashrank', 't5', etc.).
lang (str, optional): The language for multilingual models.
api_key (str, optional): The API key for models accessed through an API.
api_provider (str, optional): The provider of the API.
Returns:
QueryResult: A Pydantic model representing the generated response.
Expand All @@ -274,8 +272,7 @@ def retrieve_and_generate(
rerank_model,
model_type,
lang,
api_key,
api_provider
api_key
)
compressor = ranker.as_langchain_compressor(k=top_k)
compression_retriever = ContextualCompressionRetriever(
Expand Down
7 changes: 2 additions & 5 deletions src/_google/doc_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,7 @@ def retrieve_and_generate(
rerank_model: str = 'flashrank',
model_type: Optional[str] = None,
lang: Optional[str] = None,
api_key: Optional[str] = None,
api_provider: Optional[str] = None,
api_key: Optional[str] = None
) -> QueryResult:
"""
Retrieve documents from the Pinecone index and generate a response.
Expand All @@ -269,7 +268,6 @@ def retrieve_and_generate(
model_type (str, optional): The type of the model (e.g., 'cross-encoder', 'flashrank', 't5', etc.).
lang (str, optional): The language for multilingual models.
api_key (str, optional): The API key for models accessed through an API.
api_provider (str, optional): The provider of the API.
Returns:
QueryResult: A Pydantic model representing the generated response.
Expand All @@ -287,8 +285,7 @@ def retrieve_and_generate(
rerank_model,
model_type,
lang,
api_key,
api_provider
api_key
)
compressor = ranker.as_langchain_compressor(k=top_k)
compression_retriever = ContextualCompressionRetriever(
Expand Down
7 changes: 2 additions & 5 deletions src/_openai/doc_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,6 @@ def retrieve_and_generate(
model_type: Optional[str] = None,
lang: Optional[str] = None,
api_key: Optional[str] = None,
api_provider: Optional[str] = None,
) -> QueryResult:
"""
Retrieve documents from the Pinecone index and generate a response.
Expand All @@ -268,7 +267,6 @@ def retrieve_and_generate(
model_type (str, optional): The type of the model (e.g., 'cross-encoder', 'flashrank', 't5', etc.).
lang (str, optional): The language for multilingual models.
api_key (str, optional): The API key for models accessed through an API.
api_provider (str, optional): The provider of the API.
Returns:
QueryResult: A Pydantic model representing the generated response.
Expand All @@ -284,10 +282,9 @@ def retrieve_and_generate(
retriever = vector_store.as_retriever()
ranker = RerankerConfig.get_ranker(
rerank_model,
model_type,
lang,
api_key,
api_provider
api_key,
model_type,
)
compressor = ranker.as_langchain_compressor(k=top_k)
compression_retriever = ContextualCompressionRetriever(
Expand Down
3 changes: 2 additions & 1 deletion src/tests/openaiindex_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ def test_04_retrieve_and_generate(self):
query = "give a short summary of the introduction",
vector_store = vectorstore,
top_k = 3,
rerank_model = "t5"
# lang= "en",
rerank_model = "flashrank"
)
self.assertIsNotNone(response, "The retriever response should not be None.")

Expand Down
58 changes: 27 additions & 31 deletions src/utils/rerank.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,45 @@
from typing import Optional

from rerankers import Reranker

class RerankerConfig:
    """Factory for ``rerankers.Reranker`` instances with O(1) model lookup.

    ``SUPPORTED_MODELS`` maps each supported model name to the init kwargs
    that model accepts, so validation and kwarg selection are both a single
    dictionary lookup instead of a linear scan over model names.
    """

    # Each value flags which optional constructor kwargs the model accepts.
    SUPPORTED_MODELS = {
        'cohere': {'lang': True, 'api_key': True},
        'jina': {'api_key': True},
        'cross-encoder': {},
        'flashrank': {},
        't5': {},
        'rankgpt': {'api_key': True},
        'rankgpt3': {'api_key': True},
        'colbert': {},
        'mixedbread-ai/mxbai-rerank-large-v1': {'model_type': True},
        'ce-esci-MiniLM-L12-v2': {'model_type': True},
        'unicamp-dl/InRanker-base': {'model_type': True},
    }

    @staticmethod
    def get_ranker(
        rerank_model: str,
        lang: Optional[str] = None,
        api_key: Optional[str] = None,
        model_type: Optional[str] = None,
    ) -> 'Reranker':
        """
        Return a Reranker instance based on the provided parameters.

        Args:
            rerank_model (str): The name or path of the model.
            lang (str, optional): The language for multilingual models. Defaults to None.
            api_key (str, optional): The API key for models accessed through an API. Defaults to None.
            model_type (str, optional): The model type of a reranker. Defaults to None.

        Returns:
            Reranker: An instance of Reranker.

        Raises:
            ValueError: If an unsupported rerank_model is provided.
        """
        # Single O(1) lookup replaces the previous `in` test + subscript pair.
        model_config = RerankerConfig.SUPPORTED_MODELS.get(rerank_model)
        if model_config is None:
            raise ValueError(f"Unsupported rerank_model provided: {rerank_model}")

        # Forward only the kwargs this model supports; drop unset values so
        # Reranker sees no spurious None arguments.
        candidate_kwargs = {
            'lang': lang if model_config.get('lang') else None,
            'api_key': api_key if model_config.get('api_key') else None,
            'model_type': model_type if model_config.get('model_type') else None,
        }
        init_kwargs = {k: v for k, v in candidate_kwargs.items() if v is not None}
        return Reranker(rerank_model, **init_kwargs)

if rerank_model == 'cohere':
return Reranker(rerank_model, lang=lang, api_key=api_key)
elif rerank_model == 'jina':
return Reranker(rerank_model, api_key=api_key)
elif rerank_model == 'cross-encoder':
return Reranker(rerank_model)
elif rerank_model == 'flashrank':
return Reranker(rerank_model)
elif rerank_model == 't5':
return Reranker(rerank_model)
elif rerank_model == 'rankgpt':
return Reranker(rerank_model, api_key=api_key)
elif rerank_model == 'rankgpt3':
return Reranker(rerank_model, api_key=api_key)
elif rerank_model == 'colbert':
return Reranker(rerank_model)
elif rerank_model == "mixedbread-ai/mxbai-rerank-large-v1":
return Reranker(rerank_model, model_type='cross-encoder')
elif rerank_model == "ce-esci-MiniLM-L12-v2":
return Reranker(rerank_model, model_type='flashrank')
elif rerank_model == "unicamp-dl/InRanker-base":
return Reranker(rerank_model, model_type='t5')
else:
return None

0 comments on commit c7630c8

Please sign in to comment.