diff --git a/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java
index 207af156c..ad1f73c29 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java
@@ -13,6 +13,7 @@
 import java.util.Optional;
 
 import lombok.Getter;
+import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting;
 import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
 import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
 import org.opensearch.search.fetch.FetchSearchResult;
@@ -98,7 +99,8 @@ public boolean isIgnoreFailure() {
         return false;
     }
 
-    private <Result extends SearchPhaseResult> boolean shouldSkipProcessor(SearchPhaseResults<Result> searchPhaseResult) {
+    @VisibleForTesting
+    <Result extends SearchPhaseResult> boolean shouldSkipProcessor(SearchPhaseResults<Result> searchPhaseResult) {
         if (Objects.isNull(searchPhaseResult) || !(searchPhaseResult instanceof QueryPhaseResultConsumer queryPhaseResultConsumer)) {
             return true;
         }
@@ -111,7 +113,8 @@ private <Result extends SearchPhaseResult> boolean shouldSkipProcessor(SearchPha
      * @param searchPhaseResult
      * @return true if results are from hybrid query
      */
-    private boolean isHybridQuery(final SearchPhaseResult searchPhaseResult) {
+    @VisibleForTesting
+    boolean isHybridQuery(final SearchPhaseResult searchPhaseResult) {
         // check for delimiter at the end of the score docs.
         return Objects.nonNull(searchPhaseResult.queryResult())
             && Objects.nonNull(searchPhaseResult.queryResult().topDocs())
@@ -120,9 +123,7 @@ private boolean isHybridQuery(final SearchPhaseResult searchPhaseResult) {
             && isHybridQueryStartStopElement(searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs[0]);
     }
 
-    private <Result extends SearchPhaseResult> List<QuerySearchResult> getQueryPhaseSearchResults(
-        final SearchPhaseResults<Result> results
-    ) {
+    <Result extends SearchPhaseResult> List<QuerySearchResult> getQueryPhaseSearchResults(final SearchPhaseResults<Result> results) {
         return results.getAtomicArray()
             .asList()
             .stream()
@@ -130,7 +131,8 @@ private <Result extends SearchPhaseResult> List<QuerySearchResult> getQueryPhase
             .collect(Collectors.toList());
     }
 
-    private <Result extends SearchPhaseResult> Optional<FetchSearchResult> getFetchSearchResults(
+    @VisibleForTesting
+    <Result extends SearchPhaseResult> Optional<FetchSearchResult> getFetchSearchResults(
         final SearchPhaseResults<Result> searchPhaseResults
     ) {
         Optional<Result> optionalFirstSearchPhaseResult = searchPhaseResults.getAtomicArray().asList().stream().findFirst();
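Illustrative sketch (not part of the patch): the isHybridQuery check above keys off the special start/stop delimiter ScoreDoc that hybrid queries place at the head of their score docs. A minimal sketch of that contract, assuming the HybridSearchResultFormatUtil helpers referenced elsewhere in this change are accessible as static methods; the wrapper class below is hypothetical:

    import org.apache.lucene.search.ScoreDoc;
    import org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil;

    // Hypothetical harness class, for illustration only.
    public class HybridDelimiterSketch {
        public static void main(String[] args) {
            // Hybrid query results are wrapped in start/stop delimiter elements.
            ScoreDoc delimiter = HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults(0);
            ScoreDoc regular = new ScoreDoc(0, 1.0f);
            // The processor inspects the first ScoreDoc to decide whether to run RRF combination.
            System.out.println(HybridSearchResultFormatUtil.isHybridQueryStartStopElement(delimiter)); // expected: true
            System.out.println(HybridSearchResultFormatUtil.isHybridQueryStartStopElement(regular));   // expected: false
        }
    }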
diff --git a/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorIT.java
new file mode 100644
index 000000000..fccabab5c
--- /dev/null
+++ b/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorIT.java
@@ -0,0 +1,93 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.neuralsearch.processor;
+
+import lombok.SneakyThrows;
+import org.opensearch.index.query.MatchQueryBuilder;
+import org.opensearch.knn.index.query.KNNQueryBuilder;
+import org.opensearch.neuralsearch.BaseNeuralSearchIT;
+import org.opensearch.neuralsearch.query.HybridQueryBuilder;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION;
+import static org.opensearch.neuralsearch.util.TestUtils.TEST_SPACE_TYPE;
+
+public class RRFProcessorIT extends BaseNeuralSearchIT {
+
+    private int currentDoc = 1;
+    private static final String RRF_INDEX_NAME = "rrf-index";
+    private static final String RRF_SEARCH_PIPELINE = "rrf-search-pipeline";
+    private static final String RRF_INGEST_PIPELINE = "rrf-ingest-pipeline";
+
+    private static final int RRF_DIMENSION = 5;
+
+    @SneakyThrows
+    public void testRRF_whenValidInput_thenSucceed() {
+        try {
+            createPipelineProcessor(null, RRF_INGEST_PIPELINE, ProcessorType.TEXT_EMBEDDING);
+            prepareKnnIndex(
+                RRF_INDEX_NAME,
+                Collections.singletonList(new KNNFieldConfig("passage_embedding", RRF_DIMENSION, TEST_SPACE_TYPE))
+            );
+            addDocuments();
+            createDefaultRRFSearchPipeline();
+
+            HybridQueryBuilder hybridQueryBuilder = getHybridQueryBuilder();
+
+            Map<String, Object> results = search(
+                RRF_INDEX_NAME,
+                hybridQueryBuilder,
+                null,
+                5,
+                Map.of("search_pipeline", RRF_SEARCH_PIPELINE)
+            );
+            Map<String, Object> hits = (Map<String, Object>) results.get("hits");
+            ArrayList<Map<String, Object>> hitsList = (ArrayList<Map<String, Object>>) hits.get("hits");
+            assertEquals(3, hitsList.size());
+            assertEquals(0.016393442, (Double) hitsList.getFirst().get("_score"), DELTA_FOR_SCORE_ASSERTION);
+            assertEquals(0.016129032, (Double) hitsList.get(1).get("_score"), DELTA_FOR_SCORE_ASSERTION);
+            assertEquals(0.015873017, (Double) hitsList.getLast().get("_score"), DELTA_FOR_SCORE_ASSERTION);
+        } finally {
+            wipeOfTestResources(RRF_INDEX_NAME, RRF_INGEST_PIPELINE, null, RRF_SEARCH_PIPELINE);
+        }
+    }
+
+    private HybridQueryBuilder getHybridQueryBuilder() {
+        MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder("text", "cowboy rodeo bronco");
+        KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder.Builder().fieldName("passage_embedding")
+            .k(5)
+            .vector(new float[] { 0.1f, 1.2f, 2.3f, 3.4f, 4.5f })
+            .build();
+
+        HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
+        hybridQueryBuilder.add(matchQueryBuilder);
+        hybridQueryBuilder.add(knnQueryBuilder);
+        return hybridQueryBuilder;
+    }
+
+    @SneakyThrows
+    private void addDocuments() {
+        addDocument(
+            "A West Virginia university women 's basketball team , officials , and a small gathering of fans are in a West Virginia arena .",
+            "4319130149.jpg"
+        );
+        addDocument("A wild animal races across an uncut field with a minimal amount of trees .", "1775029934.jpg");
+        addDocument(
+            "People line the stands which advertise Freemont 's orthopedics , a cowboy rides a light brown bucking bronco .",
+            "2664027527.jpg"
+        );
+        addDocument("A man who is riding a wild horse in the rodeo is very near to falling off .", "4427058951.jpg");
+        addDocument("A rodeo cowboy , wearing a cowboy hat , is being thrown off of a wild white horse .", "2691147709.jpg");
+    }
+
+    @SneakyThrows
+    private void addDocument(String description, String imageText) {
+        addDocument(RRF_INDEX_NAME, String.valueOf(currentDoc++), "text", description, "image_text", imageText);
+    }
+}
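Illustrative note (not part of the patch): the asserted scores in testRRF_whenValidInput_thenSucceed appear to be plain reciprocal-rank values. RRF combines sub-query rankings as a sum of 1 / (rankConstant + rank); assuming the default rank constant of 60, ranks 1 through 3 give 1/61, 1/62 and 1/63, which match 0.016393442, 0.016129032 and 0.015873017 within DELTA_FOR_SCORE_ASSERTION. A small sketch of that arithmetic; the class name and the constant 60 are assumptions, not taken from the patch:

    // Illustrative only: prints the reciprocal-rank contributions asserted in the IT above.
    public class RrfScoreSketch {
        public static void main(String[] args) {
            int rankConstant = 60; // assumed default for the "rrf" technique
            for (int rank = 1; rank <= 3; rank++) {
                // rank 1 -> 0.016393..., rank 2 -> 0.016129..., rank 3 -> 0.015873...
                System.out.printf("rank %d -> %.9f%n", rank, 1.0 / (rankConstant + rank));
            }
        }
    }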
diff --git a/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorTests.java
new file mode 100644
index 000000000..01bbbfbec
--- /dev/null
+++ b/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorTests.java
@@ -0,0 +1,226 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.neuralsearch.processor;
+
+import lombok.SneakyThrows;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TotalHits;
+import org.junit.Before;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.opensearch.action.OriginalIndices;
+import org.opensearch.action.search.QueryPhaseResultConsumer;
+import org.opensearch.action.search.SearchPhaseContext;
+import org.opensearch.action.search.SearchPhaseName;
+import org.opensearch.action.search.SearchPhaseResults;
+import org.opensearch.action.search.SearchRequest;
+import org.opensearch.action.support.IndicesOptions;
+import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
+import org.opensearch.common.util.concurrent.AtomicArray;
+import org.opensearch.core.common.Strings;
+import org.opensearch.core.index.shard.ShardId;
+import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
+import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
+import org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil;
+import org.opensearch.search.DocValueFormat;
+import org.opensearch.search.SearchPhaseResult;
+import org.opensearch.search.SearchShardTarget;
+import org.opensearch.search.builder.SearchSourceBuilder;
+import org.opensearch.search.fetch.FetchSearchResult;
+import org.opensearch.search.internal.AliasFilter;
+import org.opensearch.search.internal.ShardSearchContextId;
+import org.opensearch.search.internal.ShardSearchRequest;
+import org.opensearch.search.query.QuerySearchResult;
+import org.opensearch.test.OpenSearchTestCase;
+
+import java.util.List;
+import java.util.Optional;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+public class RRFProcessorTests extends OpenSearchTestCase {
+
+    @Mock
+    private ScoreNormalizationTechnique mockNormalizationTechnique;
+    @Mock
+    private ScoreCombinationTechnique mockCombinationTechnique;
+    @Mock
+    private NormalizationProcessorWorkflow mockNormalizationWorkflow;
+    @Mock
+    private SearchPhaseResults<SearchPhaseResult> mockSearchPhaseResults;
+    @Mock
+    private SearchPhaseContext mockSearchPhaseContext;
+    @Mock
+    private QueryPhaseResultConsumer mockQueryPhaseResultConsumer;
+
+    private RRFProcessor rrfProcessor;
+    private static final String TAG = "tag";
+    private static final String DESCRIPTION = "description";
+
+    @Before
+    @SneakyThrows
+    public void setUp() {
+        super.setUp();
+        MockitoAnnotations.openMocks(this);
+        rrfProcessor = new RRFProcessor(TAG, DESCRIPTION, mockNormalizationTechnique, mockCombinationTechnique, mockNormalizationWorkflow);
+    }
+
+    @SneakyThrows
+    public void testGetType() {
+        assertEquals(RRFProcessor.TYPE, rrfProcessor.getType());
+    }
+
+    @SneakyThrows
+    public void testGetBeforePhase() {
+        assertEquals(SearchPhaseName.QUERY, rrfProcessor.getBeforePhase());
+    }
+
+    @SneakyThrows
+    public void testGetAfterPhase() {
+        assertEquals(SearchPhaseName.FETCH, rrfProcessor.getAfterPhase());
+    }
+
+    @SneakyThrows
+    public void testIsIgnoreFailure() {
+        assertFalse(rrfProcessor.isIgnoreFailure());
+    }
+
+    @SneakyThrows
+    public void testProcess_whenNullSearchPhaseResult_thenSkipWorkflow() {
+        rrfProcessor.process(null, mockSearchPhaseContext);
+        verify(mockNormalizationWorkflow, never()).execute(any());
+    }
+
+    @SneakyThrows
+    public void testProcess_whenNonQueryPhaseResultConsumer_thenSkipWorkflow() {
+        rrfProcessor.process(mockSearchPhaseResults, mockSearchPhaseContext);
+        verify(mockNormalizationWorkflow, never()).execute(any());
+    }
+
+    @SneakyThrows
+    public void testProcess_whenValidHybridInput_thenSucceed() {
+        QuerySearchResult result = createQuerySearchResult(true);
+        AtomicArray<SearchPhaseResult> atomicArray = new AtomicArray<>(1);
+        atomicArray.set(0, result);
+
+        when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray);
+
+        rrfProcessor.process(mockQueryPhaseResultConsumer, mockSearchPhaseContext);
+
+        verify(mockNormalizationWorkflow).execute(any(NormalizationExecuteDTO.class));
+    }
+
+    @SneakyThrows
+    public void testProcess_whenValidNonHybridInput_thenSucceed() {
+        QuerySearchResult result = createQuerySearchResult(false);
+        AtomicArray<SearchPhaseResult> atomicArray = new AtomicArray<>(1);
+        atomicArray.set(0, result);
+
+        when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray);
+
+        rrfProcessor.process(mockQueryPhaseResultConsumer, mockSearchPhaseContext);
+
+        verify(mockNormalizationWorkflow, never()).execute(any(NormalizationExecuteDTO.class));
+    }
+
+    @SneakyThrows
+    public void testGetTag() {
+        assertEquals(TAG, rrfProcessor.getTag());
+    }
+
+    @SneakyThrows
+    public void testGetDescription() {
+        assertEquals(DESCRIPTION, rrfProcessor.getDescription());
+    }
+
+    @SneakyThrows
+    public void testShouldSkipProcessor() {
+        assertTrue(rrfProcessor.shouldSkipProcessor(null));
+        assertTrue(rrfProcessor.shouldSkipProcessor(mockSearchPhaseResults));
+
+        AtomicArray<SearchPhaseResult> atomicArray = new AtomicArray<>(1);
+        atomicArray.set(0, createQuerySearchResult(false));
+        when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray);
+
+        assertTrue(rrfProcessor.shouldSkipProcessor(mockQueryPhaseResultConsumer));
+
+        atomicArray.set(0, createQuerySearchResult(true));
+        assertFalse(rrfProcessor.shouldSkipProcessor(mockQueryPhaseResultConsumer));
+    }
+
+    @SneakyThrows
+    public void testGetQueryPhaseSearchResults() {
+        AtomicArray<SearchPhaseResult> atomicArray = new AtomicArray<>(2);
+        atomicArray.set(0, createQuerySearchResult(true));
+        atomicArray.set(1, createQuerySearchResult(false));
+        when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray);
+
+        List<QuerySearchResult> results = rrfProcessor.getQueryPhaseSearchResults(mockQueryPhaseResultConsumer);
+        assertEquals(2, results.size());
+        assertNotNull(results.get(0));
+        assertNotNull(results.get(1));
+    }
+
+    @SneakyThrows
+    public void testGetFetchSearchResults() {
+        AtomicArray<SearchPhaseResult> atomicArray = new AtomicArray<>(1);
+        atomicArray.set(0, createQuerySearchResult(true));
+        when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray);
+
+        Optional<FetchSearchResult> result = rrfProcessor.getFetchSearchResults(mockQueryPhaseResultConsumer);
+        assertFalse(result.isPresent());
+    }
+
+    private QuerySearchResult createQuerySearchResult(boolean isHybrid) {
+        ShardId shardId = new ShardId("index", "uuid", 0);
+        OriginalIndices originalIndices = new OriginalIndices(new String[] { "index" }, IndicesOptions.strictExpandOpenAndForbidClosed());
+        SearchRequest searchRequest = new SearchRequest("index");
+        searchRequest.source(new SearchSourceBuilder());
+        searchRequest.allowPartialSearchResults(true);
+
+        int numberOfShards = 1;
+        AliasFilter aliasFilter = new AliasFilter(null, Strings.EMPTY_ARRAY);
+        float indexBoost = 1.0f;
+        long nowInMillis = System.currentTimeMillis();
+        String clusterAlias = null;
+        String[] indexRoutings = Strings.EMPTY_ARRAY;
+
+        ShardSearchRequest shardSearchRequest = new ShardSearchRequest(
+            originalIndices,
+            searchRequest,
+            shardId,
+            numberOfShards,
+            aliasFilter,
+            indexBoost,
+            nowInMillis,
+            clusterAlias,
+            indexRoutings
+        );
+
+        QuerySearchResult result = new QuerySearchResult(
+            new ShardSearchContextId("test", 1),
+            new SearchShardTarget("node1", shardId, clusterAlias, originalIndices),
+            shardSearchRequest
+        );
+        result.from(0).size(10);
+
+        ScoreDoc[] scoreDocs;
+        if (isHybrid) {
+            scoreDocs = new ScoreDoc[] { HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults(0) };
+        } else {
+            scoreDocs = new ScoreDoc[] { new ScoreDoc(0, 1.0f) };
+        }
+
+        TopDocs topDocs = new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), scoreDocs);
+        TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(topDocs, 1.0f);
+        result.topDocs(topDocsAndMaxScore, new DocValueFormat[0]);
+
+        return result;
+    }
+}
diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java
index afc545447..f8021b08e 100644
--- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java
+++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java
@@ -89,6 +89,7 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase {
     );
     private static final Set<RestStatus> SUCCESS_STATUSES = Set.of(RestStatus.CREATED, RestStatus.OK);
     protected static final String CONCURRENT_SEGMENT_SEARCH_ENABLED = "search.concurrent_segment_search.enabled";
+    protected static final String RRF_SEARCH_PIPELINE = "rrf-search-pipeline";
 
     protected final ClassLoader classLoader = this.getClass().getClassLoader();
 
@@ -1441,4 +1442,31 @@ protected enum ProcessorType {
         TEXT_IMAGE_EMBEDDING,
         SPARSE_ENCODING
     }
+
+    @SneakyThrows
+    protected void createDefaultRRFSearchPipeline() {
+        String requestBody = XContentFactory.jsonBuilder()
+            .startObject()
+            .field("description", "Post processor for hybrid search")
+            .startArray("phase_results_processors")
+            .startObject()
+            .startObject("score-ranker-processor")
+            .startObject("combination")
+            .field("technique", "rrf")
+            .endObject()
+            .endObject()
+            .endObject()
+            .endArray()
+            .endObject()
+            .toString();
+
+        makeRequest(
+            client(),
+            "PUT",
+            String.format(LOCALE, "/_search/pipeline/%s", RRF_SEARCH_PIPELINE),
+            null,
+            toHttpEntity(String.format(LOCALE, requestBody)),
+            ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
+        );
+    }
 }
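Illustrative note (not part of the patch): the XContentBuilder chain in createDefaultRRFSearchPipeline produces a request equivalent to the following; the pretty-printed layout is only for readability.

    PUT /_search/pipeline/rrf-search-pipeline
    {
      "description": "Post processor for hybrid search",
      "phase_results_processors": [
        {
          "score-ranker-processor": {
            "combination": {
              "technique": "rrf"
            }
          }
        }
      ]
    }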