Skip to content

Commit

Permalink
Added null checks for fieldInfo in ExactSearcher to avoid NPE while r…
Browse files Browse the repository at this point in the history
…unning exact search for segments with no vector field

Signed-off-by: Navneet Verma <navneev@amazon.com>
  • Loading branch information
navneet1v committed Nov 20, 2024
1 parent 2d1a408 commit 33f05a5
Showing 7 changed files with 96 additions and 10 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -19,11 +19,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Enhancements
- Introduced a writing layer in native engines that relies on the writing interface to process IO. [#2241](https://github.com/opensearch-project/k-NN/pull/2241)
### Bug Fixes
* Fix NPE in ANN search when a segment doesn't contain a vector field [#2278](https://github.com/opensearch-project/k-NN/pull/2278)
* Fix the bug when a segment has no vector field present for disk-based vector search [#2282](https://github.com/opensearch-project/k-NN/pull/2282)
### Infrastructure
* Updated C++ version in JNI from c++11 to c++17 [#2259](https://github.com/opensearch-project/k-NN/pull/2259)
* Upgrade bytebuddy and objenesis version to match OpenSearch core and, update github ci runner for macos [#2279](https://github.com/opensearch-project/k-NN/pull/2279)
### Documentation
### Maintenance
* Select index settings based on cluster version [#2236](https://github.com/opensearch-project/k-NN/pull/2236)
* Added null checks for fieldInfo in ExactSearcher to avoid NPE while running exact search for segments with no vector field [#2278](https://github.com/opensearch-project/k-NN/pull/2278)
### Refactoring
15 changes: 14 additions & 1 deletion src/main/java/org/opensearch/knn/common/FieldInfoExtractor.java
Original file line number Diff line number Diff line change
@@ -8,6 +8,8 @@
import lombok.experimental.UtilityClass;
import org.apache.commons.lang.StringUtils;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReader;
import org.opensearch.common.Nullable;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
@@ -27,7 +29,7 @@
import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE;

/**
* A utility class to extract information from FieldInfo.
* A utility class to extract information from FieldInfo and also provides utility functions to extract fieldInfo
*/
@UtilityClass
public class FieldInfoExtractor {
@@ -103,4 +105,15 @@ public static SpaceType getSpaceType(final ModelDao modelDao, final FieldInfo fi
}
return modelMetadata.getSpaceType();
}

/**
* Looks up the {@link FieldInfo} for the given field name in the field infos of the supplied reader.
* This method performs no null check itself: the result is {@code null} when the field is not found,
* so callers must check the returned value before using it.
* @param leafReader {@link LeafReader} whose field infos are searched
* @param fieldName {@link String} name of the field to look up
* @return the matching {@link FieldInfo}, or {@code null} if the field is not found
*/
public static @Nullable FieldInfo getFieldInfo(final LeafReader leafReader, final String fieldName) {
return leafReader.getFieldInfos().fieldInfo(fieldName);
}
}
Original file line number Diff line number Diff line change
@@ -12,6 +12,7 @@
import org.opensearch.index.fielddata.LeafFieldData;
import org.opensearch.index.fielddata.ScriptDocValues;
import org.opensearch.index.fielddata.SortedBinaryDocValues;
import org.opensearch.knn.common.FieldInfoExtractor;

import java.io.IOException;

@@ -40,7 +41,7 @@ public long ramBytesUsed() {
@Override
public ScriptDocValues<float[]> getScriptValues() {
try {
FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(fieldName);
FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, fieldName);
if (fieldInfo == null) {
return KNNVectorScriptDocValues.emptyValues(fieldName, vectorDataType);
}
22 changes: 17 additions & 5 deletions src/main/java/org/opensearch/knn/index/query/ExactSearcher.java
Original file line number Diff line number Diff line change
@@ -38,6 +38,7 @@
import org.opensearch.knn.indices.ModelDao;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
@@ -59,7 +60,11 @@ public class ExactSearcher {
*/
public Map<Integer, Float> searchLeaf(final LeafReaderContext leafReaderContext, final ExactSearcherContext exactSearcherContext)
throws IOException {
KNNIterator iterator = getKNNIterator(leafReaderContext, exactSearcherContext);
final KNNIterator iterator = getKNNIterator(leafReaderContext, exactSearcherContext);
// because of any reason if we are not able to get KNNIterator, return an empty map
if (iterator == null) {
return Collections.emptyMap();
}
if (exactSearcherContext.getKnnQuery().getRadius() != null) {
return doRadialSearch(leafReaderContext, exactSearcherContext, iterator);
}
@@ -74,8 +79,8 @@ public Map<Integer, Float> searchLeaf(final LeafReaderContext leafReaderContext,
* Perform radial search by comparing scores with min score. Currently, FAISS from native engine supports radial search.
* Hence, we assume that Radius from knnQuery is always distance, and we convert it to score since we do exact search uses scores
* to filter out the documents that does not have given min score.
* @param leafReaderContext
* @param exactSearcherContext
* @param leafReaderContext {@link LeafReaderContext}
* @param exactSearcherContext {@link ExactSearcherContext}
* @param iterator {@link KNNIterator}
* @return Map of docId and score
* @throws IOException exception raised by iterator during traversal
@@ -87,7 +92,10 @@ private Map<Integer, Float> doRadialSearch(
) throws IOException {
final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader());
final KNNQuery knnQuery = exactSearcherContext.getKnnQuery();
final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());
final FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, knnQuery.getField());
if (fieldInfo == null) {
return Collections.emptyMap();
}
final KNNEngine engine = FieldInfoExtractor.extractKNNEngine(fieldInfo);
if (KNNEngine.FAISS != engine) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "Engine [%s] does not support radial search", engine));
@@ -149,7 +157,11 @@ private KNNIterator getKNNIterator(LeafReaderContext leafReaderContext, ExactSea
final KNNQuery knnQuery = exactSearcherContext.getKnnQuery();
final BitSet matchedDocs = exactSearcherContext.getMatchedDocs();
final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader());
final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());
final FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, knnQuery.getField());
if (fieldInfo == null) {
log.debug("[KNN] Cannot get KNNIterator as Field info not found for {}:{}", knnQuery.getField(), reader.getSegmentName());
return null;
}
final SpaceType spaceType = FieldInfoExtractor.getSpaceType(modelDao, fieldInfo);

boolean isNestedRequired = exactSearcherContext.isParentHits() && knnQuery.getParentsFilter() != null;
4 changes: 2 additions & 2 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
@@ -227,7 +227,7 @@ private Map<Integer, Float> doANNSearch(
) throws IOException {
final SegmentReader reader = Lucene.segmentReader(context.reader());

FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());
FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, knnQuery.getField());

if (fieldInfo == null) {
log.debug("[KNN] Field info not found for {}:{}", knnQuery.getField(), reader.getSegmentName());
@@ -479,7 +479,7 @@ private boolean isFilteredExactSearchRequireAfterANNSearch(final int filterIdsCo
*/
private boolean isMissingNativeEngineFiles(LeafReaderContext context) {
final SegmentReader reader = Lucene.segmentReader(context.reader());
final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());
final FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, knnQuery.getField());
// if segment has no documents with at least 1 vector field, field info will be null
if (fieldInfo == null) {
return false;
Original file line number Diff line number Diff line change
@@ -6,6 +6,8 @@
package org.opensearch.knn.common;

import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.LeafReader;
import org.junit.Assert;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
@@ -63,4 +65,15 @@ public void testExtractVectorDataType() {
when(fieldInfo.getAttribute("model_id")).thenReturn(null);
assertEquals(VectorDataType.DEFAULT, FieldInfoExtractor.extractVectorDataType(fieldInfo));
}

/**
 * Checks that {@code FieldInfoExtractor.getFieldInfo} hands back whatever the reader's
 * field infos report: {@code null} for an unknown field and the stored instance for a known one.
 */
public void testGetFieldInfo_whenDifferentInput_thenSuccess() {
    // Field-infos table that knows "valid" but has nothing for "invalid".
    final FieldInfo expectedFieldInfo = Mockito.mock(FieldInfo.class);
    final FieldInfos mockedFieldInfos = Mockito.mock(FieldInfos.class);
    Mockito.when(mockedFieldInfos.fieldInfo("invalid")).thenReturn(null);
    Mockito.when(mockedFieldInfos.fieldInfo("valid")).thenReturn(expectedFieldInfo);

    // Reader that exposes the table above.
    final LeafReader mockedReader = Mockito.mock(LeafReader.class);
    Mockito.when(mockedReader.getFieldInfos()).thenReturn(mockedFieldInfos);

    final FieldInfo missingLookup = FieldInfoExtractor.getFieldInfo(mockedReader, "invalid");
    Assert.assertNull(missingLookup);

    final FieldInfo presentLookup = FieldInfoExtractor.getFieldInfo(mockedReader, "valid");
    Assert.assertEquals(expectedFieldInfo, presentLookup);
}
}
Original file line number Diff line number Diff line change
@@ -20,6 +20,7 @@
import org.mockito.Mockito;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.KNNCodecVersion;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues;
@@ -50,6 +51,50 @@ public class ExactSearcherTests extends KNNTestCase {

private static final String SEGMENT_NAME = "0";

/**
 * Verifies that exact search over a segment whose field infos have no entry for the queried
 * field returns an empty result map rather than failing.
 */
@SneakyThrows
public void testExactSearch_whenSegmentHasNoVectorField_thenNoDocsReturned() {
    final float[] queryVector = new float[] { 0.1f, 2.0f, 3.0f };
    final KNNQuery query = KNNQuery.builder().field(FIELD_NAME).queryVector(queryVector).k(10).indexName(INDEX_NAME).build();

    final ExactSearcher.ExactSearcherContext.ExactSearcherContextBuilder exactSearcherContextBuilder =
        ExactSearcher.ExactSearcherContext.builder().knnQuery(query);

    ExactSearcher exactSearcher = new ExactSearcher(null);
    final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
    final SegmentReader reader = mock(SegmentReader.class);
    when(leafReaderContext.reader()).thenReturn(reader);

    final FSDirectory directory = mock(FSDirectory.class);
    final SegmentInfo segmentInfo = new SegmentInfo(
        directory,
        Version.LATEST,
        Version.LATEST,
        SEGMENT_NAME,
        100,
        false,
        false,
        KNNCodecVersion.current().getDefaultCodecDelegate(),
        Map.of(),
        new byte[StringHelper.ID_LENGTH],
        Map.of(),
        Sort.RELEVANCE
    );
    segmentInfo.setFiles(Set.of());
    final SegmentCommitInfo segmentCommitInfo = new SegmentCommitInfo(segmentInfo, 0, 0, 0, 0, 0, new byte[StringHelper.ID_LENGTH]);
    when(reader.getSegmentInfo()).thenReturn(segmentCommitInfo);

    final Path path = mock(Path.class);
    when(directory.getDirectory()).thenReturn(path);

    // The segment has no vector field: the lookup for the queried field yields null.
    // No FieldInfo mock is created because nothing in this scenario would ever return it.
    final FieldInfos fieldInfos = mock(FieldInfos.class);
    when(reader.getFieldInfos()).thenReturn(fieldInfos);
    when(fieldInfos.fieldInfo(query.getField())).thenReturn(null);

    Map<Integer, Float> docIds = exactSearcher.searchLeaf(leafReaderContext, exactSearcherContextBuilder.build());

    // The searcher must have consulted the field infos for the queried field and returned no documents.
    Mockito.verify(fieldInfos).fieldInfo(query.getField());
    assertEquals(0, docIds.size());
}

@SneakyThrows
public void testRadialSearch_whenNoEngineFiles_thenSuccess() {
try (MockedStatic<KNNVectorValuesFactory> valuesFactoryMockedStatic = Mockito.mockStatic(KNNVectorValuesFactory.class)) {
@@ -75,6 +120,7 @@ public void testRadialSearch_whenNoEngineFiles_thenSuccess() {
.queryVector(queryVector)
.radius(radius)
.indexName(INDEX_NAME)
.vectorDataType(VectorDataType.FLOAT)
.context(context)
.build();

0 comments on commit 33f05a5

Please sign in to comment.