diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java b/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java index 55ff65516..146177ba9 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java @@ -60,6 +60,28 @@ public float[] getValue() { } } + public byte[] getByteValue() { + if (!docExists) { + String errorMessage = String.format( + "One of the document doesn't have a value for field '%s'. " + + "This can be avoided by checking if a document has a value for the field or not " + + "by doc['%s'].size() == 0 ? 0 : {your script}", + fieldName, + fieldName + ); + throw new IllegalStateException(errorMessage); + } + try { + return doGetByteValue(); + } catch (IOException e) { + throw ExceptionsHelper.convertToOpenSearchException(e); + } + } + + protected byte[] doGetByteValue() throws IOException { + throw new UnsupportedOperationException(); + } + protected abstract float[] doGetValue() throws IOException; @Override @@ -111,6 +133,15 @@ protected float[] doGetValue() throws IOException { } return value; } + + @Override + public byte[] doGetByteValue() { + try { + return values.vectorValue(); + } catch (IOException e) { + throw ExceptionsHelper.convertToOpenSearchException(e); + } + } } private static final class KNNFloatVectorScriptDocValues extends KNNVectorScriptDocValues { @@ -139,6 +170,15 @@ private static final class KNNNativeVectorScriptDocValues extends KNNVectorScrip protected float[] doGetValue() throws IOException { return getVectorDataType().getVectorFromBytesRef(values.binaryValue()); } + + @Override + public byte[] doGetByteValue() { + try { + return values.binaryValue().bytes; + } catch (IOException e) { + throw ExceptionsHelper.convertToOpenSearchException(e); + } + } } /** diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScript.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScript.java index d7a84817b..4b2a2b598 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScript.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScript.java @@ -144,4 +144,39 @@ public double execute(ScoreScript.ExplanationHolder explanationHolder) { return this.scoringMethod.apply(this.queryValue, scriptDocValues.getValue()); } } + + /** + * KNNVectors with byte[] type. The query value passed in is expected to be byte[]. The fieldType of the docs + * being searched over are expected to be KNNVector type. + */ + public static class KNNByteVectorType extends KNNScoreScript { + + public KNNByteVectorType( + Map params, + byte[] queryValue, + String field, + BiFunction scoringMethod, + SearchLookup lookup, + LeafReaderContext leafContext, + IndexSearcher searcher + ) throws IOException { + super(params, queryValue, field, scoringMethod, lookup, leafContext, searcher); + } + + /** + * This function called for each doc in the segment. We evaluate the score of the vector in the doc + * + * @param explanationHolder A helper to take in an explanation from a script and turn + * it into an {@link org.apache.lucene.search.Explanation} + * @return score of the vector to the query vector + */ + @Override + public double execute(ScoreScript.ExplanationHolder explanationHolder) { + KNNVectorScriptDocValues scriptDocValues = (KNNVectorScriptDocValues) getDoc().get(this.field); + if (scriptDocValues.isEmpty()) { + return 0.0; + } + return this.scoringMethod.apply(this.queryValue, scriptDocValues.getByteValue()); + } + } } diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java index 71616c9fd..619359453 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java @@ -26,13 +26,16 @@ import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.getVectorMagnitudeSquared; import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.isBinaryFieldType; +import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.isBinaryVectorDataType; import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.isKNNVectorFieldType; import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.isLongFieldType; import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.parseToBigInteger; +import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.parseToByteArray; import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.parseToFloatArray; import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.parseToLong; public interface KNNScoringSpace { + /** * Return the correct scoring script for a given query. The scoring script * @@ -181,25 +184,42 @@ protected BiFunction getScoringMethod(final float[] pro } } - class Hamming extends KNNFieldSpace { - private static final Set DATA_TYPES_HAMMING = Set.of(VectorDataType.BINARY); + class Hamming implements KNNScoringSpace { + private byte[] processedQuery; + BiFunction scoringMethod; public Hamming(Object query, MappedFieldType fieldType) { - super(query, fieldType, "hamming", DATA_TYPES_HAMMING); - } + if (!isKNNVectorFieldType(fieldType)) { + throw new IllegalArgumentException("Incompatible field_type for hamming space. The field type must be knn_vector."); + } + KNNVectorFieldType knnVectorFieldType = (KNNVectorFieldType) fieldType; + if (!isBinaryVectorDataType(knnVectorFieldType)) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "Incompatible field_type for hamming space. The data type should be binary but got %s", + knnVectorFieldType.getVectorDataType() + ) + ); + } - @Override - protected BiFunction getScoringMethod(final float[] processedQuery) { - // TODO we want to avoid converting back and forth between byte and float - return (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.calculateHammingBit(toByte(q), toByte(v))); + this.processedQuery = parseToByteArray( + query, + KNNVectorFieldMapperUtil.getExpectedVectorLength(knnVectorFieldType), + knnVectorFieldType.getVectorDataType() + ); + this.scoringMethod = (byte[] q, byte[] v) -> 1 / (1 + KNNScoringUtil.calculateHammingBit(q, v)); } - private byte[] toByte(final float[] vector) { - byte[] bytes = new byte[vector.length]; - for (int i = 0; i < vector.length; i++) { - bytes[i] = (byte) vector[i]; - } - return bytes; + @Override + public ScoreScript getScoreScript( + Map params, + String field, + SearchLookup lookup, + LeafReaderContext ctx, + IndexSearcher searcher + ) throws IOException { + return new KNNScoreScript.KNNByteVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx, searcher); } } diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java index 7a97fdb05..17e47ba8e 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java @@ -111,6 +111,24 @@ public static float[] parseToFloatArray(Object object, int expectedVectorLength, return floatArray; } + /** + * Convert an Object to a byte array. + * + * @param object Object to be converted to a byte array + * @param expectedVectorLength int representing the expected vector length of this array. + * @return byte[] of the object + */ + public static byte[] parseToByteArray(Object object, int expectedVectorLength, VectorDataType vectorDataType) { + byte[] byteArray = convertVectorToByteArray(object, vectorDataType); + if (expectedVectorLength != byteArray.length) { + KNNCounter.SCRIPT_QUERY_ERRORS.increment(); + throw new IllegalStateException( + "Object's length=" + byteArray.length + " does not match the " + "expected length=" + expectedVectorLength + "." + ); + } + return byteArray; + } + /** * Converts Object vector to primitive float[] * @@ -134,6 +152,29 @@ public static float[] convertVectorToPrimitive(Object vector, VectorDataType vec return primitiveVector; } + /** + * Converts Object vector to byte[] + * + * @param vector input vector + * @return Byte array representing the vector + */ + @SuppressWarnings("unchecked") + public static byte[] convertVectorToByteArray(Object vector, VectorDataType vectorDataType) { + byte[] byteVector = null; + if (vector != null) { + final List tmp = (List) vector; + byteVector = new byte[tmp.size()]; + for (int i = 0; i < byteVector.length; i++) { + float value = tmp.get(i).floatValue(); + if (VectorDataType.BYTE == vectorDataType || VectorDataType.BINARY == vectorDataType) { + validateByteVectorValue(value, vectorDataType); + } + byteVector[i] = tmp.get(i).byteValue(); + } + } + return byteVector; + } + /** * Calculates the magnitude of given vector * diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java index 4fc549d6b..cc709fa4f 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java @@ -212,8 +212,8 @@ public void testHamming_whenKNNFieldType_thenSucceed() { KNNScoringSpace.Hamming hamming = new KNNScoringSpace.Hamming(arrayListQueryObject, fieldType); - float[] arrayFloat = new float[] { 1.0f, 2.0f, 3.0f }; - assertEquals(1F, hamming.getScoringMethod().apply(arrayFloat, arrayFloat), 0.1F); + byte[] arrayByte = new byte[] { 1, 2, 3 }; + assertEquals(1F, ((BiFunction) hamming.scoringMethod).apply(arrayByte, arrayByte), 0.1F); } public void testHamming_whenNonBinaryVectorDataType_thenException() { diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java index 2374e4f7b..575e40145 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java @@ -76,6 +76,12 @@ public void testParseKNNVectorQuery() { expectThrows(ClassCastException.class, () -> KNNScoringSpaceUtil.parseToFloatArray(invalidObject, 3, VectorDataType.FLOAT)); } + public void testConvertVectorToByteArray() { + byte[] arrayByte = new byte[] { 1, 2, 3 }; + List arrayListQueryObject = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0)); + assertArrayEquals(arrayByte, KNNScoringSpaceUtil.parseToByteArray(arrayListQueryObject, 3, VectorDataType.BINARY)); + } + public void testIsBinaryVectorDataType_whenBinary_thenReturnTrue() { KNNVectorFieldType fieldType = mock(KNNVectorFieldType.class); when(fieldType.getVectorDataType()).thenReturn(VectorDataType.BINARY);