Skip to content

Commit

Permalink
change model_name_or_path to rerank_model
Browse files Browse the repository at this point in the history
  • Loading branch information
KevKibe committed May 14, 2024
1 parent 5250b06 commit 04b1dfa
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/tests/cohereindex_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def test_04_retrieve_and_generate(self):
query = "give a short summary of the introduction",
vector_store = vector_store,
top_k = 3,
reranker_model = "t5"
rerank_model = "t5"
)
print(response)
self.assertIsNotNone(response, "The retriever response should not be None.")
Expand Down
2 changes: 1 addition & 1 deletion src/tests/googleindex_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def test_04_retrieve_and_generate(self):
query = "give a short summary of the introduction",
vector_store = vector_store,
top_k = 3,
reranker_model = "t5"
rerank_model = "t5"
)
print(response)
self.assertIsNotNone(response, "The retriever response should not be None.")
Expand Down
2 changes: 1 addition & 1 deletion src/tests/openaiindex_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_04_retrieve_and_generate(self):
query = "give a short summary of the introduction",
vector_store = vectorstore,
top_k = 3,
reranker_model = "t5"
rerank_model = "t5"
)
self.assertIsNotNone(response, "The retriever response should not be None.")

Expand Down
18 changes: 9 additions & 9 deletions src/utils/rerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

class RerankerConfig:
@staticmethod
def get_ranker(model_name_or_path: str, model_type: str = None, lang: str = None, api_key: str = None, api_provider: str = None) -> Reranker:
def get_ranker(rerank_model: str, model_type: str = None, lang: str = None, api_key: str = None, api_provider: str = None) -> Reranker:
"""
Returns a Reranker instance based on the provided parameters.
Args:
model_name_or_path (str): The name or path of the model.
rerank_model (str): The name or path of the model.
model_type (str, optional): The type of the model. Defaults to None.
lang (str, optional): The language for multilingual models. Defaults to None.
api_key (str, optional): The API key for models accessed through an API. Defaults to None.
Expand All @@ -23,18 +23,18 @@ def get_ranker(model_name_or_path: str, model_type: str = None, lang: str = None
raise ValueError("Unsupported model_type provided.")

if model_type == 'cohere':
return Reranker(model_name_or_path, lang=lang, api_key=api_key)
return Reranker(rerank_model, lang=lang, api_key=api_key)
elif model_type == 'jina':
return Reranker(model_name_or_path, api_key=api_key)
return Reranker(rerank_model, api_key=api_key)
elif model_type == 'cross-encoder':
return Reranker(model_name_or_path, model_type='cross-encoder')
return Reranker(rerank_model, model_type='cross-encoder')
elif model_type == 'flashrank':
return Reranker(model_name_or_path, model_type='flashrank')
return Reranker(rerank_model, model_type='flashrank')
elif model_type == 't5':
return Reranker(model_name_or_path, model_type='t5')
return Reranker(rerank_model, model_type='t5')
elif model_type == 'rankgpt':
return Reranker(model_name_or_path, model_type='rankgpt', api_key=api_key)
return Reranker(rerank_model, model_type='rankgpt', api_key=api_key)
elif model_type == 'colbert':
return Reranker(model_name_or_path, model_type='colbert')
return Reranker(rerank_model, model_type='colbert')
else:
return Reranker(model_name_or_path, model_type=model_type, api_key=api_key, api_provider=api_provider)

0 comments on commit 04b1dfa

Please sign in to comment.