Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Pagination in hybrid query #963

Open
wants to merge 13 commits into
base: Pagination_in_hybridQuery
Choose a base branch
from
2 changes: 1 addition & 1 deletion .github/CODEOWNERS
Validating CODEOWNERS rules …
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# This should match the owning team set up in https://github.com/orgs/opensearch-project/teams
* @heemin32 @navneet1v @VijayanB @vamshin @jmazanec15 @naveentatikonda @junqiu-lei @martin-gaievski @sean-zheng-amazon @model-collapse @zane-neo @ylwu-amzn @jngz-es @vibrantvarun @zhichao-aws @yuye-aws
* @heemin32 @navneet1v @VijayanB @vamshin @jmazanec15 @naveentatikonda @junqiu-lei @martin-gaievski @sean-zheng-amazon @model-collapse @zane-neo @vibrantvarun @zhichao-aws @yuye-aws @minalsha
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this looks like artifacts of improper rebase, can you please rebase on main properly

1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.18...2.x)
### Features
- Pagination in Hybrid query ([#963](https://github.com/opensearch-project/neural-search/pull/963))
### Enhancements
### Bug Fixes
### Infrastructure
Expand Down
6 changes: 4 additions & 2 deletions MAINTAINERS.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ This document contains a list of maintainers in this repo. See [opensearch-proje
| Charlie Yang | [model-collapse](https://github.com/model-collapse) | Amazon |
| Navneet Verma | [navneet1v](https://github.com/navneet1v) | Amazon |
| Zan Niu | [zane-neo](https://github.com/zane-neo) | Amazon |
| Yaliang Wu | [ylwu-amzn](https://github.com/ylwu-amzn) | Amazon |
| Jing Zhang | [jngz-es](https://github.com/jngz-es) | Amazon |
| Heemin Kim | [heemin32](https://github.com/heemin32) | Amazon |
| Junqiu Lei | [junqiu-lei](https://github.com/junqiu-lei) | Amazon |
| Martin Gaievski | [martin-gaievski](https://github.com/martin-gaievski) | Amazon |
Expand All @@ -22,9 +20,13 @@ This document contains a list of maintainers in this repo. See [opensearch-proje
| Varun Jain | [vibrantvarun](https://github.com/vibrantvarun) | Amazon |
| Zhichao Geng | [zhichao-aws](https://github.com/zhichao-aws) | Amazon |
| Yuye Zhu | [yuye-aws](https://github.com/yuye-aws) | Amazon |
| Minal Shah | [minalsha](https://github.com/minalsha) | Amazon |


## Emeritus

| Maintainer | GitHub ID | Affiliation |
|-------------------------|---------------------------------------------|-------------|
| Junshen Wu | [wujunshen](https://github.com/wujunshen) | Independent |
| Yaliang Wu | [ylwu-amzn](https://github.com/ylwu-amzn) | Amazon |
| Jing Zhang | [jngz-es](https://github.com/jngz-es) | Amazon |
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public final class MinClusterVersionUtil {

private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_11_0;
private static final Version MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH = Version.V_2_14_0;
private static final Version MINIMAL_SUPPORTED_VERSION_PAGINATION_IN_HYBRID_QUERY = Version.V_2_19_0;

// Note this minimal version will act as a override
private static final Map<String, Version> MINIMAL_VERSION_NEURAL = ImmutableMap.<String, Version>builder()
Expand All @@ -38,6 +39,10 @@ public static boolean isClusterOnOrAfterMinReqVersionForRadialSearch() {
return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH);
}

public static boolean isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery() {
return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_PAGINATION_IN_HYBRID_QUERY);
}

public static boolean isClusterOnOrAfterMinReqVersion(String key) {
Version version;
if (MINIMAL_VERSION_NEURAL.containsKey(key)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.opensearch.action.search.SearchPhaseName;
import org.opensearch.action.search.SearchPhaseResults;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.dto.NormalizationExecuteDto;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.fetch.FetchSearchResult;
Expand Down Expand Up @@ -58,7 +59,16 @@ public <Result extends SearchPhaseResult> void process(
}
List<QuerySearchResult> querySearchResults = getQueryPhaseSearchResults(searchPhaseResult);
Optional<FetchSearchResult> fetchSearchResult = getFetchSearchResults(searchPhaseResult);
normalizationWorkflow.execute(querySearchResults, fetchSearchResult, normalizationTechnique, combinationTechnique);
// Builds data transfer object to pass into execute
NormalizationExecuteDto normalizationExecuteDto = NormalizationExecuteDto.builder()
.querySearchResults(querySearchResults)
.fetchSearchResultOptional(fetchSearchResult)
.normalizationTechnique(normalizationTechnique)
.combinationTechnique(combinationTechnique)
.searchPhaseContext(searchPhaseContext)
.build();

normalizationWorkflow.execute(normalizationExecuteDto);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TopFieldDocs;
import org.apache.lucene.search.FieldDoc;
import org.opensearch.action.search.SearchPhaseContext;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.processor.combination.CombineScoresDto;
import org.opensearch.neuralsearch.processor.dto.CombineScoresDto;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.dto.NormalizationExecuteDto;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.search.SearchHit;
Expand All @@ -47,16 +49,17 @@ public class NormalizationProcessorWorkflow {

/**
* Start execution of this workflow
* @param querySearchResults input data with QuerySearchResult from multiple shards
* @param normalizationTechnique technique for score normalization
* @param combinationTechnique technique for score combination
* @param normalizationExecuteDto contains querySearchResults input data with QuerySearchResult
* from multiple shards, fetchSearchResultOptional, normalizationTechnique technique for score normalization
* combinationTechnique technique for score combination, searchPhaseContext.
*/
public void execute(
final List<QuerySearchResult> querySearchResults,
final Optional<FetchSearchResult> fetchSearchResultOptional,
final ScoreNormalizationTechnique normalizationTechnique,
final ScoreCombinationTechnique combinationTechnique
) {
public void execute(final NormalizationExecuteDto normalizationExecuteDto) {
final List<QuerySearchResult> querySearchResults = normalizationExecuteDto.getQuerySearchResults();
final Optional<FetchSearchResult> fetchSearchResultOptional = normalizationExecuteDto.getFetchSearchResultOptional();
final ScoreNormalizationTechnique normalizationTechnique = normalizationExecuteDto.getNormalizationTechnique();
final ScoreCombinationTechnique combinationTechnique = normalizationExecuteDto.getCombinationTechnique();
final SearchPhaseContext searchPhaseContext = normalizationExecuteDto.getSearchPhaseContext();

// save original state
List<Integer> unprocessedDocIds = unprocessedDocIds(querySearchResults);

Expand All @@ -73,6 +76,8 @@ public void execute(
.scoreCombinationTechnique(combinationTechnique)
.querySearchResults(querySearchResults)
.sort(evaluateSortCriteria(querySearchResults, queryTopDocs))
.fromValueForSingleShard(searchPhaseContext.getRequest().source().from())
.isFetchResultsPresent(fetchSearchResultOptional.isPresent())
.build();

// combine
Expand All @@ -82,7 +87,12 @@ public void execute(
// post-process data
log.debug("Post-process query results after score normalization and combination");
updateOriginalQueryResults(combineScoresDTO);
updateOriginalFetchResults(querySearchResults, fetchSearchResultOptional, unprocessedDocIds);
updateOriginalFetchResults(
querySearchResults,
fetchSearchResultOptional,
unprocessedDocIds,
combineScoresDTO.getFromValueForSingleShard()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can just pass value from searchPhaseContext.getRequest().source(), without adding it to DTO

);
}

/**
Expand Down Expand Up @@ -113,15 +123,29 @@ private void updateOriginalQueryResults(final CombineScoresDto combineScoresDTO)
final List<QuerySearchResult> querySearchResults = combineScoresDTO.getQuerySearchResults();
final List<CompoundTopDocs> queryTopDocs = getCompoundTopDocs(combineScoresDTO, querySearchResults);
final Sort sort = combineScoresDTO.getSort();
int totalScoreDocsCount = 0;
for (int index = 0; index < querySearchResults.size(); index++) {
QuerySearchResult querySearchResult = querySearchResults.get(index);
CompoundTopDocs updatedTopDocs = queryTopDocs.get(index);
totalScoreDocsCount += updatedTopDocs.getScoreDocs().size();
TopDocsAndMaxScore updatedTopDocsAndMaxScore = new TopDocsAndMaxScore(
buildTopDocs(updatedTopDocs, sort),
maxScoreForShard(updatedTopDocs, sort != null)
);
// Fetch Phase had ran before the normalization phase, therefore update the from value in result of each shard.
// This will ensure the trimming of the results.
if (combineScoresDTO.isFetchResultsPresent()) {
querySearchResult.from(combineScoresDTO.getFromValueForSingleShard());
}
querySearchResult.topDocs(updatedTopDocsAndMaxScore, querySearchResult.sortValueFormats());
}

final int from = querySearchResults.get(0).from();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need final?

if (from > 0 && from > totalScoreDocsCount) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

first check looks redundant, can't we rely only on from > totalScoreDocsCount?

throw new IllegalArgumentException(
String.format(Locale.ROOT, "Reached end of search result, increase pagination_depth value to see more results")
);
}
}

private List<CompoundTopDocs> getCompoundTopDocs(CombineScoresDto combineScoresDTO, List<QuerySearchResult> querySearchResults) {
Expand Down Expand Up @@ -180,7 +204,8 @@ private TopDocs buildTopDocs(CompoundTopDocs updatedTopDocs, Sort sort) {
private void updateOriginalFetchResults(
final List<QuerySearchResult> querySearchResults,
final Optional<FetchSearchResult> fetchSearchResultOptional,
final List<Integer> docIds
final List<Integer> docIds,
final int fromValueForSingleShard
) {
if (fetchSearchResultOptional.isEmpty()) {
return;
Expand Down Expand Up @@ -212,14 +237,21 @@ private void updateOriginalFetchResults(

QuerySearchResult querySearchResult = querySearchResults.get(0);
TopDocs topDocs = querySearchResult.topDocs().topDocs;

// When normalization process will execute before the fetch phase, then from =0 is applicable.
// When normalization process runs after fetch phase, then search hits already fetched. Therefore, use the from value sent in the
// search request.
// iterate over the normalized/combined scores, that solves (1) and (3)
SearchHit[] updatedSearchHitArray = Arrays.stream(topDocs.scoreDocs).map(scoreDoc -> {
SearchHit[] updatedSearchHitArray = new SearchHit[topDocs.scoreDocs.length - fromValueForSingleShard];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please pull topDocs.scoreDocs.length - fromValueForSingleShard expression to a variable and give it meaningful name

for (int i = fromValueForSingleShard; i < topDocs.scoreDocs.length; i++) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please change the semantic here, start from 0 and do (i + offset) when you're reading from topDocs

ScoreDoc scoreDoc = topDocs.scoreDocs[i];
// get fetched hit content by doc_id
SearchHit searchHit = docIdToSearchHit.get(scoreDoc.doc);
// update score to normalized/combined value (3)
searchHit.score(scoreDoc.score);
return searchHit;
}).toArray(SearchHit[]::new);
updatedSearchHitArray[i - fromValueForSingleShard] = searchHit;
}

SearchHits updatedSearchHits = new SearchHits(
updatedSearchHitArray,
querySearchResult.getTotalHits(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.opensearch.neuralsearch.processor.CompoundTopDocs;

import lombok.extern.log4j.Log4j2;
import org.opensearch.neuralsearch.processor.dto.CombineScoresDto;

/**
* Abstracts combination of scores in query search results.
Expand Down Expand Up @@ -65,14 +66,10 @@ public class ScoreCombiner {
public void combineScores(final CombineScoresDto combineScoresDTO) {
// iterate over results from each shard. Every CompoundTopDocs object has results from
// multiple sub queries, doc ids may repeat for each sub query results
ScoreCombinationTechnique scoreCombinationTechnique = combineScoresDTO.getScoreCombinationTechnique();
Sort sort = combineScoresDTO.getSort();
combineScoresDTO.getQueryTopDocs()
.forEach(
compoundQueryTopDocs -> combineShardScores(
combineScoresDTO.getScoreCombinationTechnique(),
compoundQueryTopDocs,
combineScoresDTO.getSort()
)
);
.forEach(compoundQueryTopDocs -> combineShardScores(scoreCombinationTechnique, compoundQueryTopDocs, sort));
}

private void combineShardScores(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor.combination;
package org.opensearch.neuralsearch.processor.dto;

import java.util.List;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import org.apache.lucene.search.Sort;
import org.opensearch.common.Nullable;
import org.opensearch.neuralsearch.processor.CompoundTopDocs;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.search.query.QuerySearchResult;

/**
Expand All @@ -29,4 +31,6 @@ public class CombineScoresDto {
private List<QuerySearchResult> querySearchResults;
@Nullable
private Sort sort;
private int fromValueForSingleShard;
private boolean isFetchResultsPresent;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it doesn't look right to put this field here, it's not related to combination. please find alternative solution

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor.dto;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import org.opensearch.action.search.SearchPhaseContext;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.query.QuerySearchResult;

import java.util.List;
import java.util.Optional;

/**
* DTO object to hold data in NormalizationProcessorWorkflow class
* in NormalizationProcessorWorkflow.
*/
@AllArgsConstructor
@Builder
@Getter
public class NormalizationExecuteDto {
@NonNull
private List<QuerySearchResult> querySearchResults;
@NonNull
private Optional<FetchSearchResult> fetchSearchResultOptional;
@NonNull
private ScoreNormalizationTechnique normalizationTechnique;
@NonNull
private ScoreCombinationTechnique combinationTechnique;
@NonNull
private SearchPhaseContext searchPhaseContext;
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import java.util.Objects;
import java.util.concurrent.Callable;

import lombok.Getter;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
Expand All @@ -31,20 +32,25 @@
* Implementation of Query interface for type "hybrid". It allows execution of multiple sub-queries and collect individual
* scores for each sub-query.
*/
@Getter
public final class HybridQuery extends Query implements Iterable<Query> {

private final List<Query> subQueries;
private Integer paginationDepth;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why this is not primitive int? Operating with wrapper class is potentially error prone when boxing/unboxing a null value.


/**
* Create new instance of hybrid query object based on collection of sub queries and filter query
* @param subQueries collection of queries that are executed individually and contribute to a final list of combined scores
* @param filterQueries list of filters that will be applied to each sub query. Each filter from the list is added as bool "filter" clause. If this is null sub queries will be executed as is
*/
public HybridQuery(final Collection<Query> subQueries, final List<Query> filterQueries) {
public HybridQuery(final Collection<Query> subQueries, final List<Query> filterQueries, Integer paginationDepth) {
Objects.requireNonNull(subQueries, "collection of queries must not be null");
if (subQueries.isEmpty()) {
throw new IllegalArgumentException("collection of queries must not be empty");
}
if (paginationDepth != null && paginationDepth == 0) {
throw new IllegalArgumentException("pagination depth must not be zero");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change to pagination_depth in error message to signify that this is a parameter

}
if (Objects.isNull(filterQueries) || filterQueries.isEmpty()) {
this.subQueries = new ArrayList<>(subQueries);
} else {
Expand All @@ -57,10 +63,11 @@ public HybridQuery(final Collection<Query> subQueries, final List<Query> filterQ
}
this.subQueries = modifiedSubQueries;
}
this.paginationDepth = paginationDepth;
}

public HybridQuery(final Collection<Query> subQueries) {
this(subQueries, List.of());
public HybridQuery(final Collection<Query> subQueries, final Integer paginationDepth) {
this(subQueries, List.of(), paginationDepth);
}

/**
Expand Down Expand Up @@ -128,7 +135,7 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
return super.rewrite(indexSearcher);
}
final List<Query> rewrittenSubQueries = manager.getQueriesAfterRewrite(collectors);
return new HybridQuery(rewrittenSubQueries);
return new HybridQuery(rewrittenSubQueries, paginationDepth);
}

private Void rewriteQuery(Query query, HybridQueryExecutorCollector<IndexSearcher, Map.Entry<Query, Boolean>> collector) {
Expand Down
Loading
Loading