Skip to content

Commit

Permalink
Removing redundant type conversions for script scoring for hamming sp…
Browse files Browse the repository at this point in the history
…ace with binary vectors
  • Loading branch information
kasundra07 committed Dec 19, 2024
1 parent d57fdea commit 07f1503
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,28 @@ public float[] getValue() {
}
}

public byte[] getByteValue() {
if (!docExists) {
String errorMessage = String.format(
"One of the document doesn't have a value for field '%s'. "
+ "This can be avoided by checking if a document has a value for the field or not "
+ "by doc['%s'].size() == 0 ? 0 : {your script}",
fieldName,
fieldName
);
throw new IllegalStateException(errorMessage);
}
try {
return doGetByteValue();
} catch (IOException e) {
throw ExceptionsHelper.convertToOpenSearchException(e);
}
}

protected byte[] doGetByteValue() throws IOException {
throw new UnsupportedOperationException();
}

protected abstract float[] doGetValue() throws IOException;

@Override
Expand Down Expand Up @@ -111,6 +133,15 @@ protected float[] doGetValue() throws IOException {
}
return value;
}

@Override
public byte[] doGetByteValue() {
try {
return values.vectorValue();
} catch (IOException e) {
throw ExceptionsHelper.convertToOpenSearchException(e);
}
}
}

private static final class KNNFloatVectorScriptDocValues extends KNNVectorScriptDocValues {
Expand Down Expand Up @@ -139,6 +170,15 @@ private static final class KNNNativeVectorScriptDocValues extends KNNVectorScrip
protected float[] doGetValue() throws IOException {
return getVectorDataType().getVectorFromBytesRef(values.binaryValue());
}

@Override
public byte[] doGetByteValue() {
try {
return values.binaryValue().bytes;
} catch (IOException e) {
throw ExceptionsHelper.convertToOpenSearchException(e);
}
}
}

/**
Expand Down
35 changes: 35 additions & 0 deletions src/main/java/org/opensearch/knn/plugin/script/KNNScoreScript.java
Original file line number Diff line number Diff line change
Expand Up @@ -144,4 +144,39 @@ public double execute(ScoreScript.ExplanationHolder explanationHolder) {
return this.scoringMethod.apply(this.queryValue, scriptDocValues.getValue());
}
}

/**
* KNNVectors with byte[] type. The query value passed in is expected to be byte[]. The fieldType of the docs
* being searched over are expected to be KNNVector type.
*/
public static class KNNByteVectorType extends KNNScoreScript<byte[]> {

public KNNByteVectorType(
Map<String, Object> params,
byte[] queryValue,
String field,
BiFunction<byte[], byte[], Float> scoringMethod,
SearchLookup lookup,
LeafReaderContext leafContext,
IndexSearcher searcher
) throws IOException {
super(params, queryValue, field, scoringMethod, lookup, leafContext, searcher);
}

/**
* This function called for each doc in the segment. We evaluate the score of the vector in the doc
*
* @param explanationHolder A helper to take in an explanation from a script and turn
* it into an {@link org.apache.lucene.search.Explanation}
* @return score of the vector to the query vector
*/
@Override
public double execute(ScoreScript.ExplanationHolder explanationHolder) {
KNNVectorScriptDocValues scriptDocValues = (KNNVectorScriptDocValues) getDoc().get(this.field);
if (scriptDocValues.isEmpty()) {
return 0.0;
}
return this.scoringMethod.apply(this.queryValue, scriptDocValues.getByteValue());
}
}
}
48 changes: 34 additions & 14 deletions src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,16 @@

import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.getVectorMagnitudeSquared;
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.isBinaryFieldType;
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.isBinaryVectorDataType;
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.isKNNVectorFieldType;
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.isLongFieldType;
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.parseToBigInteger;
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.parseToByteArray;
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.parseToFloatArray;
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.parseToLong;

public interface KNNScoringSpace {

/**
* Return the correct scoring script for a given query. The scoring script
*
Expand Down Expand Up @@ -181,25 +184,42 @@ protected BiFunction<float[], float[], Float> getScoringMethod(final float[] pro
}
}

class Hamming extends KNNFieldSpace {
private static final Set<VectorDataType> DATA_TYPES_HAMMING = Set.of(VectorDataType.BINARY);
class Hamming implements KNNScoringSpace {
private byte[] processedQuery;
BiFunction<byte[], byte[], Float> scoringMethod;

public Hamming(Object query, MappedFieldType fieldType) {
super(query, fieldType, "hamming", DATA_TYPES_HAMMING);
}
if (!isKNNVectorFieldType(fieldType)) {
throw new IllegalArgumentException("Incompatible field_type for hamming space. The field type must be knn_vector.");
}
KNNVectorFieldType knnVectorFieldType = (KNNVectorFieldType) fieldType;
if (!isBinaryVectorDataType(knnVectorFieldType)) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"Incompatible field_type for hamming space. The data type should be binary but got %s",
knnVectorFieldType.getVectorDataType()
)
);
}

@Override
protected BiFunction<float[], float[], Float> getScoringMethod(final float[] processedQuery) {
// TODO we want to avoid converting back and forth between byte and float
return (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.calculateHammingBit(toByte(q), toByte(v)));
this.processedQuery = parseToByteArray(
query,
KNNVectorFieldMapperUtil.getExpectedVectorLength(knnVectorFieldType),
knnVectorFieldType.getVectorDataType()
);
this.scoringMethod = (byte[] q, byte[] v) -> 1 / (1 + KNNScoringUtil.calculateHammingBit(q, v));
}

private byte[] toByte(final float[] vector) {
byte[] bytes = new byte[vector.length];
for (int i = 0; i < vector.length; i++) {
bytes[i] = (byte) vector[i];
}
return bytes;
@Override
public ScoreScript getScoreScript(
Map<String, Object> params,
String field,
SearchLookup lookup,
LeafReaderContext ctx,
IndexSearcher searcher
) throws IOException {
return new KNNScoreScript.KNNByteVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx, searcher);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,24 @@ public static float[] parseToFloatArray(Object object, int expectedVectorLength,
return floatArray;
}

/**
* Convert an Object to a byte array.
*
* @param object Object to be converted to a byte array
* @param expectedVectorLength int representing the expected vector length of this array.
* @return byte[] of the object
*/
public static byte[] parseToByteArray(Object object, int expectedVectorLength, VectorDataType vectorDataType) {
byte[] byteArray = convertVectorToByteArray(object, vectorDataType);
if (expectedVectorLength != byteArray.length) {
KNNCounter.SCRIPT_QUERY_ERRORS.increment();
throw new IllegalStateException(
"Object's length=" + byteArray.length + " does not match the " + "expected length=" + expectedVectorLength + "."
);
}
return byteArray;
}

/**
* Converts Object vector to primitive float[]
*
Expand All @@ -134,6 +152,29 @@ public static float[] convertVectorToPrimitive(Object vector, VectorDataType vec
return primitiveVector;
}

/**
* Converts Object vector to byte[]
*
* @param vector input vector
* @return Byte array representing the vector
*/
@SuppressWarnings("unchecked")
public static byte[] convertVectorToByteArray(Object vector, VectorDataType vectorDataType) {
byte[] byteVector = null;
if (vector != null) {
final List<Number> tmp = (List<Number>) vector;
byteVector = new byte[tmp.size()];
for (int i = 0; i < byteVector.length; i++) {
float value = tmp.get(i).floatValue();
if (VectorDataType.BYTE == vectorDataType || VectorDataType.BINARY == vectorDataType) {
validateByteVectorValue(value, vectorDataType);
}
byteVector[i] = tmp.get(i).byteValue();
}
}
return byteVector;
}

/**
* Calculates the magnitude of given vector
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,8 @@ public void testHamming_whenKNNFieldType_thenSucceed() {

KNNScoringSpace.Hamming hamming = new KNNScoringSpace.Hamming(arrayListQueryObject, fieldType);

float[] arrayFloat = new float[] { 1.0f, 2.0f, 3.0f };
assertEquals(1F, hamming.getScoringMethod().apply(arrayFloat, arrayFloat), 0.1F);
byte[] arrayByte = new byte[] { 1, 2, 3 };
assertEquals(1F, ((BiFunction<byte[], byte[], Float>) hamming.scoringMethod).apply(arrayByte, arrayByte), 0.1F);
}

public void testHamming_whenNonBinaryVectorDataType_thenException() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ public void testParseKNNVectorQuery() {
expectThrows(ClassCastException.class, () -> KNNScoringSpaceUtil.parseToFloatArray(invalidObject, 3, VectorDataType.FLOAT));
}

public void testConvertVectorToByteArray() {
byte[] arrayByte = new byte[] { 1, 2, 3 };
List<Double> arrayListQueryObject = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0));
assertArrayEquals(arrayByte, KNNScoringSpaceUtil.parseToByteArray(arrayListQueryObject, 3, VectorDataType.BINARY));
}

public void testIsBinaryVectorDataType_whenBinary_thenReturnTrue() {
KNNVectorFieldType fieldType = mock(KNNVectorFieldType.class);
when(fieldType.getVectorDataType()).thenReturn(VectorDataType.BINARY);
Expand Down

0 comments on commit 07f1503

Please sign in to comment.