-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Bugfix: embedding score & hybrid retrieval & LLM Rerank (#51)
* fix bugs for retrieval * fix bugs for retrieval * fix bugs for retrieval
- Loading branch information
Showing
7 changed files
with
149 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
"""LLM reranker.""" | ||
from typing import List, Optional | ||
|
||
from llama_index.core.schema import NodeWithScore, QueryBundle | ||
from llama_index.core.postprocessor import LLMRerank | ||
|
||
|
||
class MyLLMRerank(LLMRerank):
    """LLM-based reranker that never drops candidate nodes.

    Works around the uncertainty of LLM responses: when the LLM's
    choice-select answer omits some nodes of a batch, those nodes are
    re-added with a default relevance instead of being silently lost,
    so the number of chunks after reranking does not shrink.
    """

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Rerank ``nodes`` with the LLM, one batch at a time.

        Args:
            nodes: candidate nodes (with retrieval scores) to rerank.
            query_bundle: the query; required.

        Returns:
            The top ``self.top_n`` nodes, sorted by LLM relevance descending.

        Raises:
            ValueError: if ``query_bundle`` is None.
        """
        if query_bundle is None:
            raise ValueError("Query bundle must be provided.")
        if len(nodes) == 0:
            return []

        initial_results: List[NodeWithScore] = []
        for idx in range(0, len(nodes), self.choice_batch_size):
            nodes_batch = [
                node.node for node in nodes[idx : idx + self.choice_batch_size]
            ]

            query_str = query_bundle.query_str
            fmt_batch_str = self._format_node_batch_fn(nodes_batch)
            # Call the LLM on each batch independently.
            raw_response = self.llm.predict(
                self.choice_select_prompt,
                context_str=fmt_batch_str,
                query_str=query_str,
            )

            raw_choices, relevances = self._parse_choice_select_answer_fn(
                raw_response, len(nodes_batch)
            )
            # LLM choices are 1-based positions within the current batch;
            # convert to 0-based indices into nodes_batch.
            choice_idxs = [int(choice) - 1 for choice in raw_choices]
            if len(choice_idxs) < len(nodes_batch):
                # Bug fix: the missing set must be computed per batch (not
                # against the whole node list) and in the same 0-based index
                # space as choice_idxs; the previous 1-based range was off by
                # one and could index past the end of nodes_batch.
                missing_idxs = set(range(len(nodes_batch))).difference(choice_idxs)
                choice_idxs.extend(missing_idxs)
                # Nodes the LLM skipped get a default relevance of 1.0.
                relevances.extend([1.0 for _ in missing_idxs])
            choice_nodes = [nodes_batch[i] for i in choice_idxs]
            initial_results.extend(
                [
                    NodeWithScore(node=node, score=relevance)
                    for node, relevance in zip(choice_nodes, relevances)
                ]
            )

        return sorted(initial_results, key=lambda x: x.score or 0.0, reverse=True)[
            : self.top_n
        ]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
82 changes: 82 additions & 0 deletions
82
src/pai_rag/modules/retriever/my_vector_index_retriever.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
"""Base vector store index query.""" | ||
|
||
from typing import List, Optional | ||
|
||
from llama_index.core.data_structs.data_structs import IndexDict | ||
from llama_index.core.indices.utils import log_vector_store_query_result | ||
from llama_index.core.schema import NodeWithScore, ObjectType | ||
from llama_index.core.vector_stores.types import ( | ||
VectorStoreQueryResult, | ||
) | ||
import llama_index.core.instrumentation as instrument | ||
|
||
from llama_index.core.retrievers import ( | ||
VectorIndexRetriever, | ||
) | ||
|
||
dispatcher = instrument.get_dispatcher(__name__) | ||
|
||
|
||
class MyVectorIndexRetriever(VectorIndexRetriever):
    """PAI-RAG customized vector index retriever.

    Refactors ``_build_node_list_from_query_result()`` so the returned
    nodes carry ``query_result.similarities`` sorted in descending order.

    Args:
        index (VectorStoreIndex): vector store index.
        similarity_top_k (int): number of top k results to return.
        vector_store_query_mode (str): vector store query mode
            See reference for VectorStoreQueryMode for full list of supported modes.
        filters (Optional[MetadataFilters]): metadata filters, defaults to None
        alpha (float): weight for sparse/dense retrieval, only used for
            hybrid query mode.
        doc_ids (Optional[List[str]]): list of documents to constrain search.
        vector_store_kwargs (dict): Additional vector store specific kwargs to pass
            through to the vector store at query time.
    """

    def _build_node_list_from_query_result(
        self, query_result: VectorStoreQueryResult
    ) -> List[NodeWithScore]:
        """Turn a raw ``VectorStoreQueryResult`` into scored nodes.

        Recovers node content from the docstore when the vector store
        returned only ids (or non-text nodes), then pairs each node with
        a similarity score sorted in descending order.

        Raises:
            ValueError: if the result has neither nodes nor ids.
        """
        if query_result.nodes is None:
            # NOTE: vector store does not keep text and returns node indices.
            # Need to recover all nodes from docstore.
            if query_result.ids is None:
                raise ValueError(
                    "Vector store query result should return at "
                    "least one of nodes or ids."
                )
            assert isinstance(self._index.index_struct, IndexDict)
            node_ids = [
                self._index.index_struct.nodes_dict[idx] for idx in query_result.ids
            ]
            nodes = self._docstore.get_nodes(node_ids)
            query_result.nodes = nodes
        else:
            # NOTE: vector store keeps text, returns nodes.
            # Only need to recover image or index nodes from docstore.
            for i in range(len(query_result.nodes)):
                source_node = query_result.nodes[i].source_node
                if (not self._vector_store.stores_text) or (
                    source_node is not None and source_node.node_type != ObjectType.TEXT
                ):
                    node_id = query_result.nodes[i].node_id
                    if self._docstore.document_exists(node_id):
                        query_result.nodes[i] = self._docstore.get_node(
                            node_id
                        )  # type: ignore[index]

        log_vector_store_query_result(query_result)

        node_with_scores: List[NodeWithScore] = []
        # Bug fix: only sort when similarities are present — sorting None
        # would raise TypeError before the None check in the loop below.
        # NOTE(review): sorting similarities independently of nodes assumes
        # the store returns nodes already ordered by descending relevance
        # while similarities may arrive unsorted — confirm against the
        # vector stores used in hybrid mode.
        if query_result.similarities is not None:
            query_result.similarities = sorted(
                query_result.similarities, reverse=True
            )
        for ind, node in enumerate(query_result.nodes):
            score: Optional[float] = None
            if query_result.similarities is not None:
                score = query_result.similarities[ind]
            node_with_scores.append(NodeWithScore(node=node, score=score))

        return node_with_scores
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters