
Add integration and unit tests for missing RRF coverage #997

Open · wants to merge 11 commits into base: feature/rrf-score-normalization-v2
5 changes: 3 additions & 2 deletions .github/workflows/CI.yml
@@ -59,7 +59,7 @@ jobs:

- name: Upload Coverage Report
if: ${{ !cancelled() }}
uses: codecov/codecov-action@v4
uses: codecov/codecov-action@v3
with:
token: ${{ secrets.CODECOV_TOKEN }}

@@ -132,7 +132,7 @@ jobs:

- name: Upload Coverage Report
if: ${{ !cancelled() }}
uses: codecov/codecov-action@v4
uses: codecov/codecov-action@v3
with:
token: ${{ secrets.CODECOV_TOKEN }}

@@ -163,3 +163,4 @@ jobs:
- name: Run build
run: |
./gradlew precommit --parallel

1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.18...2.x)
### Features
- Implement Reciprocal Rank Fusion score normalization/combination technique in hybrid query ([#874](https://github.com/opensearch-project/neural-search/pull/874))
### Enhancements
### Bug Fixes
### Infrastructure
@@ -30,20 +30,22 @@
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.NeuralQueryEnricherProcessor;
import org.opensearch.neuralsearch.processor.NeuralSparseTwoPhaseProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
import org.opensearch.neuralsearch.processor.SparseEncodingProcessor;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.TextChunkingProcessor;
import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.RRFProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.factory.TextChunkingProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.SparseEncodingProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.TextImageEmbeddingProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.RRFProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.neuralsearch.processor.rerank.RerankProcessor;
@@ -154,7 +156,9 @@ public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchPhaseR
) {
return Map.of(
NormalizationProcessor.TYPE,
new NormalizationProcessorFactory(normalizationProcessorWorkflow, scoreNormalizationFactory, scoreCombinationFactory)
new NormalizationProcessorFactory(normalizationProcessorWorkflow, scoreNormalizationFactory, scoreCombinationFactory),
RRFProcessor.TYPE,
new RRFProcessorFactory(normalizationProcessorWorkflow, scoreNormalizationFactory, scoreCombinationFactory)
);
}

@@ -0,0 +1,35 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.query.QuerySearchResult;

import java.util.List;
import java.util.Optional;

/**
* DTO object to hold the data passed into NormalizationProcessorWorkflow for execution.
*/
@AllArgsConstructor
@Builder
@Getter
public class NormalizationExecuteDTO {
@NonNull
private List<QuerySearchResult> querySearchResults;
@NonNull
private Optional<FetchSearchResult> fetchSearchResultOptional;
@NonNull
private ScoreNormalizationTechnique normalizationTechnique;
@NonNull
private ScoreCombinationTechnique combinationTechnique;
}
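All four fields are marked @NonNull, so an absent fetch phase result has to be wrapped as Optional.empty() rather than passed as null, and leaving any builder field unset fails fast in the Lombok-generated constructor. A minimal sketch of that contract, using Mockito mocks as stand-ins for the technique implementations (the sketch class and the mocks are illustrative assumptions, not part of this PR):

```java
import static org.mockito.Mockito.mock;

import java.util.List;
import java.util.Optional;

import org.opensearch.neuralsearch.processor.NormalizationExecuteDTO;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;

public class NormalizationExecuteDTOSketch {
    public static void main(String[] args) {
        // "No fetch phase result" is modelled as Optional.empty(), never null,
        // because fetchSearchResultOptional itself is annotated @NonNull.
        NormalizationExecuteDTO dto = NormalizationExecuteDTO.builder()
            .querySearchResults(List.of())                               // empty but non-null
            .fetchSearchResultOptional(Optional.empty())
            .normalizationTechnique(mock(ScoreNormalizationTechnique.class))
            .combinationTechnique(mock(ScoreCombinationTechnique.class))
            .build();

        // Omitting any of the four builder calls above would make build() throw a
        // NullPointerException from the Lombok-generated null checks.
        System.out.println(dto.getQuerySearchResults().size());
    }
}
```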
@@ -58,7 +58,15 @@ public <Result extends SearchPhaseResult> void process(
}
List<QuerySearchResult> querySearchResults = getQueryPhaseSearchResults(searchPhaseResult);
Optional<FetchSearchResult> fetchSearchResult = getFetchSearchResults(searchPhaseResult);
normalizationWorkflow.execute(querySearchResults, fetchSearchResult, normalizationTechnique, combinationTechnique);
// Build the data transfer object that carries the query results and the normalization
// and combination techniques into the shared workflow's execute method
NormalizationExecuteDTO normalizationExecuteDTO = NormalizationExecuteDTO.builder()
.querySearchResults(querySearchResults)
.fetchSearchResultOptional(fetchSearchResult)
.normalizationTechnique(normalizationTechnique)
.combinationTechnique(combinationTechnique)
.build();
normalizationWorkflow.execute(normalizationExecuteDTO);
}

@Override
@@ -47,26 +47,31 @@ public class NormalizationProcessorWorkflow {

/**
* Start execution of this workflow
* @param querySearchResults input data with QuerySearchResult from multiple shards
* @param normalizationTechnique technique for score normalization
* @param combinationTechnique technique for score combination
* @param normalizationExecuteDTO carries the querySearchResults (QuerySearchResult from multiple shards),
* the fetchSearchResultOptional, the normalizationTechnique for score normalization,
* and the combinationTechnique for score combination
*/
public void execute(
final List<QuerySearchResult> querySearchResults,
final Optional<FetchSearchResult> fetchSearchResultOptional,
final ScoreNormalizationTechnique normalizationTechnique,
final ScoreCombinationTechnique combinationTechnique
) {
public void execute(final NormalizationExecuteDTO normalizationExecuteDTO) {
final List<QuerySearchResult> querySearchResults = normalizationExecuteDTO.getQuerySearchResults();
final Optional<FetchSearchResult> fetchSearchResultOptional = normalizationExecuteDTO.getFetchSearchResultOptional();
final ScoreNormalizationTechnique normalizationTechnique = normalizationExecuteDTO.getNormalizationTechnique();
final ScoreCombinationTechnique combinationTechnique = normalizationExecuteDTO.getCombinationTechnique();
// save original state
List<Integer> unprocessedDocIds = unprocessedDocIds(querySearchResults);

// pre-process data
log.debug("Pre-process query results");
List<CompoundTopDocs> queryTopDocs = getQueryTopDocs(querySearchResults);

// Data transfer object that bundles the query top docs with the normalization technique for the score normalizer
NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder()
.queryTopDocs(queryTopDocs)
.normalizationTechnique(normalizationTechnique)
.build();

// normalize
log.debug("Do score normalization");
scoreNormalizer.normalizeScores(queryTopDocs, normalizationTechnique);
scoreNormalizer.normalizeScores(normalizeScoresDTO);

CombineScoresDto combineScoresDTO = CombineScoresDto.builder()
.queryTopDocs(queryTopDocs)
@@ -0,0 +1,26 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;

import java.util.List;

/**
* DTO object to hold data required for score normalization.
*/
@AllArgsConstructor
@Builder
@Getter
public class NormalizeScoresDTO {
@NonNull
private List<CompoundTopDocs> queryTopDocs;
@NonNull
private ScoreNormalizationTechnique normalizationTechnique;
}
@@ -0,0 +1,142 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor;

import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement;

import java.util.stream.Collectors;

import java.util.List;
import java.util.Objects;
import java.util.Optional;

import lombok.Getter;
import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.query.QuerySearchResult;

import org.opensearch.action.search.QueryPhaseResultConsumer;
import org.opensearch.action.search.SearchPhaseContext;
import org.opensearch.action.search.SearchPhaseName;
import org.opensearch.action.search.SearchPhaseResults;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;

/**
* Processor that implements the reciprocal rank fusion technique on post-query
* search results. Updates query results with normalized and combined scores
* for the next phase (typically FETCH) by using the ranks from the individual
* subqueries to calculate 'normalized' scores before combining the subquery
* results into the final results
*/
@Log4j2
@AllArgsConstructor
public class RRFProcessor implements SearchPhaseResultsProcessor {
public static final String TYPE = "score-ranker-processor";

@Getter
private final String tag;
@Getter
private final String description;
private final ScoreNormalizationTechnique normalizationTechnique;
private final ScoreCombinationTechnique combinationTechnique;
private final NormalizationProcessorWorkflow normalizationWorkflow;

/**
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
* are set as part of class constructor
* @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution
* @param searchPhaseContext {@link SearchPhaseContext}
*/
@Override
public <Result extends SearchPhaseResult> void process(
final SearchPhaseResults<Result> searchPhaseResult,
final SearchPhaseContext searchPhaseContext
) {
if (shouldSkipProcessor(searchPhaseResult)) {
log.debug("Query results are not compatible with RRF processor");
return;
}
List<QuerySearchResult> querySearchResults = getQueryPhaseSearchResults(searchPhaseResult);
Optional<FetchSearchResult> fetchSearchResult = getFetchSearchResults(searchPhaseResult);

// Build the data transfer object for the shared normalization workflow; RRFProcessor passes
// the same NormalizationExecuteDTO shape into execute as NormalizationProcessor does
NormalizationExecuteDTO normalizationExecuteDTO = NormalizationExecuteDTO.builder()
.querySearchResults(querySearchResults)
.fetchSearchResultOptional(fetchSearchResult)
.normalizationTechnique(normalizationTechnique)
.combinationTechnique(combinationTechnique)
.build();
normalizationWorkflow.execute(normalizationExecuteDTO);
}

@Override
public SearchPhaseName getBeforePhase() {
return SearchPhaseName.QUERY;
}

@Override
public SearchPhaseName getAfterPhase() {
return SearchPhaseName.FETCH;
}

@Override
public String getType() {
return TYPE;
}

@Override
public boolean isIgnoreFailure() {
return false;
}

@VisibleForTesting
<Result extends SearchPhaseResult> boolean shouldSkipProcessor(SearchPhaseResults<Result> searchPhaseResult) {
if (Objects.isNull(searchPhaseResult) || !(searchPhaseResult instanceof QueryPhaseResultConsumer queryPhaseResultConsumer)) {
return true;
}

return queryPhaseResultConsumer.getAtomicArray().asList().stream().filter(Objects::nonNull).noneMatch(this::isHybridQuery);
}

/**
* Return true if results are from hybrid query.
* @param searchPhaseResult
* @return true if results are from hybrid query
*/
@VisibleForTesting
boolean isHybridQuery(final SearchPhaseResult searchPhaseResult) {
// check for delimiter at the end of the score docs.
return Objects.nonNull(searchPhaseResult.queryResult())
&& Objects.nonNull(searchPhaseResult.queryResult().topDocs())
&& Objects.nonNull(searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs)
&& searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs.length > 0
&& isHybridQueryStartStopElement(searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs[0]);
}

@VisibleForTesting
<Result extends SearchPhaseResult> List<QuerySearchResult> getQueryPhaseSearchResults(final SearchPhaseResults<Result> results) {
return results.getAtomicArray()
.asList()
.stream()
.map(result -> result == null ? null : result.queryResult())
.collect(Collectors.toList());
}

@VisibleForTesting
<Result extends SearchPhaseResult> Optional<FetchSearchResult> getFetchSearchResults(
final SearchPhaseResults<Result> searchPhaseResults
) {
Optional<Result> optionalFirstSearchPhaseResult = searchPhaseResults.getAtomicArray().asList().stream().findFirst();
return optionalFirstSearchPhaseResult.map(SearchPhaseResult::fetchResult);
}
}
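Since the stated goal of this PR is to add unit and integration tests for missing RRF coverage, a minimal unit-test sketch for the processor's simplest behavior could look like the following. This is an illustration only: the actual tests added by this PR are not shown in this diff, and the use of plain JUnit 4 with Mockito mocks (rather than the project's own test base classes) is an assumption.

```java
package org.opensearch.neuralsearch.processor;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;

import org.junit.Test;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;

public class RRFProcessorSketchTests {

    @Test
    public void testTypeAndSkipBehavior() {
        // All collaborators are mocked; only trivial, side-effect-free behavior is exercised here.
        RRFProcessor processor = new RRFProcessor(
            "tag",
            "description",
            mock(ScoreNormalizationTechnique.class),
            mock(ScoreCombinationTechnique.class),
            mock(NormalizationProcessorWorkflow.class)
        );

        // The processor registers under the "score-ranker-processor" type.
        assertEquals("score-ranker-processor", processor.getType());
        // Null search phase results must be skipped rather than processed.
        assertTrue(processor.shouldSkipProcessor(null));
    }
}
```

The sketch lives in the same package as RRFProcessor so that the package-private, @VisibleForTesting shouldSkipProcessor method is reachable.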
@@ -0,0 +1,32 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor.combination;

import lombok.ToString;
import lombok.extern.log4j.Log4j2;

import java.util.Map;

/**
* Abstracts combination of scores based on reciprocal rank fusion algorithm
*/
@Log4j2
@ToString(onlyExplicitlyIncluded = true)
public class RRFScoreCombinationTechnique implements ScoreCombinationTechnique {
@ToString.Include
public static final String TECHNIQUE_NAME = "rrf";

// Weights are not currently used for RRF, so the params map and combination util are accepted but not verified or applied
public RRFScoreCombinationTechnique(final Map<String, Object> params, final ScoreCombinationUtil combinationUtil) {}

@Override
public float combine(final float[] scores) {
float sumScores = 0.0f;
for (float score : scores) {
sumScores += score;
}
return sumScores;
}
}
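For context on the summation above: under reciprocal rank fusion, each score fed into combine() is expected to already be a rank-based contribution of the form 1 / (rankConstant + rank) produced by the normalization step, so combining is just addition. A minimal standalone sketch, assuming the commonly used rank constant of 60 (the constant actually applied by the normalization technique is not shown in this diff):

```java
// Standalone illustration of reciprocal rank fusion scoring; the rank constant
// of 60 is an assumed conventional default, not a value taken from this PR.
public final class RrfScoringSketch {
    private static final int RANK_CONSTANT = 60;

    public static void main(String[] args) {
        // Positions of the same document in two subqueries of a hybrid query (1-based).
        int rankInKeywordSubquery = 2;
        int rankInNeuralSubquery = 5;

        // Normalization step: each rank becomes a reciprocal-rank contribution.
        float[] contributions = {
            1.0f / (RANK_CONSTANT + rankInKeywordSubquery),
            1.0f / (RANK_CONSTANT + rankInNeuralSubquery)
        };

        // Combination step: RRFScoreCombinationTechnique#combine simply sums the contributions.
        float combined = 0.0f;
        for (float contribution : contributions) {
            combined += contribution;
        }
        System.out.printf("combined RRF score = %.6f%n", combined); // prints ~0.031514
    }
}
```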
@@ -25,7 +25,9 @@ public class ScoreCombinationFactory
HarmonicMeanScoreCombinationTechnique.TECHNIQUE_NAME,
params -> new HarmonicMeanScoreCombinationTechnique(params, scoreCombinationUtil),
GeometricMeanScoreCombinationTechnique.TECHNIQUE_NAME,
params -> new GeometricMeanScoreCombinationTechnique(params, scoreCombinationUtil)
params -> new GeometricMeanScoreCombinationTechnique(params, scoreCombinationUtil),
RRFScoreCombinationTechnique.TECHNIQUE_NAME,
params -> new RRFScoreCombinationTechnique(params, scoreCombinationUtil)
);

/**