diff --git a/src/main/java/org/opensearch/knn/index/SpaceType.java b/src/main/java/org/opensearch/knn/index/SpaceType.java index 147e260b9..5d90071e8 100644 --- a/src/main/java/org/opensearch/knn/index/SpaceType.java +++ b/src/main/java/org/opensearch/knn/index/SpaceType.java @@ -77,11 +77,6 @@ public float scoreTranslation(float rawScore) { return Math.max((2.0F - rawScore) / 2.0F, 0.0F); } - @Override - public float scoreToDistanceTranslation(float score) { - return score; - } - @Override public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() { return KNNVectorSimilarityFunction.COSINE; diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java b/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java index a9d8e5323..1f1208527 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java @@ -17,7 +17,6 @@ import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.mapper.PerDimensionProcessor; import org.opensearch.knn.index.mapper.PerDimensionValidator; -import org.opensearch.knn.index.mapper.VectorTransformer; import java.util.Objects; import java.util.Set; @@ -90,11 +89,6 @@ protected PerDimensionProcessor doGetPerDimensionProcessor( throw new IllegalStateException("Unsupported vector data type " + vectorDataType); } - @Override - protected VectorTransformer getVectorTransformer(KNNMethodContext knnMethodContext) { - return super.getVectorTransformer(knnMethodContext); - } - static KNNLibraryIndexingContext adjustIndexDescription( MethodAsMapBuilder methodAsMapBuilder, MethodComponentContext methodComponentContext, diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java b/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java index d4222dc8d..5a0258279 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java @@ -47,11 +47,14 @@ public class Faiss extends NativeLibrary { // https://opensearch.org/docs/latest/search-plugins/knn/approximate-knn/#spaces private final static Map> SCORE_TO_DISTANCE_TRANSFORMATIONS = ImmutableMap.< SpaceType, - Function>builder().put(SpaceType.INNER_PRODUCT, score -> score > 1 ? 1 - score : 1 / score - 1).build(); + Function>builder() + .put(SpaceType.INNER_PRODUCT, score -> score > 1 ? 1 - score : (1 / score) - 1) + .put(SpaceType.COSINESIMIL, score -> 2 * score - 1) + .build(); private final static Map> DISTANCE_TRANSLATIONS = ImmutableMap.< SpaceType, - Function>builder().put(SpaceType.COSINESIMIL, distance -> (2 - distance) / 2).build(); + Function>builder().put(SpaceType.COSINESIMIL, distance -> 1 - distance).build(); // Package private so that the method resolving logic can access the methods final static Map METHODS = ImmutableMap.of(METHOD_HNSW, new FaissHNSWMethod(), METHOD_IVF, new FaissIVFMethod()); @@ -99,6 +102,7 @@ public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) { @Override public Float scoreToRadialThreshold(Float score, SpaceType spaceType) { + // Faiss engine uses distance as is and need transformation if (this.scoreTransform.containsKey(spaceType)) { return this.scoreTransform.get(spaceType).apply(score); } diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index e1a616433..8604d5506 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -675,7 +675,7 @@ protected void validatePreparse() { protected abstract VectorValidator getVectorValidator(); /** - * Getter for per dimension validator during vector parsing + * Getter for per dimension validator during vector parsing, and before any transformation * * @return PerDimensionValidator */ @@ -688,6 +688,11 @@ protected void validatePreparse() { */ protected abstract PerDimensionProcessor getPerDimensionProcessor(); + /** + * Getter for vector transformer after vector parsing and validation + * + * @return VectorTransformer + */ protected abstract VectorTransformer getVectorTransformer(); protected void parseCreateField(ParseContext context, int dimension, VectorDataType vectorDataType) throws IOException { @@ -700,8 +705,8 @@ protected void parseCreateField(ParseContext context, int dimension, VectorDataT } final byte[] array = bytesArrayOptional.get(); getVectorValidator().validateVector(array); - final byte[] transformedArray = getVectorTransformer().transform(array); - context.doc().addAll(getFieldsForByteVector(transformedArray)); + getVectorTransformer().transform(array); + context.doc().addAll(getFieldsForByteVector(array)); } else if (VectorDataType.FLOAT == vectorDataType) { Optional floatsArrayOptional = getFloatsFromContext(context, dimension); @@ -710,8 +715,8 @@ protected void parseCreateField(ParseContext context, int dimension, VectorDataT } final float[] array = floatsArrayOptional.get(); getVectorValidator().validateVector(array); - final float[] transformedArray = getVectorTransformer().transform(array); - context.doc().addAll(getFieldsForFloatVector(transformedArray)); + getVectorTransformer().transform(array); + context.doc().addAll(getFieldsForFloatVector(array)); } else { throw new IllegalArgumentException( String.format(Locale.ROOT, "Cannot parse context for unsupported values provided for field [%s]", VECTOR_DATA_TYPE_FIELD) diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java index a0832c1d0..24b2989b4 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java @@ -17,12 +17,16 @@ import org.opensearch.index.query.QueryShardException; import org.opensearch.knn.index.KNNVectorIndexFieldData; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.query.rescore.RescoreContext; +import org.opensearch.knn.indices.ModelDao; +import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.search.aggregations.support.CoreValuesSourceType; import org.opensearch.search.lookup.SearchLookup; import java.util.Locale; import java.util.Map; +import java.util.Optional; import java.util.function.Supplier; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.deserializeStoredVector; @@ -99,4 +103,37 @@ public RescoreContext resolveRescoreContext(RescoreContext userProvidedContext) Mode mode = knnMappingConfig.getMode(); return compressionLevel.getDefaultRescoreContext(mode, dimension); } + + /** + * Transforms a query vector based on the field's configuration. The transformation is performed + * in-place on the input vector according to either the KNN method context or the model ID. + * + * @param vector The float array to be transformed in-place. Must not be null. + * @throws IllegalStateException if neither KNN method context nor Model ID is configured + * + * The transformation process follows this order: + * 1. If vector is not FLOAT type, no transformation is performed + * 2. Attempts to use KNN method context if present + * 3. Falls back to model ID if KNN method context is not available + * 4. Throws exception if neither configuration is present + */ + public void transformQueryVector(float[] vector) { + if (VectorDataType.FLOAT != vectorDataType) { + return; + } + final Optional knnMethodContext = knnMappingConfig.getKnnMethodContext(); + if (knnMethodContext.isPresent()) { + VectorTransformerFactory.getVectorTransformer(knnMethodContext.get()).transform(vector); + return; + } + final Optional modelId = knnMappingConfig.getModelId(); + if (modelId.isPresent()) { + ModelDao modelDao = ModelDao.OpenSearchKNNModelDao.getInstance(); + final ModelMetadata metadata = modelDao.getMetadata(modelId.get()); + VectorTransformerFactory.getVectorTransformer(metadata).transform(vector); + return; + } + throw new IllegalStateException("Either KNN method context or Model Id should be configured"); + + } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java index 63a72637a..fc7638fd0 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java @@ -217,10 +217,8 @@ private void initVectorTransformer() { KNNMethodConfigContext knnMethodConfigContext = getKNNMethodConfigContextFromModelMetadata(modelMetadata); // Need to handle BWC case if (knnMethodContext == null || knnMethodConfigContext == null) { - log.debug( - "Method Context not available - falling back to Model Metadata for Engine and Space type to determine VectorTransformer instance" - ); - vectorTransformer = VectorTransformerFactory.getVectorTransformer(modelMetadata.getKnnEngine(), modelMetadata.getSpaceType()); + log.debug("Method Context not available - falling back to Model Metadata to determine VectorTransformer instance"); + vectorTransformer = VectorTransformerFactory.getVectorTransformer(modelMetadata); return; } diff --git a/src/main/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformer.java b/src/main/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformer.java index 6a9642435..a348cd1b8 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformer.java +++ b/src/main/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformer.java @@ -2,30 +2,25 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.knn.index.mapper; import org.apache.lucene.util.VectorUtil; /** - * Normalizes vectors using L2 (Euclidean) normalization. This transformation ensures - * that the vector's magnitude becomes 1 while preserving its directional properties. + * Normalizes vectors using L2 (Euclidean) normalization, ensuring the vector's + * magnitude becomes 1 while preserving its directional properties. */ public class NormalizeVectorTransformer implements VectorTransformer { - /** - * Transforms the input vector into unit vector by applying L2 normalization. - * - * @param vector The input vector to be normalized. Must not be null. - * @return A new float array containing the L2-normalized version of the input vector. - * Each component is divided by the Euclidean norm of the vector. - * @throws IllegalArgumentException if the input vector is null, empty, or a zero vector - */ @Override - public float[] transform(float[] vector) { + public void transform(float[] vector) { + validateVector(vector); + VectorUtil.l2normalize(vector); + } + + private void validateVector(float[] vector) { if (vector == null || vector.length == 0) { throw new IllegalArgumentException("Vector cannot be null or empty"); } - return VectorUtil.l2normalize(vector); } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/VectorTransformer.java b/src/main/java/org/opensearch/knn/index/mapper/VectorTransformer.java index f02df13ef..ac6a9b1ac 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/VectorTransformer.java +++ b/src/main/java/org/opensearch/knn/index/mapper/VectorTransformer.java @@ -2,11 +2,8 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.knn.index.mapper; -import java.util.Arrays; - /** * Defines operations for transforming vectors in the k-NN search context. * Implementations can modify vectors while preserving their dimensional properties @@ -15,50 +12,31 @@ public interface VectorTransformer { /** - * Transforms a float vector into a new vector of the same type. - * - * Example: - *
{@code
-     * float[] input = {1.0f, 2.0f, 3.0f};
-     * float[] transformed = transformer.transform(input);
-     * }
+ * Transforms a float vector in place. * * @param vector The input vector to transform (must not be null) - * @return The transformed vector * @throws IllegalArgumentException if the input vector is null */ - default float[] transform(final float[] vector) { + default void transform(final float[] vector) { if (vector == null) { throw new IllegalArgumentException("Input vector cannot be null"); } - return Arrays.copyOf(vector, vector.length); } /** - * Transforms a byte vector into a new vector of the same type. - * - * Example: - *
{@code
-     * byte[] input = {1, 2, 3};
-     * byte[] transformed = transformer.transform(input);
-     * }
+ * Transforms a byte vector in place. * * @param vector The input vector to transform (must not be null) - * @return The transformed vector * @throws IllegalArgumentException if the input vector is null */ - default byte[] transform(final byte[] vector) { + default void transform(final byte[] vector) { if (vector == null) { throw new IllegalArgumentException("Input vector cannot be null"); } - // return copy of vector to avoid side effects - return Arrays.copyOf(vector, vector.length); - } /** * A no-operation transformer that returns vector values unchanged. - * This constant can be used when no transformation is needed. */ VectorTransformer NOOP_VECTOR_TRANSFORMER = new VectorTransformer() { }; diff --git a/src/main/java/org/opensearch/knn/index/mapper/VectorTransformerFactory.java b/src/main/java/org/opensearch/knn/index/mapper/VectorTransformerFactory.java index 8726c48b0..94463533c 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/VectorTransformerFactory.java +++ b/src/main/java/org/opensearch/knn/index/mapper/VectorTransformerFactory.java @@ -10,6 +10,7 @@ import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.engine.KNNMethodContext; +import org.opensearch.knn.indices.ModelMetadata; /** * Factory class responsible for creating appropriate vector transformers based on the KNN method context. @@ -35,6 +36,28 @@ public static VectorTransformer getVectorTransformer(final KNNMethodContext cont return getVectorTransformer(context.getKnnEngine(), context.getSpaceType()); } + /** + * Creates a VectorTransformer based on the provided model metadata. + * + * @param metadata The model metadata containing KNN engine and space type configuration. + * This parameter must not be null. + * @return A VectorTransformer instance configured according to the model metadata + * @throws IllegalArgumentException if metadata is null + * + * The factory determines the appropriate transformer implementation based on: + * - The KNN engine (e.g., FAISS, NMSLIB) + * - The space type (e.g., L2, COSINE) + * + * The returned transformer can be used to modify vectors in-place according to + * the specified engine and space type requirements. + */ + public static VectorTransformer getVectorTransformer(final ModelMetadata metadata) { + if (metadata == null) { + throw new IllegalArgumentException("ModelMetadata cannot be null"); + } + return getVectorTransformer(metadata.getKnnEngine(), metadata.getSpaceType()); + } + /** * Returns a vector transformer based on the provided KNN engine and space type. * For FAISS engine with cosine similarity space type, returns a NormalizeVectorTransformer @@ -45,7 +68,7 @@ public static VectorTransformer getVectorTransformer(final KNNMethodContext cont * @param spaceType The space type * @return VectorTransformer An appropriate vector transformer instance */ - public static VectorTransformer getVectorTransformer(final KNNEngine knnEngine, final SpaceType spaceType) { + private static VectorTransformer getVectorTransformer(final KNNEngine knnEngine, final SpaceType spaceType) { return shouldNormalizeVector(knnEngine, spaceType) ? new NormalizeVectorTransformer() : VectorTransformer.NOOP_VECTOR_TRANSFORMER; } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index c2998df6c..3ba273e71 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -12,7 +12,6 @@ import org.apache.commons.lang.StringUtils; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; -import org.apache.lucene.util.VectorUtil; import org.opensearch.common.ValidationException; import org.opensearch.core.ParseField; import org.opensearch.core.common.Strings; @@ -429,6 +428,7 @@ protected Query doToQuery(QueryShardContext context) { SpaceType spaceType = queryConfigFromMapping.get().getSpaceType(); VectorDataType vectorDataType = queryConfigFromMapping.get().getVectorDataType(); RescoreContext processedRescoreContext = knnVectorFieldType.resolveRescoreContext(rescoreContext); + knnVectorFieldType.transformQueryVector(vector); VectorQueryType vectorQueryType = getVectorQueryType(k, maxDistance, minScore); updateQueryStats(vectorQueryType); @@ -542,7 +542,7 @@ protected Query doToQuery(QueryShardContext context) { .knnEngine(knnEngine) .indexName(indexName) .fieldName(this.fieldName) - .vector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine, spaceType)) + .vector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine)) .byteVector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine, byteVector)) .vectorDataType(vectorDataType) .k(this.k) @@ -559,8 +559,8 @@ protected Query doToQuery(QueryShardContext context) { .knnEngine(knnEngine) .indexName(indexName) .fieldName(this.fieldName) - .vector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine, spaceType)) - .byteVector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine, byteVector)) + .vector(VectorDataType.FLOAT == vectorDataType ? this.vector : null) + .byteVector(VectorDataType.BYTE == vectorDataType ? byteVector : null) .vectorDataType(vectorDataType) .radius(radius) .methodParameters(this.methodParameters) @@ -612,13 +612,7 @@ private void updateQueryStats(VectorQueryType vectorQueryType) { } } - private float[] getVectorForCreatingQueryRequest(VectorDataType vectorDataType, KNNEngine knnEngine, SpaceType spaceType) { - - // Cosine similarity is supported as Inner product by FAISS by normalizing input vector, hence, we have to normalize - // query vector before applying search - if (knnEngine == KNNEngine.FAISS && spaceType == SpaceType.COSINESIMIL && VectorDataType.FLOAT == vectorDataType) { - return VectorUtil.l2normalize(this.vector); - } + private float[] getVectorForCreatingQueryRequest(VectorDataType vectorDataType, KNNEngine knnEngine) { if ((VectorDataType.FLOAT == vectorDataType) || (VectorDataType.BYTE == vectorDataType && KNNEngine.FAISS == knnEngine)) { return this.vector; } diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index e4fe782f6..4113579bf 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -2053,6 +2053,61 @@ public void testCosineSimilarity_withNoGraphs_withRadialSearch_withScore_thenSuc validateGraphEviction(); } + public void testEndToEnd_withApproxAndExactSearch_inSameIndex_ForCosineSpaceType() throws Exception { + String indexName = randomLowerCaseString(); + String fieldName = randomLowerCaseString(); + SpaceType spaceType = SpaceType.COSINESIMIL; + Integer dimension = testData.indexData.vectors[0].length; + + // Create an index + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field("dimension", dimension) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, KNNConstants.METHOD_HNSW) + .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .endObject() + .endObject() + .endObject() + .endObject(); + + Map mappingMap = xContentBuilderToMap(builder); + String mapping = builder.toString(); + + createKnnIndex(indexName, buildKNNIndexSettings(0), mapping); + + // Index one document + addKnnDoc(indexName, randomAlphaOfLength(5), fieldName, Floats.asList(testData.indexData.vectors[0]).toArray()); + + // Assert we have the right number of documents in the index + refreshAllIndices(); + assertEquals(1, getDocCount(indexName)); + // update threshold setting to skip building graph + updateIndexSettings(indexName, Settings.builder().put(KNNSettings.INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD, -1)); + // add duplicate document with different id + addKnnDoc(indexName, randomAlphaOfLength(5), fieldName, Floats.asList(testData.indexData.vectors[0]).toArray()); + assertEquals(2, getDocCount(indexName)); + final int k = 2; + // search index + Response response = searchKNNIndex( + indexName, + KNNQueryBuilder.builder().fieldName(fieldName).vector(testData.queries[0]).k(k).build(), + k + ); + String responseBody = EntityUtils.toString(response.getEntity()); + List knnResults = parseSearchResponse(responseBody, fieldName); + assertEquals(k, knnResults.size()); + + List actualScores = parseSearchResponseScore(responseBody, fieldName); + + // both document should have identical score + assertEquals(actualScores.get(0), actualScores.get(1), 0.001); + } + protected void setupKNNIndexForFilterQuery() throws Exception { setupKNNIndexForFilterQuery(getKNNDefaultIndexSettings()); } diff --git a/src/test/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformerTests.java b/src/test/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformerTests.java index 532985232..1c4237d7b 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformerTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformerTests.java @@ -21,13 +21,13 @@ public void testNormalizeTransformer_withEmptyVector_thenThrowsException() { public void testNormalizeTransformer_withValidVector_thenSuccess() { float[] input = { -3.0f, 4.0f }; - float[] normalized = transformer.transform(input); + transformer.transform(input); - assertEquals(-0.6f, normalized[0], DELTA); - assertEquals(0.8f, normalized[1], DELTA); + assertEquals(-0.6f, input[0], DELTA); + assertEquals(0.8f, input[1], DELTA); // Verify the magnitude is 1 - assertEquals(1.0f, calculateMagnitude(normalized), DELTA); + assertEquals(1.0f, calculateMagnitude(input), DELTA); } private float calculateMagnitude(float[] vector) { diff --git a/src/test/java/org/opensearch/knn/index/mapper/VectorTransformerFactoryTests.java b/src/test/java/org/opensearch/knn/index/mapper/VectorTransformerFactoryTests.java index 6148f83d6..3e50dd546 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/VectorTransformerFactoryTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/VectorTransformerFactoryTests.java @@ -9,29 +9,41 @@ import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.engine.KNNMethodContext; +import org.opensearch.knn.indices.ModelMetadata; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; public class VectorTransformerFactoryTests extends KNNTestCase { - public void testAllSpaceTypes_withFaiss() { + + public void testGetVectorTransformer_withNullModelMetadata() { + // Test case for null context + assertThrows(IllegalArgumentException.class, () -> VectorTransformerFactory.getVectorTransformer((ModelMetadata) null)); + } + + public void testAllSpaceTypes_usingModelMetadata_withFaiss() { for (SpaceType spaceType : SpaceType.values()) { - VectorTransformer transformer = VectorTransformerFactory.getVectorTransformer(KNNEngine.FAISS, spaceType); + ModelMetadata metaData = mock(ModelMetadata.class); + when(metaData.getKnnEngine()).thenReturn(KNNEngine.FAISS); + when(metaData.getSpaceType()).thenReturn(spaceType); + VectorTransformer transformer = VectorTransformerFactory.getVectorTransformer(metaData); validateTransformer(spaceType, KNNEngine.FAISS, transformer); } } - public void testAllEngines_withCosine() { - // Test all engines with COSINESIMIL space type + public void testAllEngines_usingModelMetadata_withCosine() { for (KNNEngine engine : KNNEngine.values()) { - VectorTransformer transformer = VectorTransformerFactory.getVectorTransformer(engine, SpaceType.COSINESIMIL); + ModelMetadata metaData = mock(ModelMetadata.class); + when(metaData.getKnnEngine()).thenReturn(engine); + when(metaData.getSpaceType()).thenReturn(SpaceType.COSINESIMIL); + VectorTransformer transformer = VectorTransformerFactory.getVectorTransformer(metaData); validateTransformer(SpaceType.COSINESIMIL, engine, transformer); } } public void testGetVectorTransformer_withNullContext() { // Test case for null context - assertThrows(IllegalArgumentException.class, () -> VectorTransformerFactory.getVectorTransformer(null)); + assertThrows(IllegalArgumentException.class, () -> VectorTransformerFactory.getVectorTransformer((KNNMethodContext) null)); } public void testAllSpaceTypes_usingContext_withFaiss() { diff --git a/src/test/java/org/opensearch/knn/recall/RecallTestsIT.java b/src/test/java/org/opensearch/knn/recall/RecallTestsIT.java index be2dd7b82..ae162401b 100644 --- a/src/test/java/org/opensearch/knn/recall/RecallTestsIT.java +++ b/src/test/java/org/opensearch/knn/recall/RecallTestsIT.java @@ -231,7 +231,7 @@ public void testRecall_whenLuceneHnswFP32_thenRecallAbove75percent() { */ @SneakyThrows public void testRecall_whenFaissHnswFP32_thenRecallAbove75percent() { - List spaceTypes = List.of(SpaceType.L2, SpaceType.INNER_PRODUCT); + List spaceTypes = List.of(SpaceType.L2, SpaceType.INNER_PRODUCT, SpaceType.COSINESIMIL); for (SpaceType spaceType : spaceTypes) { String indexName = createIndexName(KNNEngine.FAISS, spaceType); XContentBuilder builder = XContentFactory.jsonBuilder()