diff --git a/src/utils/rerank.py b/src/utils/rerank.py index 2988e24..6f93131 100644 --- a/src/utils/rerank.py +++ b/src/utils/rerank.py @@ -19,22 +19,31 @@ def get_ranker(rerank_model: str, model_type: str = None, lang: str = None, api_ Raises: ValueError: If unsupported model_type is provided. """ - if model_type and model_type not in ['cross-encoder', 'flashrank', 't5', 'rankgpt', 'colbert']: + if rerank_model and rerank_model not in ["cross-encoder", "flashrank", "t5", "rankgpt", "colbert", "mixedbread-ai/mxbai-rerank-large-v1", "ce-esci-MiniLM-L12-v2", "unicamp-dl/InRanker-base", "jina", + "rankgpt", "rankgpt3"]: raise ValueError("Unsupported model_type provided.") - if model_type == 'cohere': + if rerank_model == 'cohere': return Reranker(rerank_model, lang=lang, api_key=api_key) - elif model_type == 'jina': + elif rerank_model == 'jina': return Reranker(rerank_model, api_key=api_key) - elif model_type == 'cross-encoder': + elif rerank_model == 'cross-encoder': + return Reranker(rerank_model) + elif rerank_model == 'flashrank': + return Reranker(rerank_model) + elif rerank_model == 't5': + return Reranker(rerank_model) + elif rerank_model == 'rankgpt': + return Reranker(rerank_model, api_key=api_key) + elif rerank_model == 'rankgpt3': + return Reranker(rerank_model, api_key=api_key) + elif rerank_model == 'colbert': + return Reranker(rerank_model) + elif rerank_model == "mixedbread-ai/mxbai-rerank-large-v1": return Reranker(rerank_model, model_type='cross-encoder') - elif model_type == 'flashrank': + elif rerank_model == "ce-esci-MiniLM-L12-v2": return Reranker(rerank_model, model_type='flashrank') - elif model_type == 't5': + elif rerank_model == "unicamp-dl/InRanker-base": return Reranker(rerank_model, model_type='t5') - elif model_type == 'rankgpt': - return Reranker(rerank_model, model_type='rankgpt', api_key=api_key) - elif model_type == 'colbert': - return Reranker(rerank_model, model_type='colbert') else: - return Reranker(rerank_model, model_type=model_type, api_key=api_key, api_provider=api_provider) + return None