From c7630c8d089f7a9d5be0456e996b0fd8a1e21dbc Mon Sep 17 00:00:00 2001
From: Kevin Kibe
Date: Wed, 5 Jun 2024 17:56:23 +0300
Subject: [PATCH] Update: Optimize get reranker method to run in constant time
 complexity (#39)

* fix: run tests on pull_request_target
* fix: add secrets to dependabot configuration
* fix: revert back to on pull_request
* fix: GH secrets access in GA run
* chore: add gitignore
* Merf
* chore(dependencies): rename dependabot test run
* chore(dependabot): add permissions field
* chore(dependabot): add test trigger on pull request target master branch
* update: add dictionary to get reranker models in constant time complexity
* fix: remove api_provider parameter from retrieve_and_generate method
* chore(dependencies): add rerankers package
* fix: get reranker method and implementation in test
---
 setup.cfg                     |  1 +
 src/_cohere/doc_index.py      |  7 ++---
 src/_google/doc_index.py      |  7 ++---
 src/_openai/doc_index.py      |  7 ++---
 src/tests/openaiindex_test.py |  3 +-
 src/utils/rerank.py           | 58 ++++++++++++++++------------------
 6 files changed, 36 insertions(+), 47 deletions(-)

diff --git a/setup.cfg b/setup.cfg
index 3dde6b7..4435274 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -28,6 +28,7 @@ install_requires =
     markdown==3.6
    langchain-core==0.1.46
     langchain-cohere==0.1.4
+    rerankers[all]==0.2.0
 
 package_dir=
     =src
diff --git a/src/_cohere/doc_index.py b/src/_cohere/doc_index.py
index 343ed9a..b201b6b 100644
--- a/src/_cohere/doc_index.py
+++ b/src/_cohere/doc_index.py
@@ -242,8 +242,7 @@ def retrieve_and_generate(
         rerank_model: str = 'flashrank',
         model_type: Optional[str] = None,
         lang: Optional[str] = None,
-        api_key: Optional[str] = None,
-        api_provider: Optional[str] = None,
+        api_key: Optional[str] = None
     ) -> QueryResult:
         """
         Retrieve documents from the Pinecone index and generate a response.
@@ -256,7 +255,6 @@ def retrieve_and_generate(
             model_type (str, optional): The type of the model (e.g., 'cross-encoder', 'flashrank', 't5', etc.).
             lang (str, optional): The language for multilingual models.
             api_key (str, optional): The API key for models accessed through an API.
-            api_provider (str, optional): The provider of the API.
 
         Returns:
             QueryResult: A Pydantic model representing the generated response.
@@ -274,8 +272,7 @@ def retrieve_and_generate(
             rerank_model,
             model_type,
             lang,
-            api_key,
-            api_provider
+            api_key
         )
         compressor = ranker.as_langchain_compressor(k=top_k)
         compression_retriever = ContextualCompressionRetriever(
diff --git a/src/_google/doc_index.py b/src/_google/doc_index.py
index c8d07f9..eb32129 100644
--- a/src/_google/doc_index.py
+++ b/src/_google/doc_index.py
@@ -255,8 +255,7 @@ def retrieve_and_generate(
         rerank_model: str = 'flashrank',
         model_type: Optional[str] = None,
         lang: Optional[str] = None,
-        api_key: Optional[str] = None,
-        api_provider: Optional[str] = None,
+        api_key: Optional[str] = None
     ) -> QueryResult:
         """
         Retrieve documents from the Pinecone index and generate a response.
@@ -269,7 +268,6 @@ def retrieve_and_generate(
             model_type (str, optional): The type of the model (e.g., 'cross-encoder', 'flashrank', 't5', etc.).
             lang (str, optional): The language for multilingual models.
             api_key (str, optional): The API key for models accessed through an API.
-            api_provider (str, optional): The provider of the API.
 
         Returns:
             QueryResult: A Pydantic model representing the generated response.
@@ -287,8 +285,7 @@ def retrieve_and_generate(
             rerank_model,
             model_type,
             lang,
-            api_key,
-            api_provider
+            api_key
         )
         compressor = ranker.as_langchain_compressor(k=top_k)
         compression_retriever = ContextualCompressionRetriever(
diff --git a/src/_openai/doc_index.py b/src/_openai/doc_index.py
index f20d056..ef572d5 100644
--- a/src/_openai/doc_index.py
+++ b/src/_openai/doc_index.py
@@ -255,7 +255,6 @@ def retrieve_and_generate(
         model_type: Optional[str] = None,
         lang: Optional[str] = None,
         api_key: Optional[str] = None,
-        api_provider: Optional[str] = None,
     ) -> QueryResult:
         """
         Retrieve documents from the Pinecone index and generate a response.
@@ -268,7 +267,6 @@ def retrieve_and_generate(
             model_type (str, optional): The type of the model (e.g., 'cross-encoder', 'flashrank', 't5', etc.).
             lang (str, optional): The language for multilingual models.
             api_key (str, optional): The API key for models accessed through an API.
-            api_provider (str, optional): The provider of the API.
 
         Returns:
             QueryResult: A Pydantic model representing the generated response.
@@ -284,10 +282,9 @@ def retrieve_and_generate(
         retriever = vector_store.as_retriever()
         ranker = RerankerConfig.get_ranker(
             rerank_model,
-            model_type,
             lang,
-            api_key,
-            api_provider
+            api_key,
+            model_type,
         )
         compressor = ranker.as_langchain_compressor(k=top_k)
         compression_retriever = ContextualCompressionRetriever(
diff --git a/src/tests/openaiindex_test.py b/src/tests/openaiindex_test.py
index 8ab1c40..cc9f7fd 100644
--- a/src/tests/openaiindex_test.py
+++ b/src/tests/openaiindex_test.py
@@ -67,7 +67,8 @@ def test_04_retrieve_and_generate(self):
             query = "give a short summary of the introduction",
             vector_store = vectorstore,
             top_k = 3,
-            rerank_model = "t5"
+            # lang= "en",
+            rerank_model = "flashrank"
         )
         self.assertIsNotNone(response, "The retriever response should not be None.")
diff --git a/src/utils/rerank.py b/src/utils/rerank.py
index 6f93131..f7733a0 100644
--- a/src/utils/rerank.py
+++ b/src/utils/rerank.py
@@ -1,49 +1,45 @@
 from rerankers import Reranker
 
 class RerankerConfig:
+    SUPPORTED_MODELS = {
+        'cohere': {'lang': True, 'api_key': True},
+        'jina': {'api_key': True},
+        'cross-encoder': {},
+        'flashrank': {},
+        't5': {},
+        'rankgpt': {'api_key': True},
+        'rankgpt3': {'api_key': True},
+        'colbert': {},
+        'mixedbread-ai/mxbai-rerank-large-v1': {'model_type': True},
+        'ce-esci-MiniLM-L12-v2': {'model_type': True},
+        'unicamp-dl/InRanker-base': {'model_type': True},
+    }
     @staticmethod
-    def get_ranker(rerank_model: str, model_type: str = None, lang: str = None, api_key: str = None, api_provider: str = None) -> Reranker:
+    def get_ranker(rerank_model: str, lang: str = None, api_key: str = None, model_type: str = None) -> Reranker:
         """
         Returns a Reranker instance based on the provided parameters.
 
         Args:
             rerank_model (str): The name or path of the model.
-            model_type (str, optional): The type of the model. Defaults to None.
             lang (str, optional): The language for multilingual models. Defaults to None.
             api_key (str, optional): The API key for models accessed through an API. Defaults to None.
-            api_provider (str, optional): The provider of the API. Defaults to None.
+            model_type (str, optional): The model type of a reranker, defaults to None.
 
         Returns:
             Reranker: An instance of Reranker.
 
         Raises:
-            ValueError: If unsupported model_type is provided.
+            ValueError: If unsupported rerank_model is provided.
""" - if rerank_model and rerank_model not in ["cross-encoder", "flashrank", "t5", "rankgpt", "colbert", "mixedbread-ai/mxbai-rerank-large-v1", "ce-esci-MiniLM-L12-v2", "unicamp-dl/InRanker-base", "jina", - "rankgpt", "rankgpt3"]: - raise ValueError("Unsupported model_type provided.") + if rerank_model not in RerankerConfig.SUPPORTED_MODELS: + raise ValueError(f"Unsupported rerank_model provided: {rerank_model}") + + model_config = RerankerConfig.SUPPORTED_MODELS[rerank_model] + init_kwargs = { + 'lang': lang if model_config.get('lang') else None, + 'api_key': api_key if model_config.get('api_key') else None, + 'model_type': model_type if model_config.get('model_type') else None + } + init_kwargs = {k: v for k, v in init_kwargs.items() if v is not None} + return Reranker(rerank_model, **init_kwargs) - if rerank_model == 'cohere': - return Reranker(rerank_model, lang=lang, api_key=api_key) - elif rerank_model == 'jina': - return Reranker(rerank_model, api_key=api_key) - elif rerank_model == 'cross-encoder': - return Reranker(rerank_model) - elif rerank_model == 'flashrank': - return Reranker(rerank_model) - elif rerank_model == 't5': - return Reranker(rerank_model) - elif rerank_model == 'rankgpt': - return Reranker(rerank_model, api_key=api_key) - elif rerank_model == 'rankgpt3': - return Reranker(rerank_model, api_key=api_key) - elif rerank_model == 'colbert': - return Reranker(rerank_model) - elif rerank_model == "mixedbread-ai/mxbai-rerank-large-v1": - return Reranker(rerank_model, model_type='cross-encoder') - elif rerank_model == "ce-esci-MiniLM-L12-v2": - return Reranker(rerank_model, model_type='flashrank') - elif rerank_model == "unicamp-dl/InRanker-base": - return Reranker(rerank_model, model_type='t5') - else: - return None