Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for search using the "fields" parameter with knn_vector field #2314

Merged
merged 11 commits into from
Jan 14, 2025
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Use one formula to calculate cosine similarity (#2357)[https://github.com/opensearch-project/k-NN/pull/2357]
### Bug Fixes
* Fixing the bug when a segment has no vector field present for disk based vector search (#2282)[https://github.com/opensearch-project/k-NN/pull/2282]
* Fixing the bug where search fails with "fields" parameter for an index with a knn_vector field (#2314)[https://github.com/opensearch-project/k-NN/pull/2314]
* Fix for NPE while merging segments after all the vector fields docs are deleted (#2365)[https://github.com/opensearch-project/k-NN/pull/2365]
* Allow validation for non knn index only after 2.17.0 (#2315)[https://github.com/opensearch-project/k-NN/pull/2315]
* Release query vector memory after execution (#2346)[https://github.com/opensearch-project/k-NN/pull/2346]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@
package org.opensearch.knn.index.mapper;

import lombok.Getter;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.DocValuesFieldExistsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.BytesRef;
import org.opensearch.index.fielddata.IndexFieldData;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.mapper.ArraySourceValueFetcher;
import org.opensearch.index.mapper.TextSearchInfo;
import org.opensearch.index.mapper.ValueFetcher;
import org.opensearch.index.query.QueryShardContext;
Expand All @@ -21,6 +24,8 @@
import org.opensearch.search.aggregations.support.CoreValuesSourceType;
import org.opensearch.search.lookup.SearchLookup;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Locale;
import java.util.Map;
import java.util.function.Supplier;
Expand All @@ -32,6 +37,7 @@
*/
@Getter
public class KNNVectorFieldType extends MappedFieldType {
private static final Logger logger = LogManager.getLogger(KNNVectorFieldType.class);
KNNMappingConfig knnMappingConfig;
VectorDataType vectorDataType;

Expand All @@ -51,7 +57,17 @@ public KNNVectorFieldType(String name, Map<String, String> metadata, VectorDataT

@Override
public ValueFetcher valueFetcher(QueryShardContext context, SearchLookup searchLookup, String format) {
throw new UnsupportedOperationException("KNN Vector do not support fields search");
e-emoto marked this conversation as resolved.
Show resolved Hide resolved
return new ArraySourceValueFetcher(name(), context) {
martin-gaievski marked this conversation as resolved.
Show resolved Hide resolved
@Override
e-emoto marked this conversation as resolved.
Show resolved Hide resolved
protected Object parseSourceValue(Object value) {
if (value instanceof ArrayList) {
shatejas marked this conversation as resolved.
Show resolved Hide resolved
return value;
} else {
logger.warn("Expected type ArrayList for value, but got {} ", value.getClass());
return Collections.emptyList();
}
}
};
}

@Override
Expand Down
286 changes: 285 additions & 1 deletion src/test/java/org/opensearch/knn/index/OpenSearchIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import com.google.common.collect.ImmutableList;
import com.google.common.primitives.Floats;
import java.util.Locale;

import lombok.SneakyThrows;
import org.apache.hc.core5.http.ParseException;
import org.junit.BeforeClass;
Expand All @@ -39,6 +39,7 @@
import java.net.URL;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.TreeMap;

Expand Down Expand Up @@ -924,6 +925,289 @@ public void testKNNIndex_whenBuildVectorGraphThresholdIsProvidedEndToEnd_thenBui
deleteKNNIndex(indexName);
}

public void testKNNIndexSearchFieldsParameter() throws Exception {
e-emoto marked this conversation as resolved.
Show resolved Hide resolved
createKnnIndex(INDEX_NAME, createKnnIndexMapping(Arrays.asList("vector1", "vector2", "vector3"), Arrays.asList(2, 3, 5)));
// Add docs with knn_vector fields
for (int i = 1; i <= 20; i++) {
Float[] vector1 = { (float) i, (float) (i + 1) };
Float[] vector2 = { (float) i, (float) (i + 1), (float) (i + 2) };
Float[] vector3 = { (float) i, (float) (i + 1), (float) (i + 2), (float) (i + 3), (float) (i + 4) };
addKnnDoc(
INDEX_NAME,
Integer.toString(i),
Arrays.asList("vector1", "vector2", "vector3"),
Arrays.asList(vector1, vector2, vector3)
);
}
int k = 10; // nearest 10 neighbors

// Create match_all search body, all fields
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.field("fields", new String[] { "*" })
.startObject("query")
.startObject("match_all")
e-emoto marked this conversation as resolved.
Show resolved Hide resolved
.endObject()
.endObject()
.endObject();
Response response = searchKNNIndex(INDEX_NAME, builder, k);
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response.getEntity()), "vector1"));
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response.getEntity()), "vector2"));
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response.getEntity()), "vector3"));

// Create match_all search body, some fields
builder = XContentFactory.jsonBuilder()
.startObject()
.field("fields", new String[] { "vector1", "vector2" })
.startObject("query")
.startObject("match_all")
.endObject()
.endObject()
.endObject();
Response response2 = searchKNNIndex(INDEX_NAME, builder, k);
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response2.getEntity()), "vector1"));
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response2.getEntity()), "vector2"));
assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response2.getEntity()), "vector3"));

// Create knn search body, all fields
builder = XContentFactory.jsonBuilder()
.startObject()
.field("fields", new String[] { "*" })
.startObject("query")
.startObject("knn")
.startObject("vector2")
.field("vector", new float[] { 2.0f, 2.0f, 2.0f })
.field("k", k)
.endObject()
.endObject()
.endObject()
.endObject();
Response response3 = searchKNNIndex(INDEX_NAME, builder, k);
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response3.getEntity()), "vector1"));
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response3.getEntity()), "vector2"));
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response3.getEntity()), "vector3"));

// Create knn search body, some fields
builder = XContentFactory.jsonBuilder()
.startObject()
.field("fields", new String[] { "vector1", "vector2" })
.startObject("query")
.startObject("knn")
.startObject("vector2")
.field("vector", new float[] { 2.0f, 2.0f, 2.0f })
.field("k", k)
.endObject()
.endObject()
.endObject()
.endObject();
Response response4 = searchKNNIndex(INDEX_NAME, builder, k);
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response4.getEntity()), "vector1"));
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response4.getEntity()), "vector2"));
assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response4.getEntity()), "vector3"));
}

public void testKNNIndexSearchFieldsParameterWithOtherFields() throws Exception {
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
.startObject()
.startObject("properties")
.startObject("vector1")
.field("type", "knn_vector")
.field("dimension", "2")
.endObject()
.startObject("vector2")
.field("type", "knn_vector")
.field("dimension", "3")
.endObject()
.startObject("float1")
.field("type", "float")
.endObject()
.startObject("float2")
.field("type", "float")
.endObject()
.endObject()
.endObject();
createKnnIndex(INDEX_NAME, xContentBuilder.toString());
// Add docs with knn_vector and other fields
for (int i = 1; i <= 20; i++) {
Float[] vector1 = { (float) i, (float) (i + 1) };
Float[] vector2 = { (float) i, (float) (i + 1), (float) (i + 2) };
Float[] float1 = { (float) i };
Float[] float2 = { (float) (i + 1) };
addKnnDoc(
INDEX_NAME,
Integer.toString(i),
Arrays.asList("vector1", "vector2", "float1", "float2"),
Arrays.asList(vector1, vector2, float1, float2)
);
}
int k = 10; // nearest 10 neighbors

// Create match_all search body, all fields
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.field("fields", new String[] { "*" })
.startObject("query")
.startObject("match_all")
.endObject()
.endObject()
.endObject();
Response response = searchKNNIndex(INDEX_NAME, builder, k);
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response.getEntity()), "vector1"));
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response.getEntity()), "vector2"));
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response.getEntity()), "float1"));
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response.getEntity()), "float2"));

// Create match_all search body, some fields
builder = XContentFactory.jsonBuilder()
.startObject()
.field("fields", new String[] { "vector1", "float2" })
.startObject("query")
.startObject("match_all")
.endObject()
.endObject()
.endObject();
Response response2 = searchKNNIndex(INDEX_NAME, builder, k);
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response2.getEntity()), "vector1"));
assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response2.getEntity()), "vector2"));
assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response2.getEntity()), "float1"));
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response2.getEntity()), "float2"));

// Create knn search body, all fields
builder = XContentFactory.jsonBuilder()
.startObject()
.field("fields", new String[] { "*" })
.startObject("query")
.startObject("knn")
.startObject("vector2")
.field("vector", new float[] { 2.0f, 2.0f, 2.0f })
.field("k", k)
.endObject()
.endObject()
.endObject()
.endObject();
Response response3 = searchKNNIndex(INDEX_NAME, builder, k);
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response3.getEntity()), "vector1"));
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response3.getEntity()), "vector2"));
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response3.getEntity()), "float1"));
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response3.getEntity()), "float2"));

// Create knn search body, some fields
builder = XContentFactory.jsonBuilder()
.startObject()
.field("fields", new String[] { "vector1", "float2" })
.startObject("query")
.startObject("knn")
.startObject("vector2")
.field("vector", new float[] { 2.0f, 2.0f, 2.0f })
.field("k", k)
.endObject()
.endObject()
.endObject()
.endObject();
Response response4 = searchKNNIndex(INDEX_NAME, builder, k);
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response4.getEntity()), "vector1"));
assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response4.getEntity()), "vector2"));
assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response4.getEntity()), "float1"));
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response4.getEntity()), "float2"));
}

public void testKNNIndexSearchFieldsParameterDocsWithOnlyOtherFields() throws Exception {
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
.startObject()
.startObject("properties")
.startObject("vector1")
.field("type", "knn_vector")
.field("dimension", "2")
.endObject()
.startObject("vector2")
.field("type", "knn_vector")
.field("dimension", "3")
.endObject()
.startObject("text1")
.field("type", "text")
.endObject()
.endObject()
.endObject();
createKnnIndex(INDEX_NAME, xContentBuilder.toString());
// Add knn_vector docs
for (int i = 1; i <= 20; i++) {
Float[] vector1 = { (float) i, (float) (i + 1) };
Float[] vector2 = { (float) i, (float) (i + 1), (float) (i + 2) };
addKnnDoc(INDEX_NAME, Integer.toString(i), Arrays.asList("vector1", "vector2"), Arrays.asList(vector1, vector2));
}
// Add non knn_vector docs
for (int i = 21; i <= 40; i++) {
addNonKNNDoc(INDEX_NAME, Integer.toString(i), "text1", "text " + i);
}
int k = 10; // nearest 10 neighbors

// Create match search body, all non vector fields
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.field("fields", new String[] { "text1" })
.startObject("query")
.startObject("match")
.field("text1", "text")
.endObject()
.endObject()
.endObject();
Response response = searchKNNIndex(INDEX_NAME, builder, k);
assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response.getEntity()), "vector1"));
assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response.getEntity()), "vector2"));
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response.getEntity()), "text1"));

// Create match search body, all vector fields
builder = XContentFactory.jsonBuilder()
.startObject()
.field("fields", new String[] { "vector1", "vector2" })
.startObject("query")
.startObject("match")
.field("text1", "text")
.endObject()
.endObject()
.endObject();
Response response2 = searchKNNIndex(INDEX_NAME, builder, k);
assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response2.getEntity()), "vector1"));
assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response2.getEntity()), "vector2"));
assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response2.getEntity()), "text1"));

// Create knn search body, all vector fields
builder = XContentFactory.jsonBuilder()
.startObject()
.field("fields", new String[] { "vector1", "vector2" })
.startObject("query")
.startObject("knn")
.startObject("vector2")
.field("vector", new float[] { 2.0f, 2.0f, 2.0f })
.field("k", k)
.endObject()
.endObject()
.endObject()
.endObject();
Response response3 = searchKNNIndex(INDEX_NAME, builder, k);
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response3.getEntity()), "vector1"));
assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response3.getEntity()), "vector2"));
assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response3.getEntity()), "text1"));

// Create knn search body, all non vector fields
builder = XContentFactory.jsonBuilder()
.startObject()
.field("fields", new String[] { "text1" })
.startObject("query")
.startObject("knn")
.startObject("vector2")
.field("vector", new float[] { 2.0f, 2.0f, 2.0f })
.field("k", k)
.endObject()
.endObject()
.endObject()
.endObject();
Response response4 = searchKNNIndex(INDEX_NAME, builder, k);
assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response4.getEntity()), "vector1"));
assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response4.getEntity()), "vector2"));
assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response4.getEntity()), "text1"));
}

private List<KNNResult> getResults(final String indexName, final String fieldName, final float[] vector, final int k)
throws IOException, ParseException {
final Response searchResponseField = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName, vector, k), k);
Expand Down
Loading
Loading