Skip to content

Commit

Permalink
Fix code review comments
Browse files Browse the repository at this point in the history
Signed-off-by: Vijayan Balasubramanian <[email protected]>
  • Loading branch information
VijayanB committed Jan 14, 2025
1 parent 67c8eac commit e993c4d
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,14 @@ public class Faiss extends NativeLibrary {
// https://opensearch.org/docs/latest/search-plugins/knn/approximate-knn/#spaces
private final static Map<SpaceType, Function<Float, Float>> SCORE_TO_DISTANCE_TRANSFORMATIONS = ImmutableMap.<
SpaceType,
Function<Float, Float>>builder().put(SpaceType.INNER_PRODUCT, score -> score > 1 ? 1 - score : 1 / score - 1).build();
Function<Float, Float>>builder()
.put(SpaceType.INNER_PRODUCT, score -> score > 1 ? 1 - score : (1 / score) - 1)
.put(SpaceType.COSINESIMIL, score -> 2 - 2 * score)
.build();

private final static Map<SpaceType, Function<Float, Float>> DISTANCE_TRANSLATIONS = ImmutableMap.<
SpaceType,
Function<Float, Float>>builder().put(SpaceType.COSINESIMIL, distance -> (2 - distance) / 2).build();
Function<Float, Float>>builder().put(SpaceType.COSINESIMIL, distance -> 1 - distance).build();

// Package private so that the method resolving logic can access the methods
final static Map<String, KNNMethod> METHODS = ImmutableMap.of(METHOD_HNSW, new FaissHNSWMethod(), METHOD_IVF, new FaissIVFMethod());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,7 @@ protected void validatePreparse() {
protected abstract VectorValidator getVectorValidator();

/**
* Getter for per dimension validator during vector parsing
* Getter for per dimension validator during vector parsing, and before any transformation
*
* @return PerDimensionValidator
*/
Expand All @@ -681,6 +681,11 @@ protected void validatePreparse() {
*/
protected abstract PerDimensionProcessor getPerDimensionProcessor();

/**
* Getter for vector transformer after vector parsing and validation
*
* @return VectorTransformer
*/
protected abstract VectorTransformer getVectorTransformer();

protected void parseCreateField(ParseContext context, int dimension, VectorDataType vectorDataType) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.QueryShardException;
import org.opensearch.knn.index.KNNVectorIndexFieldData;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.search.aggregations.support.CoreValuesSourceType;
import org.opensearch.search.lookup.SearchLookup;
Expand Down Expand Up @@ -99,4 +101,14 @@ public RescoreContext resolveRescoreContext(RescoreContext userProvidedContext)
Mode mode = knnMappingConfig.getMode();
return compressionLevel.getDefaultRescoreContext(mode, dimension);
}

public float[] transformQueryVector(float[] vector, KNNEngine knnEngine, SpaceType spaceType) {
if (vector == null) {
throw new IllegalArgumentException("Vector cannot be null");
}
if (knnEngine != KNNEngine.FAISS || VectorDataType.FLOAT != vectorDataType) {
return vector;
}
return VectorTransformerFactory.getVectorTransformer(knnEngine, spaceType).transform(vector);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import org.apache.lucene.util.VectorUtil;

import java.util.Arrays;

/**
* Normalizes vectors using L2 (Euclidean) normalization. This transformation ensures
* that the vector's magnitude becomes 1 while preserving its directional properties.
Expand All @@ -26,6 +28,9 @@ public float[] transform(float[] vector) {
if (vector == null || vector.length == 0) {
throw new IllegalArgumentException("Vector cannot be null or empty");
}
return VectorUtil.l2normalize(vector);
// l2normalize method will update input vector in place, hence, to avoid side effects,
// copy input vector and normalize it
float[] transformedVector = Arrays.copyOf(vector, vector.length);
return VectorUtil.l2normalize(transformedVector);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import org.apache.commons.lang.StringUtils;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.VectorUtil;
import org.opensearch.common.ValidationException;
import org.opensearch.core.ParseField;
import org.opensearch.core.common.Strings;
Expand Down Expand Up @@ -528,6 +527,7 @@ protected Query doToQuery(QueryShardContext context) {
default:
spaceType.validateVector(vector);
}
float[] transformedVector = knnVectorFieldType.transformQueryVector(vector, knnEngine, spaceType);

if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)
&& filter != null
Expand Down Expand Up @@ -613,12 +613,6 @@ private void updateQueryStats(VectorQueryType vectorQueryType) {
}

private float[] getVectorForCreatingQueryRequest(VectorDataType vectorDataType, KNNEngine knnEngine, SpaceType spaceType) {

// Cosine similarity is supported as Inner product by FAISS by normalizing input vector, hence, we have to normalize
// query vector before applying search
if (knnEngine == KNNEngine.FAISS && spaceType == SpaceType.COSINESIMIL && VectorDataType.FLOAT == vectorDataType) {
return VectorUtil.l2normalize(this.vector);
}
if ((VectorDataType.FLOAT == vectorDataType) || (VectorDataType.BYTE == vectorDataType && KNNEngine.FAISS == knnEngine)) {
return this.vector;
}
Expand Down

0 comments on commit e993c4d

Please sign in to comment.