Skip to content

Commit

Permalink
Bugfix: embedding score & hybrid retrieval & LLM Rerank (#51)
Browse files Browse the repository at this point in the history
* fix bugs for retrieval

* fix bugs for retrieval

* fix bugs for retrieval
  • Loading branch information
wwxxzz authored Jun 5, 2024
1 parent 63d3339 commit 11c802a
Show file tree
Hide file tree
Showing 7 changed files with 149 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/docker.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v4

# Uses the `docker/login-action` action to log in to the Container registry registry using the account and password that will publish the packages. Once published, the packages are scoped to the account defined here.
- name: Login to ACR Beijing region
uses: aliyun/acr-login@v1
Expand Down
2 changes: 1 addition & 1 deletion src/pai_rag/app/web/tabs/chat_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def create_chat_tab() -> Dict[str, Any]:
elem_id="rerank_model",
)
retrieval_mode = gr.Radio(
["Embedding Only", "Keyword Ensembled", "Keyword Only"],
["Embedding Only", "Keyword Only", "Hybrid"],
label="Retrieval Mode",
elem_id="retrieval_mode",
)
Expand Down
4 changes: 2 additions & 2 deletions src/pai_rag/app/web/view_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def sync_app_config(self, config):

self.similarity_top_k = config["retriever"].get("similarity_top_k", 5)
if config["retriever"]["retrieval_mode"] == "hybrid":
self.retrieval_mode = "Keyword Ensembled"
self.retrieval_mode = "Hybrid"
self.BM25_weight = config["retriever"]["BM25_weight"]
self.vector_weight = config["retriever"]["vector_weight"]
self.fusion_mode = config["retriever"]["fusion_mode"]
Expand Down Expand Up @@ -275,7 +275,7 @@ def to_app_config(self):
] = self.milvus_collection_name

config["retriever"]["similarity_top_k"] = self.similarity_top_k
if self.retrieval_mode == "Keyword Ensembled":
if self.retrieval_mode == "Hybrid":
config["retriever"]["retrieval_mode"] = "hybrid"
config["retriever"]["vector_weight"] = self.vector_weight
config["retriever"]["BM25_weight"] = self.BM25_weight
Expand Down
57 changes: 57 additions & 0 deletions src/pai_rag/modules/postprocessor/my_llm_rerank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""LLM reranker."""
from typing import List, Optional

from llama_index.core.schema import NodeWithScore, QueryBundle
from llama_index.core.postprocessor import LLMRerank


class MyLLMRerank(LLMRerank):
"""LLM-based reranker.
Fix the bug : The number of retrieved chunks after LLM Rerank has decreased due to the uncertainty in LLM responses.
"""

def _postprocess_nodes(
self,
nodes: List[NodeWithScore],
query_bundle: Optional[QueryBundle] = None,
) -> List[NodeWithScore]:
if query_bundle is None:
raise ValueError("Query bundle must be provided.")
if len(nodes) == 0:
return []

initial_results: List[NodeWithScore] = []
for idx in range(0, len(nodes), self.choice_batch_size):
nodes_batch = [
node.node for node in nodes[idx : idx + self.choice_batch_size]
]

query_str = query_bundle.query_str
fmt_batch_str = self._format_node_batch_fn(nodes_batch)
# call each batch independently
raw_response = self.llm.predict(
self.choice_select_prompt,
context_str=fmt_batch_str,
query_str=query_str,
)

raw_choices, relevances = self._parse_choice_select_answer_fn(
raw_response, len(nodes_batch)
)
choice_idxs = [int(choice) - 1 for choice in raw_choices]
relevances = relevances
if len(choice_idxs) < len(nodes):
missing_numbers = set(range(1, len(nodes))).difference(choice_idxs)
choice_idxs.extend(missing_numbers)
relevances.extend([1.0 for _ in missing_numbers])
choice_nodes = [nodes_batch[idx] for idx in choice_idxs]
initial_results.extend(
[
NodeWithScore(node=node, score=relevance)
for node, relevance in zip(choice_nodes, relevances)
]
)

return sorted(initial_results, key=lambda x: x.score or 0.0, reverse=True)[
: self.top_n
]
8 changes: 3 additions & 5 deletions src/pai_rag/modules/postprocessor/postprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,12 @@
from typing import Dict, List, Any

# from modules.query.postprocessor.base import BaseNodePostprocessor
from llama_index.core.postprocessor import (
SimilarityPostprocessor,
LLMRerank,
)
from llama_index.core.postprocessor import SimilarityPostprocessor
from llama_index.postprocessor.flag_embedding_reranker import FlagEmbeddingReranker
from pai_rag.utils.constants import DEFAULT_MODEL_DIR
from pai_rag.modules.base.configurable_module import ConfigurableModule
from pai_rag.modules.base.module_constants import MODULE_PARAM_CONFIG
from pai_rag.modules.postprocessor.my_llm_rerank import MyLLMRerank

DEFAULT_RANK_MODEL = "bge-reranker-base"
DEFAULT_RANK_TOP_N = 2
Expand Down Expand Up @@ -43,7 +41,7 @@ def _create_new_instance(self, new_params: Dict[str, Any]):
if rerank_model == "llm-reranker":
top_n = config.get("top_n", DEFAULT_RANK_TOP_N)
logger.info(f"[PostProcessor]: Llm reranker used with top_n {top_n}.")
post_processors.append(LLMRerank(top_n=top_n, llm=llm))
post_processors.append(MyLLMRerank(top_n=top_n, llm=llm))

elif (
rerank_model == "bge-reranker-base" or rerank_model == "bge-reranker-large"
Expand Down
82 changes: 82 additions & 0 deletions src/pai_rag/modules/retriever/my_vector_index_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""Base vector store index query."""

from typing import List, Optional

from llama_index.core.data_structs.data_structs import IndexDict
from llama_index.core.indices.utils import log_vector_store_query_result
from llama_index.core.schema import NodeWithScore, ObjectType
from llama_index.core.vector_stores.types import (
VectorStoreQueryResult,
)
import llama_index.core.instrumentation as instrument

from llama_index.core.retrievers import (
VectorIndexRetriever,
)

dispatcher = instrument.get_dispatcher(__name__)


class MyVectorIndexRetriever(VectorIndexRetriever):
"""PAI-RAG customized vector index retriever.
Refactor the _build_node_list_from_query_result() function
and return the results with the query_result.similarities sorted in descending order.
Args:
index (VectorStoreIndex): vector store index.
similarity_top_k (int): number of top k results to return.
vector_store_query_mode (str): vector store query mode
See reference for VectorStoreQueryMode for full list of supported modes.
filters (Optional[MetadataFilters]): metadata filters, defaults to None
alpha (float): weight for sparse/dense retrieval, only used for
hybrid query mode.
doc_ids (Optional[List[str]]): list of documents to constrain search.
vector_store_kwargs (dict): Additional vector store specific kwargs to pass
through to the vector store at query time.
"""

def _build_node_list_from_query_result(
self, query_result: VectorStoreQueryResult
) -> List[NodeWithScore]:
if query_result.nodes is None:
# NOTE: vector store does not keep text and returns node indices.
# Need to recover all nodes from docstore
if query_result.ids is None:
raise ValueError(
"Vector store query result should return at "
"least one of nodes or ids."
)
assert isinstance(self._index.index_struct, IndexDict)
node_ids = [
self._index.index_struct.nodes_dict[idx] for idx in query_result.ids
]
nodes = self._docstore.get_nodes(node_ids)
query_result.nodes = nodes
else:
# NOTE: vector store keeps text, returns nodes.
# Only need to recover image or index nodes from docstore
for i in range(len(query_result.nodes)):
source_node = query_result.nodes[i].source_node
if (not self._vector_store.stores_text) or (
source_node is not None and source_node.node_type != ObjectType.TEXT
):
node_id = query_result.nodes[i].node_id
if self._docstore.document_exists(node_id):
query_result.nodes[i] = self._docstore.get_node(
node_id
) # type: ignore[index]

log_vector_store_query_result(query_result)

node_with_scores: List[NodeWithScore] = []
query_result.similarities = sorted(query_result.similarities, reverse=True)
for ind, node in enumerate(query_result.nodes):
score: Optional[float] = None
if query_result.similarities is not None:
score = query_result.similarities[ind]
node_with_scores.append(NodeWithScore(node=node, score=score))

return node_with_scores
9 changes: 3 additions & 6 deletions src/pai_rag/modules/retriever/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@
import jieba
from nltk.corpus import stopwords
from llama_index.core.indices.list.base import SummaryIndex
from llama_index.core.retrievers import (
VectorIndexRetriever,
)
from llama_index.core.retrievers import QueryFusionRetriever
from llama_index.core.tools import RetrieverTool
from llama_index.core.selectors import LLMSingleSelector
Expand All @@ -19,7 +16,7 @@
from pai_rag.modules.base.configurable_module import ConfigurableModule
from pai_rag.modules.base.module_constants import MODULE_PARAM_CONFIG
from pai_rag.utils.prompt_template import QUERY_GEN_PROMPT

from pai_rag.modules.retriever.my_vector_index_retriever import MyVectorIndexRetriever

logger = logging.getLogger(__name__)

Expand All @@ -41,7 +38,7 @@ def _create_new_instance(self, new_params: Dict[str, Any]):

similarity_top_k = config.get("similarity_top_k", 5)
# vector
vector_retriever = VectorIndexRetriever(
vector_retriever = MyVectorIndexRetriever(
index=vector_index, similarity_top_k=similarity_top_k
)

Expand All @@ -53,7 +50,7 @@ def _create_new_instance(self, new_params: Dict[str, Any]):
)

if config["retrieval_mode"] == "embedding":
logger.info(f"VectorIndexRetriever used with top_k {similarity_top_k}.")
logger.info(f"MyVectorIndexRetriever used with top_k {similarity_top_k}.")
return vector_retriever

elif config["retrieval_mode"] == "keyword":
Expand Down

0 comments on commit 11c802a

Please sign in to comment.