From 273f181ce3ff3c5d6a757ecefa63f571edeb5b22 Mon Sep 17 00:00:00 2001 From: Bansi Kasundra Date: Tue, 14 Jan 2025 11:45:10 -0800 Subject: [PATCH] Addressed comments Signed-off-by: Bansi Kasundra --- .../knn/plugin/script/KNNScoringSpace.java | 8 ++++++-- .../knn/plugin/script/KNNScoringSpaceUtil.java | 1 + .../opensearch/knn/integ/KNNScriptScoringIT.java | 13 ++----------- 3 files changed, 9 insertions(+), 13 deletions(-) diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java index 33907dd3a..b77b0f475 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java @@ -92,7 +92,7 @@ public ScoreScript getScoreScript( ctx, searcher ); - } else { + } else if (processedQuery instanceof byte[]) { return new KNNScoreScript.KNNByteVectorType( params, (byte[]) this.processedQuery, @@ -102,6 +102,10 @@ public ScoreScript getScoreScript( ctx, searcher ); + } else { + throw new IllegalStateException( + "Unexpected type for processedQuery. Expected float[] or byte[], but got: " + processedQuery.getClass().getName() + ); } } @@ -139,7 +143,7 @@ protected Object getProcessedQuery(final Object query, final KNNVectorFieldType VectorDataType vectorDataType = knnVectorFieldType.getVectorDataType() == null ? VectorDataType.FLOAT : knnVectorFieldType.getVectorDataType(); - if (vectorDataType.equals(VectorDataType.FLOAT)) { + if (vectorDataType == VectorDataType.FLOAT) { return parseToFloatArray( query, KNNVectorFieldMapperUtil.getExpectedVectorLength(knnVectorFieldType), diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java index b4789e145..7699403b3 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java @@ -155,6 +155,7 @@ public static byte[] parseToByteArray(Object object, int expectedVectorLength, V /** * Converts Object vector to byte[] * + * Expects all numbers in the Object vector to be in the byte range of [-128 to 127] * @param vector input vector * @return Byte array representing the vector */ diff --git a/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java b/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java index ef7afcd99..81bed80cf 100644 --- a/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java +++ b/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java @@ -779,15 +779,6 @@ private float[] randomVector(final int dimensions, final VectorDataType vectorDa return vector; } - private byte[] randomByteVector(final int dimensions, final VectorDataType vectorDataType) { - int size = VectorDataType.BINARY == vectorDataType ? dimensions / 8 : dimensions; - final byte[] vector = new byte[size]; - for (int i = 0; i < size; i++) { - vector[i] = randomByte(); - } - return vector; - } - private Map createDataset( Function scoreFunction, int dimensions, @@ -821,7 +812,7 @@ private Map createDataset( FIELD_NAME, Collections.emptyMap(), vectorDataType, - getMappingConfigForFlatMapping(vectorDataType.equals(VectorDataType.BINARY) ? queryVector.length * 8 : queryVector.length) + getMappingConfigForFlatMapping(vectorDataType == VectorDataType.BINARY ? queryVector.length * 8 : queryVector.length) ) ); switch (spaceType) { @@ -831,7 +822,7 @@ private Map createDataset( case COSINESIMIL: case INNER_PRODUCT: case HAMMING: - if (vectorDataType.equals(VectorDataType.FLOAT)) { + if (vectorDataType == VectorDataType.FLOAT) { return ((KNNScoringSpace.KNNFieldSpace) knnScoringSpace).getScoringMethod(queryVector); } return ((KNNScoringSpace.KNNFieldSpace) knnScoringSpace).getScoringMethod(toByte(queryVector));