From 04b1dfab5d5d5fe70025455882b7774164ecc637 Mon Sep 17 00:00:00 2001 From: KevKibe Date: Tue, 14 May 2024 13:45:36 +0300 Subject: [PATCH] change model_name_or_path to rerank_model --- src/tests/cohereindex_test.py | 2 +- src/tests/googleindex_test.py | 2 +- src/tests/openaiindex_test.py | 2 +- src/utils/rerank.py | 18 +++++++++--------- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/tests/cohereindex_test.py b/src/tests/cohereindex_test.py index 5aaf444..7ecc41a 100644 --- a/src/tests/cohereindex_test.py +++ b/src/tests/cohereindex_test.py @@ -63,7 +63,7 @@ def test_04_retrieve_and_generate(self): query = "give a short summary of the introduction", vector_store = vector_store, top_k = 3, - reranker_model = "t5" + rerank_model = "t5" ) print(response) self.assertIsNotNone(response, "The retriever response should not be None.") diff --git a/src/tests/googleindex_test.py b/src/tests/googleindex_test.py index 56efb7e..3abdf90 100644 --- a/src/tests/googleindex_test.py +++ b/src/tests/googleindex_test.py @@ -68,7 +68,7 @@ def test_04_retrieve_and_generate(self): query = "give a short summary of the introduction", vector_store = vector_store, top_k = 3, - reranker_model = "t5" + rerank_model = "t5" ) print(response) self.assertIsNotNone(response, "The retriever response should not be None.") diff --git a/src/tests/openaiindex_test.py b/src/tests/openaiindex_test.py index c0d6e9a..8ab1c40 100644 --- a/src/tests/openaiindex_test.py +++ b/src/tests/openaiindex_test.py @@ -67,7 +67,7 @@ def test_04_retrieve_and_generate(self): query = "give a short summary of the introduction", vector_store = vectorstore, top_k = 3, - reranker_model = "t5" + rerank_model = "t5" ) self.assertIsNotNone(response, "The retriever response should not be None.") diff --git a/src/utils/rerank.py b/src/utils/rerank.py index 382b3b9..b191573 100644 --- a/src/utils/rerank.py +++ b/src/utils/rerank.py @@ -2,12 +2,12 @@ class RerankerConfig: @staticmethod - def get_ranker(model_name_or_path: str, model_type: str = None, lang: str = None, api_key: str = None, api_provider: str = None) -> Reranker: + def get_ranker(rerank_model: str, model_type: str = None, lang: str = None, api_key: str = None, api_provider: str = None) -> Reranker: """ Returns a Reranker instance based on the provided parameters. Args: - model_name_or_path (str): The name or path of the model. + rerank_model (str): The name or path of the model. model_type (str, optional): The type of the model. Defaults to None. lang (str, optional): The language for multilingual models. Defaults to None. api_key (str, optional): The API key for models accessed through an API. Defaults to None. @@ -23,18 +23,18 @@ def get_ranker(model_name_or_path: str, model_type: str = None, lang: str = None raise ValueError("Unsupported model_type provided.") if model_type == 'cohere': - return Reranker(model_name_or_path, lang=lang, api_key=api_key) + return Reranker(rerank_model, lang=lang, api_key=api_key) elif model_type == 'jina': - return Reranker(model_name_or_path, api_key=api_key) + return Reranker(rerank_model, api_key=api_key) elif model_type == 'cross-encoder': - return Reranker(model_name_or_path, model_type='cross-encoder') + return Reranker(rerank_model, model_type='cross-encoder') elif model_type == 'flashrank': - return Reranker(model_name_or_path, model_type='flashrank') + return Reranker(rerank_model, model_type='flashrank') elif model_type == 't5': - return Reranker(model_name_or_path, model_type='t5') + return Reranker(rerank_model, model_type='t5') elif model_type == 'rankgpt': - return Reranker(model_name_or_path, model_type='rankgpt', api_key=api_key) + return Reranker(rerank_model, model_type='rankgpt', api_key=api_key) elif model_type == 'colbert': - return Reranker(model_name_or_path, model_type='colbert') + return Reranker(rerank_model, model_type='colbert') else: return Reranker(model_name_or_path, model_type=model_type, api_key=api_key, api_provider=api_provider)