Skip to content

Commit

Permalink
Binary hamming fix (#2456)
Browse files Browse the repository at this point in the history
Removing redundant type conversions for script scoring for hamming space with binary vectors

Signed-off-by: Bansi Kasundra <[email protected]>
  • Loading branch information
kasundra07 authored Jan 28, 2025
1 parent 02fdd70 commit 69e2303
Show file tree
Hide file tree
Showing 16 changed files with 559 additions and 169 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Make the build work for M series MacOS without manual code changes and local JAVA_HOME config (#2397)[https://github.com/opensearch-project/k-NN/pull/2397]
- Remove DocsWithFieldSet reference from NativeEngineFieldVectorsWriter (#2408)[https://github.com/opensearch-project/k-NN/pull/2408]
- Remove skip building graph check for quantization use case (#2430)[https://github.com/opensearch-project/k-NN/pull/2430]
- Removing redundant type conversions for script scoring for hamming space with binary vectors (#2351)[https://github.com/opensearch-project/k-NN/pull/2351]
### Bug Fixes
* Fixing the bug when a segment has no vector field present for disk based vector search (#2282)[https://github.com/opensearch-project/k-NN/pull/2282]
* Fixing the bug where search fails with "fields" parameter for an index with a knn_vector field (#2314)[https://github.com/opensearch-project/k-NN/pull/2314]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public long ramBytesUsed() {
}

@Override
public ScriptDocValues<float[]> getScriptValues() {
public ScriptDocValues<?> getScriptValues() {
try {
FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, fieldName);
if (fieldInfo == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import org.opensearch.index.fielddata.ScriptDocValues;

@RequiredArgsConstructor(access = AccessLevel.PRIVATE)
public abstract class KNNVectorScriptDocValues extends ScriptDocValues<float[]> {
public abstract class KNNVectorScriptDocValues<T> extends ScriptDocValues<T> {

private final DocIdSetIterator vectorValues;
private final String fieldName;
Expand All @@ -42,7 +42,7 @@ public void setNextDocId(int docId) throws IOException {
docExists = lastDocID == curDocID;
}

public float[] getValue() {
public T getValue() {
if (!docExists) {
String errorMessage = String.format(
"One of the document doesn't have a value for field '%s'. "
Expand All @@ -60,15 +60,15 @@ public float[] getValue() {
}
}

protected abstract float[] doGetValue() throws IOException;
protected abstract T doGetValue() throws IOException;

@Override
public int size() {
return docExists ? 1 : 0;
}

@Override
public float[] get(int i) {
public T get(int i) {
throw new UnsupportedOperationException("knn vector does not support this operation");
}

Expand All @@ -81,20 +81,20 @@ public float[] get(int i) {
* @return A KNNVectorScriptDocValues object based on the type of the values.
* @throws IllegalArgumentException If the type of values is unsupported.
*/
public static KNNVectorScriptDocValues create(DocIdSetIterator values, String fieldName, VectorDataType vectorDataType) {
public static KNNVectorScriptDocValues<?> create(DocIdSetIterator values, String fieldName, VectorDataType vectorDataType) {
Objects.requireNonNull(values, "values must not be null");
if (values instanceof ByteVectorValues) {
return new KNNByteVectorScriptDocValues((ByteVectorValues) values, fieldName, vectorDataType);
} else if (values instanceof FloatVectorValues) {
return new KNNFloatVectorScriptDocValues((FloatVectorValues) values, fieldName, vectorDataType);
} else if (values instanceof BinaryDocValues) {
return new KNNNativeVectorScriptDocValues((BinaryDocValues) values, fieldName, vectorDataType);
return new KNNNativeVectorScriptDocValues<>((BinaryDocValues) values, fieldName, vectorDataType);
} else {
throw new IllegalArgumentException("Unsupported values type: " + values.getClass());
}
}

private static final class KNNByteVectorScriptDocValues extends KNNVectorScriptDocValues {
private static final class KNNByteVectorScriptDocValues extends KNNVectorScriptDocValues<byte[]> {
private final ByteVectorValues values;

KNNByteVectorScriptDocValues(ByteVectorValues values, String field, VectorDataType type) {
Expand All @@ -103,17 +103,16 @@ private static final class KNNByteVectorScriptDocValues extends KNNVectorScriptD
}

@Override
protected float[] doGetValue() throws IOException {
byte[] bytes = values.vectorValue();
float[] value = new float[bytes.length];
for (int i = 0; i < bytes.length; i++) {
value[i] = (float) bytes[i];
protected byte[] doGetValue() throws IOException {
try {
return values.vectorValue();
} catch (IOException e) {
throw ExceptionsHelper.convertToOpenSearchException(e);
}
return value;
}
}

private static final class KNNFloatVectorScriptDocValues extends KNNVectorScriptDocValues {
private static final class KNNFloatVectorScriptDocValues extends KNNVectorScriptDocValues<float[]> {
private final FloatVectorValues values;

KNNFloatVectorScriptDocValues(FloatVectorValues values, String field, VectorDataType type) {
Expand All @@ -127,7 +126,7 @@ protected float[] doGetValue() throws IOException {
}
}

private static final class KNNNativeVectorScriptDocValues extends KNNVectorScriptDocValues {
private static final class KNNNativeVectorScriptDocValues<T> extends KNNVectorScriptDocValues<T> {
private final BinaryDocValues values;

KNNNativeVectorScriptDocValues(BinaryDocValues values, String field, VectorDataType type) {
Expand All @@ -136,7 +135,7 @@ private static final class KNNNativeVectorScriptDocValues extends KNNVectorScrip
}

@Override
protected float[] doGetValue() throws IOException {
protected T doGetValue() throws IOException {
return getVectorDataType().getVectorFromBytesRef(values.binaryValue());
}
}
Expand All @@ -148,10 +147,18 @@ protected float[] doGetValue() throws IOException {
* @param type The data type of the vector.
* @return An empty KNNVectorScriptDocValues object.
*/
public static KNNVectorScriptDocValues emptyValues(String fieldName, VectorDataType type) {
return new KNNVectorScriptDocValues(DocIdSetIterator.empty(), fieldName, type) {
public static KNNVectorScriptDocValues<?> emptyValues(String fieldName, VectorDataType type) {
if (type == VectorDataType.FLOAT) {
return new KNNVectorScriptDocValues<float[]>(DocIdSetIterator.empty(), fieldName, type) {
@Override
protected float[] doGetValue() throws IOException {
throw new UnsupportedOperationException("empty values");
}
};
}
return new KNNVectorScriptDocValues<byte[]>(DocIdSetIterator.empty(), fieldName, type) {
@Override
protected float[] doGetValue() throws IOException {
protected byte[] doGetValue() throws IOException {
throw new UnsupportedOperationException("empty values");
}
};
Expand Down
24 changes: 5 additions & 19 deletions src/main/java/org/opensearch/knn/index/VectorDataType.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,8 @@ public FieldType createKnnVectorFieldType(int dimension, KNNVectorSimilarityFunc
}

@Override
public float[] getVectorFromBytesRef(BytesRef binaryValue) {
float[] vector = new float[binaryValue.length];
int i = 0;
int j = binaryValue.offset;

while (i < binaryValue.length) {
vector[i++] = binaryValue.bytes[j++];
}
return vector;
public byte[] getVectorFromBytesRef(BytesRef binaryValue) {
    // A BytesRef is a window (offset, length) over a backing array that may be larger than
    // the value itself. Hand back the backing array directly only when the window covers it
    // exactly (no copy needed); otherwise copy out just the referenced slice so the caller
    // never sees bytes outside [offset, offset + length).
    if (binaryValue.offset == 0 && binaryValue.length == binaryValue.bytes.length) {
        return binaryValue.bytes;
    }
    return java.util.Arrays.copyOfRange(binaryValue.bytes, binaryValue.offset, binaryValue.offset + binaryValue.length);
}

@Override
Expand All @@ -75,15 +68,8 @@ public FieldType createKnnVectorFieldType(int dimension, KNNVectorSimilarityFunc
}

@Override
public float[] getVectorFromBytesRef(BytesRef binaryValue) {
float[] vector = new float[binaryValue.length];
int i = 0;
int j = binaryValue.offset;

while (i < binaryValue.length) {
vector[i++] = binaryValue.bytes[j++];
}
return vector;
public byte[] getVectorFromBytesRef(BytesRef binaryValue) {
    // Honor the BytesRef's offset and length: the backing array can be longer than the
    // referenced value. Reuse the array as-is when the reference spans all of it, and
    // otherwise return a copy of exactly the referenced range.
    if (binaryValue.offset == 0 && binaryValue.length == binaryValue.bytes.length) {
        return binaryValue.bytes;
    }
    return java.util.Arrays.copyOfRange(binaryValue.bytes, binaryValue.offset, binaryValue.offset + binaryValue.length);
}

@Override
Expand Down Expand Up @@ -143,7 +129,7 @@ public void freeNativeMemory(long memoryAddress) {
* @param binaryValue Binary Value
* @return vector deserialized from binary value; its concrete type is chosen by the implementing data type
*/
public abstract float[] getVectorFromBytesRef(BytesRef binaryValue);
public abstract <T> T getVectorFromBytesRef(BytesRef binaryValue);

/**
* @param trainingDataAllocation training data that has been allocated in native memory
Expand Down
43 changes: 40 additions & 3 deletions src/main/java/org/opensearch/knn/plugin/script/KNNScoreScript.java
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,9 @@ public double execute(ScoreScript.ExplanationHolder explanationHolder) {
* KNNVectors with float[] type. The query value passed in is expected to be float[]. The fieldType of the docs
* being searched over is expected to be KNNVector type.
*/
public static class KNNVectorType extends KNNScoreScript<float[]> {
public static class KNNFloatVectorType extends KNNScoreScript<float[]> {

public KNNVectorType(
public KNNFloatVectorType(
Map<String, Object> params,
float[] queryValue,
String field,
Expand All @@ -136,8 +136,45 @@ public KNNVectorType(
* @return score of the vector to the query vector
*/
@Override
@SuppressWarnings("unchecked")
public double execute(ScoreScript.ExplanationHolder explanationHolder) {
KNNVectorScriptDocValues scriptDocValues = (KNNVectorScriptDocValues) getDoc().get(this.field);
KNNVectorScriptDocValues<float[]> scriptDocValues = (KNNVectorScriptDocValues<float[]>) getDoc().get(this.field);
if (scriptDocValues.isEmpty()) {
return 0.0;
}
return this.scoringMethod.apply(this.queryValue, scriptDocValues.getValue());
}
}

/**
* KNNVectors with byte[] type. The query value passed in is expected to be byte[]. The fieldType of the docs
* being searched over is expected to be KNNVector type.
*/
public static class KNNByteVectorType extends KNNScoreScript<byte[]> {

public KNNByteVectorType(
Map<String, Object> params,
byte[] queryValue,
String field,
BiFunction<byte[], byte[], Float> scoringMethod,
SearchLookup lookup,
LeafReaderContext leafContext,
IndexSearcher searcher
) throws IOException {
super(params, queryValue, field, scoringMethod, lookup, leafContext, searcher);
}

/**
* This function called for each doc in the segment. We evaluate the score of the vector in the doc
*
* @param explanationHolder A helper to take in an explanation from a script and turn
* it into an {@link org.apache.lucene.search.Explanation}
* @return score of the vector to the query vector
*/
@Override
@SuppressWarnings("unchecked")
public double execute(ScoreScript.ExplanationHolder explanationHolder) {
KNNVectorScriptDocValues<byte[]> scriptDocValues = (KNNVectorScriptDocValues<byte[]>) getDoc().get(this.field);
if (scriptDocValues.isEmpty()) {
return 0.0;
}
Expand Down
Loading

0 comments on commit 69e2303

Please sign in to comment.