From 1dc2ec9f17b83866258dab918b861b3c43f1289c Mon Sep 17 00:00:00 2001 From: Balasubramanian Date: Thu, 16 Jan 2025 21:24:48 -0800 Subject: [PATCH] Refactor Signed-off-by: Balasubramanian --- .../knn/index/engine/AbstractKNNMethod.java | 6 +++--- .../engine/faiss/AbstractFaissMethod.java | 15 ++++++++------- .../knn/index/mapper/KNNVectorFieldMapper.java | 2 +- .../knn/index/mapper/LuceneFieldMapper.java | 7 ------- .../knn/index/mapper/ModelFieldMapper.java | 15 ++++++++++----- .../mapper/NormalizeVectorTransformer.java | 12 ++++++++++++ .../index/mapper/VectorTransformerFactory.java | 18 ++---------------- .../NormalizeVectorTransformerTests.java | 4 ++++ .../mapper/VectorTransformerFactoryTests.java | 2 +- 9 files changed, 41 insertions(+), 40 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java b/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java index 9768e56f7..bfd908a09 100644 --- a/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java @@ -108,8 +108,8 @@ protected PerDimensionProcessor doGetPerDimensionProcessor( return PerDimensionProcessor.NOOP_PROCESSOR; } - protected VectorTransformer getVectorTransformer(KNNMethodContext knnMethodContext) { - return VectorTransformerFactory.getVectorTransformer(knnMethodContext.getKnnEngine(), knnMethodContext.getSpaceType()); + protected VectorTransformer getVectorTransformer(SpaceType spaceType) { + return VectorTransformerFactory.NOOP_VECTOR_TRANSFORMER; } @Override @@ -130,7 +130,7 @@ public KNNLibraryIndexingContext getKNNLibraryIndexingContext( .vectorValidator(doGetVectorValidator(knnMethodContext, knnMethodConfigContext)) .perDimensionValidator(doGetPerDimensionValidator(knnMethodContext, knnMethodConfigContext)) .perDimensionProcessor(doGetPerDimensionProcessor(knnMethodContext, knnMethodConfigContext)) - .vectorTransformer(getVectorTransformer(knnMethodContext)) 
+ .vectorTransformer(getVectorTransformer(knnMethodContext.getSpaceType())) .build(); } diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java b/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java index 1f1208527..5e7b72b69 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java @@ -8,15 +8,11 @@ import org.apache.commons.lang.StringUtils; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.engine.AbstractKNNMethod; -import org.opensearch.knn.index.engine.KNNLibraryIndexingContext; -import org.opensearch.knn.index.engine.KNNLibrarySearchContext; -import org.opensearch.knn.index.engine.KNNMethodConfigContext; -import org.opensearch.knn.index.engine.KNNMethodContext; -import org.opensearch.knn.index.engine.MethodComponent; -import org.opensearch.knn.index.engine.MethodComponentContext; +import org.opensearch.knn.index.engine.*; import org.opensearch.knn.index.mapper.PerDimensionProcessor; import org.opensearch.knn.index.mapper.PerDimensionValidator; +import org.opensearch.knn.index.mapper.VectorTransformer; +import org.opensearch.knn.index.mapper.VectorTransformerFactory; import java.util.Objects; import java.util.Set; @@ -143,4 +139,9 @@ protected SpaceType convertUserToMethodSpaceType(SpaceType spaceType) { } return super.convertUserToMethodSpaceType(spaceType); } + + @Override + protected VectorTransformer getVectorTransformer(SpaceType spaceType) { + return VectorTransformerFactory.getVectorTransformer(KNNEngine.FAISS, spaceType); + } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index 4e552a9e0..99c6ebe2a 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ 
b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -702,7 +702,7 @@ protected void validatePreparse() { * */ protected VectorTransformer getVectorTransformer() { - return VectorTransformerFactory.getVectorTransformer(); + return VectorTransformerFactory.NOOP_VECTOR_TRANSFORMER; } protected void parseCreateField(ParseContext context, int dimension, VectorDataType vectorDataType) throws IOException { diff --git a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java index 83f3ce4c5..4ceb9b4b2 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java @@ -42,7 +42,6 @@ public class LuceneFieldMapper extends KNNVectorFieldMapper { private final PerDimensionProcessor perDimensionProcessor; private final PerDimensionValidator perDimensionValidator; private final VectorValidator vectorValidator; - private final VectorTransformer vectorTransformer; static LuceneFieldMapper createFieldMapper( String fullname, @@ -123,7 +122,6 @@ private LuceneFieldMapper( this.perDimensionProcessor = knnLibraryIndexingContext.getPerDimensionProcessor(); this.perDimensionValidator = knnLibraryIndexingContext.getPerDimensionValidator(); this.vectorValidator = knnLibraryIndexingContext.getVectorValidator(); - this.vectorTransformer = knnLibraryIndexingContext.getVectorTransformer(); } @Override @@ -171,11 +169,6 @@ protected PerDimensionProcessor getPerDimensionProcessor() { return perDimensionProcessor; } - @Override - protected VectorTransformer getVectorTransformer() { - return vectorTransformer; - } - @Override void updateEngineStats() { KNNEngine.LUCENE.setInitialized(true); diff --git a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java index 879706aa8..d472090fc 100644 --- 
a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java @@ -197,8 +197,11 @@ protected PerDimensionProcessor getPerDimensionProcessor() { @Override protected VectorTransformer getVectorTransformer() { + // We don't want to call model metadata to get space type and engine for every vector, + // since getVectorTransformer() will be called once per vector. Hence, + // we initialize it once and reuse it on every subsequent call. initVectorTransformer(); - return vectorTransformer; + return this.vectorTransformer; } /** @@ -207,6 +210,7 @@ protected VectorTransformer getVectorTransformer() { * and KNN method context. * @throws IllegalStateException if model metadata cannot be retrieved */ + private void initVectorTransformer() { if (vectorTransformer != null) { return; @@ -215,13 +219,14 @@ private void initVectorTransformer() { KNNMethodContext knnMethodContext = getKNNMethodContextFromModelMetadata(modelMetadata); KNNMethodConfigContext knnMethodConfigContext = getKNNMethodConfigContextFromModelMetadata(modelMetadata); - // Need to handle BWC case + // Need to handle BWC case where method context is not available if (knnMethodContext == null || knnMethodConfigContext == null) { - log.debug("Method Context not available - falling back to Model Metadata to determine VectorTransformer instance"); - vectorTransformer = VectorTransformerFactory.getVectorTransformer(modelMetadata.getKnnEngine(), modelMetadata.getSpaceType()); + vectorTransformer = VectorTransformerFactory.NOOP_VECTOR_TRANSFORMER; return; } - + // Get the vector transformer from the Indexing Context. We want the Engine/Library to provide the necessary + // input rather than creating the Transformer from the engine and space type. This design + // decision ensures that the Engine, rather than the Field Mapper, drives the implementation. 
KNNLibraryIndexingContext knnLibraryIndexingContext = knnMethodContext.getKnnEngine() .getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext); vectorTransformer = knnLibraryIndexingContext.getVectorTransformer(); diff --git a/src/main/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformer.java b/src/main/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformer.java index a348cd1b8..cd1331d3d 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformer.java +++ b/src/main/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformer.java @@ -18,6 +18,18 @@ public void transform(float[] vector) { VectorUtil.l2normalize(vector); } + /** + * Transforms a byte array vector by normalizing it. + * This operation is currently not supported for byte arrays. + * + * @param vector the byte array to be normalized + * @throws UnsupportedOperationException when this method is called, as byte array normalization is not supported + */ + @Override + public void transform(byte[] vector) { + throw new UnsupportedOperationException("Byte array normalization is not supported"); + } + private void validateVector(float[] vector) { if (vector == null || vector.length == 0) { throw new IllegalArgumentException("Vector cannot be null or empty"); diff --git a/src/main/java/org/opensearch/knn/index/mapper/VectorTransformerFactory.java b/src/main/java/org/opensearch/knn/index/mapper/VectorTransformerFactory.java index 4db901a4f..f87e496df 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/VectorTransformerFactory.java +++ b/src/main/java/org/opensearch/knn/index/mapper/VectorTransformerFactory.java @@ -20,23 +20,9 @@ public final class VectorTransformerFactory { /** * A no-operation transformer that returns vector values unchanged. 
*/ - private final static VectorTransformer NOOP_VECTOR_TRANSFORMER = new VectorTransformer() { + public final static VectorTransformer NOOP_VECTOR_TRANSFORMER = new VectorTransformer() { }; - /** - * Returns a vector transformer instance for vector transformations. - * This method provides access to the default no-operation vector transformer - * that performs identity transformation on vectors. The transformer does not - * modify the input vectors and returns them as-is.This implementation returns a stateless, thread-safe transformer - * instance that can be safely shared across multiple calls - * - * @return VectorTransformer A singleton instance of the no-operation vector - * transformer (NOOP_VECTOR_TRANSFORMER) - */ - public static VectorTransformer getVectorTransformer() { - return NOOP_VECTOR_TRANSFORMER; - } - /** * Returns a vector transformer based on the provided KNN engine and space type. * For FAISS engine with cosine similarity space type, returns a NormalizeVectorTransformer @@ -48,7 +34,7 @@ public static VectorTransformer getVectorTransformer() { * @return VectorTransformer An appropriate vector transformer instance */ public static VectorTransformer getVectorTransformer(final KNNEngine knnEngine, final SpaceType spaceType) { - return shouldNormalizeVector(knnEngine, spaceType) ? new NormalizeVectorTransformer() : getVectorTransformer(); + return shouldNormalizeVector(knnEngine, spaceType) ? 
new NormalizeVectorTransformer() : NOOP_VECTOR_TRANSFORMER; } private static boolean shouldNormalizeVector(final KNNEngine knnEngine, final SpaceType spaceType) { diff --git a/src/test/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformerTests.java b/src/test/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformerTests.java index 1c4237d7b..4b17b9a12 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformerTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformerTests.java @@ -19,6 +19,10 @@ public void testNormalizeTransformer_withEmptyVector_thenThrowsException() { assertThrows(IllegalArgumentException.class, () -> transformer.transform(new float[0])); } + public void testNormalizeTransformer_withByteVector_thenThrowsException() { + assertThrows(UnsupportedOperationException.class, () -> transformer.transform(new byte[0])); + } + public void testNormalizeTransformer_withValidVector_thenSuccess() { float[] input = { -3.0f, 4.0f }; transformer.transform(input); diff --git a/src/test/java/org/opensearch/knn/index/mapper/VectorTransformerFactoryTests.java b/src/test/java/org/opensearch/knn/index/mapper/VectorTransformerFactoryTests.java index d93a836a1..6e213c151 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/VectorTransformerFactoryTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/VectorTransformerFactoryTests.java @@ -34,7 +34,7 @@ private static void validateTransformer(SpaceType spaceType, KNNEngine engine, V } else { assertSame( "Should return NOOP transformer for " + engine + " with COSINESIMIL", - VectorTransformerFactory.getVectorTransformer(), + VectorTransformerFactory.NOOP_VECTOR_TRANSFORMER, transformer ); }