Skip to content

Commit

Permalink
fix: get reranker method and implementation in test
Browse files Browse the repository at this point in the history
  • Loading branch information
KevKibe committed Jun 5, 2024
1 parent 9ed15a9 commit b244dec
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
3 changes: 2 additions & 1 deletion src/tests/openaiindex_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ def test_04_retrieve_and_generate(self):
query = "give a short summary of the introduction",
vector_store = vectorstore,
top_k = 3,
rerank_model = "t5"
# lang= "en",
rerank_model = "flashrank"
)
self.assertIsNotNone(response, "The retriever response should not be None.")

Expand Down
19 changes: 12 additions & 7 deletions src/utils/rerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ class RerankerConfig:
SUPPORTED_MODELS = {
'cohere': {'lang': True, 'api_key': True},
'jina': {'api_key': True},
'cross-encoder': {'api_key': False},
'flashrank': {'api_key': False},
't5': {'api_key': False},
'cross-encoder': {},
'flashrank': {},
't5': {},
'rankgpt': {'api_key': True},
'rankgpt3': {'api_key': True},
'colbert': {'api_key': False},
'colbert': {},
'mixedbread-ai/mxbai-rerank-large-v1': {'model_type': True},
'ce-esci-MiniLM-L12-v2': {'model_type': True},
'unicamp-dl/InRanker-base': {'model_type': True},
Expand All @@ -35,6 +35,11 @@ def get_ranker(rerank_model: str, lang: str = None, api_key: str = None, model_t
raise ValueError(f"Unsupported rerank_model provided: {rerank_model}")

model_config = RerankerConfig.SUPPORTED_MODELS[rerank_model]
return Reranker(rerank_model, lang=lang if model_config.get('lang') else None,
api_key=api_key if model_config.get('api_key') else None,
model_type=model_type if model_config.get('model_type') else None)
init_kwargs = {
'lang': lang if model_config.get('lang') else None,
'api_key': api_key if model_config.get('api_key') else None,
'model_type': model_type if model_config.get('model_type') else None
}
init_kwargs = {k: v for k, v in init_kwargs.items() if v is not None}
return Reranker(rerank_model, **init_kwargs)

0 comments on commit b244dec

Please sign in to comment.