-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update: Optimize get reranker method to run in constant time complexity (#39)
* fix: run tests on pull_request_target * fix: add secrets to dependabot configuration * fix: revert back to on pull_request * fix: GH secrets access in GA run * chore: add gitignore * Merf * chore(dependencies): rename dependabot test run * chore(dependabot): add permissions field * chore(dependabot): add test trigger on pull request target master branch * update: add dictionary to get reranker models in constant time complexity * fix: remove api_provider parameter from retrieve_and_generate method * chore(dependencies): add rerankers package * fix: get reranker method and implementation in test
- Loading branch information
Showing
6 changed files
with
36 additions
and
47 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,49 +1,45 @@ | ||
from rerankers import Reranker | ||
|
||
class RerankerConfig:
    """Registry of supported reranker models and factory for ``Reranker`` objects.

    Model support is looked up in a dictionary, so validating a model name and
    selecting its constructor arguments is a single hash lookup instead of a
    linear scan or an ``if``/``elif`` chain.
    """

    # Maps each supported model name to flags naming which optional keyword
    # arguments are forwarded to ``Reranker`` for that model.
    SUPPORTED_MODELS = {
        'cohere': {'lang': True, 'api_key': True},
        'jina': {'api_key': True},
        'cross-encoder': {},
        'flashrank': {},
        't5': {},
        'rankgpt': {'api_key': True},
        'rankgpt3': {'api_key': True},
        'colbert': {},
        'mixedbread-ai/mxbai-rerank-large-v1': {'model_type': True},
        'ce-esci-MiniLM-L12-v2': {'model_type': True},
        'unicamp-dl/InRanker-base': {'model_type': True},
    }

    # Fallback ``model_type`` per checkpoint, used only when the caller does
    # not pass one. These are the values the earlier if/elif implementation
    # hard-coded, so omitting ``model_type`` keeps working for these models.
    DEFAULT_MODEL_TYPES = {
        'mixedbread-ai/mxbai-rerank-large-v1': 'cross-encoder',
        'ce-esci-MiniLM-L12-v2': 'flashrank',
        'unicamp-dl/InRanker-base': 't5',
    }

    @staticmethod
    def _build_init_kwargs(rerank_model: str, lang: str = None, api_key: str = None, model_type: str = None) -> dict:
        """Select the keyword arguments to forward to ``Reranker`` for a supported model."""
        flags = RerankerConfig.SUPPORTED_MODELS[rerank_model]
        if model_type is None:
            model_type = RerankerConfig.DEFAULT_MODEL_TYPES.get(rerank_model)
        candidates = {'lang': lang, 'api_key': api_key, 'model_type': model_type}
        # Forward an argument only when the model is flagged as accepting it
        # and a value is actually available.
        return {
            name: value
            for name, value in candidates.items()
            if flags.get(name) and value is not None
        }

    @staticmethod
    def get_ranker(rerank_model: str, lang: str = None, api_key: str = None, model_type: str = None) -> "Reranker":
        """
        Returns a Reranker instance based on the provided parameters.

        Args:
            rerank_model (str): The name or path of the model. Must be a key of
                ``SUPPORTED_MODELS``.
            lang (str, optional): The language for multilingual models. Only
                forwarded for models flagged with ``lang``. Defaults to None.
            api_key (str, optional): The API key for models accessed through an
                API. Only forwarded for models flagged with ``api_key``.
                Defaults to None.
            model_type (str, optional): The model type of a reranker. Only
                forwarded for models flagged with ``model_type``; when omitted,
                the entry in ``DEFAULT_MODEL_TYPES`` (if any) is used.
                Defaults to None.

        Returns:
            Reranker: An instance of Reranker.

        Raises:
            ValueError: If unsupported rerank_model is provided.
        """
        if rerank_model not in RerankerConfig.SUPPORTED_MODELS:
            raise ValueError(f"Unsupported rerank_model provided: {rerank_model}")

        init_kwargs = RerankerConfig._build_init_kwargs(
            rerank_model, lang=lang, api_key=api_key, model_type=model_type
        )
        return Reranker(rerank_model, **init_kwargs)