Skip to content

Commit

Permalink
fix: RerankerConfig.get_ranker method
Browse files Browse the repository at this point in the history
  • Loading branch information
KevKibe committed May 14, 2024
1 parent 44cbee2 commit 87bc85d
Showing 1 changed file with 20 additions and 11 deletions.
31 changes: 20 additions & 11 deletions src/utils/rerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,31 @@ def get_ranker(rerank_model: str, model_type: str = None, lang: str = None, api_
Raises:
ValueError: If unsupported model_type is provided.
"""
if model_type and model_type not in ['cross-encoder', 'flashrank', 't5', 'rankgpt', 'colbert']:
if rerank_model and rerank_model not in ["cross-encoder", "flashrank", "t5", "rankgpt", "colbert", "mixedbread-ai/mxbai-rerank-large-v1", "ce-esci-MiniLM-L12-v2", "unicamp-dl/InRanker-base", "jina",
"rankgpt", "rankgpt3"]:
raise ValueError("Unsupported model_type provided.")

if model_type == 'cohere':
if rerank_model == 'cohere':
return Reranker(rerank_model, lang=lang, api_key=api_key)
elif model_type == 'jina':
elif rerank_model == 'jina':
return Reranker(rerank_model, api_key=api_key)
elif model_type == 'cross-encoder':
elif rerank_model == 'cross-encoder':
return Reranker(rerank_model)
elif rerank_model == 'flashrank':
return Reranker(rerank_model)
elif rerank_model == 't5':
return Reranker(rerank_model)
elif rerank_model == 'rankgpt':
return Reranker(rerank_model, api_key=api_key)
elif rerank_model == 'rankgpt3':
return Reranker(rerank_model, api_key=api_key)
elif rerank_model == 'colbert':
return Reranker(rerank_model)
elif rerank_model == "mixedbread-ai/mxbai-rerank-large-v1":
return Reranker(rerank_model, model_type='cross-encoder')
elif model_type == 'flashrank':
elif rerank_model == "ce-esci-MiniLM-L12-v2":
return Reranker(rerank_model, model_type='flashrank')
elif model_type == 't5':
elif rerank_model == "unicamp-dl/InRanker-base":
return Reranker(rerank_model, model_type='t5')
elif model_type == 'rankgpt':
return Reranker(rerank_model, model_type='rankgpt', api_key=api_key)
elif model_type == 'colbert':
return Reranker(rerank_model, model_type='colbert')
else:
return Reranker(rerank_model, model_type=model_type, api_key=api_key, api_provider=api_provider)
return None

0 comments on commit 87bc85d

Please sign in to comment.