Skip to content

Commit

Permalink
Addressed comments
Browse files Browse the repository at this point in the history
Signed-off-by: Bansi Kasundra <[email protected]>
  • Loading branch information
kasundra07 committed Jan 14, 2025
1 parent bdfc4f8 commit 273f181
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ public ScoreScript getScoreScript(
ctx,
searcher
);
} else {
} else if (processedQuery instanceof byte[]) {
return new KNNScoreScript.KNNByteVectorType(
params,
(byte[]) this.processedQuery,
Expand All @@ -102,6 +102,10 @@ public ScoreScript getScoreScript(
ctx,
searcher
);
} else {
throw new IllegalStateException(
"Unexpected type for processedQuery. Expected float[] or byte[], but got: " + processedQuery.getClass().getName()
);
}
}

Expand Down Expand Up @@ -139,7 +143,7 @@ protected Object getProcessedQuery(final Object query, final KNNVectorFieldType
VectorDataType vectorDataType = knnVectorFieldType.getVectorDataType() == null
? VectorDataType.FLOAT
: knnVectorFieldType.getVectorDataType();
if (vectorDataType.equals(VectorDataType.FLOAT)) {
if (vectorDataType == VectorDataType.FLOAT) {
return parseToFloatArray(
query,
KNNVectorFieldMapperUtil.getExpectedVectorLength(knnVectorFieldType),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ public static byte[] parseToByteArray(Object object, int expectedVectorLength, V
/**
* Converts Object vector to byte[]
*
* Expects all numbers in the Object vector to be in the byte range of [-128 to 127]
* @param vector input vector
* @return Byte array representing the vector
*/
Expand Down
13 changes: 2 additions & 11 deletions src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -779,15 +779,6 @@ private float[] randomVector(final int dimensions, final VectorDataType vectorDa
return vector;
}

private byte[] randomByteVector(final int dimensions, final VectorDataType vectorDataType) {
int size = VectorDataType.BINARY == vectorDataType ? dimensions / 8 : dimensions;
final byte[] vector = new byte[size];
for (int i = 0; i < size; i++) {
vector[i] = randomByte();
}
return vector;
}

private Map<String, KNNResult> createDataset(
Function<float[], Float> scoreFunction,
int dimensions,
Expand Down Expand Up @@ -821,7 +812,7 @@ private Map<String, KNNResult> createDataset(
FIELD_NAME,
Collections.emptyMap(),
vectorDataType,
getMappingConfigForFlatMapping(vectorDataType.equals(VectorDataType.BINARY) ? queryVector.length * 8 : queryVector.length)
getMappingConfigForFlatMapping(vectorDataType == VectorDataType.BINARY ? queryVector.length * 8 : queryVector.length)
)
);
switch (spaceType) {
Expand All @@ -831,7 +822,7 @@ private Map<String, KNNResult> createDataset(
case COSINESIMIL:
case INNER_PRODUCT:
case HAMMING:
if (vectorDataType.equals(VectorDataType.FLOAT)) {
if (vectorDataType == VectorDataType.FLOAT) {
return ((KNNScoringSpace.KNNFieldSpace) knnScoringSpace).getScoringMethod(queryVector);
}
return ((KNNScoringSpace.KNNFieldSpace) knnScoringSpace).getScoringMethod(toByte(queryVector));
Expand Down

0 comments on commit 273f181

Please sign in to comment.