From 5c63fdef2cc6f2bbcf114b9884fa2ef40d0b464a Mon Sep 17 00:00:00 2001 From: Ethan Emoto Date: Tue, 14 Jan 2025 14:09:46 -0800 Subject: [PATCH] Add support for search using the "fields" parameter with knn_vector field (#2314) Signed-off-by: Ethan Emoto --- CHANGELOG.md | 1 + .../knn/index/mapper/KNNVectorFieldType.java | 18 +- .../opensearch/knn/index/OpenSearchIT.java | 286 +++++++++++++++++- .../index/mapper/KNNVectorFieldTypeTests.java | 34 +++ .../org/opensearch/knn/KNNRestTestCase.java | 21 ++ 5 files changed, 358 insertions(+), 2 deletions(-) create mode 100644 src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldTypeTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index a98cdfdf4..0cdca353a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Use one formula to calculate cosine similarity (#2357)[https://github.com/opensearch-project/k-NN/pull/2357] ### Bug Fixes * Fixing the bug when a segment has no vector field present for disk based vector search (#2282)[https://github.com/opensearch-project/k-NN/pull/2282] +* Fixing the bug where search fails with "fields" parameter for an index with a knn_vector field (#2314)[https://github.com/opensearch-project/k-NN/pull/2314] * Fix for NPE while merging segments after all the vector fields docs are deleted (#2365)[https://github.com/opensearch-project/k-NN/pull/2365] * Allow validation for non knn index only after 2.17.0 (#2315)[https://github.com/opensearch-project/k-NN/pull/2315] * Fixing the bug to prevent updating the index.knn setting after index creation(#2348)[https://github.com/opensearch-project/k-NN/pull/2348] diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java index a0832c1d0..d12247ad7 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java @@ -6,11 +6,14 @@ package org.opensearch.knn.index.mapper; import lombok.Getter; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.apache.lucene.search.DocValuesFieldExistsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.util.BytesRef; import org.opensearch.index.fielddata.IndexFieldData; import org.opensearch.index.mapper.MappedFieldType; +import org.opensearch.index.mapper.ArraySourceValueFetcher; import org.opensearch.index.mapper.TextSearchInfo; import org.opensearch.index.mapper.ValueFetcher; import org.opensearch.index.query.QueryShardContext; @@ -21,6 +24,8 @@ import org.opensearch.search.aggregations.support.CoreValuesSourceType; import org.opensearch.search.lookup.SearchLookup; +import java.util.ArrayList; +import java.util.Collections; import java.util.Locale; import java.util.Map; import java.util.function.Supplier; @@ -32,6 +37,7 @@ */ @Getter public class KNNVectorFieldType extends MappedFieldType { + private static final Logger logger = LogManager.getLogger(KNNVectorFieldType.class); KNNMappingConfig knnMappingConfig; VectorDataType vectorDataType; @@ -51,7 +57,17 @@ public KNNVectorFieldType(String name, Map metadata, VectorDataT @Override public ValueFetcher valueFetcher(QueryShardContext context, SearchLookup searchLookup, String format) { - throw new UnsupportedOperationException("KNN Vector do not support fields search"); + return new ArraySourceValueFetcher(name(), context) { + @Override + protected Object parseSourceValue(Object value) { + if (value instanceof ArrayList) { + return value; + } else { + logger.warn("Expected type ArrayList for value, but got {} ", value.getClass()); + return Collections.emptyList(); + } + } + }; } @Override diff --git a/src/test/java/org/opensearch/knn/index/OpenSearchIT.java b/src/test/java/org/opensearch/knn/index/OpenSearchIT.java index 88dd908d3..ca07cd0f3 100644 --- a/src/test/java/org/opensearch/knn/index/OpenSearchIT.java +++ b/src/test/java/org/opensearch/knn/index/OpenSearchIT.java @@ -13,7 +13,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.primitives.Floats; -import java.util.Locale; + import lombok.SneakyThrows; import org.apache.hc.core5.http.ParseException; import org.junit.BeforeClass; @@ -39,6 +39,7 @@ import java.net.URL; import java.util.Arrays; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.TreeMap; @@ -924,6 +925,289 @@ public void testKNNIndex_whenBuildVectorGraphThresholdIsProvidedEndToEnd_thenBui deleteKNNIndex(indexName); } + public void testKNNIndexSearchFieldsParameter() throws Exception { + createKnnIndex(INDEX_NAME, createKnnIndexMapping(Arrays.asList("vector1", "vector2", "vector3"), Arrays.asList(2, 3, 5))); + // Add docs with knn_vector fields + for (int i = 1; i <= 20; i++) { + Float[] vector1 = { (float) i, (float) (i + 1) }; + Float[] vector2 = { (float) i, (float) (i + 1), (float) (i + 2) }; + Float[] vector3 = { (float) i, (float) (i + 1), (float) (i + 2), (float) (i + 3), (float) (i + 4) }; + addKnnDoc( + INDEX_NAME, + Integer.toString(i), + Arrays.asList("vector1", "vector2", "vector3"), + Arrays.asList(vector1, vector2, vector3) + ); + } + int k = 10; // nearest 10 neighbors + + // Create match_all search body, all fields + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .field("fields", new String[] { "*" }) + .startObject("query") + .startObject("match_all") + .endObject() + .endObject() + .endObject(); + Response response = searchKNNIndex(INDEX_NAME, builder, k); + assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response.getEntity()), "vector1")); + assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response.getEntity()), "vector2")); + assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response.getEntity()), "vector3")); + + // Create match_all search body, some fields + builder = XContentFactory.jsonBuilder() + .startObject() + .field("fields", new String[] { "vector1", "vector2" }) + .startObject("query") + .startObject("match_all") + .endObject() + .endObject() + .endObject(); + Response response2 = searchKNNIndex(INDEX_NAME, builder, k); + assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response2.getEntity()), "vector1")); + assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response2.getEntity()), "vector2")); + assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response2.getEntity()), "vector3")); + + // Create knn search body, all fields + builder = XContentFactory.jsonBuilder() + .startObject() + .field("fields", new String[] { "*" }) + .startObject("query") + .startObject("knn") + .startObject("vector2") + .field("vector", new float[] { 2.0f, 2.0f, 2.0f }) + .field("k", k) + .endObject() + .endObject() + .endObject() + .endObject(); + Response response3 = searchKNNIndex(INDEX_NAME, builder, k); + assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response3.getEntity()), "vector1")); + assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response3.getEntity()), "vector2")); + assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response3.getEntity()), "vector3")); + + // Create knn search body, some fields + builder = XContentFactory.jsonBuilder() + .startObject() + .field("fields", new String[] { "vector1", "vector2" }) + .startObject("query") + .startObject("knn") + .startObject("vector2") + .field("vector", new float[] { 2.0f, 2.0f, 2.0f }) + .field("k", k) + .endObject() + .endObject() + .endObject() + .endObject(); + Response response4 = searchKNNIndex(INDEX_NAME, builder, k); + assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response4.getEntity()), "vector1")); + assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response4.getEntity()), "vector2")); + assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response4.getEntity()), "vector3")); + } + + public void testKNNIndexSearchFieldsParameterWithOtherFields() throws Exception { + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject("vector1") + .field("type", "knn_vector") + .field("dimension", "2") + .endObject() + .startObject("vector2") + .field("type", "knn_vector") + .field("dimension", "3") + .endObject() + .startObject("float1") + .field("type", "float") + .endObject() + .startObject("float2") + .field("type", "float") + .endObject() + .endObject() + .endObject(); + createKnnIndex(INDEX_NAME, xContentBuilder.toString()); + // Add docs with knn_vector and other fields + for (int i = 1; i <= 20; i++) { + Float[] vector1 = { (float) i, (float) (i + 1) }; + Float[] vector2 = { (float) i, (float) (i + 1), (float) (i + 2) }; + Float[] float1 = { (float) i }; + Float[] float2 = { (float) (i + 1) }; + addKnnDoc( + INDEX_NAME, + Integer.toString(i), + Arrays.asList("vector1", "vector2", "float1", "float2"), + Arrays.asList(vector1, vector2, float1, float2) + ); + } + int k = 10; // nearest 10 neighbors + + // Create match_all search body, all fields + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .field("fields", new String[] { "*" }) + .startObject("query") + .startObject("match_all") + .endObject() + .endObject() + .endObject(); + Response response = searchKNNIndex(INDEX_NAME, builder, k); + assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response.getEntity()), "vector1")); + assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response.getEntity()), "vector2")); + assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response.getEntity()), "float1")); + assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response.getEntity()), "float2")); + + // Create match_all search body, some fields + builder = XContentFactory.jsonBuilder() + .startObject() + .field("fields", new String[] { "vector1", "float2" }) + .startObject("query") + .startObject("match_all") + .endObject() + .endObject() + .endObject(); + Response response2 = searchKNNIndex(INDEX_NAME, builder, k); + assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response2.getEntity()), "vector1")); + assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response2.getEntity()), "vector2")); + assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response2.getEntity()), "float1")); + assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response2.getEntity()), "float2")); + + // Create knn search body, all fields + builder = XContentFactory.jsonBuilder() + .startObject() + .field("fields", new String[] { "*" }) + .startObject("query") + .startObject("knn") + .startObject("vector2") + .field("vector", new float[] { 2.0f, 2.0f, 2.0f }) + .field("k", k) + .endObject() + .endObject() + .endObject() + .endObject(); + Response response3 = searchKNNIndex(INDEX_NAME, builder, k); + assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response3.getEntity()), "vector1")); + assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response3.getEntity()), "vector2")); + assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response3.getEntity()), "float1")); + assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response3.getEntity()), "float2")); + + // Create knn search body, some fields + builder = XContentFactory.jsonBuilder() + .startObject() + .field("fields", new String[] { "vector1", "float2" }) + .startObject("query") + .startObject("knn") + .startObject("vector2") + .field("vector", new float[] { 2.0f, 2.0f, 2.0f }) + .field("k", k) + .endObject() + .endObject() + .endObject() + .endObject(); + Response response4 = searchKNNIndex(INDEX_NAME, builder, k); + assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response4.getEntity()), "vector1")); + assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response4.getEntity()), "vector2")); + assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response4.getEntity()), "float1")); + assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response4.getEntity()), "float2")); + } + + public void testKNNIndexSearchFieldsParameterDocsWithOnlyOtherFields() throws Exception { + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject("vector1") + .field("type", "knn_vector") + .field("dimension", "2") + .endObject() + .startObject("vector2") + .field("type", "knn_vector") + .field("dimension", "3") + .endObject() + .startObject("text1") + .field("type", "text") + .endObject() + .endObject() + .endObject(); + createKnnIndex(INDEX_NAME, xContentBuilder.toString()); + // Add knn_vector docs + for (int i = 1; i <= 20; i++) { + Float[] vector1 = { (float) i, (float) (i + 1) }; + Float[] vector2 = { (float) i, (float) (i + 1), (float) (i + 2) }; + addKnnDoc(INDEX_NAME, Integer.toString(i), Arrays.asList("vector1", "vector2"), Arrays.asList(vector1, vector2)); + } + // Add non knn_vector docs + for (int i = 21; i <= 40; i++) { + addNonKNNDoc(INDEX_NAME, Integer.toString(i), "text1", "text " + i); + } + int k = 10; // nearest 10 neighbors + + // Create match search body, all non vector fields + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .field("fields", new String[] { "text1" }) + .startObject("query") + .startObject("match") + .field("text1", "text") + .endObject() + .endObject() + .endObject(); + Response response = searchKNNIndex(INDEX_NAME, builder, k); + assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response.getEntity()), "vector1")); + assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response.getEntity()), "vector2")); + assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response.getEntity()), "text1")); + + // Create match search body, all vector fields + builder = XContentFactory.jsonBuilder() + .startObject() + .field("fields", new String[] { "vector1", "vector2" }) + .startObject("query") + .startObject("match") + .field("text1", "text") + .endObject() + .endObject() + .endObject(); + Response response2 = searchKNNIndex(INDEX_NAME, builder, k); + assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response2.getEntity()), "vector1")); + assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response2.getEntity()), "vector2")); + assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response2.getEntity()), "text1")); + + // Create knn search body, all vector fields + builder = XContentFactory.jsonBuilder() + .startObject() + .field("fields", new String[] { "vector1", "vector2" }) + .startObject("query") + .startObject("knn") + .startObject("vector2") + .field("vector", new float[] { 2.0f, 2.0f, 2.0f }) + .field("k", k) + .endObject() + .endObject() + .endObject() + .endObject(); + Response response3 = searchKNNIndex(INDEX_NAME, builder, k); + assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response3.getEntity()), "vector1")); + assertEquals(k, parseSearchResponseFieldsCount(EntityUtils.toString(response3.getEntity()), "vector2")); + assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response3.getEntity()), "text1")); + + // Create knn search body, all non vector fields + builder = XContentFactory.jsonBuilder() + .startObject() + .field("fields", new String[] { "text1" }) + .startObject("query") + .startObject("knn") + .startObject("vector2") + .field("vector", new float[] { 2.0f, 2.0f, 2.0f }) + .field("k", k) + .endObject() + .endObject() + .endObject() + .endObject(); + Response response4 = searchKNNIndex(INDEX_NAME, builder, k); + assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response4.getEntity()), "vector1")); + assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response4.getEntity()), "vector2")); + assertEquals(0, parseSearchResponseFieldsCount(EntityUtils.toString(response4.getEntity()), "text1")); + } + private List getResults(final String indexName, final String fieldName, final float[] vector, final int k) throws IOException, ParseException { final Response searchResponseField = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName, vector, k), k); diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldTypeTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldTypeTests.java new file mode 100644 index 000000000..56727a865 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldTypeTests.java @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +import org.opensearch.index.mapper.ArraySourceValueFetcher; +import org.opensearch.index.mapper.ValueFetcher; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.KNNMethodContext; + +import java.util.Collections; + +import static org.mockito.Mockito.mock; + +public class KNNVectorFieldTypeTests extends KNNTestCase { + private static final String FIELD_NAME = "test-field"; + + public void testValueFetcher() { + KNNMethodContext knnMethodContext = getDefaultKNNMethodContext(); + KNNVectorFieldType knnVectorFieldType = new KNNVectorFieldType( + FIELD_NAME, + Collections.emptyMap(), + VectorDataType.FLOAT, + getMappingConfigForMethodMapping(knnMethodContext, 3) + ); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + ValueFetcher valueFetcher = knnVectorFieldType.valueFetcher(mockQueryShardContext, null, null); + assertTrue(valueFetcher instanceof ArraySourceValueFetcher); + } +} diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 896674a18..477816d91 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -337,6 +337,27 @@ protected List parseSearchResponseScriptFields(final String responseB return knnSearchResponses; } + protected int parseSearchResponseFieldsCount(String responseBody, String fieldName) throws IOException { + @SuppressWarnings("unchecked") + List hits = (List) ((Map) createParser( + MediaTypeRegistry.getDefaultMediaType().xContent(), + responseBody + ).map().get("hits")).get("hits"); + + @SuppressWarnings("unchecked") + List fieldFound = hits.stream().map(hit -> { + if (((Map) hit).get("fields") == null) { + return 0; + } + if (((Map) ((Map) hit).get("fields")).get(fieldName) != null) { + return 1; + } else { + return 0; + } + }).collect(Collectors.toList()); + return fieldFound.stream().mapToInt(Integer::intValue).sum(); + } + /** * Parse the response of Aggregation to extract the value */