Commit

Add LongContextReorder (run-llama#7719)
netoferraz authored Sep 20, 2023
1 parent fe905dc commit fdf0d7f
Showing 3 changed files with 62 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -5,6 +5,7 @@
### New Features
- Add support for `gpt-3.5-turbo-instruct` (#7729)
- Add support for `TimescaleVectorStore` (#7727)
- Added `LongContextReorder` for lost-in-the-middle issues (#7719)

### Bug Fixes / Nits
- Added node post-processors to async context chat engine (#7731)
33 changes: 33 additions & 0 deletions llama_index/indices/postprocessor/node.py
@@ -349,3 +349,36 @@ def postprocess_nodes(

sorted_nodes = sorted(all_nodes.values(), key=lambda x: x.node.node_id)
return list(sorted_nodes)


class LongContextReorder(BaseNodePostprocessor):
"""
Models struggle to access significant details found
in the center of extended contexts. A study
(https://arxiv.org/abs/2307.03172) observed that the best
performance typically arises when crucial data is positioned
at the start or conclusion of the input context. Additionally,
as the input context lengthens, performance drops notably, even
    in models designed for long contexts.
"""

@classmethod
def class_name(cls) -> str:
return "LongContextReorder"

def postprocess_nodes(
self,
nodes: List[NodeWithScore],
query_bundle: Optional[QueryBundle] = None,
) -> List[NodeWithScore]:
"""Postprocess nodes."""
reordered_nodes: List[NodeWithScore] = []
ordered_nodes: List[NodeWithScore] = sorted(
nodes, key=lambda x: x.score if x.score is not None else 0
)
for i, node in enumerate(ordered_nodes):
if i % 2 == 0:
reordered_nodes.insert(0, node)
else:
reordered_nodes.append(node)
return reordered_nodes
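
For orientation, here is a minimal usage sketch of the new postprocessor (not part of the commit). It builds NodeWithScore objects the same way the test below does and calls postprocess_nodes directly; the texts and scores are purely illustrative.

from llama_index.indices.postprocessor.node import LongContextReorder
from llama_index.schema import Node, NodeWithScore

# Retrieval results with relevance scores, e.g. as returned by a retriever.
nodes = [
    NodeWithScore(node=Node(text="a"), score=0.3),
    NodeWithScore(node=Node(text="b"), score=0.9),
    NodeWithScore(node=Node(text="c"), score=0.5),
    NodeWithScore(node=Node(text="d"), score=0.7),
]

# Higher-scoring nodes are alternately placed at the front and the back of the
# list, pushing the lowest-scoring nodes toward the middle of the context.
reordered = LongContextReorder().postprocess_nodes(nodes)
print([n.score for n in reordered])  # [0.7, 0.3, 0.5, 0.9]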
28 changes: 28 additions & 0 deletions tests/indices/postprocessor/test_longcontext_reorder.py
@@ -0,0 +1,28 @@
from typing import List

from llama_index.indices.postprocessor.node import LongContextReorder
from llama_index.schema import Node, NodeWithScore


def test_long_context_reorder() -> None:
nodes = [
NodeWithScore(node=Node(text="text"), score=0.7),
NodeWithScore(node=Node(text="text"), score=0.8),
NodeWithScore(node=Node(text="text"), score=1.0),
NodeWithScore(node=Node(text="text"), score=0.2),
NodeWithScore(node=Node(text="text"), score=0.9),
NodeWithScore(node=Node(text="text"), score=1.5),
NodeWithScore(node=Node(text="text"), score=0.1),
NodeWithScore(node=Node(text="text"), score=1.6),
NodeWithScore(node=Node(text="text"), score=3.0),
NodeWithScore(node=Node(text="text"), score=0.4),
]
ordered_nodes: List[NodeWithScore] = sorted(
nodes, key=lambda x: x.score if x.score is not None else 0, reverse=True
)
expected_scores_at_tails = [n.score for n in ordered_nodes[:4]]
lcr = LongContextReorder()
filtered_nodes = lcr.postprocess_nodes(nodes)
nodes_lost_in_the_middle = [n.score for n in filtered_nodes[3:-2]]
assert set(expected_scores_at_tails).intersection(nodes_lost_in_the_middle) == set()
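
For reference, a walk-through of the assertion above (not part of the commit): sorting the test scores ascending and alternating front-insertion with appending produces the sequence below, so all four of the highest scores land at the ends of the list and the middle slice [3:-2] contains none of them.

# ascending:  [0.1, 0.2, 0.4, 0.7, 0.8, 0.9, 1.0, 1.5, 1.6, 3.0]
# reordered:  [1.6, 1.0, 0.8, 0.4, 0.1, 0.2, 0.7, 0.9, 1.5, 3.0]
# [3:-2]:     [0.4, 0.1, 0.2, 0.7, 0.9]  -> disjoint from {3.0, 1.6, 1.5, 1.0}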
