Skip to content

Commit

Permalink
Add support for search using the "fields" parameter with knn_vector f…
Browse files Browse the repository at this point in the history
…ield (#2314)

Signed-off-by: Ethan Emoto <[email protected]>
  • Loading branch information
e-emoto authored Jan 14, 2025
1 parent 60824c8 commit 5c63fde
Show file tree
Hide file tree
Showing 5 changed files with 358 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Use one formula to calculate cosine similarity (#2357)[https://github.com/opensearch-project/k-NN/pull/2357]
### Bug Fixes
* Fixing the bug when a segment has no vector field present for disk based vector search (#2282)[https://github.com/opensearch-project/k-NN/pull/2282]
* Fixing the bug where search fails with "fields" parameter for an index with a knn_vector field (#2314)[https://github.com/opensearch-project/k-NN/pull/2314]
* Fix for NPE while merging segments after all the vector fields docs are deleted (#2365)[https://github.com/opensearch-project/k-NN/pull/2365]
* Allow validation for non knn index only after 2.17.0 (#2315)[https://github.com/opensearch-project/k-NN/pull/2315]
* Fixing the bug to prevent updating the index.knn setting after index creation(#2348)[https://github.com/opensearch-project/k-NN/pull/2348]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@
package org.opensearch.knn.index.mapper;

import lombok.Getter;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.DocValuesFieldExistsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.BytesRef;
import org.opensearch.index.fielddata.IndexFieldData;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.mapper.ArraySourceValueFetcher;
import org.opensearch.index.mapper.TextSearchInfo;
import org.opensearch.index.mapper.ValueFetcher;
import org.opensearch.index.query.QueryShardContext;
Expand All @@ -21,6 +24,8 @@
import org.opensearch.search.aggregations.support.CoreValuesSourceType;
import org.opensearch.search.lookup.SearchLookup;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Locale;
import java.util.Map;
import java.util.function.Supplier;
Expand All @@ -32,6 +37,7 @@
*/
@Getter
public class KNNVectorFieldType extends MappedFieldType {
private static final Logger logger = LogManager.getLogger(KNNVectorFieldType.class);
KNNMappingConfig knnMappingConfig;
VectorDataType vectorDataType;

Expand All @@ -51,7 +57,17 @@ public KNNVectorFieldType(String name, Map<String, String> metadata, VectorDataT

@Override
public ValueFetcher valueFetcher(QueryShardContext context, SearchLookup searchLookup, String format) {
throw new UnsupportedOperationException("KNN Vector do not support fields search");
return new ArraySourceValueFetcher(name(), context) {
@Override
protected Object parseSourceValue(Object value) {
if (value instanceof ArrayList) {
return value;
} else {
logger.warn("Expected type ArrayList for value, but got {} ", value.getClass());
return Collections.emptyList();
}
}
};
}

@Override
Expand Down
286 changes: 285 additions & 1 deletion src/test/java/org/opensearch/knn/index/OpenSearchIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import com.google.common.collect.ImmutableList;
import com.google.common.primitives.Floats;
import java.util.Locale;

import lombok.SneakyThrows;
import org.apache.hc.core5.http.ParseException;
import org.junit.BeforeClass;
Expand All @@ -39,6 +39,7 @@
import java.net.URL;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.TreeMap;

Expand Down Expand Up @@ -924,6 +925,289 @@ public void testKNNIndex_whenBuildVectorGraphThresholdIsProvidedEndToEnd_thenBui
deleteKNNIndex(indexName);
}

public void testKNNIndexSearchFieldsParameter() throws Exception {
createKnnIndex(INDEX_NAME, createKnnIndexMapping(Arrays.asList("vector1", "vector2", "vector3"), Arrays.asList(2, 3, 5)));
// Add docs with knn_vector fields
for (int i = 1; i <= 20; i++) {
Float[] vector1 = { (float) i, (float) (i + 1) };
Float[] vector2 = { (float) i, (float) (i + 1), (float) (i + 2) };
Float[] vector3 = { (float) i, (float) (i + 1), (float) (i + 2), (float) (i + 3), (float) (i + 4) };
addKnnDoc(
INDEX_NAME,
Integer.toString(i),
Arrays.asList("vector1", "vector2", "vector3"),
Arrays.asList(vector1, vector2, vector3)
);
}
int k = 10; // nearest 10 neighbors

// Create match_all search body, all fields
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.field("fields", new String[] { "*" })
.startObject("query")
.startObject("match_all")
.endObject()
.endObject()
.endObject();
Response response = searchKNNIndex(INDEX_NAME, builder, k);
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response.getEntity()), "vector1"));
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response.getEntity()), "vector2"));
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response.getEntity()), "vector3"));

// Create match_all search body, some fields
builder = XContentFactory.jsonBuilder()
.startObject()
.field("fields", new String[] { "vector1", "vector2" })
.startObject("query")
.startObject("match_all")
.endObject()
.endObject()
.endObject();
Response response2 = searchKNNIndex(INDEX_NAME, builder, k);
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response2.getEntity()), "vector1"));
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response2.getEntity()), "vector2"));
assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response2.getEntity()), "vector3"));

// Create knn search body, all fields
builder = XContentFactory.jsonBuilder()
.startObject()
.field("fields", new String[] { "*" })
.startObject("query")
.startObject("knn")
.startObject("vector2")
.field("vector", new float[] { 2.0f, 2.0f, 2.0f })
.field("k", k)
.endObject()
.endObject()
.endObject()
.endObject();
Response response3 = searchKNNIndex(INDEX_NAME, builder, k);
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response3.getEntity()), "vector1"));
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response3.getEntity()), "vector2"));
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response3.getEntity()), "vector3"));

// Create knn search body, some fields
builder = XContentFactory.jsonBuilder()
.startObject()
.field("fields", new String[] { "vector1", "vector2" })
.startObject("query")
.startObject("knn")
.startObject("vector2")
.field("vector", new float[] { 2.0f, 2.0f, 2.0f })
.field("k", k)
.endObject()
.endObject()
.endObject()
.endObject();
Response response4 = searchKNNIndex(INDEX_NAME, builder, k);
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response4.getEntity()), "vector1"));
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response4.getEntity()), "vector2"));
assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response4.getEntity()), "vector3"));
}

public void testKNNIndexSearchFieldsParameterWithOtherFields() throws Exception {
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
.startObject()
.startObject("properties")
.startObject("vector1")
.field("type", "knn_vector")
.field("dimension", "2")
.endObject()
.startObject("vector2")
.field("type", "knn_vector")
.field("dimension", "3")
.endObject()
.startObject("float1")
.field("type", "float")
.endObject()
.startObject("float2")
.field("type", "float")
.endObject()
.endObject()
.endObject();
createKnnIndex(INDEX_NAME, xContentBuilder.toString());
// Add docs with knn_vector and other fields
for (int i = 1; i <= 20; i++) {
Float[] vector1 = { (float) i, (float) (i + 1) };
Float[] vector2 = { (float) i, (float) (i + 1), (float) (i + 2) };
Float[] float1 = { (float) i };
Float[] float2 = { (float) (i + 1) };
addKnnDoc(
INDEX_NAME,
Integer.toString(i),
Arrays.asList("vector1", "vector2", "float1", "float2"),
Arrays.asList(vector1, vector2, float1, float2)
);
}
int k = 10; // nearest 10 neighbors

// Create match_all search body, all fields
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.field("fields", new String[] { "*" })
.startObject("query")
.startObject("match_all")
.endObject()
.endObject()
.endObject();
Response response = searchKNNIndex(INDEX_NAME, builder, k);
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response.getEntity()), "vector1"));
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response.getEntity()), "vector2"));
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response.getEntity()), "float1"));
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response.getEntity()), "float2"));

// Create match_all search body, some fields
builder = XContentFactory.jsonBuilder()
.startObject()
.field("fields", new String[] { "vector1", "float2" })
.startObject("query")
.startObject("match_all")
.endObject()
.endObject()
.endObject();
Response response2 = searchKNNIndex(INDEX_NAME, builder, k);
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response2.getEntity()), "vector1"));
assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response2.getEntity()), "vector2"));
assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response2.getEntity()), "float1"));
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response2.getEntity()), "float2"));

// Create knn search body, all fields
builder = XContentFactory.jsonBuilder()
.startObject()
.field("fields", new String[] { "*" })
.startObject("query")
.startObject("knn")
.startObject("vector2")
.field("vector", new float[] { 2.0f, 2.0f, 2.0f })
.field("k", k)
.endObject()
.endObject()
.endObject()
.endObject();
Response response3 = searchKNNIndex(INDEX_NAME, builder, k);
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response3.getEntity()), "vector1"));
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response3.getEntity()), "vector2"));
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response3.getEntity()), "float1"));
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response3.getEntity()), "float2"));

// Create knn search body, some fields
builder = XContentFactory.jsonBuilder()
.startObject()
.field("fields", new String[] { "vector1", "float2" })
.startObject("query")
.startObject("knn")
.startObject("vector2")
.field("vector", new float[] { 2.0f, 2.0f, 2.0f })
.field("k", k)
.endObject()
.endObject()
.endObject()
.endObject();
Response response4 = searchKNNIndex(INDEX_NAME, builder, k);
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response4.getEntity()), "vector1"));
assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response4.getEntity()), "vector2"));
assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response4.getEntity()), "float1"));
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response4.getEntity()), "float2"));
}

public void testKNNIndexSearchFieldsParameterDocsWithOnlyOtherFields() throws Exception {
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
.startObject()
.startObject("properties")
.startObject("vector1")
.field("type", "knn_vector")
.field("dimension", "2")
.endObject()
.startObject("vector2")
.field("type", "knn_vector")
.field("dimension", "3")
.endObject()
.startObject("text1")
.field("type", "text")
.endObject()
.endObject()
.endObject();
createKnnIndex(INDEX_NAME, xContentBuilder.toString());
// Add knn_vector docs
for (int i = 1; i <= 20; i++) {
Float[] vector1 = { (float) i, (float) (i + 1) };
Float[] vector2 = { (float) i, (float) (i + 1), (float) (i + 2) };
addKnnDoc(INDEX_NAME, Integer.toString(i), Arrays.asList("vector1", "vector2"), Arrays.asList(vector1, vector2));
}
// Add non knn_vector docs
for (int i = 21; i <= 40; i++) {
addNonKNNDoc(INDEX_NAME, Integer.toString(i), "text1", "text " + i);
}
int k = 10; // nearest 10 neighbors

// Create match search body, all non vector fields
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.field("fields", new String[] { "text1" })
.startObject("query")
.startObject("match")
.field("text1", "text")
.endObject()
.endObject()
.endObject();
Response response = searchKNNIndex(INDEX_NAME, builder, k);
assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response.getEntity()), "vector1"));
assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response.getEntity()), "vector2"));
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response.getEntity()), "text1"));

// Create match search body, all vector fields
builder = XContentFactory.jsonBuilder()
.startObject()
.field("fields", new String[] { "vector1", "vector2" })
.startObject("query")
.startObject("match")
.field("text1", "text")
.endObject()
.endObject()
.endObject();
Response response2 = searchKNNIndex(INDEX_NAME, builder, k);
assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response2.getEntity()), "vector1"));
assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response2.getEntity()), "vector2"));
assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response2.getEntity()), "text1"));

// Create knn search body, all vector fields
builder = XContentFactory.jsonBuilder()
.startObject()
.field("fields", new String[] { "vector1", "vector2" })
.startObject("query")
.startObject("knn")
.startObject("vector2")
.field("vector", new float[] { 2.0f, 2.0f, 2.0f })
.field("k", k)
.endObject()
.endObject()
.endObject()
.endObject();
Response response3 = searchKNNIndex(INDEX_NAME, builder, k);
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response3.getEntity()), "vector1"));
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response3.getEntity()), "vector2"));
assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response3.getEntity()), "text1"));

// Create knn search body, all non vector fields
builder = XContentFactory.jsonBuilder()
.startObject()
.field("fields", new String[] { "text1" })
.startObject("query")
.startObject("knn")
.startObject("vector2")
.field("vector", new float[] { 2.0f, 2.0f, 2.0f })
.field("k", k)
.endObject()
.endObject()
.endObject()
.endObject();
Response response4 = searchKNNIndex(INDEX_NAME, builder, k);
assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response4.getEntity()), "vector1"));
assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response4.getEntity()), "vector2"));
assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response4.getEntity()), "text1"));
}

private List<KNNResult> getResults(final String indexName, final String fieldName, final float[] vector, final int k)
throws IOException, ParseException {
final Response searchResponseField = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName, vector, k), k);
Expand Down
Loading

0 comments on commit 5c63fde

Please sign in to comment.