Skip to content

Commit

Permalink
Add binary index support for Lucene engine
Browse files Browse the repository at this point in the history
Signed-off-by: Jay Deng <[email protected]>
  • Loading branch information
jed326 authored and Jay Deng committed Nov 27, 2024
1 parent 7523cc3 commit 625fd1c
Show file tree
Hide file tree
Showing 13 changed files with 259 additions and 20 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.18...2.x)
### Features
- Add binary index support for Lucene engine. (#2292)[https://github.com/opensearch-project/k-NN/pull/2292]
### Enhancements
- Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
### Bug Fixes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ public float compare(byte[] v1, byte[] v2) {

@Override
public VectorSimilarityFunction getVectorSimilarityFunction() {
throw new IllegalStateException("VectorSimilarityFunction is not available for Hamming space");
// This is not used in binary case
return VectorSimilarityFunction.EUCLIDEAN;
}
};

Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/opensearch/knn/index/VectorDataType.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public enum VectorDataType {

@Override
public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) {
throw new IllegalStateException("Unsupported method");
return KnnByteVectorField.createFieldType(dimension / 8, vectorSimilarityFunction);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,12 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {
}
}

KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams(params, defaultMaxConnections, defaultBeamWidth);
KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams(
params,
defaultMaxConnections,
defaultBeamWidth,
knnMethodContext.getSpaceType()
);
log.debug(
"Initialize KNN vector format for field [{}] with params [{}] = \"{}\" and [{}] = \"{}\"",
field,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.KNN990Codec;

import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;

import java.io.IOException;

public class KNN990BinaryVectorScorer implements FlatVectorsScorer {
@Override
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
VectorSimilarityFunction vectorSimilarityFunction,
RandomAccessVectorValues randomAccessVectorValues
) throws IOException {
assert randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes;
if (randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes) {
return new BinaryRandomVectorScorerSupplier((RandomAccessVectorValues.Bytes) randomAccessVectorValues);
}
throw new IllegalArgumentException("vectorValues must be an instance of RandomAccessVectorValues.Bytes");
}

@Override
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction vectorSimilarityFunction,
RandomAccessVectorValues randomAccessVectorValues,
float[] queryVector
) throws IOException {
throw new IllegalArgumentException("binary vectors do not support float[] targets");
}

@Override
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction vectorSimilarityFunction,
RandomAccessVectorValues randomAccessVectorValues,
byte[] queryVector
) throws IOException {
assert randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes;
if (randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes) {
return new BinaryRandomVectorScorer((RandomAccessVectorValues.Bytes) randomAccessVectorValues, queryVector);
}
throw new IllegalArgumentException("vectorValues must be an instance of RandomAccessVectorValues.Bytes");
}

static class BinaryRandomVectorScorer implements RandomVectorScorer {
private final RandomAccessVectorValues.Bytes vectorValues;
private final int bitDimensions;
private final byte[] queryVector;

BinaryRandomVectorScorer(RandomAccessVectorValues.Bytes vectorValues, byte[] query) {
this.queryVector = query;
this.bitDimensions = vectorValues.dimension() * Byte.SIZE;
this.vectorValues = vectorValues;
}

@Override
public float score(int node) throws IOException {
return (bitDimensions - VectorUtil.xorBitCount(queryVector, vectorValues.vectorValue(node))) / (float) bitDimensions;
}

@Override
public int maxOrd() {
return vectorValues.size();
}

@Override
public int ordToDoc(int ord) {
return vectorValues.ordToDoc(ord);
}

@Override
public Bits getAcceptOrds(Bits acceptDocs) {
return vectorValues.getAcceptOrds(acceptDocs);
}
}

static class BinaryRandomVectorScorerSupplier implements RandomVectorScorerSupplier {
protected final RandomAccessVectorValues.Bytes vectorValues;
protected final RandomAccessVectorValues.Bytes vectorValues1;
protected final RandomAccessVectorValues.Bytes vectorValues2;

public BinaryRandomVectorScorerSupplier(RandomAccessVectorValues.Bytes vectorValues) throws IOException {
this.vectorValues = vectorValues;
this.vectorValues1 = vectorValues.copy();
this.vectorValues2 = vectorValues.copy();
}

@Override
public RandomVectorScorer scorer(int ord) throws IOException {
byte[] queryVector = vectorValues1.vectorValue(ord);
return new BinaryRandomVectorScorer(vectorValues2, queryVector);
}

@Override
public RandomVectorScorerSupplier copy() throws IOException {
return new BinaryRandomVectorScorerSupplier(vectorValues.copy());
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.KNN990Codec;

import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.search.TaskExecutor;

import java.io.IOException;
import java.util.concurrent.ExecutorService;

public class KNN990HnswBinaryVectorsFormat extends KnnVectorsFormat {

private final int maxConn;
private final int beamWidth;
private static final FlatVectorsFormat flatVectorsFormat = new Lucene99FlatVectorsFormat(new KNN990BinaryVectorScorer());
private final int numMergeWorkers;
private final TaskExecutor mergeExec;

private static final String NAME = "KNN990HnswBinaryVectorsFormat";

public KNN990HnswBinaryVectorsFormat() {
this(16, 100, 1, (ExecutorService) null);
}

public KNN990HnswBinaryVectorsFormat(int maxConn, int beamWidth) {
this(maxConn, beamWidth, 1, (ExecutorService) null);
}

public KNN990HnswBinaryVectorsFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) {
super(NAME);
if (maxConn > 0 && maxConn <= 512) {
if (beamWidth > 0 && beamWidth <= 3200) {
this.maxConn = maxConn;
this.beamWidth = beamWidth;
if (numMergeWorkers == 1 && mergeExec != null) {
throw new IllegalArgumentException("No executor service is needed as we'll use single thread to merge");
} else {
this.numMergeWorkers = numMergeWorkers;
if (mergeExec != null) {
this.mergeExec = new TaskExecutor(mergeExec);
} else {
this.mergeExec = null;
}

}
} else {
throw new IllegalArgumentException("beamWidth must be positive and less than or equal to 3200; beamWidth=" + beamWidth);
}
} else {
throw new IllegalArgumentException("maxConn must be positive and less than or equal to 512; maxConn=" + maxConn);
}
}

@Override
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new Lucene99HnswVectorsWriter(
state,
this.maxConn,
this.beamWidth,
flatVectorsFormat.fieldsWriter(state),
this.numMergeWorkers,
this.mergeExec
);
}

@Override
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state));
}

@Override
public int getMaxDimensions(String fieldName) {
return 1024;
}

@Override
public String toString() {
return "KNN990HnswBinaryVectorsFormat(name=KNN990HnswBinaryVectorsFormat, maxConn="
+ this.maxConn
+ ", beamWidth="
+ this.beamWidth
+ ", flatVectorFormat="
+ flatVectorsFormat
+ ")";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.codec.BasePerFieldKnnVectorsFormat;
import org.opensearch.knn.index.engine.KNNEngine;

Expand All @@ -24,11 +25,17 @@ public KNN990PerFieldKnnVectorsFormat(final Optional<MapperService> mapperServic
mapperService,
Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
() -> new Lucene99HnswVectorsFormat(),
knnVectorsFormatParams -> new Lucene99HnswVectorsFormat(
knnVectorsFormatParams.getMaxConnections(),
knnVectorsFormatParams.getBeamWidth()
),
Lucene99HnswVectorsFormat::new,
knnVectorsFormatParams -> {
if (knnVectorsFormatParams.getSpaceType() == SpaceType.HAMMING) {
return new KNN990HnswBinaryVectorsFormat(
knnVectorsFormatParams.getMaxConnections(),
knnVectorsFormatParams.getBeamWidth()
);
} else {
return new Lucene99HnswVectorsFormat(knnVectorsFormatParams.getMaxConnections(), knnVectorsFormatParams.getBeamWidth());
}
},
knnScalarQuantizedVectorsFormatParams -> new Lucene99HnswScalarQuantizedVectorsFormat(
knnScalarQuantizedVectorsFormatParams.getMaxConnections(),
knnScalarQuantizedVectorsFormatParams.getBeamWidth(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import lombok.Getter;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.SpaceType;

import java.util.Map;

Expand All @@ -17,10 +18,16 @@
public class KNNVectorsFormatParams {
private int maxConnections;
private int beamWidth;
private final SpaceType spaceType;

public KNNVectorsFormatParams(final Map<String, Object> params, int defaultMaxConnections, int defaultBeamWidth) {
this(params, defaultMaxConnections, defaultBeamWidth, SpaceType.UNDEFINED);
}

public KNNVectorsFormatParams(final Map<String, Object> params, int defaultMaxConnections, int defaultBeamWidth, SpaceType spaceType) {
initMaxConnections(params, defaultMaxConnections);
initBeamWidth(params, defaultBeamWidth);
this.spaceType = spaceType;
}

public boolean validate(final Map<String, Object> params) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,18 @@
*/
public class LuceneHNSWMethod extends AbstractKNNMethod {

private static final Set<VectorDataType> SUPPORTED_DATA_TYPES = ImmutableSet.of(VectorDataType.FLOAT, VectorDataType.BYTE);
private static final Set<VectorDataType> SUPPORTED_DATA_TYPES = ImmutableSet.of(
VectorDataType.FLOAT,
VectorDataType.BYTE,
VectorDataType.BINARY
);

public final static List<SpaceType> SUPPORTED_SPACES = Arrays.asList(
SpaceType.UNDEFINED,
SpaceType.L2,
SpaceType.COSINESIMIL,
SpaceType.INNER_PRODUCT
SpaceType.INNER_PRODUCT,
SpaceType.HAMMING
);

final static Encoder SQ_ENCODER = new LuceneSQEncoder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ public static Query create(CreateQueryRequest createQueryRequest) {
log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k));
switch (vectorDataType) {
case BYTE:
case BINARY:
return getKnnByteVectorQuery(fieldName, byteVector, luceneK, filterQuery, parentFilter);
case FLOAT:
return getKnnFloatVectorQuery(fieldName, vector, luceneK, filterQuery, parentFilter);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@
#

org.opensearch.knn.index.codec.KNN990Codec.NativeEngines990KnnVectorsFormat
org.opensearch.knn.index.codec.KNN990Codec.KNN990HnswBinaryVectorsFormat
Loading

0 comments on commit 625fd1c

Please sign in to comment.