From 9799c6c2e88e30e19e250d1ac40c21fff03a4ebb Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Fri, 24 Jan 2025 10:55:22 -0800 Subject: [PATCH] Validate Disjunction query in HybridQueryPhaseSearcher (#1127) * Validate Disjunction query in HybridQueryPhaseSearcher Signed-off-by: Owais --- CHANGELOG.md | 1 + .../query/HybridQueryPhaseSearcher.java | 25 ++++- .../query/HybridQueryPhaseSearcherTests.java | 99 +++++++++++++++++++ 3 files changed, 124 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e9703b309..5e0bbdeaa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Support empty string for fields in text embedding processor ([#1041](https://github.com/opensearch-project/neural-search/pull/1041)) - Optimize ML inference connection retry logic ([#1054](https://github.com/opensearch-project/neural-search/pull/1054)) - Support for builder constructor in Neural Query Builder ([#1047](https://github.com/opensearch-project/neural-search/pull/1047)) +- Validate Disjunction query to avoid having nested hybrid query ([#1127](https://github.com/opensearch-project/neural-search/pull/1127)) ### Bug Fixes - Address inconsistent scoring in hybrid query results ([#998](https://github.com/opensearch-project/neural-search/pull/998)) - Fix bug where ingested document has list of nested objects ([#1040](https://github.com/opensearch-project/neural-search/pull/1040)) diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java index aca93b77b..0788f43a8 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java @@ -14,6 +14,7 @@ import lombok.NoArgsConstructor; import org.apache.lucene.search.BooleanClause; import 
org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.DisjunctionMaxQuery; import org.apache.lucene.search.Query; import org.opensearch.common.settings.Settings; import org.opensearch.index.mapper.MapperService; @@ -104,7 +105,7 @@ protected Query extractHybridQuery(final SearchContext searchContext, final Quer * } * ] * } - * TODO add similar validation for other compound type queries like dis_max, constant_score etc. + * TODO add similar validation for other compound type queries like constant_score, function_score etc. * * @param query query to validate */ @@ -114,6 +115,10 @@ private void validateQuery(final SearchContext searchContext, final Query query) for (BooleanClause booleanClause : booleanClauses) { validateNestedBooleanQuery(booleanClause.getQuery(), getMaxDepthLimit(searchContext)); } + } else if (query instanceof DisjunctionMaxQuery) { + for (Query disjunct : (DisjunctionMaxQuery) query) { + validateNestedDisJunctionQuery(disjunct, getMaxDepthLimit(searchContext)); + } } } @@ -135,6 +140,24 @@ private void validateNestedBooleanQuery(final Query query, final int level) { } } + private void validateNestedDisJunctionQuery(final Query query, final int level) { + if (query instanceof HybridQuery) { + throw new IllegalArgumentException("hybrid query must be a top level query and cannot be wrapped into other queries"); + } + if (level <= 0) { + // ideally we should throw an error here but this code is on the main search workflow path and that might block + // execution of some queries. 
Instead, we silently exit and allow such a query to execute and potentially produce incorrect + // results in case a hybrid query is wrapped into such a dis_max query + log.error("reached max nested query limit, cannot process dis_max query with that many nested clauses"); + return; + } + if (query instanceof DisjunctionMaxQuery) { + for (Query disjunct : (DisjunctionMaxQuery) query) { + validateNestedDisJunctionQuery(disjunct, level - 1); + } + } + } + private int getMaxDepthLimit(final SearchContext searchContext) { Settings indexSettings = searchContext.getQueryShardContext().getIndexSettings().getSettings(); return MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(indexSettings).intValue(); diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java index 2aafa2ece..bd920f41a 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java @@ -58,6 +58,7 @@ import org.opensearch.index.mapper.TextFieldMapper; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.DisMaxQueryBuilder; import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.index.remote.RemoteStoreEnums; @@ -516,6 +517,104 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBool_thenFail() { releaseResources(directory, w, reader); } + @SneakyThrows + public void testWrappedHybridQuery_whenHybridNestedInDisjunctionQuery_thenFail() { + HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher(); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + 
TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + MapperService mapperService = mock(MapperService.class); + when(mapperService.hasNested()).thenReturn(false); + + Directory directory = newDirectory(); + IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + int docId1 = RandomizedTest.randomInt(); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft)); + w.commit(); + + IndexReader reader = DirectoryReader.open(w); + SearchContext searchContext = mock(SearchContext.class); + + ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true, + null, + searchContext + ); + + ShardId shardId = new ShardId(dummyIndex, 1); + SearchShardTarget shardTarget = new SearchShardTarget( + randomAlphaOfLength(10), + shardId, + randomAlphaOfLength(10), + OriginalIndices.NONE + ); + when(searchContext.shardTarget()).thenReturn(shardTarget); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + when(searchContext.size()).thenReturn(4); + QuerySearchResult querySearchResult = new QuerySearchResult(); + when(searchContext.queryResult()).thenReturn(querySearchResult); + when(searchContext.numberOfShards()).thenReturn(1); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + IndexShard indexShard = mock(IndexShard.class); + when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); + 
when(indexShard.getSearchOperationListener()).thenReturn(mock(SearchOperationListener.class)); + when(searchContext.indexShard()).thenReturn(indexShard); + when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + when(searchContext.mapperService()).thenReturn(mapperService); + when(searchContext.getQueryShardContext()).thenReturn(mockQueryShardContext); + IndexMetadata indexMetadata = getIndexMetadata(); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build(); + IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); + + LinkedList collectors = new LinkedList<>(); + boolean hasFilterCollector = randomBoolean(); + boolean hasTimeout = randomBoolean(); + + // Create a HybridQueryBuilder + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1)); + hybridQueryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2)); + hybridQueryBuilder.paginationDepth(10); + + // Create a regular term query + TermQueryBuilder termQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2); + + // Create a disjunction query (OR) with the hybrid query and the term query + DisMaxQueryBuilder disjunctionMaxQueryBuilder = QueryBuilders.disMaxQuery().add(hybridQueryBuilder).add(termQuery); + + Query query = disjunctionMaxQueryBuilder.toQuery(mockQueryShardContext); + when(searchContext.query()).thenReturn(query); + + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> hybridQueryPhaseSearcher.searchWith( + searchContext, + contextIndexSearcher, + query, + collectors, + hasFilterCollector, + hasTimeout + ) + ); + + org.hamcrest.MatcherAssert.assertThat( + exception.getMessage(), + containsString("hybrid query must be a top level query and cannot be 
wrapped into other queries") + ); + + releaseResources(directory, w, reader); + } + @SneakyThrows public void testWrappedHybridQuery_whenHybridWrappedIntoBoolAndIncorrectStructure_thenFail() { HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher();