From 951d11a19ded8b415691b2b6f64ed6e56e440a61 Mon Sep 17 00:00:00 2001 From: Isaac Johnson <114550967+Johnsonisaacn@users.noreply.github.com> Date: Fri, 18 Oct 2024 09:44:07 -0700 Subject: [PATCH 01/10] Reciprocal Rank Fusion (RRF) normalization technique in hybrid query (#874) * initial commit of RRF Signed-off-by: Isaac Johnson Co-authored-by: Varun Jain Signed-off-by: Martin Gaievski --- CHANGELOG.md | 1 + .../neuralsearch/plugin/NeuralSearch.java | 10 +- .../processor/NormalizationExecuteDTO.java | 35 +++ .../processor/NormalizationProcessor.java | 10 +- .../NormalizationProcessorWorkflow.java | 25 +- .../processor/NormalizeScoresDTO.java | 26 ++ .../neuralsearch/processor/RRFProcessor.java | 139 +++++++++++ .../RRFScoreCombinationTechnique.java | 32 +++ .../combination/ScoreCombinationFactory.java | 4 +- .../combination/ScoreCombinationUtil.java | 6 +- .../factory/RRFProcessorFactory.java | 79 ++++++ .../L2ScoreNormalizationTechnique.java | 4 +- .../MinMaxScoreNormalizationTechnique.java | 5 +- .../RRFNormalizationTechnique.java | 106 ++++++++ .../ScoreNormalizationFactory.java | 18 +- .../ScoreNormalizationTechnique.java | 14 +- .../normalization/ScoreNormalizationUtil.java | 57 +++++ .../normalization/ScoreNormalizer.java | 15 +- .../plugin/NeuralSearchTests.java | 11 +- .../NormalizationProcessorTests.java | 7 +- .../NormalizationProcessorWorkflowTests.java | 84 ++++--- .../ScoreNormalizationTechniqueTests.java | 31 ++- .../RRFScoreCombinationTechniqueTests.java | 35 +++ .../ScoreCombinationFactoryTests.java | 8 + ....java => ScoreNormalizationUtilTests.java} | 2 +- .../factory/RRFProcessorFactoryTests.java | 214 ++++++++++++++++ .../L2ScoreNormalizationTechniqueTests.java | 21 +- ...inMaxScoreNormalizationTechniqueTests.java | 19 +- .../RRFNormalizationTechniqueTests.java | 232 ++++++++++++++++++ .../ScoreNormalizationFactoryTests.java | 8 + .../query/OpenSearchQueryTestCase.java | 2 + .../query/HybridCollectorManagerTests.java | 1 - 
.../HybridQueryScoreDocsMergerTests.java | 2 - .../search/query/TopDocsMergerTests.java | 2 - 34 files changed, 1168 insertions(+), 97 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/NormalizationExecuteDTO.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/NormalizeScoresDTO.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechnique.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactory.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationUtil.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechniqueTests.java rename src/test/java/org/opensearch/neuralsearch/processor/combination/{ScoreCombinationUtilTests.java => ScoreNormalizationUtilTests.java} (97%) create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactoryTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechniqueTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 595ea7dd4..e8621d2a4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.18...2.x) ### Features +- Implement Reciprocal Rank Fusion score normalization/combination technique in hybrid query ([#874](https://github.com/opensearch-project/neural-search/pull/874)) ### Enhancements ### Bug Fixes ### Infrastructure diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java 
b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index 8b173ba81..8b9016323 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -30,20 +30,22 @@ import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.NeuralQueryEnricherProcessor; import org.opensearch.neuralsearch.processor.NeuralSparseTwoPhaseProcessor; -import org.opensearch.neuralsearch.processor.NormalizationProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; import org.opensearch.neuralsearch.processor.SparseEncodingProcessor; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; import org.opensearch.neuralsearch.processor.TextChunkingProcessor; import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor; +import org.opensearch.neuralsearch.processor.RRFProcessor; +import org.opensearch.neuralsearch.processor.NormalizationProcessor; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; import org.opensearch.neuralsearch.processor.factory.TextChunkingProcessorFactory; -import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory; import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory; import org.opensearch.neuralsearch.processor.factory.SparseEncodingProcessorFactory; import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory; import org.opensearch.neuralsearch.processor.factory.TextImageEmbeddingProcessorFactory; +import org.opensearch.neuralsearch.processor.factory.RRFProcessorFactory; +import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; import 
org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; import org.opensearch.neuralsearch.processor.rerank.RerankProcessor; @@ -154,7 +156,9 @@ public Map querySearchResults; + @NonNull + private Optional fetchSearchResultOptional; + @NonNull + private ScoreNormalizationTechnique normalizationTechnique; + @NonNull + private ScoreCombinationTechnique combinationTechnique; +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java index 0563c92a0..231749f33 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java @@ -58,7 +58,15 @@ public void process( } List querySearchResults = getQueryPhaseSearchResults(searchPhaseResult); Optional fetchSearchResult = getFetchSearchResults(searchPhaseResult); - normalizationWorkflow.execute(querySearchResults, fetchSearchResult, normalizationTechnique, combinationTechnique); + // Builds data transfer object to pass into execute, DTO has nullable field for rankConstant which + // is only used for RRF technique + NormalizationExecuteDTO normalizationExecuteDTO = NormalizationExecuteDTO.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(fetchSearchResult) + .normalizationTechnique(normalizationTechnique) + .combinationTechnique(combinationTechnique) + .build(); + normalizationWorkflow.execute(normalizationExecuteDTO); } @Override diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java index c64f1c1f4..6507e3bd9 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java @@ -47,16 +47,15 @@ 
public class NormalizationProcessorWorkflow { /** * Start execution of this workflow - * @param querySearchResults input data with QuerySearchResult from multiple shards - * @param normalizationTechnique technique for score normalization - * @param combinationTechnique technique for score combination + * @param normalizationExecuteDTO contains querySearchResults input data with QuerySearchResult + * from multiple shards, fetchSearchResultOptional, normalizationTechnique technique for score normalization + * combinationTechnique technique for score combination, and nullable rankConstant only used in RRF technique */ - public void execute( - final List querySearchResults, - final Optional fetchSearchResultOptional, - final ScoreNormalizationTechnique normalizationTechnique, - final ScoreCombinationTechnique combinationTechnique - ) { + public void execute(final NormalizationExecuteDTO normalizationExecuteDTO) { + final List querySearchResults = normalizationExecuteDTO.getQuerySearchResults(); + final Optional fetchSearchResultOptional = normalizationExecuteDTO.getFetchSearchResultOptional(); + final ScoreNormalizationTechnique normalizationTechnique = normalizationExecuteDTO.getNormalizationTechnique(); + final ScoreCombinationTechnique combinationTechnique = normalizationExecuteDTO.getCombinationTechnique(); // save original state List unprocessedDocIds = unprocessedDocIds(querySearchResults); @@ -64,9 +63,15 @@ public void execute( log.debug("Pre-process query results"); List queryTopDocs = getQueryTopDocs(querySearchResults); + // Data transfer object for score normalization used to pass nullable rankConstant which is only used in RRF + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(queryTopDocs) + .normalizationTechnique(normalizationTechnique) + .build(); + // normalize log.debug("Do score normalization"); - scoreNormalizer.normalizeScores(queryTopDocs, normalizationTechnique); + 
scoreNormalizer.normalizeScores(normalizeScoresDTO); CombineScoresDto combineScoresDTO = CombineScoresDto.builder() .queryTopDocs(queryTopDocs) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizeScoresDTO.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizeScoresDTO.java new file mode 100644 index 000000000..c932a157d --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizeScoresDTO.java @@ -0,0 +1,26 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; +import lombok.NonNull; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; + +import java.util.List; + +/** + * DTO object to hold data required for score normalization. + */ +@AllArgsConstructor +@Builder +@Getter +public class NormalizeScoresDTO { + @NonNull + private List queryTopDocs; + @NonNull + private ScoreNormalizationTechnique normalizationTechnique; +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java new file mode 100644 index 000000000..207af156c --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java @@ -0,0 +1,139 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement; + +import java.util.stream.Collectors; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import lombok.Getter; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; +import 
org.opensearch.search.fetch.FetchSearchResult; +import org.opensearch.search.query.QuerySearchResult; + +import org.opensearch.action.search.QueryPhaseResultConsumer; +import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.action.search.SearchPhaseName; +import org.opensearch.action.search.SearchPhaseResults; +import org.opensearch.search.SearchPhaseResult; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; + +import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; + +/** + * Processor for implementing reciprocal rank fusion technique on post + * query search results. Updates query results with + * normalized and combined scores for next phase (typically it's FETCH) + * by using ranks from individual subqueries to calculate 'normalized' + * scores before combining results from subqueries into final results + */ +@Log4j2 +@AllArgsConstructor +public class RRFProcessor implements SearchPhaseResultsProcessor { + public static final String TYPE = "score-ranker-processor"; + + @Getter + private final String tag; + @Getter + private final String description; + private final ScoreNormalizationTechnique normalizationTechnique; + private final ScoreCombinationTechnique combinationTechnique; + private final NormalizationProcessorWorkflow normalizationWorkflow; + + /** + * Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage + * are set as part of class constructor + * @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. 
Results will be mutated as part of this method execution + * @param searchPhaseContext {@link SearchContext} + */ + @Override + public void process( + final SearchPhaseResults searchPhaseResult, + final SearchPhaseContext searchPhaseContext + ) { + if (shouldSkipProcessor(searchPhaseResult)) { + log.debug("Query results are not compatible with RRF processor"); + return; + } + List querySearchResults = getQueryPhaseSearchResults(searchPhaseResult); + Optional fetchSearchResult = getFetchSearchResults(searchPhaseResult); + + // make data transfer object to pass in, execute will get object with 4 or 5 fields, depending + // on coming from NormalizationProcessor or RRFProcessor + NormalizationExecuteDTO normalizationExecuteDTO = NormalizationExecuteDTO.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(fetchSearchResult) + .normalizationTechnique(normalizationTechnique) + .combinationTechnique(combinationTechnique) + .build(); + normalizationWorkflow.execute(normalizationExecuteDTO); + } + + @Override + public SearchPhaseName getBeforePhase() { + return SearchPhaseName.QUERY; + } + + @Override + public SearchPhaseName getAfterPhase() { + return SearchPhaseName.FETCH; + } + + @Override + public String getType() { + return TYPE; + } + + @Override + public boolean isIgnoreFailure() { + return false; + } + + private boolean shouldSkipProcessor(SearchPhaseResults searchPhaseResult) { + if (Objects.isNull(searchPhaseResult) || !(searchPhaseResult instanceof QueryPhaseResultConsumer queryPhaseResultConsumer)) { + return true; + } + + return queryPhaseResultConsumer.getAtomicArray().asList().stream().filter(Objects::nonNull).noneMatch(this::isHybridQuery); + } + + /** + * Return true if results are from hybrid query. + * @param searchPhaseResult + * @return true if results are from hybrid query + */ + private boolean isHybridQuery(final SearchPhaseResult searchPhaseResult) { + // check for delimiter at the end of the score docs. 
+ return Objects.nonNull(searchPhaseResult.queryResult()) + && Objects.nonNull(searchPhaseResult.queryResult().topDocs()) + && Objects.nonNull(searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs) + && searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs.length > 0 + && isHybridQueryStartStopElement(searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs[0]); + } + + private List getQueryPhaseSearchResults( + final SearchPhaseResults results + ) { + return results.getAtomicArray() + .asList() + .stream() + .map(result -> result == null ? null : result.queryResult()) + .collect(Collectors.toList()); + } + + private Optional getFetchSearchResults( + final SearchPhaseResults searchPhaseResults + ) { + Optional optionalFirstSearchPhaseResult = searchPhaseResults.getAtomicArray().asList().stream().findFirst(); + return optionalFirstSearchPhaseResult.map(SearchPhaseResult::fetchResult); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechnique.java new file mode 100644 index 000000000..befe14dda --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechnique.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.combination; + +import lombok.ToString; +import lombok.extern.log4j.Log4j2; + +import java.util.Map; + +@Log4j2 +/** + * Abstracts combination of scores based on reciprocal rank fusion algorithm + */ +@ToString(onlyExplicitlyIncluded = true) +public class RRFScoreCombinationTechnique implements ScoreCombinationTechnique { + @ToString.Include + public static final String TECHNIQUE_NAME = "rrf"; + + // Not currently using weights for RRF, no need to modify or verify these params + public RRFScoreCombinationTechnique(final Map params, final 
ScoreCombinationUtil combinationUtil) {} + + @Override + public float combine(final float[] scores) { + float sumScores = 0.0f; + for (float score : scores) { + sumScores += score; + } + return sumScores; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java index 23d8e01be..1e560342a 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java @@ -25,7 +25,9 @@ public class ScoreCombinationFactory { HarmonicMeanScoreCombinationTechnique.TECHNIQUE_NAME, params -> new HarmonicMeanScoreCombinationTechnique(params, scoreCombinationUtil), GeometricMeanScoreCombinationTechnique.TECHNIQUE_NAME, - params -> new GeometricMeanScoreCombinationTechnique(params, scoreCombinationUtil) + params -> new GeometricMeanScoreCombinationTechnique(params, scoreCombinationUtil), + RRFScoreCombinationTechnique.TECHNIQUE_NAME, + params -> new RRFScoreCombinationTechnique(params, scoreCombinationUtil) ); /** diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java index a915057df..b7a07395f 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java @@ -25,7 +25,7 @@ @Log4j2 class ScoreCombinationUtil { private static final String PARAM_NAME_WEIGHTS = "weights"; - private static final float DELTA_FOR_SCORE_ASSERTION = 0.01f; + private static final float DELTA_FOR_WEIGHTS_ASSERTION = 0.01f; /** * Get collection of weights based on user provided config @@ -117,7 +117,7 @@ protected void validateIfWeightsMatchScores(final float[] 
scores, final List weightsList) { - boolean isOutOfRange = weightsList.stream().anyMatch(weight -> !Range.between(0.0f, 1.0f).contains(weight)); + boolean isOutOfRange = weightsList.stream().anyMatch(weight -> !Range.of(0.0f, 1.0f).contains(weight)); if (isOutOfRange) { throw new IllegalArgumentException( String.format( @@ -128,7 +128,7 @@ private void validateWeights(final List weightsList) { ); } float sumOfWeights = weightsList.stream().reduce(0.0f, Float::sum); - if (!DoubleMath.fuzzyEquals(1.0f, sumOfWeights, DELTA_FOR_SCORE_ASSERTION)) { + if (!DoubleMath.fuzzyEquals(1.0f, sumOfWeights, DELTA_FOR_WEIGHTS_ASSERTION)) { throw new IllegalArgumentException( String.format( Locale.ROOT, diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactory.java new file mode 100644 index 000000000..fa4f39942 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactory.java @@ -0,0 +1,79 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.factory; + +import java.util.Map; +import java.util.Objects; + +import org.opensearch.neuralsearch.processor.RRFProcessor; +import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; +import org.opensearch.neuralsearch.processor.combination.RRFScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.normalization.RRFNormalizationTechnique; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; +import org.opensearch.search.pipeline.Processor; +import 
org.opensearch.search.pipeline.SearchPhaseResultsProcessor; + +import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; + +import static org.opensearch.ingest.ConfigurationUtils.readOptionalMap; +import static org.opensearch.ingest.ConfigurationUtils.readStringProperty; + +/** + * Factory class to instantiate RRF processor based on user provided input. + */ +@AllArgsConstructor +@Log4j2 +public class RRFProcessorFactory implements Processor.Factory { + public static final String COMBINATION_CLAUSE = "combination"; + public static final String TECHNIQUE = "technique"; + public static final String PARAMETERS = "parameters"; + + private final NormalizationProcessorWorkflow normalizationProcessorWorkflow; + private ScoreNormalizationFactory scoreNormalizationFactory; + private ScoreCombinationFactory scoreCombinationFactory; + + @Override + public SearchPhaseResultsProcessor create( + final Map> processorFactories, + final String tag, + final String description, + final boolean ignoreFailure, + final Map config, + final Processor.PipelineContext pipelineContext + ) throws Exception { + // assign defaults + ScoreNormalizationTechnique normalizationTechnique = scoreNormalizationFactory.createNormalization( + RRFNormalizationTechnique.TECHNIQUE_NAME + ); + ScoreCombinationTechnique scoreCombinationTechnique = scoreCombinationFactory.createCombination( + RRFScoreCombinationTechnique.TECHNIQUE_NAME + ); + Map combinationClause = readOptionalMap(RRFProcessor.TYPE, tag, config, COMBINATION_CLAUSE); + if (Objects.nonNull(combinationClause)) { + String combinationTechnique = readStringProperty( + RRFProcessor.TYPE, + tag, + combinationClause, + TECHNIQUE, + RRFScoreCombinationTechnique.TECHNIQUE_NAME + ); + // check for optional combination params + Map params = readOptionalMap(RRFProcessor.TYPE, tag, combinationClause, PARAMETERS); + normalizationTechnique = scoreNormalizationFactory.createNormalization(RRFNormalizationTechnique.TECHNIQUE_NAME, params); + 
scoreCombinationTechnique = scoreCombinationFactory.createCombination(combinationTechnique); + } + log.info( + "Creating search phase results processor of type [{}] with normalization [{}] and combination [{}]", + RRFProcessor.TYPE, + normalizationTechnique, + scoreCombinationTechnique + ); + return new RRFProcessor(tag, description, normalizationTechnique, scoreCombinationTechnique, normalizationProcessorWorkflow); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java index 2bb6bbed7..4acaf9626 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java @@ -11,6 +11,7 @@ import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.opensearch.neuralsearch.processor.CompoundTopDocs; +import org.opensearch.neuralsearch.processor.NormalizeScoresDTO; import lombok.ToString; @@ -31,7 +32,8 @@ public class L2ScoreNormalizationTechnique implements ScoreNormalizationTechniqu * - iterate over each result and update score as per formula above where "score" is raw score returned by Hybrid query */ @Override - public void normalize(final List queryTopDocs) { + public void normalize(final NormalizeScoresDTO normalizeScoresDTO) { + List queryTopDocs = normalizeScoresDTO.getQueryTopDocs(); // get l2 norms for each sub-query List normsPerSubquery = getL2Norm(queryTopDocs); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java index 4fdf3c0a6..dcaae402e 100644 --- 
a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java @@ -11,10 +11,12 @@ import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.opensearch.neuralsearch.processor.CompoundTopDocs; +import org.opensearch.neuralsearch.processor.NormalizeScoresDTO; import com.google.common.primitives.Floats; import lombok.ToString; +import org.opensearch.neuralsearch.processor.NormalizeScoresDTO; /** * Abstracts normalization of scores based on min-max method @@ -34,7 +36,8 @@ public class MinMaxScoreNormalizationTechnique implements ScoreNormalizationTech * - iterate over each result and update score as per formula above where "score" is raw score returned by Hybrid query */ @Override - public void normalize(final List queryTopDocs) { + public void normalize(final NormalizeScoresDTO normalizeScoresDTO) { + final List queryTopDocs = normalizeScoresDTO.getQueryTopDocs(); int numOfSubqueries = queryTopDocs.stream() .filter(Objects::nonNull) .filter(topDocs -> topDocs.getTopDocs().size() > 0) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java new file mode 100644 index 000000000..16ef83d05 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java @@ -0,0 +1,106 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.normalization; + +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Locale; +import java.util.Set; + +import org.apache.commons.lang3.Range; +import 
org.apache.commons.lang3.math.NumberUtils; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.opensearch.neuralsearch.processor.CompoundTopDocs; + +import lombok.ToString; +import org.opensearch.neuralsearch.processor.NormalizeScoresDTO; + +/** + * Abstracts calculation of rank scores for each document returned as part of + * reciprocal rank fusion. Rank scores are summed across subqueries in combination classes. + */ +@ToString(onlyExplicitlyIncluded = true) +public class RRFNormalizationTechnique implements ScoreNormalizationTechnique { + @ToString.Include + public static final String TECHNIQUE_NAME = "rrf"; + public static final int DEFAULT_RANK_CONSTANT = 60; + public static final String PARAM_NAME_RANK_CONSTANT = "rank_constant"; + private static final Set SUPPORTED_PARAMS = Set.of(PARAM_NAME_RANK_CONSTANT); + private static final int MIN_RANK_CONSTANT = 1; + private static final int MAX_RANK_CONSTANT = 10_000; + private static final Range RANK_CONSTANT_RANGE = Range.of(MIN_RANK_CONSTANT, MAX_RANK_CONSTANT); + @ToString.Include + private final int rankConstant; + + public RRFNormalizationTechnique(final Map params, final ScoreNormalizationUtil scoreNormalizationUtil) { + scoreNormalizationUtil.validateParams(params, SUPPORTED_PARAMS); + rankConstant = getRankConstant(params); + } + + /** + * Reciprocal Rank Fusion normalization technique + * @param normalizeScoresDTO is a data transfer object that contains queryTopDocs + * original query results from multiple shards and multiple sub-queries, ScoreNormalizationTechnique, + * and nullable rankConstant, which has a default value of 60 if not specified by user + * algorithm as follows, where document_n_score is the new score for each document in queryTopDocs + * and subquery_result_rank is the position in the array of documents returned for each subquery + * (j + 1 is used to adjust for 0 indexing) + * document_n_score = 1 / (rankConstant + subquery_result_rank) + * 
document scores are summed in combination step + */ + @Override + public void normalize(final NormalizeScoresDTO normalizeScoresDTO) { + final List queryTopDocs = normalizeScoresDTO.getQueryTopDocs(); + for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) { + if (Objects.isNull(compoundQueryTopDocs)) { + continue; + } + List topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs(); + for (TopDocs topDocs : topDocsPerSubQuery) { + int docsCountPerSubQuery = topDocs.scoreDocs.length; + ScoreDoc[] scoreDocs = topDocs.scoreDocs; + for (int j = 0; j < docsCountPerSubQuery; j++) { + // using big decimal approach to minimize error caused by floating point ops + // score = 1.f / (float) (rankConstant + j + 1)) + scoreDocs[j].score = BigDecimal.ONE.divide(BigDecimal.valueOf(rankConstant + j + 1), 10, RoundingMode.HALF_UP) + .floatValue(); + } + } + } + } + + private int getRankConstant(final Map params) { + if (Objects.isNull(params) || !params.containsKey(PARAM_NAME_RANK_CONSTANT)) { + return DEFAULT_RANK_CONSTANT; + } + int rankConstant = getParamAsInteger(params, PARAM_NAME_RANK_CONSTANT); + validateRankConstant(rankConstant); + return rankConstant; + } + + private void validateRankConstant(final int rankConstant) { + if (!RANK_CONSTANT_RANGE.contains(rankConstant)) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "rank constant must be in the interval between 1 and 10000, submitted rank constant: %d", + rankConstant + ) + ); + } + } + + public static int getParamAsInteger(final Map parameters, final String fieldName) { + try { + return NumberUtils.createInteger(String.valueOf(parameters.get(fieldName))); + } catch (NumberFormatException e) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "parameter [%s] must be an integer", fieldName)); + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java 
b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java index ca6ad20d6..7c62893a5 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java @@ -6,19 +6,24 @@ import java.util.Map; import java.util.Optional; +import java.util.function.Function; /** * Abstracts creation of exact score normalization method based on technique name */ public class ScoreNormalizationFactory { + private static final ScoreNormalizationUtil scoreNormalizationUtil = new ScoreNormalizationUtil(); + public static final ScoreNormalizationTechnique DEFAULT_METHOD = new MinMaxScoreNormalizationTechnique(); - private final Map scoreNormalizationMethodsMap = Map.of( + private final Map, ScoreNormalizationTechnique>> scoreNormalizationMethodsMap = Map.of( MinMaxScoreNormalizationTechnique.TECHNIQUE_NAME, - new MinMaxScoreNormalizationTechnique(), + params -> new MinMaxScoreNormalizationTechnique(), L2ScoreNormalizationTechnique.TECHNIQUE_NAME, - new L2ScoreNormalizationTechnique() + params -> new L2ScoreNormalizationTechnique(), + RRFNormalizationTechnique.TECHNIQUE_NAME, + params -> new RRFNormalizationTechnique(params, scoreNormalizationUtil) ); /** @@ -27,7 +32,12 @@ public class ScoreNormalizationFactory { * @return instance of ScoreNormalizationMethod for technique name */ public ScoreNormalizationTechnique createNormalization(final String technique) { + return createNormalization(technique, Map.of()); + } + + public ScoreNormalizationTechnique createNormalization(final String technique, final Map params) { return Optional.ofNullable(scoreNormalizationMethodsMap.get(technique)) - .orElseThrow(() -> new IllegalArgumentException("provided normalization technique is not supported")); + .orElseThrow(() -> new IllegalArgumentException("provided normalization technique is not supported")) + 
.apply(params); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationTechnique.java index 0b784c678..f8190a728 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationTechnique.java @@ -4,9 +4,7 @@ */ package org.opensearch.neuralsearch.processor.normalization; -import java.util.List; - -import org.opensearch.neuralsearch.processor.CompoundTopDocs; +import org.opensearch.neuralsearch.processor.NormalizeScoresDTO; /** * Abstracts normalization of scores in query search results. @@ -14,8 +12,12 @@ public interface ScoreNormalizationTechnique { /** - * Performs score normalization based on input normalization technique. Mutates input object by updating normalized scores. - * @param queryTopDocs original query results from multiple shards and multiple sub-queries + * Performs score normalization based on input normalization technique. + * Mutates input object by updating normalized scores. 
+ * @param normalizeScoresDTO is a data transfer object that contains queryTopDocs + * original query results from multiple shards and multiple sub-queries, ScoreNormalizationTechnique, + * and nullable rankConstant that is only used in RRF technique */ - void normalize(final List queryTopDocs); + void normalize(final NormalizeScoresDTO normalizeScoresDTO); + } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationUtil.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationUtil.java new file mode 100644 index 000000000..ad24b0aaa --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationUtil.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.normalization; + +import lombok.extern.log4j.Log4j2; + +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; + +/** + * Collection of utility methods for score normalization technique classes + */ +@Log4j2 +class ScoreNormalizationUtil { + private static final String PARAM_NAME_WEIGHTS = "weights"; + private static final float DELTA_FOR_SCORE_ASSERTION = 0.01f; + + /** + * Validate config parameters for this technique + * @param actualParams map of parameters in form of name-value + * @param supportedParams collection of parameters that we should validate against, typically that's what is supported by exact technique + */ + public void validateParams(final Map actualParams, final Set supportedParams) { + if (Objects.isNull(actualParams) || actualParams.isEmpty()) { + return; + } + // check if only supported params are passed + Optional optionalNotSupportedParam = actualParams.keySet() + .stream() + .filter(paramName -> !supportedParams.contains(paramName)) + .findFirst(); + if
(optionalNotSupportedParam.isPresent()) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "provided parameter for combination technique is not supported. supported parameters are [%s]", + String.join(",", supportedParams) + ) + ); + } + + // check param types + if (actualParams.keySet().stream().anyMatch(PARAM_NAME_WEIGHTS::equalsIgnoreCase)) { + if (!(actualParams.get(PARAM_NAME_WEIGHTS) instanceof List)) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "parameter [%s] must be a collection of numbers", PARAM_NAME_WEIGHTS) + ); + } + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java index 263115f8f..2ce131bf9 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java @@ -8,17 +8,22 @@ import java.util.Objects; import org.opensearch.neuralsearch.processor.CompoundTopDocs; +import org.opensearch.neuralsearch.processor.NormalizeScoresDTO; public class ScoreNormalizer { /** - * Performs score normalization based on input normalization technique. Mutates input object by updating normalized scores. - * @param queryTopDocs original query results from multiple shards and multiple sub-queries - * @param scoreNormalizationTechnique exact normalization technique that should be applied + * Performs score normalization based on input normalization technique. + * Mutates input object by updating normalized scores. 
+ * @param normalizeScoresDTO used as data transfer object to pass in queryTopDocs, original query results + * from multiple shards and multiple sub-queries, scoreNormalizationTechnique exact normalization technique + * that should be applied, and nullable rankConstant that is only used in RRF technique */ - public void normalizeScores(final List queryTopDocs, final ScoreNormalizationTechnique scoreNormalizationTechnique) { + public void normalizeScores(final NormalizeScoresDTO normalizeScoresDTO) { + final List queryTopDocs = normalizeScoresDTO.getQueryTopDocs(); + final ScoreNormalizationTechnique scoreNormalizationTechnique = normalizeScoresDTO.getNormalizationTechnique(); if (canQueryResultsBeNormalized(queryTopDocs)) { - scoreNormalizationTechnique.normalize(queryTopDocs); + scoreNormalizationTechnique.normalize(normalizeScoresDTO); } } diff --git a/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java b/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java index 9a969e71b..a4ad9f2d4 100644 --- a/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java +++ b/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java @@ -27,8 +27,10 @@ import org.opensearch.neuralsearch.processor.NeuralQueryEnricherProcessor; import org.opensearch.neuralsearch.processor.NeuralSparseTwoPhaseProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessor; +import org.opensearch.neuralsearch.processor.RRFProcessor; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory; +import org.opensearch.neuralsearch.processor.factory.RRFProcessorFactory; import org.opensearch.neuralsearch.processor.rerank.RerankProcessor; import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; @@ -143,12 +145,19 @@ public void testSearchPhaseResultsProcessors() { Map> 
searchPhaseResultsProcessors = plugin .getSearchPhaseResultsProcessors(searchParameters); assertNotNull(searchPhaseResultsProcessors); - assertEquals(1, searchPhaseResultsProcessors.size()); + assertEquals(2, searchPhaseResultsProcessors.size()); + // assert normalization processor conditions assertTrue(searchPhaseResultsProcessors.containsKey("normalization-processor")); org.opensearch.search.pipeline.Processor.Factory scoringProcessor = searchPhaseResultsProcessors.get( NormalizationProcessor.TYPE ); assertTrue(scoringProcessor instanceof NormalizationProcessorFactory); + // assert rrf processor conditions + assertTrue(searchPhaseResultsProcessors.containsKey("score-ranker-processor")); + org.opensearch.search.pipeline.Processor.Factory rankingProcessor = searchPhaseResultsProcessors.get( + RRFProcessor.TYPE + ); + assertTrue(rankingProcessor instanceof RRFProcessorFactory); } public void testGetSettings() { diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java index e93c9b9ec..4b34f7fe1 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java @@ -271,8 +271,7 @@ public void testEmptySearchResults_whenEmptySearchResults_thenDoNotExecuteWorkfl ); SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); normalizationProcessor.process(null, searchPhaseContext); - - verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any()); + verify(normalizationProcessorWorkflow, never()).execute(any()); } public void testNotHybridSearchResult_whenResultsNotEmptyAndNotHybridSearchResult_thenDoNotExecuteWorkflow() { @@ -327,8 +326,8 @@ public void testNotHybridSearchResult_whenResultsNotEmptyAndNotHybridSearchResul SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); 
when(searchPhaseContext.getNumShards()).thenReturn(numberOfShards); normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext); + verify(normalizationProcessorWorkflow, never()).execute(any()); - verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any()); } public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormalization() { @@ -417,7 +416,7 @@ public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormaliz .collect(Collectors.toList()); TestUtils.assertQueryResultScores(querySearchResults); - verify(normalizationProcessorWorkflow).execute(any(), any(), any(), any()); + verify(normalizationProcessorWorkflow).execute(any()); } public void testResultTypes_whenQueryAndFetchPresentButSizeDifferent_thenFail() { diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java index 59fb51563..9f7e7e785 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java @@ -71,13 +71,14 @@ public void testSearchResultTypes_whenResultsOfHybridSearch_thenDoNormalizationC querySearchResult.setShardIndex(shardId); querySearchResults.add(querySearchResult); } + NormalizationExecuteDTO normalizationExecuteDTO = NormalizationExecuteDTO.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.empty()) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .build(); - normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.empty(), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD - ); + normalizationProcessorWorkflow.execute(normalizationExecuteDTO); 
TestUtils.assertQueryResultScores(querySearchResults); } @@ -114,12 +115,14 @@ public void testSearchResultTypes_whenNoMatches_thenReturnZeroResults() { querySearchResults.add(querySearchResult); } - normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.empty(), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD - ); + NormalizationExecuteDTO normalizationExecuteDTO = NormalizationExecuteDTO.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.empty()) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .build(); + + normalizationProcessorWorkflow.execute(normalizationExecuteDTO); TestUtils.assertQueryResultScoresWithNoMatches(querySearchResults); } @@ -173,12 +176,14 @@ public void testFetchResults_whenOneShardAndQueryAndFetchResultsPresent_thenDoNo SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10); fetchSearchResult.hits(searchHits); - normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.of(fetchSearchResult), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD - ); + NormalizationExecuteDTO normalizationExecuteDTO = NormalizationExecuteDTO.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.of(fetchSearchResult)) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .build(); + + normalizationProcessorWorkflow.execute(normalizationExecuteDTO); TestUtils.assertQueryResultScores(querySearchResults); TestUtils.assertFetchResultScores(fetchSearchResult, 4); @@ -233,12 +238,14 @@ public void testFetchResults_whenOneShardAndMultipleNodes_thenDoNormalizationCom SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10); 
fetchSearchResult.hits(searchHits); - normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.of(fetchSearchResult), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD - ); + NormalizationExecuteDTO normalizationExecuteDTO = NormalizationExecuteDTO.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.of(fetchSearchResult)) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .build(); + + normalizationProcessorWorkflow.execute(normalizationExecuteDTO); TestUtils.assertQueryResultScores(querySearchResults); TestUtils.assertFetchResultScores(fetchSearchResult, 4); @@ -284,16 +291,14 @@ public void testFetchResultsAndNoCache_whenOneShardAndMultipleNodesAndMismatchRe querySearchResults.add(querySearchResult); SearchHits searchHits = getSearchHits(); fetchSearchResult.hits(searchHits); + NormalizationExecuteDTO normalizationExecuteDTO = NormalizationExecuteDTO.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.of(fetchSearchResult)) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .build(); - expectThrows( - IllegalStateException.class, - () -> normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.of(fetchSearchResult), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD - ) - ); + expectThrows(IllegalStateException.class, () -> normalizationProcessorWorkflow.execute(normalizationExecuteDTO)); } public void testFetchResultsAndCache_whenOneShardAndMultipleNodesAndMismatchResults_thenSuccessful() { @@ -336,13 +341,14 @@ public void testFetchResultsAndCache_whenOneShardAndMultipleNodesAndMismatchResu querySearchResults.add(querySearchResult); SearchHits searchHits = getSearchHits(); fetchSearchResult.hits(searchHits); + 
NormalizationExecuteDTO normalizationExecuteDTO = NormalizationExecuteDTO.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.of(fetchSearchResult)) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .build(); - normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.of(fetchSearchResult), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD - ); + normalizationProcessorWorkflow.execute(normalizationExecuteDTO); TestUtils.assertQueryResultScores(querySearchResults); TestUtils.assertFetchResultScores(fetchSearchResult, 4); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechniqueTests.java index 67abd552f..9f0be0300 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechniqueTests.java @@ -12,6 +12,7 @@ import org.apache.lucene.search.TotalHits; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; + import org.opensearch.test.OpenSearchTestCase; import lombok.SneakyThrows; @@ -20,7 +21,11 @@ public class ScoreNormalizationTechniqueTests extends OpenSearchTestCase { public void testEmptyResults_whenEmptyResultsAndDefaultMethod_thenNoProcessing() { ScoreNormalizer scoreNormalizationMethod = new ScoreNormalizer(); - scoreNormalizationMethod.normalizeScores(List.of(), ScoreNormalizationFactory.DEFAULT_METHOD); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(List.of()) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .build(); + 
scoreNormalizationMethod.normalizeScores(normalizeScoresDTO); } @SneakyThrows @@ -33,7 +38,11 @@ public void testNormalization_whenOneSubqueryAndOneShardAndDefaultMethod_thenSco false ) ); - scoreNormalizationMethod.normalizeScores(queryTopDocs, ScoreNormalizationFactory.DEFAULT_METHOD); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(queryTopDocs) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .build(); + scoreNormalizationMethod.normalizeScores(normalizeScoresDTO); assertNotNull(queryTopDocs); assertEquals(1, queryTopDocs.size()); CompoundTopDocs resultDoc = queryTopDocs.get(0); @@ -64,7 +73,11 @@ public void testNormalization_whenOneSubqueryMultipleHitsAndOneShardAndDefaultMe false ) ); - scoreNormalizationMethod.normalizeScores(queryTopDocs, ScoreNormalizationFactory.DEFAULT_METHOD); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(queryTopDocs) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .build(); + scoreNormalizationMethod.normalizeScores(normalizeScoresDTO); assertNotNull(queryTopDocs); assertEquals(1, queryTopDocs.size()); CompoundTopDocs resultDoc = queryTopDocs.get(0); @@ -101,7 +114,11 @@ public void testNormalization_whenMultipleSubqueriesMultipleHitsAndOneShardAndDe false ) ); - scoreNormalizationMethod.normalizeScores(queryTopDocs, ScoreNormalizationFactory.DEFAULT_METHOD); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(queryTopDocs) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .build(); + scoreNormalizationMethod.normalizeScores(normalizeScoresDTO); assertNotNull(queryTopDocs); assertEquals(1, queryTopDocs.size()); @@ -169,7 +186,11 @@ public void testNormalization_whenMultipleSubqueriesMultipleHitsMultipleShardsAn false ) ); - scoreNormalizationMethod.normalizeScores(queryTopDocs, ScoreNormalizationFactory.DEFAULT_METHOD); + NormalizeScoresDTO normalizeScoresDTO 
= NormalizeScoresDTO.builder() + .queryTopDocs(queryTopDocs) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .build(); + scoreNormalizationMethod.normalizeScores(normalizeScoresDTO); assertNotNull(queryTopDocs); assertEquals(3, queryTopDocs.size()); // shard one diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechniqueTests.java new file mode 100644 index 000000000..daed466d3 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechniqueTests.java @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.combination; + +import java.util.List; +import java.util.Map; + +public class RRFScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests { + + private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); + + public RRFScoreCombinationTechniqueTests() { + this.expectedScoreFunction = (scores, weights) -> RRF(scores, weights); + } + + public void testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores() { + ScoreCombinationTechnique technique = new RRFScoreCombinationTechnique(Map.of(), scoreCombinationUtil); + testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores(technique); + } + + public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() { + ScoreCombinationTechnique technique = new RRFScoreCombinationTechnique(Map.of(), scoreCombinationUtil); + testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores(technique); + } + + private float RRF(List scores, List weights) { + float sumScores = 0.0f; + for (float score : scores) { + sumScores += score; + } + return sumScores; + } +} diff --git 
a/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactoryTests.java index b36a6b492..5ca534dac 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactoryTests.java @@ -34,6 +34,14 @@ public void testGeometricWeightedMean_whenCreatingByName_thenReturnCorrectInstan assertTrue(scoreCombinationTechnique instanceof GeometricMeanScoreCombinationTechnique); } + public void testRRF_whenCreatingByName_thenReturnCorrectInstance() { + ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory(); + ScoreCombinationTechnique scoreCombinationTechnique = scoreCombinationFactory.createCombination("rrf"); + + assertNotNull(scoreCombinationTechnique); + assertTrue(scoreCombinationTechnique instanceof RRFScoreCombinationTechnique); + } + public void testUnsupportedTechnique_whenPassingInvalidName_thenFail() { ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory(); IllegalArgumentException illegalArgumentException = expectThrows( diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtilTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreNormalizationUtilTests.java similarity index 97% rename from src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtilTests.java rename to src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreNormalizationUtilTests.java index 9e00e3833..009681116 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtilTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreNormalizationUtilTests.java @@ -12,7 +12,7 @@ import 
org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; -public class ScoreCombinationUtilTests extends OpenSearchQueryTestCase { +public class ScoreNormalizationUtilTests extends OpenSearchQueryTestCase { public void testCombinationWeights_whenEmptyInputPassed_thenCreateEmptyWeightCollection() { ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactoryTests.java new file mode 100644 index 000000000..3097402a0 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactoryTests.java @@ -0,0 +1,214 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.factory; + +import lombok.SneakyThrows; +import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; +import org.opensearch.neuralsearch.processor.RRFProcessor; +import org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; +import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; +import org.opensearch.search.pipeline.Processor; +import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.HashMap; +import java.util.Map; + +import static org.mockito.Mockito.mock; + +import static org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory.NORMALIZATION_CLAUSE; +import static org.opensearch.neuralsearch.processor.factory.RRFProcessorFactory.PARAMETERS; +import static 
org.opensearch.neuralsearch.processor.factory.RRFProcessorFactory.COMBINATION_CLAUSE; +import static org.opensearch.neuralsearch.processor.factory.RRFProcessorFactory.TECHNIQUE; + +public class RRFProcessorFactoryTests extends OpenSearchTestCase { + + @SneakyThrows + public void testDefaults_whenNoValuesPassed_thenSuccessful() { + RRFProcessorFactory rrfProcessorFactory = new RRFProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Map config = new HashMap<>(); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + SearchPhaseResultsProcessor searchPhaseResultsProcessor = rrfProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + config, + pipelineContext + ); + assertRRFProcessor(searchPhaseResultsProcessor); + } + + @SneakyThrows + public void testCombinationParams_whenValidValues_thenSuccessful() { + RRFProcessorFactory rrfProcessorFactory = new RRFProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Map config = new HashMap<>(); + config.put(COMBINATION_CLAUSE, new HashMap<>(Map.of(TECHNIQUE, "rrf", PARAMETERS, new HashMap<>(Map.of("rank_constant", 100))))); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + SearchPhaseResultsProcessor searchPhaseResultsProcessor = rrfProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + config, + pipelineContext + ); + assertRRFProcessor(searchPhaseResultsProcessor); + } + + 
@SneakyThrows + public void testInvalidCombinationParams_whenRankIsNegative_thenFail() { + RRFProcessorFactory rrfProcessorFactory = new RRFProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + + Map config = new HashMap<>(); + config.put(COMBINATION_CLAUSE, new HashMap<>(Map.of(TECHNIQUE, "rrf", PARAMETERS, new HashMap<>(Map.of("rank_constant", -1))))); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> rrfProcessorFactory.create(processorFactories, tag, description, ignoreFailure, config, pipelineContext) + ); + assertTrue( + exception.getMessage().contains("rank constant must be in the interval between 1 and 10000, submitted rank constant: -1") + ); + } + + @SneakyThrows + public void testInvalidCombinationParams_whenRankIsTooLarge_thenFail() { + RRFProcessorFactory rrfProcessorFactory = new RRFProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + + Map config = new HashMap<>(); + config.put(COMBINATION_CLAUSE, new HashMap<>(Map.of(TECHNIQUE, "rrf", PARAMETERS, new HashMap<>(Map.of("rank_constant", 50_000))))); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> rrfProcessorFactory.create(processorFactories, tag, description, ignoreFailure, config, pipelineContext) + ); + assertTrue( + 
exception.getMessage().contains("rank constant must be in the interval between 1 and 10000, submitted rank constant: 50000") + ); + } + + @SneakyThrows + public void testInvalidCombinationParams_whenRankIsNotNumeric_thenFail() { + RRFProcessorFactory rrfProcessorFactory = new RRFProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + + Map config = new HashMap<>(); + config.put( + COMBINATION_CLAUSE, + new HashMap<>(Map.of(TECHNIQUE, "rrf", PARAMETERS, new HashMap<>(Map.of("rank_constant", "string")))) + ); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> rrfProcessorFactory.create(processorFactories, tag, description, ignoreFailure, config, pipelineContext) + ); + assertTrue(exception.getMessage().contains("parameter [rank_constant] must be an integer")); + } + + @SneakyThrows + public void testInvalidCombinationName_whenUnsupportedFunction_thenFail() { + RRFProcessorFactory rrfProcessorFactory = new RRFProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + + Map config = new HashMap<>(); + config.put( + COMBINATION_CLAUSE, + new HashMap<>(Map.of(TECHNIQUE, "my_function", PARAMETERS, new HashMap<>(Map.of("rank_constant", 100)))) + ); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> 
rrfProcessorFactory.create(processorFactories, tag, description, ignoreFailure, config, pipelineContext) + ); + assertTrue(exception.getMessage().contains("provided combination technique is not supported")); + } + + @SneakyThrows + public void testInvalidTechniqueType_whenPassingNormalization_thenSuccessful() { + RRFProcessorFactory rrfProcessorFactory = new RRFProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + + Map config = new HashMap<>(); + config.put(COMBINATION_CLAUSE, new HashMap<>(Map.of(TECHNIQUE, "rrf", PARAMETERS, new HashMap<>(Map.of("rank_constant", 100))))); + config.put( + NORMALIZATION_CLAUSE, + new HashMap<>(Map.of(TECHNIQUE, ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME, PARAMETERS, new HashMap<>(Map.of()))) + ); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + SearchPhaseResultsProcessor searchPhaseResultsProcessor = rrfProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + config, + pipelineContext + ); + assertRRFProcessor(searchPhaseResultsProcessor); + } + + private static void assertRRFProcessor(SearchPhaseResultsProcessor searchPhaseResultsProcessor) { + assertNotNull(searchPhaseResultsProcessor); + assertTrue(searchPhaseResultsProcessor instanceof RRFProcessor); + RRFProcessor rrfProcessor = (RRFProcessor) searchPhaseResultsProcessor; + assertEquals("score-ranker-processor", rrfProcessor.getType()); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java index ba4bfee0d..29fdb735f 100644 --- 
a/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java @@ -12,9 +12,10 @@ import org.apache.lucene.search.TotalHits; import org.opensearch.neuralsearch.processor.CompoundTopDocs; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; +import org.opensearch.neuralsearch.processor.NormalizeScoresDTO; /** - * Abstracts normalization of scores based on min-max method + * Abstracts normalization of scores based on L2 method */ public class L2ScoreNormalizationTechniqueTests extends OpenSearchQueryTestCase { private static final float DELTA_FOR_ASSERTION = 0.0001f; @@ -34,7 +35,11 @@ public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() false ) ); - normalizationTechnique.normalize(compoundTopDocs); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) + .normalizationTechnique(normalizationTechnique) + .build(); + normalizationTechnique.normalize(normalizeScoresDTO); CompoundTopDocs expectedCompoundDocs = new CompoundTopDocs( new TotalHits(2, TotalHits.Relation.EQUAL_TO), @@ -81,7 +86,11 @@ public void testNormalization_whenResultFromOneShardMultipleSubQueries_thenSucce false ) ); - normalizationTechnique.normalize(compoundTopDocs); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) + .normalizationTechnique(normalizationTechnique) + .build(); + normalizationTechnique.normalize(normalizeScoresDTO); CompoundTopDocs expectedCompoundDocs = new CompoundTopDocs( new TotalHits(3, TotalHits.Relation.EQUAL_TO), @@ -155,7 +164,11 @@ public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_the false ) ); - normalizationTechnique.normalize(compoundTopDocs); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) + 
.normalizationTechnique(normalizationTechnique) + .build(); + normalizationTechnique.normalize(normalizeScoresDTO); CompoundTopDocs expectedCompoundDocsShard1 = new CompoundTopDocs( new TotalHits(3, TotalHits.Relation.EQUAL_TO), diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java index d0445f0ca..239496355 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java @@ -10,6 +10,7 @@ import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; import org.opensearch.neuralsearch.processor.CompoundTopDocs; +import org.opensearch.neuralsearch.processor.NormalizeScoresDTO; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; /** @@ -32,7 +33,11 @@ public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() false ) ); - normalizationTechnique.normalize(compoundTopDocs); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) + .normalizationTechnique(normalizationTechnique) + .build(); + normalizationTechnique.normalize(normalizeScoresDTO); CompoundTopDocs expectedCompoundDocs = new CompoundTopDocs( new TotalHits(2, TotalHits.Relation.EQUAL_TO), @@ -72,7 +77,11 @@ public void testNormalization_whenResultFromOneShardMultipleSubQueries_thenSucce false ) ); - normalizationTechnique.normalize(compoundTopDocs); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) + .normalizationTechnique(normalizationTechnique) + .build(); + normalizationTechnique.normalize(normalizeScoresDTO); CompoundTopDocs expectedCompoundDocs = new CompoundTopDocs( new TotalHits(3, 
TotalHits.Relation.EQUAL_TO), @@ -127,7 +136,11 @@ public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_the false ) ); - normalizationTechnique.normalize(compoundTopDocs); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) + .normalizationTechnique(normalizationTechnique) + .build(); + normalizationTechnique.normalize(normalizeScoresDTO); CompoundTopDocs expectedCompoundDocsShard1 = new CompoundTopDocs( new TotalHits(3, TotalHits.Relation.EQUAL_TO), diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechniqueTests.java new file mode 100644 index 000000000..00ec13b73 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechniqueTests.java @@ -0,0 +1,232 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.normalization; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.opensearch.neuralsearch.processor.CompoundTopDocs; +import org.opensearch.neuralsearch.processor.NormalizeScoresDTO; +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; + +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.util.List; +import java.util.Map; + +/** + * Abstracts testing of normalization of scores based on RRF method + */ +public class RRFNormalizationTechniqueTests extends OpenSearchQueryTestCase { + static final int RANK_CONSTANT = 60; + private ScoreNormalizationUtil scoreNormalizationUtil = new ScoreNormalizationUtil(); + + public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() { + RRFNormalizationTechnique normalizationTechnique = new RRFNormalizationTechnique(Map.of(), 
scoreNormalizationUtil); + float[] scores = { 0.5f, 0.2f }; + List compoundTopDocs = List.of( + new CompoundTopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, scores[0]), new ScoreDoc(4, scores[1]) } + ) + ), + false + ) + ); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) + .normalizationTechnique(normalizationTechnique) + .build(); + normalizationTechnique.normalize(normalizeScoresDTO); + + CompoundTopDocs expectedCompoundDocs = new CompoundTopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, rrfNorm(0)), new ScoreDoc(4, rrfNorm(1)) } + ) + ), + false + ); + assertNotNull(compoundTopDocs); + assertEquals(1, compoundTopDocs.size()); + assertNotNull(compoundTopDocs.get(0).getTopDocs()); + assertCompoundTopDocs( + new TopDocs(expectedCompoundDocs.getTotalHits(), expectedCompoundDocs.getScoreDocs().toArray(new ScoreDoc[0])), + compoundTopDocs.get(0).getTopDocs().get(0) + ); + } + + public void testNormalization_whenResultFromOneShardMultipleSubQueries_thenSuccessful() { + RRFNormalizationTechnique normalizationTechnique = new RRFNormalizationTechnique(Map.of(), scoreNormalizationUtil); + float[] scoresQuery1 = { 0.5f, 0.2f }; + float[] scoresQuery2 = { 0.9f, 0.7f, 0.1f }; + List compoundTopDocs = List.of( + new CompoundTopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, scoresQuery1[0]), new ScoreDoc(4, scoresQuery1[1]) } + ), + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + new ScoreDoc(3, scoresQuery2[0]), + new ScoreDoc(4, scoresQuery2[1]), + new 
ScoreDoc(2, scoresQuery2[2]) } + ) + ), + false + ) + ); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) + .normalizationTechnique(normalizationTechnique) + .build(); + normalizationTechnique.normalize(normalizeScoresDTO); + + CompoundTopDocs expectedCompoundDocs = new CompoundTopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, rrfNorm(0)), new ScoreDoc(4, rrfNorm(1)) } + ), + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(3, rrfNorm(0)), new ScoreDoc(4, rrfNorm(1)), new ScoreDoc(2, rrfNorm(2)) } + ) + ), + false + ); + assertNotNull(compoundTopDocs); + assertEquals(1, compoundTopDocs.size()); + assertNotNull(compoundTopDocs.get(0).getTopDocs()); + for (int i = 0; i < expectedCompoundDocs.getTopDocs().size(); i++) { + assertCompoundTopDocs(expectedCompoundDocs.getTopDocs().get(i), compoundTopDocs.get(0).getTopDocs().get(i)); + } + } + + public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_thenSuccessful() { + RRFNormalizationTechnique normalizationTechnique = new RRFNormalizationTechnique(Map.of(), scoreNormalizationUtil); + float[] scoresShard1Query1 = { 0.5f, 0.2f }; + float[] scoresShard1and2Query3 = { 0.9f, 0.7f, 0.1f, 0.8f, 0.7f, 0.6f, 0.5f }; + float[] scoresShard2Query2 = { 2.9f, 0.7f }; + List compoundTopDocs = List.of( + new CompoundTopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, scoresShard1Query1[0]), new ScoreDoc(4, scoresShard1Query1[1]) } + ), + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + new 
ScoreDoc(3, scoresShard1and2Query3[0]), + new ScoreDoc(4, scoresShard1and2Query3[1]), + new ScoreDoc(2, scoresShard1and2Query3[2]) } + ) + ), + false + ), + new CompoundTopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(7, scoresShard2Query2[0]), new ScoreDoc(9, scoresShard2Query2[1]) } + ), + new TopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + new ScoreDoc(3, scoresShard1and2Query3[3]), + new ScoreDoc(9, scoresShard1and2Query3[4]), + new ScoreDoc(10, scoresShard1and2Query3[5]), + new ScoreDoc(15, scoresShard1and2Query3[6]) } + ) + ), + false + ) + ); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) + .normalizationTechnique(normalizationTechnique) + .build(); + normalizationTechnique.normalize(normalizeScoresDTO); + + CompoundTopDocs expectedCompoundDocsShard1 = new CompoundTopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, rrfNorm(0)), new ScoreDoc(4, rrfNorm(1)) } + ), + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(3, rrfNorm(0)), new ScoreDoc(4, rrfNorm(1)), new ScoreDoc(2, rrfNorm(2)) } + ) + ), + false + ); + + CompoundTopDocs expectedCompoundDocsShard2 = new CompoundTopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(7, rrfNorm(0)), new ScoreDoc(9, rrfNorm(1)) } + ), + new TopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + new 
ScoreDoc(3, rrfNorm(3)), + new ScoreDoc(9, rrfNorm(4)), + new ScoreDoc(10, rrfNorm(5)), + new ScoreDoc(15, rrfNorm(6)) } + ) + ), + false + ); + + assertNotNull(compoundTopDocs); + assertEquals(2, compoundTopDocs.size()); + assertNotNull(compoundTopDocs.get(0).getTopDocs()); + for (int i = 0; i < expectedCompoundDocsShard1.getTopDocs().size(); i++) { + assertCompoundTopDocs(expectedCompoundDocsShard1.getTopDocs().get(i), compoundTopDocs.get(0).getTopDocs().get(i)); + } + assertNotNull(compoundTopDocs.get(1).getTopDocs()); + for (int i = 0; i < expectedCompoundDocsShard2.getTopDocs().size(); i++) { + assertCompoundTopDocs(expectedCompoundDocsShard2.getTopDocs().get(i), compoundTopDocs.get(1).getTopDocs().get(i)); + } + } + + private float rrfNorm(int rank) { + // 1.0f / (float) (rank + RANK_CONSTANT + 1); + return BigDecimal.ONE.divide(BigDecimal.valueOf(rank + RANK_CONSTANT + 1), 10, RoundingMode.HALF_UP).floatValue(); + } + + private void assertCompoundTopDocs(TopDocs expected, TopDocs actual) { + assertEquals(expected.totalHits.value, actual.totalHits.value); + assertEquals(expected.totalHits.relation, actual.totalHits.relation); + assertEquals(expected.scoreDocs.length, actual.scoreDocs.length); + for (int i = 0; i < expected.scoreDocs.length; i++) { + assertEquals(expected.scoreDocs[i].score, actual.scoreDocs[i].score, DELTA_FOR_ASSERTION); + assertEquals(expected.scoreDocs[i].doc, actual.scoreDocs[i].doc); + assertEquals(expected.scoreDocs[i].shardIndex, actual.scoreDocs[i].shardIndex); + } + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactoryTests.java index d9dcd5540..cecdf8779 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactoryTests.java @@ 
-26,6 +26,14 @@ public void testL2Norm_whenCreatingByName_thenReturnCorrectInstance() { assertTrue(scoreNormalizationTechnique instanceof L2ScoreNormalizationTechnique); } + public void testRRFNorm_whenCreatingByName_thenReturnCorrectInstance() { + ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory(); + ScoreNormalizationTechnique scoreNormalizationTechnique = scoreNormalizationFactory.createNormalization("rrf"); + + assertNotNull(scoreNormalizationTechnique); + assertTrue(scoreNormalizationTechnique instanceof RRFNormalizationTechnique); + } + public void testUnsupportedTechnique_whenPassingInvalidName_thenFail() { ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory(); IllegalArgumentException illegalArgumentException = expectThrows( diff --git a/src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java b/src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java index a1e8210e6..9c162ce11 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java +++ b/src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java @@ -53,6 +53,8 @@ public abstract class OpenSearchQueryTestCase extends OpenSearchTestCase { + protected static final float DELTA_FOR_ASSERTION = 0.001f; + protected final MapperService createMapperService(Version version, XContentBuilder mapping) throws IOException { IndexMetadata meta = IndexMetadata.builder("index") .settings(Settings.builder().put("index.version.created", version)) diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java index 24ebebe5b..891ad7d28 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java @@ -81,7 +81,6 @@ 
public class HybridCollectorManagerTests extends OpenSearchQueryTestCase { private static final String TEST_DOC_TEXT3 = "We would like to welcome everyone"; private static final String QUERY1 = "hello"; private static final String QUERY2 = "hi"; - private static final float DELTA_FOR_ASSERTION = 0.001f; protected static final String QUERY3 = "everyone"; @SneakyThrows diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMergerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMergerTests.java index 196014220..f91dae327 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMergerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMergerTests.java @@ -21,8 +21,6 @@ public class HybridQueryScoreDocsMergerTests extends OpenSearchQueryTestCase { - private static final float DELTA_FOR_ASSERTION = 0.001f; - public void testIncorrectInput_whenScoreDocsAreNullOrNotEnoughElements_thenFail() { HybridQueryScoreDocsMerger scoreDocsMerger = new HybridQueryScoreDocsMerger<>(); TopDocsMerger topDocsMerger = new TopDocsMerger(null); diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/TopDocsMergerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/TopDocsMergerTests.java index 9c2718687..2e064913f 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/TopDocsMergerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/TopDocsMergerTests.java @@ -27,8 +27,6 @@ public class TopDocsMergerTests extends OpenSearchQueryTestCase { - private static final float DELTA_FOR_ASSERTION = 0.001f; - @SneakyThrows public void testMergeScoreDocs_whenBothTopDocsHasHits_thenSuccessful() { TopDocsMerger topDocsMerger = new TopDocsMerger(null); From 74c99c53dae22d013967edbdb9a1add3076e6318 Mon Sep 17 00:00:00 2001 From: Isaac Johnson <114550967+Johnsonisaacn@users.noreply.github.com> Date: Fri, 18 
Oct 2024 09:44:07 -0700 Subject: [PATCH 02/10] Reciprocal Rank Fusion (RRF) normalization technique in hybrid query (#874) * initial commit of RRF Signed-off-by: Isaac Johnson Co-authored-by: Varun Jain Signed-off-by: Martin Gaievski --- .github/workflows/CI.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index f1b997db7..5bcc517bc 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -163,3 +163,4 @@ jobs: - name: Run build run: | ./gradlew precommit --parallel + From a957fd3925e7857a6b4c1900d719d10718bc35cc Mon Sep 17 00:00:00 2001 From: Peter Zhu Date: Thu, 21 Nov 2024 15:55:19 -0500 Subject: [PATCH 03/10] Switch codecov to v3 as it requires python3.11 instead of node20 (#999) Signed-off-by: Peter Zhu --- .github/workflows/CI.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 5bcc517bc..1fc751e90 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -59,7 +59,7 @@ jobs: - name: Upload Coverage Report if: ${{ !cancelled() }} - uses: codecov/codecov-action@v4 + uses: codecov/codecov-action@v3 with: token: ${{ secrets.CODECOV_TOKEN }} @@ -132,7 +132,7 @@ jobs: - name: Upload Coverage Report if: ${{ !cancelled() }} - uses: codecov/codecov-action@v4 + uses: codecov/codecov-action@v3 with: token: ${{ secrets.CODECOV_TOKEN }} From 0a0181fef3ff319febdab164bff22a4a5b527eb1 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Thu, 14 Nov 2024 14:01:49 -0800 Subject: [PATCH 04/10] Add RRF integ test Signed-off-by: Ryan Bogan --- .../neuralsearch/processor/RRFSearchIT.java | 171 ++++++++++++++++++ 1 file changed, 171 insertions(+) create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/RRFSearchIT.java diff --git a/src/test/java/org/opensearch/neuralsearch/processor/RRFSearchIT.java b/src/test/java/org/opensearch/neuralsearch/processor/RRFSearchIT.java new file mode 100644 index 
000000000..5cf0f0031 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/RRFSearchIT.java @@ -0,0 +1,171 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import com.google.common.collect.ImmutableList; +import lombok.SneakyThrows; +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.apache.hc.core5.http.message.BasicHeader; +import org.opensearch.client.Request; +import org.opensearch.client.Response; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.neuralsearch.BaseNeuralSearchIT; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Map; + +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_USER_AGENT; +import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; + +public class RRFSearchIT extends BaseNeuralSearchIT { + + private int currentDoc = 1; + private static final String RRF_INDEX_NAME = "rrf-index"; + private static final String RRF_SEARCH_PIPELINE = "rrf-search-pipeline"; + + @SneakyThrows + public void testRRF() { + String modelId = prepareModel(); + String ingestPipelineName = "rrf-ingest-pipeline"; + createPipelineProcessor(modelId, ingestPipelineName, ProcessorType.TEXT_EMBEDDING); + Settings indexSettings = Settings.builder().put("index.knn", true).put("default_pipeline", ingestPipelineName).build(); + String indexMappings = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject("id") + .field("type", "text") + .endObject() + .startObject("passage_embedding") + .field("type", "knn_vector") + 
.field("dimension", "768") + .startObject("method") + .field("engine", "lucene") + .field("space_type", "l2") + .field("name", "hnsw") + .endObject() + .endObject() + .startObject("text") + .field("type", "text") + .endObject() + .endObject() + .endObject() + .toString(); + // Removes the {} around the string, since they are already included with createIndex + indexMappings = indexMappings.substring(1, indexMappings.length() - 1); + String indexName = "rrf-index"; + createIndex(indexName, indexSettings, indexMappings, null); + addRRFDocuments(); + createDefaultRRFSearchPipeline(); + + Map results = searchRRF(modelId); + Map hits = (Map) results.get("hits"); + ArrayList> hitsList = (ArrayList>) hits.get("hits"); + assertEquals(3, hitsList.size()); + assertEquals(0.016393442, (Double) hitsList.getFirst().get("_score"), DELTA_FOR_SCORE_ASSERTION); + assertEquals(0.016129032, (Double) hitsList.get(1).get("_score"), DELTA_FOR_SCORE_ASSERTION); + assertEquals(0.015873017, (Double) hitsList.getLast().get("_score"), DELTA_FOR_SCORE_ASSERTION); + } + + @SneakyThrows + private void addRRFDocuments() { + addRRFDocument( + "A West Virginia university women 's basketball team , officials , and a small gathering of fans are in a West Virginia arena .", + "4319130149.jpg" + ); + addRRFDocument("A wild animal races across an uncut field with a minimal amount of trees .", "1775029934.jpg"); + addRRFDocument( + "People line the stands which advertise Freemont 's orthopedics , a cowboy rides a light brown bucking bronco .", + "2664027527.jpg" + ); + addRRFDocument("A man who is riding a wild horse in the rodeo is very near to falling off .", "4427058951.jpg"); + addRRFDocument("A rodeo cowboy , wearing a cowboy hat , is being thrown off of a wild white horse .", "2691147709.jpg"); + } + + @SneakyThrows + private void addRRFDocument(String description, String imageText) { + addDocument(RRF_INDEX_NAME, String.valueOf(currentDoc++), "text", description, "image_text", imageText); + } + + 
@SneakyThrows + private void createDefaultRRFSearchPipeline() { + String requestBody = XContentFactory.jsonBuilder() + .startObject() + .field("description", "Post processor for hybrid search") + .startArray("phase_results_processors") + .startObject() + .startObject("score-ranker-processor") + .startObject("combination") + .field("technique", "rrf") + .startObject("parameters") + .field("rank_constant", 60) + .endObject() + .endObject() + .endObject() + .endObject() + .endArray() + .endObject() + .toString(); + + makeRequest( + client(), + "PUT", + String.format(LOCALE, "/_search/pipeline/%s", RRF_SEARCH_PIPELINE), + null, + toHttpEntity(String.format(LOCALE, requestBody)), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + } + + @SneakyThrows + private Map searchRRF(String modelId) { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("_source") + .startArray("exclude") + .value("passage_embedding") + .endArray() + .endObject() + .startObject("query") + .startObject("hybrid") + .startArray("queries") + .startObject() + .startObject("match") + .startObject("text") + .field("query", "cowboy rodeo bronco") + .endObject() + .endObject() + .endObject() + .startObject() + .startObject("neural") + .startObject("passage_embedding") + .field("query_text", "wild west") + .field("model_id", modelId) + .field("k", 5) + .endObject() + .endObject() + .endObject() + .endArray() + .endObject() + .endObject() + .endObject(); + + Request request = new Request("GET", "/" + RRF_INDEX_NAME + "/_search?timeout=1000s&search_pipeline=" + RRF_SEARCH_PIPELINE); + logger.info("Sorting request " + builder); + request.setJsonEntity(builder.toString()); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + String responseBody = EntityUtils.toString(response.getEntity()); + 
logger.info("Response " + responseBody); + return XContentHelper.convertToMap(XContentType.JSON.xContent(), responseBody, false); + } +} From 85ac5af2efd6b6c7046664039fb46ef461b01ff1 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Tue, 19 Nov 2024 12:33:55 -0800 Subject: [PATCH 05/10] Initial unit test implementation Signed-off-by: Ryan Bogan --- .../neuralsearch/processor/RRFProcessor.java | 15 +- .../processor/RRFProcessorTests.java | 230 ++++++++++++++++++ 2 files changed, 239 insertions(+), 6 deletions(-) create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorTests.java diff --git a/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java index 207af156c..a083fa0b7 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java @@ -13,6 +13,7 @@ import java.util.Optional; import lombok.Getter; +import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; import org.opensearch.search.fetch.FetchSearchResult; @@ -98,7 +99,8 @@ public boolean isIgnoreFailure() { return false; } - private boolean shouldSkipProcessor(SearchPhaseResults searchPhaseResult) { + @VisibleForTesting + boolean shouldSkipProcessor(SearchPhaseResults searchPhaseResult) { if (Objects.isNull(searchPhaseResult) || !(searchPhaseResult instanceof QueryPhaseResultConsumer queryPhaseResultConsumer)) { return true; } @@ -111,7 +113,8 @@ private boolean shouldSkipProcessor(SearchPha * @param searchPhaseResult * @return true if results are from hybrid query */ - private boolean isHybridQuery(final SearchPhaseResult searchPhaseResult) { + @VisibleForTesting + boolean isHybridQuery(final SearchPhaseResult 
searchPhaseResult) { // check for delimiter at the end of the score docs. return Objects.nonNull(searchPhaseResult.queryResult()) && Objects.nonNull(searchPhaseResult.queryResult().topDocs()) @@ -120,9 +123,8 @@ private boolean isHybridQuery(final SearchPhaseResult searchPhaseResult) { && isHybridQueryStartStopElement(searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs[0]); } - private List getQueryPhaseSearchResults( - final SearchPhaseResults results - ) { + @VisibleForTesting + List getQueryPhaseSearchResults(final SearchPhaseResults results) { return results.getAtomicArray() .asList() .stream() @@ -130,7 +132,8 @@ private List getQueryPhase .collect(Collectors.toList()); } - private Optional getFetchSearchResults( + @VisibleForTesting + Optional getFetchSearchResults( final SearchPhaseResults searchPhaseResults ) { Optional optionalFirstSearchPhaseResult = searchPhaseResults.getAtomicArray().asList().stream().findFirst(); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorTests.java new file mode 100644 index 000000000..70dcb7aee --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorTests.java @@ -0,0 +1,230 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import lombok.SneakyThrows; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.junit.Before; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.OriginalIndices; +import org.opensearch.action.search.QueryPhaseResultConsumer; +import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.action.search.SearchPhaseName; +import org.opensearch.action.search.SearchPhaseResults; +import org.opensearch.action.search.SearchRequest; +import 
org.opensearch.action.support.IndicesOptions; +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.common.util.concurrent.AtomicArray; +import org.opensearch.core.common.Strings; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; +import org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.SearchPhaseResult; +import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.fetch.FetchSearchResult; +import org.opensearch.search.internal.AliasFilter; +import org.opensearch.search.internal.ShardSearchContextId; +import org.opensearch.search.internal.ShardSearchRequest; +import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.List; +import java.util.Optional; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class RRFProcessorTests extends OpenSearchTestCase { + + @Mock + private ScoreNormalizationTechnique mockNormalizationTechnique; + @Mock + private ScoreCombinationTechnique mockCombinationTechnique; + @Mock + private NormalizationProcessorWorkflow mockNormalizationWorkflow; + @Mock + private SearchPhaseResults mockSearchPhaseResults; + @Mock + private SearchPhaseContext mockSearchPhaseContext; + @Mock + private QueryPhaseResultConsumer mockQueryPhaseResultConsumer; + + private RRFProcessor rrfProcessor; + + @Before + @SneakyThrows + public void setUp() { + super.setUp(); + MockitoAnnotations.openMocks(this); + rrfProcessor = new RRFProcessor( + "tag", + "description", + mockNormalizationTechnique, + 
mockCombinationTechnique, + mockNormalizationWorkflow + ); + } + + @SneakyThrows + public void testGetType() { + assertEquals("score-ranker-processor", rrfProcessor.getType()); + } + + @SneakyThrows + public void testGetBeforePhase() { + assertEquals(SearchPhaseName.QUERY, rrfProcessor.getBeforePhase()); + } + + @SneakyThrows + public void testGetAfterPhase() { + assertEquals(SearchPhaseName.FETCH, rrfProcessor.getAfterPhase()); + } + + @SneakyThrows + public void testIsIgnoreFailure() { + assertFalse(rrfProcessor.isIgnoreFailure()); + } + + @SneakyThrows + public void testProcessWithNullSearchPhaseResult() { + rrfProcessor.process(null, mockSearchPhaseContext); + verify(mockNormalizationWorkflow, never()).execute(any()); + } + + @SneakyThrows + public void testProcessWithNonQueryPhaseResultConsumer() { + rrfProcessor.process(mockSearchPhaseResults, mockSearchPhaseContext); + verify(mockNormalizationWorkflow, never()).execute(any()); + } + + @SneakyThrows + public void testProcessWithValidHybridInput() { + QuerySearchResult result = createQuerySearchResult(true); + AtomicArray atomicArray = new AtomicArray<>(1); + atomicArray.set(0, result); + + when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray); + + rrfProcessor.process(mockQueryPhaseResultConsumer, mockSearchPhaseContext); + + verify(mockNormalizationWorkflow).execute(any(NormalizationExecuteDTO.class)); + } + + @SneakyThrows + public void testProcessWithValidNonHybridInput() { + QuerySearchResult result = createQuerySearchResult(false); + AtomicArray atomicArray = new AtomicArray<>(1); + atomicArray.set(0, result); + + when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray); + + rrfProcessor.process(mockQueryPhaseResultConsumer, mockSearchPhaseContext); + + verify(mockNormalizationWorkflow, never()).execute(any(NormalizationExecuteDTO.class)); + } + + @SneakyThrows + public void testGetTag() { + assertEquals("tag", rrfProcessor.getTag()); + } + + @SneakyThrows + 
public void testGetDescription() { + assertEquals("description", rrfProcessor.getDescription()); + } + + @SneakyThrows + public void testShouldSkipProcessor() { + assertTrue(rrfProcessor.shouldSkipProcessor(null)); + assertTrue(rrfProcessor.shouldSkipProcessor(mockSearchPhaseResults)); + + AtomicArray atomicArray = new AtomicArray<>(1); + atomicArray.set(0, createQuerySearchResult(false)); + when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray); + + assertTrue(rrfProcessor.shouldSkipProcessor(mockQueryPhaseResultConsumer)); + + atomicArray.set(0, createQuerySearchResult(true)); + assertFalse(rrfProcessor.shouldSkipProcessor(mockQueryPhaseResultConsumer)); + } + + @SneakyThrows + public void testGetQueryPhaseSearchResults() { + AtomicArray atomicArray = new AtomicArray<>(2); + atomicArray.set(0, createQuerySearchResult(true)); + atomicArray.set(1, createQuerySearchResult(false)); + when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray); + + List results = rrfProcessor.getQueryPhaseSearchResults(mockQueryPhaseResultConsumer); + assertEquals(2, results.size()); + assertNotNull(results.get(0)); + assertNotNull(results.get(1)); + } + + @SneakyThrows + public void testGetFetchSearchResults() { + AtomicArray atomicArray = new AtomicArray<>(1); + atomicArray.set(0, createQuerySearchResult(true)); + when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray); + + Optional result = rrfProcessor.getFetchSearchResults(mockQueryPhaseResultConsumer); + assertFalse(result.isPresent()); + } + + private QuerySearchResult createQuerySearchResult(boolean isHybrid) { + ShardId shardId = new ShardId("index", "uuid", 0); + OriginalIndices originalIndices = new OriginalIndices(new String[] { "index" }, IndicesOptions.strictExpandOpenAndForbidClosed()); + SearchRequest searchRequest = new SearchRequest("index"); + searchRequest.source(new SearchSourceBuilder()); + searchRequest.allowPartialSearchResults(true); + + int 
numberOfShards = 1; + AliasFilter aliasFilter = new AliasFilter(null, Strings.EMPTY_ARRAY); + float indexBoost = 1.0f; + long nowInMillis = System.currentTimeMillis(); + String clusterAlias = null; + String[] indexRoutings = Strings.EMPTY_ARRAY; + + ShardSearchRequest shardSearchRequest = new ShardSearchRequest( + originalIndices, + searchRequest, + shardId, + numberOfShards, + aliasFilter, + indexBoost, + nowInMillis, + clusterAlias, + indexRoutings + ); + + QuerySearchResult result = new QuerySearchResult( + new ShardSearchContextId("test", 1), + new SearchShardTarget("node1", shardId, clusterAlias, originalIndices), + shardSearchRequest + ); + result.from(0).size(10); + + ScoreDoc[] scoreDocs; + if (isHybrid) { + scoreDocs = new ScoreDoc[] { HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults(0) }; + } else { + scoreDocs = new ScoreDoc[] { new ScoreDoc(0, 1.0f) }; + } + + TopDocs topDocs = new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), scoreDocs); + TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(topDocs, 1.0f); + result.topDocs(topDocsAndMaxScore, new DocValueFormat[0]); + + return result; + } +} From 371d77217d91edb0495f7ab1e61f20dc5f2dd58a Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Thu, 21 Nov 2024 11:58:55 -0800 Subject: [PATCH 06/10] Address feedback on integration test Signed-off-by: Ryan Bogan --- .../{RRFSearchIT.java => RRFProcessorIT.java} | 25 ++++++++----------- 1 file changed, 10 insertions(+), 15 deletions(-) rename src/test/java/org/opensearch/neuralsearch/processor/{RRFSearchIT.java => RRFProcessorIT.java} (88%) diff --git a/src/test/java/org/opensearch/neuralsearch/processor/RRFSearchIT.java b/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorIT.java similarity index 88% rename from src/test/java/org/opensearch/neuralsearch/processor/RRFSearchIT.java rename to src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorIT.java index 5cf0f0031..8379972a0 100644 --- 
a/src/test/java/org/opensearch/neuralsearch/processor/RRFSearchIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorIT.java @@ -26,14 +26,14 @@ import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_USER_AGENT; import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; -public class RRFSearchIT extends BaseNeuralSearchIT { +public class RRFProcessorIT extends BaseNeuralSearchIT { private int currentDoc = 1; private static final String RRF_INDEX_NAME = "rrf-index"; private static final String RRF_SEARCH_PIPELINE = "rrf-search-pipeline"; @SneakyThrows - public void testRRF() { + public void testRRF_whenValidInput_thenSucceed() { String modelId = prepareModel(); String ingestPipelineName = "rrf-ingest-pipeline"; createPipelineProcessor(modelId, ingestPipelineName, ProcessorType.TEXT_EMBEDDING); @@ -63,7 +63,7 @@ public void testRRF() { indexMappings = indexMappings.substring(1, indexMappings.length() - 1); String indexName = "rrf-index"; createIndex(indexName, indexSettings, indexMappings, null); - addRRFDocuments(); + addDocuments(); createDefaultRRFSearchPipeline(); Map results = searchRRF(modelId); @@ -76,22 +76,22 @@ public void testRRF() { } @SneakyThrows - private void addRRFDocuments() { - addRRFDocument( + private void addDocuments() { + addDocument( "A West Virginia university women 's basketball team , officials , and a small gathering of fans are in a West Virginia arena .", "4319130149.jpg" ); - addRRFDocument("A wild animal races across an uncut field with a minimal amount of trees .", "1775029934.jpg"); - addRRFDocument( + addDocument("A wild animal races across an uncut field with a minimal amount of trees .", "1775029934.jpg"); + addDocument( "People line the stands which advertise Freemont 's orthopedics , a cowboy rides a light brown bucking bronco .", "2664027527.jpg" ); - addRRFDocument("A man who is riding a wild horse in the rodeo is very near to falling off .", "4427058951.jpg"); - 
addRRFDocument("A rodeo cowboy , wearing a cowboy hat , is being thrown off of a wild white horse .", "2691147709.jpg"); + addDocument("A man who is riding a wild horse in the rodeo is very near to falling off .", "4427058951.jpg"); + addDocument("A rodeo cowboy , wearing a cowboy hat , is being thrown off of a wild white horse .", "2691147709.jpg"); } @SneakyThrows - private void addRRFDocument(String description, String imageText) { + private void addDocument(String description, String imageText) { addDocument(RRF_INDEX_NAME, String.valueOf(currentDoc++), "text", description, "image_text", imageText); } @@ -105,9 +105,6 @@ private void createDefaultRRFSearchPipeline() { .startObject("score-ranker-processor") .startObject("combination") .field("technique", "rrf") - .startObject("parameters") - .field("rank_constant", 60) - .endObject() .endObject() .endObject() .endObject() @@ -159,13 +156,11 @@ private Map searchRRF(String modelId) { .endObject(); Request request = new Request("GET", "/" + RRF_INDEX_NAME + "/_search?timeout=1000s&search_pipeline=" + RRF_SEARCH_PIPELINE); - logger.info("Sorting request " + builder); request.setJsonEntity(builder.toString()); Response response = client().performRequest(request); assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); String responseBody = EntityUtils.toString(response.getEntity()); - logger.info("Response " + responseBody); return XContentHelper.convertToMap(XContentType.JSON.xContent(), responseBody, false); } } From 4a1a20db9c80bb6aaec15fae60f241c60b80b50e Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Thu, 21 Nov 2024 12:23:54 -0800 Subject: [PATCH 07/10] Address remaining feedback for integ test Signed-off-by: Ryan Bogan --- .../processor/RRFProcessorIT.java | 193 +++++++----------- .../neuralsearch/BaseNeuralSearchIT.java | 28 +++ 2 files changed, 103 insertions(+), 118 deletions(-) diff --git 
a/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorIT.java index 8379972a0..69ad63a74 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorIT.java @@ -4,75 +4,101 @@ */ package org.opensearch.neuralsearch.processor; -import com.google.common.collect.ImmutableList; import lombok.SneakyThrows; -import org.apache.hc.core5.http.HttpHeaders; -import org.apache.hc.core5.http.io.entity.EntityUtils; -import org.apache.hc.core5.http.message.BasicHeader; -import org.opensearch.client.Request; -import org.opensearch.client.Response; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.common.xcontent.XContentHelper; -import org.opensearch.common.xcontent.XContentType; -import org.opensearch.core.rest.RestStatus; -import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.query.MatchQueryBuilder; import org.opensearch.neuralsearch.BaseNeuralSearchIT; +import org.opensearch.neuralsearch.query.HybridQueryBuilder; +import org.opensearch.neuralsearch.query.NeuralQueryBuilder; import java.util.ArrayList; import java.util.HashMap; import java.util.Map; -import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_USER_AGENT; import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; public class RRFProcessorIT extends BaseNeuralSearchIT { private int currentDoc = 1; private static final String RRF_INDEX_NAME = "rrf-index"; - private static final String RRF_SEARCH_PIPELINE = "rrf-search-pipeline"; @SneakyThrows public void testRRF_whenValidInput_thenSucceed() { - String modelId = prepareModel(); String ingestPipelineName = "rrf-ingest-pipeline"; - createPipelineProcessor(modelId, ingestPipelineName, ProcessorType.TEXT_EMBEDDING); - Settings indexSettings = 
Settings.builder().put("index.knn", true).put("default_pipeline", ingestPipelineName).build(); - String indexMappings = XContentFactory.jsonBuilder() - .startObject() - .startObject("properties") - .startObject("id") - .field("type", "text") - .endObject() - .startObject("passage_embedding") - .field("type", "knn_vector") - .field("dimension", "768") - .startObject("method") - .field("engine", "lucene") - .field("space_type", "l2") - .field("name", "hnsw") - .endObject() - .endObject() - .startObject("text") - .field("type", "text") - .endObject() - .endObject() - .endObject() - .toString(); - // Removes the {} around the string, since they are already included with createIndex - indexMappings = indexMappings.substring(1, indexMappings.length() - 1); - String indexName = "rrf-index"; - createIndex(indexName, indexSettings, indexMappings, null); - addDocuments(); - createDefaultRRFSearchPipeline(); + String modelId = null; + try { + modelId = prepareModel(); + createPipelineProcessor(modelId, ingestPipelineName, ProcessorType.TEXT_EMBEDDING); + Settings indexSettings = Settings.builder().put("index.knn", true).put("default_pipeline", ingestPipelineName).build(); + String indexMappings = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject("id") + .field("type", "text") + .endObject() + .startObject("passage_embedding") + .field("type", "knn_vector") + .field("dimension", "768") + .startObject("method") + .field("engine", "lucene") + .field("space_type", "l2") + .field("name", "hnsw") + .endObject() + .endObject() + .startObject("text") + .field("type", "text") + .endObject() + .endObject() + .endObject() + .toString(); + // Removes the {} around the string, since they are already included with createIndex + indexMappings = indexMappings.substring(1, indexMappings.length() - 1); + createIndex(RRF_INDEX_NAME, indexSettings, indexMappings, null); + addDocuments(); + createDefaultRRFSearchPipeline(); - Map results = 
searchRRF(modelId); - Map hits = (Map) results.get("hits"); - ArrayList> hitsList = (ArrayList>) hits.get("hits"); - assertEquals(3, hitsList.size()); - assertEquals(0.016393442, (Double) hitsList.getFirst().get("_score"), DELTA_FOR_SCORE_ASSERTION); - assertEquals(0.016129032, (Double) hitsList.get(1).get("_score"), DELTA_FOR_SCORE_ASSERTION); - assertEquals(0.015873017, (Double) hitsList.getLast().get("_score"), DELTA_FOR_SCORE_ASSERTION); + HybridQueryBuilder hybridQueryBuilder = getHybridQueryBuilder(modelId); + + Map results = search( + RRF_INDEX_NAME, + hybridQueryBuilder, + null, + 5, + Map.of("search_pipeline", RRF_SEARCH_PIPELINE) + ); + Map hits = (Map) results.get("hits"); + ArrayList> hitsList = (ArrayList>) hits.get("hits"); + assertEquals(3, hitsList.size()); + assertEquals(0.016393442, (Double) hitsList.getFirst().get("_score"), DELTA_FOR_SCORE_ASSERTION); + assertEquals(0.016129032, (Double) hitsList.get(1).get("_score"), DELTA_FOR_SCORE_ASSERTION); + assertEquals(0.015873017, (Double) hitsList.getLast().get("_score"), DELTA_FOR_SCORE_ASSERTION); + } finally { + wipeOfTestResources(RRF_INDEX_NAME, ingestPipelineName, modelId, RRF_SEARCH_PIPELINE); + } + } + + private HybridQueryBuilder getHybridQueryBuilder(String modelId) { + MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder("text", "cowboy rodeo bronco"); + + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( + "passage_embedding", + "wild_west", + "", + modelId, + 5, + null, + null, + null, + null, + null, + null + ); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(matchQueryBuilder); + hybridQueryBuilder.add(neuralQueryBuilder); + return hybridQueryBuilder; } @SneakyThrows @@ -94,73 +120,4 @@ private void addDocuments() { private void addDocument(String description, String imageText) { addDocument(RRF_INDEX_NAME, String.valueOf(currentDoc++), "text", description, "image_text", imageText); } - - @SneakyThrows - private void 
createDefaultRRFSearchPipeline() { - String requestBody = XContentFactory.jsonBuilder() - .startObject() - .field("description", "Post processor for hybrid search") - .startArray("phase_results_processors") - .startObject() - .startObject("score-ranker-processor") - .startObject("combination") - .field("technique", "rrf") - .endObject() - .endObject() - .endObject() - .endArray() - .endObject() - .toString(); - - makeRequest( - client(), - "PUT", - String.format(LOCALE, "/_search/pipeline/%s", RRF_SEARCH_PIPELINE), - null, - toHttpEntity(String.format(LOCALE, requestBody)), - ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) - ); - } - - @SneakyThrows - private Map searchRRF(String modelId) { - XContentBuilder builder = XContentFactory.jsonBuilder() - .startObject() - .startObject("_source") - .startArray("exclude") - .value("passage_embedding") - .endArray() - .endObject() - .startObject("query") - .startObject("hybrid") - .startArray("queries") - .startObject() - .startObject("match") - .startObject("text") - .field("query", "cowboy rodeo bronco") - .endObject() - .endObject() - .endObject() - .startObject() - .startObject("neural") - .startObject("passage_embedding") - .field("query_text", "wild west") - .field("model_id", modelId) - .field("k", 5) - .endObject() - .endObject() - .endObject() - .endArray() - .endObject() - .endObject() - .endObject(); - - Request request = new Request("GET", "/" + RRF_INDEX_NAME + "/_search?timeout=1000s&search_pipeline=" + RRF_SEARCH_PIPELINE); - request.setJsonEntity(builder.toString()); - Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); - - String responseBody = EntityUtils.toString(response.getEntity()); - return XContentHelper.convertToMap(XContentType.JSON.xContent(), responseBody, false); - } } diff --git 
a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index afc545447..f8021b08e 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -89,6 +89,7 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase { ); private static final Set SUCCESS_STATUSES = Set.of(RestStatus.CREATED, RestStatus.OK); protected static final String CONCURRENT_SEGMENT_SEARCH_ENABLED = "search.concurrent_segment_search.enabled"; + protected static final String RRF_SEARCH_PIPELINE = "rrf-search-pipeline"; protected final ClassLoader classLoader = this.getClass().getClassLoader(); @@ -1441,4 +1442,31 @@ protected enum ProcessorType { TEXT_IMAGE_EMBEDDING, SPARSE_ENCODING } + + @SneakyThrows + protected void createDefaultRRFSearchPipeline() { + String requestBody = XContentFactory.jsonBuilder() + .startObject() + .field("description", "Post processor for hybrid search") + .startArray("phase_results_processors") + .startObject() + .startObject("score-ranker-processor") + .startObject("combination") + .field("technique", "rrf") + .endObject() + .endObject() + .endObject() + .endArray() + .endObject() + .toString(); + + makeRequest( + client(), + "PUT", + String.format(LOCALE, "/_search/pipeline/%s", RRF_SEARCH_PIPELINE), + null, + toHttpEntity(String.format(LOCALE, requestBody)), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + } } From dcfa37412912a93622726cddbf3b3425b06c126a Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Thu, 21 Nov 2024 12:49:35 -0800 Subject: [PATCH 08/10] Adjust unit test names Signed-off-by: Ryan Bogan --- .../neuralsearch/processor/RRFProcessorTests.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorTests.java 
b/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorTests.java index 70dcb7aee..76d7b008d 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorTests.java @@ -96,19 +96,19 @@ public void testIsIgnoreFailure() { } @SneakyThrows - public void testProcessWithNullSearchPhaseResult() { + public void testProcess_whenNullSearchPhaseResult_thenSkipWorkflow() { rrfProcessor.process(null, mockSearchPhaseContext); verify(mockNormalizationWorkflow, never()).execute(any()); } @SneakyThrows - public void testProcessWithNonQueryPhaseResultConsumer() { + public void testProcess_whenNonQueryPhaseResultConsumer_thenSkipWorkflow() { rrfProcessor.process(mockSearchPhaseResults, mockSearchPhaseContext); verify(mockNormalizationWorkflow, never()).execute(any()); } @SneakyThrows - public void testProcessWithValidHybridInput() { + public void testProcess_whenValidHybridInput_thenSucceed() { QuerySearchResult result = createQuerySearchResult(true); AtomicArray atomicArray = new AtomicArray<>(1); atomicArray.set(0, result); @@ -121,7 +121,7 @@ public void testProcessWithValidHybridInput() { } @SneakyThrows - public void testProcessWithValidNonHybridInput() { + public void testProcess_whenValidNonHybridInput_thenSucceed() { QuerySearchResult result = createQuerySearchResult(false); AtomicArray atomicArray = new AtomicArray<>(1); atomicArray.set(0, result); From 24c7fa8e2ba4249974f8ad51e0edb632c5757463 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Thu, 21 Nov 2024 13:35:17 -0800 Subject: [PATCH 09/10] Extrapolate strings to constants Signed-off-by: Ryan Bogan --- .../processor/RRFProcessorTests.java | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorTests.java index 76d7b008d..01bbbfbec 100644 --- 
a/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorTests.java @@ -60,24 +60,20 @@ public class RRFProcessorTests extends OpenSearchTestCase { private QueryPhaseResultConsumer mockQueryPhaseResultConsumer; private RRFProcessor rrfProcessor; + private static final String TAG = "tag"; + private static final String DESCRIPTION = "description"; @Before @SneakyThrows public void setUp() { super.setUp(); MockitoAnnotations.openMocks(this); - rrfProcessor = new RRFProcessor( - "tag", - "description", - mockNormalizationTechnique, - mockCombinationTechnique, - mockNormalizationWorkflow - ); + rrfProcessor = new RRFProcessor(TAG, DESCRIPTION, mockNormalizationTechnique, mockCombinationTechnique, mockNormalizationWorkflow); } @SneakyThrows public void testGetType() { - assertEquals("score-ranker-processor", rrfProcessor.getType()); + assertEquals(RRFProcessor.TYPE, rrfProcessor.getType()); } @SneakyThrows @@ -135,12 +131,12 @@ public void testProcess_whenValidNonHybridInput_thenSucceed() { @SneakyThrows public void testGetTag() { - assertEquals("tag", rrfProcessor.getTag()); + assertEquals(TAG, rrfProcessor.getTag()); } @SneakyThrows public void testGetDescription() { - assertEquals("description", rrfProcessor.getDescription()); + assertEquals(DESCRIPTION, rrfProcessor.getDescription()); } @SneakyThrows From f539741bda29f6c204d42566d107ae06337383f4 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Tue, 26 Nov 2024 09:28:42 -0800 Subject: [PATCH 10/10] Address integration test feedback Signed-off-by: Ryan Bogan --- .../processor/RRFProcessorIT.java | 70 ++++++------------- 1 file changed, 20 insertions(+), 50 deletions(-) diff --git a/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorIT.java index 69ad63a74..fccabab5c 100644 --- 
a/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorIT.java @@ -5,60 +5,40 @@ package org.opensearch.neuralsearch.processor; import lombok.SneakyThrows; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.index.query.MatchQueryBuilder; +import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.neuralsearch.BaseNeuralSearchIT; import org.opensearch.neuralsearch.query.HybridQueryBuilder; -import org.opensearch.neuralsearch.query.NeuralQueryBuilder; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.Map; import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; +import static org.opensearch.neuralsearch.util.TestUtils.TEST_SPACE_TYPE; public class RRFProcessorIT extends BaseNeuralSearchIT { private int currentDoc = 1; private static final String RRF_INDEX_NAME = "rrf-index"; + private static final String RRF_SEARCH_PIPELINE = "rrf-search-pipeline"; + private static final String RRF_INGEST_PIPELINE = "rrf-ingest-pipeline"; + + private static final int RRF_DIMENSION = 5; @SneakyThrows public void testRRF_whenValidInput_thenSucceed() { - String ingestPipelineName = "rrf-ingest-pipeline"; - String modelId = null; try { - modelId = prepareModel(); - createPipelineProcessor(modelId, ingestPipelineName, ProcessorType.TEXT_EMBEDDING); - Settings indexSettings = Settings.builder().put("index.knn", true).put("default_pipeline", ingestPipelineName).build(); - String indexMappings = XContentFactory.jsonBuilder() - .startObject() - .startObject("properties") - .startObject("id") - .field("type", "text") - .endObject() - .startObject("passage_embedding") - .field("type", "knn_vector") - .field("dimension", "768") - .startObject("method") - .field("engine", "lucene") - .field("space_type", "l2") - .field("name", "hnsw") - 
.endObject() - .endObject() - .startObject("text") - .field("type", "text") - .endObject() - .endObject() - .endObject() - .toString(); - // Removes the {} around the string, since they are already included with createIndex - indexMappings = indexMappings.substring(1, indexMappings.length() - 1); - createIndex(RRF_INDEX_NAME, indexSettings, indexMappings, null); + createPipelineProcessor(null, RRF_INGEST_PIPELINE, ProcessorType.TEXT_EMBEDDING); + prepareKnnIndex( + RRF_INDEX_NAME, + Collections.singletonList(new KNNFieldConfig("passage_embedding", RRF_DIMENSION, TEST_SPACE_TYPE)) + ); addDocuments(); createDefaultRRFSearchPipeline(); - HybridQueryBuilder hybridQueryBuilder = getHybridQueryBuilder(modelId); + HybridQueryBuilder hybridQueryBuilder = getHybridQueryBuilder(); Map results = search( RRF_INDEX_NAME, @@ -74,30 +54,20 @@ public void testRRF_whenValidInput_thenSucceed() { assertEquals(0.016129032, (Double) hitsList.get(1).get("_score"), DELTA_FOR_SCORE_ASSERTION); assertEquals(0.015873017, (Double) hitsList.getLast().get("_score"), DELTA_FOR_SCORE_ASSERTION); } finally { - wipeOfTestResources(RRF_INDEX_NAME, ingestPipelineName, modelId, RRF_SEARCH_PIPELINE); + wipeOfTestResources(RRF_INDEX_NAME, RRF_INGEST_PIPELINE, null, RRF_SEARCH_PIPELINE); } } - private HybridQueryBuilder getHybridQueryBuilder(String modelId) { + private HybridQueryBuilder getHybridQueryBuilder() { MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder("text", "cowboy rodeo bronco"); - - NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( - "passage_embedding", - "wild_west", - "", - modelId, - 5, - null, - null, - null, - null, - null, - null - ); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder.Builder().fieldName("passage_embedding") + .k(5) + .vector(new float[] { 0.1f, 1.2f, 2.3f, 3.4f, 4.5f }) + .build(); HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); hybridQueryBuilder.add(matchQueryBuilder); - hybridQueryBuilder.add(neuralQueryBuilder); 
+ hybridQueryBuilder.add(knnQueryBuilder); return hybridQueryBuilder; }