Validate Disjunction query in HybridQueryPhaseSearcher (#1127)
* Validate Disjunction query in HybridQueryPhaseSearcher

Signed-off-by: Owais <[email protected]>
owaiskazi19 authored Jan 24, 2025
1 parent c6b8ac4 commit 9799c6c
Showing 3 changed files with 124 additions and 1 deletion.
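For context, the change rejects a hybrid query that is nested inside a dis_max (disjunction max) query. Below is a minimal sketch of the now-rejected shape, built with the same query builders the new test uses; the class name, field name, term values, and the HybridQueryBuilder package are placeholders assumed from the neural-search codebase:

```java
import org.opensearch.index.query.DisMaxQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.neuralsearch.query.HybridQueryBuilder; // package assumed

public class DisMaxWrappedHybridSketch {

    // Builds the shape this commit validates against: a hybrid query wrapped in a dis_max query.
    public static DisMaxQueryBuilder rejectedShape() {
        HybridQueryBuilder hybrid = new HybridQueryBuilder();
        hybrid.add(QueryBuilders.termQuery("passage_text", "hello"));   // placeholder field/term
        hybrid.add(QueryBuilders.termQuery("passage_text", "world"));

        // At search time, HybridQueryPhaseSearcher now throws
        // IllegalArgumentException("hybrid query must be a top level query and cannot be wrapped into other queries")
        // when it encounters this structure.
        return QueryBuilders.disMaxQuery()
            .add(hybrid)
            .add(QueryBuilders.termQuery("passage_text", "greeting"));
    }
}
```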
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -25,6 +25,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Support empty string for fields in text embedding processor ([#1041](https://github.com/opensearch-project/neural-search/pull/1041))
- Optimize ML inference connection retry logic ([#1054](https://github.com/opensearch-project/neural-search/pull/1054))
- Support for builder constructor in Neural Query Builder ([#1047](https://github.com/opensearch-project/neural-search/pull/1047))
- Validate Disjunction query to avoid having a nested hybrid query ([#1127](https://github.com/opensearch-project/neural-search/pull/1127))
### Bug Fixes
- Address inconsistent scoring in hybrid query results ([#998](https://github.com/opensearch-project/neural-search/pull/998))
- Fix bug where ingested document has list of nested objects ([#1040](https://github.com/opensearch-project/neural-search/pull/1040))
@@ -14,6 +14,7 @@
import lombok.NoArgsConstructor;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.DisjunctionMaxQuery;
import org.apache.lucene.search.Query;
import org.opensearch.common.settings.Settings;
import org.opensearch.index.mapper.MapperService;
@@ -104,7 +105,7 @@ protected Query extractHybridQuery(final SearchContext searchContext, final Quer
* }
* ]
* }
* TODO add similar validation for other compound type queries like dis_max, constant_score etc.
* TODO add similar validation for other compound type queries like constant_score, function_score etc.
*
* @param query query to validate
*/
@@ -114,6 +115,10 @@ private void validateQuery(final SearchContext searchContext, final Query query)
for (BooleanClause booleanClause : booleanClauses) {
validateNestedBooleanQuery(booleanClause.getQuery(), getMaxDepthLimit(searchContext));
}
} else if (query instanceof DisjunctionMaxQuery) {
for (Query disjunct : (DisjunctionMaxQuery) query) {
validateNestedDisJunctionQuery(disjunct, getMaxDepthLimit(searchContext));
}
}
}
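A note on the loop above: Lucene's DisjunctionMaxQuery implements Iterable<Query>, which is why its disjuncts can be walked with an enhanced for loop. A minimal standalone sketch of the same traversal pattern, with arbitrary placeholder disjuncts and tie-breaker value:

```java
import java.util.List;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.DisjunctionMaxQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;

public class DisjunctTraversalSketch {
    public static void main(String[] args) {
        // dis_max over two term queries; tie breaker 0.0f means only the best sub-score counts
        DisjunctionMaxQuery disMax = new DisjunctionMaxQuery(
            List.of(
                new TermQuery(new Term("field", "hello")),
                new TermQuery(new Term("field", "world"))
            ),
            0.0f
        );

        // DisjunctionMaxQuery is Iterable<Query>, so each disjunct can be visited directly,
        // mirroring the loop in validateQuery and validateNestedDisJunctionQuery
        for (Query disjunct : disMax) {
            System.out.println(disjunct);
        }
    }
}
```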

@@ -135,6 +140,24 @@ private void validateNestedBooleanQuery(final Query query, final int level) {
}
}

private void validateNestedDisJunctionQuery(final Query query, final int level) {
if (query instanceof HybridQuery) {
throw new IllegalArgumentException("hybrid query must be a top level query and cannot be wrapped into other queries");
}
if (level <= 0) {
// ideally we should throw an error here, but this code is on the main search workflow path and that might block
// execution of some queries. Instead, we silently exit and allow such a query to execute, potentially producing
// incorrect results if a hybrid query is wrapped inside such a dis_max query
log.error("reached max nested query limit, cannot process dis_max query with that many nested clauses");
return;
}
if (query instanceof DisjunctionMaxQuery) {
for (Query disjunct : (DisjunctionMaxQuery) query) {
validateNestedDisJunctionQuery(disjunct, level - 1);
}
}
}

private int getMaxDepthLimit(final SearchContext searchContext) {
Settings indexSettings = searchContext.getQueryShardContext().getIndexSettings().getSettings();
return MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(indexSettings).intValue();
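The recursion budget used by the validation methods comes from the index's mapping depth limit, as getMaxDepthLimit shows. A minimal sketch of that lookup in isolation, assuming in-memory index settings (the production code reads them from the QueryShardContext):

```java
import org.opensearch.common.settings.Settings;
import org.opensearch.index.mapper.MapperService;

public class DepthLimitSketch {
    public static void main(String[] args) {
        // Built in-memory for illustration; index.mapping.depth.limit defaults to 20 when unset.
        Settings indexSettings = Settings.builder()
            .put("index.mapping.depth.limit", 20)
            .build();

        // Same lookup the searcher performs to bound the nested-query validation.
        int maxDepth = MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(indexSettings).intValue();
        System.out.println("max nested validation depth: " + maxDepth);
    }
}
```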
@@ -58,6 +58,7 @@
import org.opensearch.index.mapper.TextFieldMapper;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.DisMaxQueryBuilder;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.index.remote.RemoteStoreEnums;
@@ -516,6 +517,104 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBool_thenFail() {
releaseResources(directory, w, reader);
}

@SneakyThrows
public void testWrappedHybridQuery_whenHybridNestedInDisjunctionQuery_thenFail() {
HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher();
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
when(mockQueryShardContext.index()).thenReturn(dummyIndex);
TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
MapperService mapperService = mock(MapperService.class);
when(mapperService.hasNested()).thenReturn(false);

Directory directory = newDirectory();
IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random())));
FieldType ft = new FieldType(TextField.TYPE_NOT_STORED);
ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS);
ft.setOmitNorms(random().nextBoolean());
ft.freeze();
int docId1 = RandomizedTest.randomInt();
w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft));
w.commit();

IndexReader reader = DirectoryReader.open(w);
SearchContext searchContext = mock(SearchContext.class);

ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher(
reader,
IndexSearcher.getDefaultSimilarity(),
IndexSearcher.getDefaultQueryCache(),
IndexSearcher.getDefaultQueryCachingPolicy(),
true,
null,
searchContext
);

ShardId shardId = new ShardId(dummyIndex, 1);
SearchShardTarget shardTarget = new SearchShardTarget(
randomAlphaOfLength(10),
shardId,
randomAlphaOfLength(10),
OriginalIndices.NONE
);
when(searchContext.shardTarget()).thenReturn(shardTarget);
when(searchContext.searcher()).thenReturn(contextIndexSearcher);
when(searchContext.size()).thenReturn(4);
QuerySearchResult querySearchResult = new QuerySearchResult();
when(searchContext.queryResult()).thenReturn(querySearchResult);
when(searchContext.numberOfShards()).thenReturn(1);
when(searchContext.searcher()).thenReturn(contextIndexSearcher);
IndexShard indexShard = mock(IndexShard.class);
when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0));
when(indexShard.getSearchOperationListener()).thenReturn(mock(SearchOperationListener.class));
when(searchContext.indexShard()).thenReturn(indexShard);
when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR);
when(searchContext.mapperService()).thenReturn(mapperService);
when(searchContext.getQueryShardContext()).thenReturn(mockQueryShardContext);
IndexMetadata indexMetadata = getIndexMetadata();
Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build();
IndexSettings indexSettings = new IndexSettings(indexMetadata, settings);
when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings);

LinkedList<QueryCollectorContext> collectors = new LinkedList<>();
boolean hasFilterCollector = randomBoolean();
boolean hasTimeout = randomBoolean();

// Create a HybridQueryBuilder
HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
hybridQueryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1));
hybridQueryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2));
hybridQueryBuilder.paginationDepth(10);

// Create a regular term query
TermQueryBuilder termQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2);

// Create a disjunction query (OR) with the hybrid query and the term query
DisMaxQueryBuilder disjunctionMaxQueryBuilder = QueryBuilders.disMaxQuery().add(hybridQueryBuilder).add(termQuery);

Query query = disjunctionMaxQueryBuilder.toQuery(mockQueryShardContext);
when(searchContext.query()).thenReturn(query);

IllegalArgumentException exception = expectThrows(
IllegalArgumentException.class,
() -> hybridQueryPhaseSearcher.searchWith(
searchContext,
contextIndexSearcher,
query,
collectors,
hasFilterCollector,
hasTimeout
)
);

org.hamcrest.MatcherAssert.assertThat(
exception.getMessage(),
containsString("hybrid query must be a top level query and cannot be wrapped into other queries")
);

releaseResources(directory, w, reader);
}

@SneakyThrows
public void testWrappedHybridQuery_whenHybridWrappedIntoBoolAndIncorrectStructure_thenFail() {
HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher();