From f6eca79d5d99fa9d9193a1dcb2001e946f7adbe7 Mon Sep 17 00:00:00 2001 From: Joel Knighton Date: Tue, 12 Nov 2024 16:36:10 -0600 Subject: [PATCH 1/8] WIP --- .../github/jbellis/jvector/pq/PQDecoder.java | 10 +---- .../jbellis/jvector/vector/VectorUtil.java | 4 ++ .../jvector/vector/VectorUtilSupport.java | 14 +++++++ .../vector/PanamaVectorUtilSupport.java | 6 +++ .../jbellis/jvector/vector/SimdOps.java | 39 +++++++++++++++++++ 5 files changed, 64 insertions(+), 9 deletions(-) diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQDecoder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQDecoder.java index 0be0e088..e417d50e 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQDecoder.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQDecoder.java @@ -131,18 +131,10 @@ public float similarityTo(int node2) { } protected float decodedCosine(int node2) { - float sum = 0.0f; - float aMag = 0.0f; ByteSequence encoded = cv.get(node2); - for (int m = 0; m < encoded.length(); ++m) { - int centroidIndex = Byte.toUnsignedInt(encoded.get(m)); - sum += partialSums.get((m * cv.pq.getClusterCount()) + centroidIndex); - aMag += aMagnitude.get((m * cv.pq.getClusterCount()) + centroidIndex); - } - - return (float) (sum / Math.sqrt(aMag * bMagnitude)); + return VectorUtil.decodedCosineSimilarity(encoded, cv.pq.getClusterCount(), partialSums, aMagnitude, bMagnitude); } } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java index a6d87807..d860cc1b 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java @@ -194,4 +194,8 @@ public static float max(VectorFloat v) { public static float min(VectorFloat v) { return impl.min(v); } + + public static float decodedCosineSimilarity(ByteSequence encoded, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) { + return impl.decodedCosineSimilarity(encoded, clusterCount, partialSums, aMagnitude, bMagnitude); + } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java index 46cb4f18..e3e2953e 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java @@ -199,4 +199,18 @@ default void bulkShuffleQuantizedSimilarityCosine(ByteSequence shuffles, int float max(VectorFloat v); float min(VectorFloat v); + default float decodedCosineSimilarity(ByteSequence encoded, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) + { + float sum = 0.0f; + float aMag = 0.0f; + + for (int m = 0; m < encoded.length(); ++m) { + int centroidIndex = Byte.toUnsignedInt(encoded.get(m)); + var index = m * clusterCount + centroidIndex; + sum += partialSums.get(index); + aMag += aMagnitude.get(index); + } + + return (float) (sum / Math.sqrt(aMag * bMagnitude)); + } } diff --git a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java index e0c2be5e..e71d780a 100644 --- a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java +++ b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java @@ -159,5 +159,11 @@ public void calculatePartialSums(VectorFloat codebook, int codebookIndex, int public void quantizePartials(float delta, VectorFloat partials, VectorFloat partialBases, ByteSequence quantizedPartials) { SimdOps.quantizePartials(delta, (ArrayVectorFloat) partials, (ArrayVectorFloat) partialBases, (ArrayByteSequence) quantizedPartials); } + + @Override + public float decodedCosineSimilarity(ByteSequence encoded, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) + { + return SimdOps.decodedCosineSimilarity((ArrayByteSequence) encoded, clusterCount, (ArrayVectorFloat) partialSums, (ArrayVectorFloat) aMagnitude, bMagnitude); + } } diff --git a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java index d5d132e8..b16d82b8 100644 --- a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java +++ b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java @@ -658,4 +658,43 @@ public static void quantizePartials(float delta, ArrayVectorFloat partials, Arra } } } + + public static float decodedCosineSimilarity(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) { + var sum = FloatVector.zero(FloatVector.SPECIES_512); + var vaMagnitude = FloatVector.zero(FloatVector.SPECIES_512); + var baseOffsets = encoded.get(); + var partialSumsArray = partialSums.get(); + var aMagnitudeArray = aMagnitude.get(); + + int[] convOffsets = scratchInt512.get(); + int i = 0; + int limit = ByteVector.SPECIES_128.loopBound(baseOffsets.length); + + var scale = IntVector.zero(IntVector.SPECIES_512).addIndex(clusterCount); + + for (; i < limit; i += ByteVector.SPECIES_128.length()) { + + ByteVector.fromArray(ByteVector.SPECIES_128, baseOffsets, i) + .convertShape(VectorOperators.B2I, IntVector.SPECIES_512, 0) + .lanewise(VectorOperators.AND, BYTE_TO_INT_MASK_512) + .reinterpretAsInts() + .add(scale) + .intoArray(convOffsets,0); + + var offset = i * clusterCount; + sum = sum.add(FloatVector.fromArray(FloatVector.SPECIES_512, partialSumsArray, offset, convOffsets, 0)); + vaMagnitude = vaMagnitude.add(FloatVector.fromArray(FloatVector.SPECIES_512, aMagnitudeArray, offset, convOffsets, 0)); + } + + float sumResult = sum.reduceLanes(VectorOperators.ADD); + float aMagnitudeResult = vaMagnitude.reduceLanes(VectorOperators.ADD); + + for (; i < baseOffsets.length; i++) { + int offset = clusterCount * i + Byte.toUnsignedInt(baseOffsets[i]); + sumResult += partialSumsArray[offset]; + aMagnitudeResult += aMagnitudeArray[offset]; + } + + return (float) (sumResult / Math.sqrt(aMagnitudeResult * bMagnitude)); + } } From 20ed4ac99bb0ea242618cd5399912338c496a9e0 Mon Sep 17 00:00:00 2001 From: Michael Marshall Date: Thu, 14 Nov 2024 13:39:43 -0600 Subject: [PATCH 2/8] Break decodedCosineSimilarity out by HAS_AVX512 --- .../jbellis/jvector/vector/SimdOps.java | 45 +++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java index b16d82b8..55ef28f3 100644 --- a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java +++ b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java @@ -660,6 +660,12 @@ public static void quantizePartials(float delta, ArrayVectorFloat partials, Arra } public static float decodedCosineSimilarity(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) { + return HAS_AVX512 + ? decodedCosineSimilarity512(encoded, clusterCount, partialSums, aMagnitude, bMagnitude) + : decodedCosineSimilarity256(encoded, clusterCount, partialSums, aMagnitude, bMagnitude); + } + + public static float decodedCosineSimilarity512(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) { var sum = FloatVector.zero(FloatVector.SPECIES_512); var vaMagnitude = FloatVector.zero(FloatVector.SPECIES_512); var baseOffsets = encoded.get(); @@ -697,4 +703,43 @@ public static float decodedCosineSimilarity(ArrayByteSequence encoded, int clust return (float) (sumResult / Math.sqrt(aMagnitudeResult * bMagnitude)); } + + public static float decodedCosineSimilarity256(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) { + var sum = FloatVector.zero(FloatVector.SPECIES_256); + var vaMagnitude = FloatVector.zero(FloatVector.SPECIES_256); + var baseOffsets = encoded.get(); + var partialSumsArray = partialSums.get(); + var aMagnitudeArray = aMagnitude.get(); + + int[] convOffsets = scratchInt256.get(); + int i = 0; + int limit = ByteVector.SPECIES_64.loopBound(baseOffsets.length); + + var scale = IntVector.zero(IntVector.SPECIES_256).addIndex(clusterCount); + + for (; i < limit; i += ByteVector.SPECIES_64.length()) { + + ByteVector.fromArray(ByteVector.SPECIES_64, baseOffsets, i) + .convertShape(VectorOperators.B2I, IntVector.SPECIES_256, 0) + .lanewise(VectorOperators.AND, BYTE_TO_INT_MASK_256) + .reinterpretAsInts() + .add(scale) + .intoArray(convOffsets,0); + + var offset = i * clusterCount; + sum = sum.add(FloatVector.fromArray(FloatVector.SPECIES_256, partialSumsArray, offset, convOffsets, 0)); + vaMagnitude = vaMagnitude.add(FloatVector.fromArray(FloatVector.SPECIES_256, aMagnitudeArray, offset, convOffsets, 0)); + } + + float sumResult = sum.reduceLanes(VectorOperators.ADD); + float aMagnitudeResult = vaMagnitude.reduceLanes(VectorOperators.ADD); + + for (; i < baseOffsets.length; i++) { + int offset = clusterCount * i + Byte.toUnsignedInt(baseOffsets[i]); + sumResult += partialSumsArray[offset]; + aMagnitudeResult += aMagnitudeArray[offset]; + } + + return (float) (sumResult / Math.sqrt(aMagnitudeResult * bMagnitude)); + } } From bebb4f6f60d7cac9a6073922652f7a6e96df41d6 Mon Sep 17 00:00:00 2001 From: Michael Marshall Date: Thu, 14 Nov 2024 15:08:26 -0600 Subject: [PATCH 3/8] Make PVUS#assembleAndSum use SimdOps; optimize SimdOps assembleAndSum --- .../jbellis/jvector/vector/PanamaVectorUtilSupport.java | 6 +----- .../java/io/github/jbellis/jvector/vector/SimdOps.java | 9 ++++----- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java index e71d780a..060ca9e1 100644 --- a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java +++ b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java @@ -92,11 +92,7 @@ public VectorFloat sub(VectorFloat a, int aOffset, VectorFloat b, int b @Override public float assembleAndSum(VectorFloat data, int dataBase, ByteSequence baseOffsets) { - float sum = 0f; - for (int i = 0; i < baseOffsets.length(); i++) { - sum += data.get(dataBase * i + Byte.toUnsignedInt(baseOffsets.get(i))); - } - return sum; + return SimdOps.assembleAndSum(((ArrayVectorFloat) data).get(), dataBase, ((ArrayByteSequence) baseOffsets).get()); } @Override diff --git a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java index 55ef28f3..60c5170b 100644 --- a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java +++ b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java @@ -525,11 +525,10 @@ static float assembleAndSum512(float[] data, int dataBase, byte[] baseOffsets) { FloatVector sum = FloatVector.zero(FloatVector.SPECIES_512); int i = 0; int limit = ByteVector.SPECIES_128.loopBound(baseOffsets.length); + var scale = IntVector.zero(IntVector.SPECIES_512).addIndex(dataBase); for (; i < limit; i += ByteVector.SPECIES_128.length()) { - var scale = IntVector.zero(IntVector.SPECIES_512).addIndex(1).add(i).mul(dataBase); - - ByteVector.fromArray(ByteVector.SPECIES_128, baseOffsets, i) + ByteVector.fromArray(ByteVector.SPECIES_128, baseOffsets, i * dataBase) .convertShape(VectorOperators.B2I, IntVector.SPECIES_512, 0) .lanewise(VectorOperators.AND, BYTE_TO_INT_MASK_512) .reinterpretAsInts() @@ -553,11 +552,11 @@ static float assembleAndSum256(float[] data, int dataBase, byte[] baseOffsets) { FloatVector sum = FloatVector.zero(FloatVector.SPECIES_256); int i = 0; int limit = ByteVector.SPECIES_64.loopBound(baseOffsets.length); + var scale = IntVector.zero(IntVector.SPECIES_256).addIndex(dataBase); for (; i < limit; i += ByteVector.SPECIES_64.length()) { - var scale = IntVector.zero(IntVector.SPECIES_256).addIndex(1).add(i).mul(dataBase); - ByteVector.fromArray(ByteVector.SPECIES_64, baseOffsets, i) + ByteVector.fromArray(ByteVector.SPECIES_64, baseOffsets, i * dataBase) .convertShape(VectorOperators.B2I, IntVector.SPECIES_256, 0) .lanewise(VectorOperators.AND, BYTE_TO_INT_MASK_256) .reinterpretAsInts() From feb1e02b3c068e5e294513903dfb9e3167cc24eb Mon Sep 17 00:00:00 2001 From: Michael Marshall Date: Thu, 14 Nov 2024 16:18:47 -0600 Subject: [PATCH 4/8] Attempt native implemenation --- jvector-native/src/main/c/jvector_simd.c | 45 ++++++++++++++++++++++++ jvector-native/src/main/c/jvector_simd.h | 1 + 2 files changed, 46 insertions(+) diff --git a/jvector-native/src/main/c/jvector_simd.c b/jvector-native/src/main/c/jvector_simd.c index 886186fa..22422502 100644 --- a/jvector-native/src/main/c/jvector_simd.c +++ b/jvector-native/src/main/c/jvector_simd.c @@ -318,6 +318,51 @@ float assemble_and_sum_f32_512(const float* data, int dataBase, const unsigned c return res; } +float decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int baseOffsetsLength, int clusterCount, const float* partialSums, const float* aMagnitude, float bMagnitude) { + __m512 sum = _mm512_setzero_ps(); + __m512 vaMagnitude = _mm512_setzero_ps(); + int i = 0; + int limit = baseOffsetsLength - (baseOffsetsLength % 16); + __m512i indexRegister = initialIndexRegister; + __m512i scale = _mm512_set1_epi32(clusterCount); + + + for (; i < limit; i += 16) { + // Load and convert baseOffsets to integers + __m128i baseOffsetsRaw = _mm_loadu_si128((__m128i *)(baseOffsets + i)); + __m512i baseOffsetsInt = _mm512_cvtepu8_epi32(baseOffsetsRaw); + + indexRegister = _mm512_add_epi32(indexRegister, indexIncrement); + // Scale the baseOffsets by the cluster count + __m512i scaledOffsets = _mm512_mullo_epi32(indexRegister, scale); + + // Compute the offset base by multiplying 'i' with clusterCount and broadcasting to all lanes + __m512i offsetBase = _mm512_set1_epi32(i * clusterCount); + + // Calculate the final convOffsets by adding the scaled offsets and the offset base + __m512i convOffsets = _mm512_add_epi32(scaledOffsets, offsetBase); + + // Gather and sum values for partial sums and a magnitude + __m512 partialSumVals = _mm512_i32gather_ps(convOffsets, partialSums, 4); + sum = _mm512_add_ps(sum, partialSumVals); + + __m512 aMagnitudeVals = _mm512_i32gather_ps(convOffsets, aMagnitude, 4); + vaMagnitude = _mm512_add_ps(vaMagnitude, aMagnitudeVals); + } + + // Reduce sums + float sumResult = _mm512_reduce_add_ps(sum); + float aMagnitudeResult = _mm512_reduce_add_ps(vaMagnitude); + + // Handle the remaining elements + for (; i < baseOffsetsLength; i++) { + int offset = clusterCount * i + baseOffsets[i]; + sumResult += partialSums[offset]; + aMagnitudeResult += aMagnitude[offset]; + } + + return sumResult / sqrtf(aMagnitudeResult * bMagnitude); +} void calculate_partial_sums_dot_f32_512(const float* codebook, int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, float* partialSums) { int codebookBase = codebookIndex * clusterCount; diff --git a/jvector-native/src/main/c/jvector_simd.h b/jvector-native/src/main/c/jvector_simd.h index a5410ef5..1b96a0a8 100644 --- a/jvector-native/src/main/c/jvector_simd.h +++ b/jvector-native/src/main/c/jvector_simd.h @@ -29,6 +29,7 @@ void bulk_quantized_shuffle_dot_f32_512(const unsigned char* shuffles, int codeb void bulk_quantized_shuffle_euclidean_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartials, float delta, float minDistance, float* results); void bulk_quantized_shuffle_cosine_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartialSums, float sumDelta, float minDistance, const char* quantizedPartialMagnitudes, float magnitudeDelta, float minMagnitude, float queryMagnitudeSquared, float* results); float assemble_and_sum_f32_512(const float* data, int dataBase, const unsigned char* baseOffsets, int baseOffsetsLength); +float decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int baseOffsetsLength, int clusterCount, const float* partialSums, const float* aMagnitude, float bMagnitude); void calculate_partial_sums_dot_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums); void calculate_partial_sums_euclidean_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums); void calculate_partial_sums_best_dot_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums, float* partialBestDistances); From aee061ea656b795d1f8a6d47f9dd3ac4a0129c59 Mon Sep 17 00:00:00 2001 From: Joel Knighton Date: Thu, 14 Nov 2024 16:38:19 -0600 Subject: [PATCH 5/8] Fix assembleAndSum --- .../java/io/github/jbellis/jvector/vector/SimdOps.java | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java index 60c5170b..a6a73f3d 100644 --- a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java +++ b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java @@ -528,14 +528,15 @@ static float assembleAndSum512(float[] data, int dataBase, byte[] baseOffsets) { var scale = IntVector.zero(IntVector.SPECIES_512).addIndex(dataBase); for (; i < limit; i += ByteVector.SPECIES_128.length()) { - ByteVector.fromArray(ByteVector.SPECIES_128, baseOffsets, i * dataBase) + ByteVector.fromArray(ByteVector.SPECIES_128, baseOffsets, i) .convertShape(VectorOperators.B2I, IntVector.SPECIES_512, 0) .lanewise(VectorOperators.AND, BYTE_TO_INT_MASK_512) .reinterpretAsInts() .add(scale) .intoArray(convOffsets,0); - sum = sum.add(FloatVector.fromArray(FloatVector.SPECIES_512, data, 0, convOffsets, 0)); + var offset = i * dataBase; + sum = sum.add(FloatVector.fromArray(FloatVector.SPECIES_512, data, offset, convOffsets, 0)); } float res = sum.reduceLanes(VectorOperators.ADD); @@ -556,14 +557,15 @@ static float assembleAndSum256(float[] data, int dataBase, byte[] baseOffsets) { for (; i < limit; i += ByteVector.SPECIES_64.length()) { - ByteVector.fromArray(ByteVector.SPECIES_64, baseOffsets, i * dataBase) + ByteVector.fromArray(ByteVector.SPECIES_64, baseOffsets, i) .convertShape(VectorOperators.B2I, IntVector.SPECIES_256, 0) .lanewise(VectorOperators.AND, BYTE_TO_INT_MASK_256) .reinterpretAsInts() .add(scale) .intoArray(convOffsets,0); - sum = sum.add(FloatVector.fromArray(FloatVector.SPECIES_256, data, 0, convOffsets, 0)); + var offset = i * dataBase; + sum = sum.add(FloatVector.fromArray(FloatVector.SPECIES_256, data, offset, convOffsets, 0)); } float res = sum.reduceLanes(VectorOperators.ADD); From 3d79217e01d462752bff8f67bf6843c6308a6cd2 Mon Sep 17 00:00:00 2001 From: Joel Knighton Date: Mon, 18 Nov 2024 17:47:18 -0600 Subject: [PATCH 6/8] Fix decoded_cosine_similarity_f32_512. Generate bindings with jextract. Call binding from NativeVectorUtilSupport --- jvector-native/src/main/c/jvector_simd.c | 8 ++- .../vector/NativeVectorUtilSupport.java | 6 +++ .../jvector/vector/cnative/NativeSimdOps.java | 52 +++++++++++++++++++ 3 files changed, 61 insertions(+), 5 deletions(-) diff --git a/jvector-native/src/main/c/jvector_simd.c b/jvector-native/src/main/c/jvector_simd.c index 22422502..c807f33c 100644 --- a/jvector-native/src/main/c/jvector_simd.c +++ b/jvector-native/src/main/c/jvector_simd.c @@ -318,6 +318,7 @@ float assemble_and_sum_f32_512(const float* data, int dataBase, const unsigned c return res; } + float decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int baseOffsetsLength, int clusterCount, const float* partialSums, const float* aMagnitude, float bMagnitude) { __m512 sum = _mm512_setzero_ps(); __m512 vaMagnitude = _mm512_setzero_ps(); @@ -336,11 +337,8 @@ float decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int ba // Scale the baseOffsets by the cluster count __m512i scaledOffsets = _mm512_mullo_epi32(indexRegister, scale); - // Compute the offset base by multiplying 'i' with clusterCount and broadcasting to all lanes - __m512i offsetBase = _mm512_set1_epi32(i * clusterCount); - - // Calculate the final convOffsets by adding the scaled offsets and the offset base - __m512i convOffsets = _mm512_add_epi32(scaledOffsets, offsetBase); + // Calculate the final convOffsets by adding the scaled indexes and the base offsets + __m512i convOffsets = _mm512_add_epi32(scaledOffsets, baseOffsetsInt); // Gather and sum values for partial sums and a magnitude __m512 partialSumVals = _mm512_i32gather_ps(convOffsets, partialSums, 4); diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java index 3fe12e71..624ec27f 100644 --- a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java +++ b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java @@ -155,4 +155,10 @@ public void bulkShuffleQuantizedSimilarityCosine(ByteSequence shuffles, int c NativeSimdOps.bulk_quantized_shuffle_cosine_f32_512(((MemorySegmentByteSequence) shuffles).get(), codebookCount, ((MemorySegmentByteSequence) quantizedPartialSums).get(), sumDelta, minDistance, ((MemorySegmentByteSequence) quantizedPartialSquaredMagnitudes).get(), magnitudeDelta, minMagnitude, queryMagnitudeSquared, ((MemorySegmentVectorFloat) results).get()); } + + @Override + public float decodedCosineSimilarity(ByteSequence encoded, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) + { + return NativeSimdOps.decoded_cosine_similarity_f32_512(((MemorySegmentByteSequence) encoded).get(), encoded.length(), clusterCount, ((MemorySegmentVectorFloat) partialSums).get(), ((MemorySegmentVectorFloat) aMagnitude).get(), bMagnitude); + } } diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java index e148b1be..47d616ed 100644 --- a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java +++ b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java @@ -452,6 +452,58 @@ public static float assemble_and_sum_f32_512(MemorySegment data, int dataBase, M } } + private static class decoded_cosine_similarity_f32_512 { + public static final FunctionDescriptor DESC = FunctionDescriptor.of( + NativeSimdOps.C_FLOAT, + NativeSimdOps.C_POINTER, + NativeSimdOps.C_INT, + NativeSimdOps.C_INT, + NativeSimdOps.C_POINTER, + NativeSimdOps.C_POINTER, + NativeSimdOps.C_FLOAT + ); + + public static final MethodHandle HANDLE = Linker.nativeLinker().downcallHandle( + NativeSimdOps.findOrThrow("decoded_cosine_similarity_f32_512"), + DESC, Linker.Option.critical(true)); + } + + /** + * Function descriptor for: + * {@snippet lang=c : + * float decoded_cosine_similarity_f32_512(const unsigned char *baseOffsets, int baseOffsetsLength, int clusterCount, const float *partialSums, const float *aMagnitude, float bMagnitude) + * } + */ + public static FunctionDescriptor decoded_cosine_similarity_f32_512$descriptor() { + return decoded_cosine_similarity_f32_512.DESC; + } + + /** + * Downcall method handle for: + * {@snippet lang=c : + * float decoded_cosine_similarity_f32_512(const unsigned char *baseOffsets, int baseOffsetsLength, int clusterCount, const float *partialSums, const float *aMagnitude, float bMagnitude) + * } + */ + public static MethodHandle decoded_cosine_similarity_f32_512$handle() { + return decoded_cosine_similarity_f32_512.HANDLE; + } + /** + * {@snippet lang=c : + * float decoded_cosine_similarity_f32_512(const unsigned char *baseOffsets, int baseOffsetsLength, int clusterCount, const float *partialSums, const float *aMagnitude, float bMagnitude) + * } + */ + public static float decoded_cosine_similarity_f32_512(MemorySegment baseOffsets, int baseOffsetsLength, int clusterCount, MemorySegment partialSums, MemorySegment aMagnitude, float bMagnitude) { + var mh$ = decoded_cosine_similarity_f32_512.HANDLE; + try { + if (TRACE_DOWNCALLS) { + traceDowncall("decoded_cosine_similarity_f32_512", baseOffsets, baseOffsetsLength, clusterCount, partialSums, aMagnitude, bMagnitude); + } + return (float)mh$.invokeExact(baseOffsets, baseOffsetsLength, clusterCount, partialSums, aMagnitude, bMagnitude); + } catch (Throwable ex$) { + throw new AssertionError("should not reach here", ex$); + } + } + private static class calculate_partial_sums_dot_f32_512 { public static final FunctionDescriptor DESC = FunctionDescriptor.ofVoid( NativeSimdOps.C_POINTER, From f99c7462f333024efc6377adca879a45cbff8c66 Mon Sep 17 00:00:00 2001 From: Michael Marshall Date: Wed, 20 Nov 2024 21:07:01 -0600 Subject: [PATCH 7/8] Add 'pq' prefix to all forms of decodedCosineSimilarity methods --- .../github/jbellis/jvector/pq/PQDecoder.java | 2 +- .../jbellis/jvector/vector/VectorUtil.java | 4 ++-- .../jvector/vector/VectorUtilSupport.java | 2 +- jvector-native/src/main/c/jvector_simd.c | 2 +- jvector-native/src/main/c/jvector_simd.h | 2 +- .../vector/NativeVectorUtilSupport.java | 4 ++-- .../jvector/vector/cnative/NativeSimdOps.java | 24 +++++++++---------- .../vector/PanamaVectorUtilSupport.java | 4 ++-- .../jbellis/jvector/vector/SimdOps.java | 10 ++++---- 9 files changed, 27 insertions(+), 27 deletions(-) diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQDecoder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQDecoder.java index e417d50e..bd678395 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQDecoder.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQDecoder.java @@ -134,7 +134,7 @@ protected float decodedCosine(int node2) { ByteSequence encoded = cv.get(node2); - return VectorUtil.decodedCosineSimilarity(encoded, cv.pq.getClusterCount(), partialSums, aMagnitude, bMagnitude); + return VectorUtil.pqDecodedCosineSimilarity(encoded, cv.pq.getClusterCount(), partialSums, aMagnitude, bMagnitude); } } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java index d860cc1b..595cb915 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java @@ -195,7 +195,7 @@ public static float min(VectorFloat v) { return impl.min(v); } - public static float decodedCosineSimilarity(ByteSequence encoded, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) { - return impl.decodedCosineSimilarity(encoded, clusterCount, partialSums, aMagnitude, bMagnitude); + public static float pqDecodedCosineSimilarity(ByteSequence encoded, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) { + return impl.pqDecodedCosineSimilarity(encoded, clusterCount, partialSums, aMagnitude, bMagnitude); } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java index e3e2953e..320f71a1 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java @@ -199,7 +199,7 @@ default void bulkShuffleQuantizedSimilarityCosine(ByteSequence shuffles, int float max(VectorFloat v); float min(VectorFloat v); - default float decodedCosineSimilarity(ByteSequence encoded, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) + default float pqDecodedCosineSimilarity(ByteSequence encoded, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) { float sum = 0.0f; float aMag = 0.0f; diff --git a/jvector-native/src/main/c/jvector_simd.c b/jvector-native/src/main/c/jvector_simd.c index c807f33c..63dddf6c 100644 --- a/jvector-native/src/main/c/jvector_simd.c +++ b/jvector-native/src/main/c/jvector_simd.c @@ -319,7 +319,7 @@ float assemble_and_sum_f32_512(const float* data, int dataBase, const unsigned c return res; } -float decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int baseOffsetsLength, int clusterCount, const float* partialSums, const float* aMagnitude, float bMagnitude) { +float pq_decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int baseOffsetsLength, int clusterCount, const float* partialSums, const float* aMagnitude, float bMagnitude) { __m512 sum = _mm512_setzero_ps(); __m512 vaMagnitude = _mm512_setzero_ps(); int i = 0; diff --git a/jvector-native/src/main/c/jvector_simd.h b/jvector-native/src/main/c/jvector_simd.h index 1b96a0a8..76bbb928 100644 --- a/jvector-native/src/main/c/jvector_simd.h +++ b/jvector-native/src/main/c/jvector_simd.h @@ -29,7 +29,7 @@ void bulk_quantized_shuffle_dot_f32_512(const unsigned char* shuffles, int codeb void bulk_quantized_shuffle_euclidean_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartials, float delta, float minDistance, float* results); void bulk_quantized_shuffle_cosine_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartialSums, float sumDelta, float minDistance, const char* quantizedPartialMagnitudes, float magnitudeDelta, float minMagnitude, float queryMagnitudeSquared, float* results); float assemble_and_sum_f32_512(const float* data, int dataBase, const unsigned char* baseOffsets, int baseOffsetsLength); -float decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int baseOffsetsLength, int clusterCount, const float* partialSums, const float* aMagnitude, float bMagnitude); +float pq_decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int baseOffsetsLength, int clusterCount, const float* partialSums, const float* aMagnitude, float bMagnitude); void calculate_partial_sums_dot_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums); void calculate_partial_sums_euclidean_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums); void calculate_partial_sums_best_dot_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums, float* partialBestDistances); diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java index 624ec27f..0af7ebcb 100644 --- a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java +++ b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java @@ -157,8 +157,8 @@ public void bulkShuffleQuantizedSimilarityCosine(ByteSequence shuffles, int c } @Override - public float decodedCosineSimilarity(ByteSequence encoded, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) + public float pqDecodedCosineSimilarity(ByteSequence encoded, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) { - return NativeSimdOps.decoded_cosine_similarity_f32_512(((MemorySegmentByteSequence) encoded).get(), encoded.length(), clusterCount, ((MemorySegmentVectorFloat) partialSums).get(), ((MemorySegmentVectorFloat) aMagnitude).get(), bMagnitude); + return NativeSimdOps.pq_decoded_cosine_similarity_f32_512(((MemorySegmentByteSequence) encoded).get(), encoded.length(), clusterCount, ((MemorySegmentVectorFloat) partialSums).get(), ((MemorySegmentVectorFloat) aMagnitude).get(), bMagnitude); } } diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java index 47d616ed..baa28739 100644 --- a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java +++ b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java @@ -452,7 +452,7 @@ public static float assemble_and_sum_f32_512(MemorySegment data, int dataBase, M } } - private static class decoded_cosine_similarity_f32_512 { + private static class pq_decoded_cosine_similarity_f32_512 { public static final FunctionDescriptor DESC = FunctionDescriptor.of( NativeSimdOps.C_FLOAT, NativeSimdOps.C_POINTER, @@ -464,39 +464,39 @@ private static class decoded_cosine_similarity_f32_512 { ); public static final MethodHandle HANDLE = Linker.nativeLinker().downcallHandle( - NativeSimdOps.findOrThrow("decoded_cosine_similarity_f32_512"), + NativeSimdOps.findOrThrow("pq_decoded_cosine_similarity_f32_512"), DESC, Linker.Option.critical(true)); } /** * Function descriptor for: * {@snippet lang=c : - * float decoded_cosine_similarity_f32_512(const unsigned char *baseOffsets, int baseOffsetsLength, int clusterCount, const float *partialSums, const float *aMagnitude, float bMagnitude) + * float pq_decoded_cosine_similarity_f32_512(const unsigned char *baseOffsets, int baseOffsetsLength, int clusterCount, const float *partialSums, const float *aMagnitude, float bMagnitude) * } */ - public static FunctionDescriptor decoded_cosine_similarity_f32_512$descriptor() { - return decoded_cosine_similarity_f32_512.DESC; + public static FunctionDescriptor pq_decoded_cosine_similarity_f32_512$descriptor() { + return pq_decoded_cosine_similarity_f32_512.DESC; } /** * Downcall method handle for: * {@snippet lang=c : - * float decoded_cosine_similarity_f32_512(const unsigned char *baseOffsets, int baseOffsetsLength, int clusterCount, const float *partialSums, const float *aMagnitude, float bMagnitude) + * float pq_decoded_cosine_similarity_f32_512(const unsigned char *baseOffsets, int baseOffsetsLength, int clusterCount, const float *partialSums, const float *aMagnitude, float bMagnitude) * } */ - public static MethodHandle decoded_cosine_similarity_f32_512$handle() { - return decoded_cosine_similarity_f32_512.HANDLE; + public static MethodHandle pq_decoded_cosine_similarity_f32_512$handle() { + return pq_decoded_cosine_similarity_f32_512.HANDLE; } /** * {@snippet lang=c : - * float decoded_cosine_similarity_f32_512(const unsigned char *baseOffsets, int baseOffsetsLength, int clusterCount, const float *partialSums, const float *aMagnitude, float bMagnitude) + * float pq_decoded_cosine_similarity_f32_512(const unsigned char *baseOffsets, int baseOffsetsLength, int clusterCount, const float *partialSums, const float *aMagnitude, float bMagnitude) * } */ - public static float decoded_cosine_similarity_f32_512(MemorySegment baseOffsets, int baseOffsetsLength, int clusterCount, MemorySegment partialSums, MemorySegment aMagnitude, float bMagnitude) { - var mh$ = decoded_cosine_similarity_f32_512.HANDLE; + public static float pq_decoded_cosine_similarity_f32_512(MemorySegment baseOffsets, int baseOffsetsLength, int clusterCount, MemorySegment partialSums, MemorySegment aMagnitude, float bMagnitude) { + var mh$ = pq_decoded_cosine_similarity_f32_512.HANDLE; try { if (TRACE_DOWNCALLS) { - traceDowncall("decoded_cosine_similarity_f32_512", baseOffsets, baseOffsetsLength, clusterCount, partialSums, aMagnitude, bMagnitude); + traceDowncall("pq_decoded_cosine_similarity_f32_512", baseOffsets, baseOffsetsLength, clusterCount, partialSums, aMagnitude, bMagnitude); } return (float)mh$.invokeExact(baseOffsets, baseOffsetsLength, clusterCount, partialSums, aMagnitude, bMagnitude); } catch (Throwable ex$) { diff --git a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java index 060ca9e1..4b2602cb 100644 --- a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java +++ b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java @@ -157,9 +157,9 @@ public void quantizePartials(float delta, VectorFloat partials, VectorFloat encoded, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) + public float pqDecodedCosineSimilarity(ByteSequence encoded, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) { - return SimdOps.decodedCosineSimilarity((ArrayByteSequence) encoded, clusterCount, (ArrayVectorFloat) partialSums, (ArrayVectorFloat) aMagnitude, bMagnitude); + return SimdOps.pqDecodedCosineSimilarity((ArrayByteSequence) encoded, clusterCount, (ArrayVectorFloat) partialSums, (ArrayVectorFloat) aMagnitude, bMagnitude); } } diff --git a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java index a6a73f3d..034aa987 100644 --- a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java +++ b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java @@ -660,13 +660,13 @@ public static void quantizePartials(float delta, ArrayVectorFloat partials, Arra } } - public static float decodedCosineSimilarity(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) { + public static float pqDecodedCosineSimilarity(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) { return HAS_AVX512 - ? decodedCosineSimilarity512(encoded, clusterCount, partialSums, aMagnitude, bMagnitude) - : decodedCosineSimilarity256(encoded, clusterCount, partialSums, aMagnitude, bMagnitude); + ? pqDecodedCosineSimilarity512(encoded, clusterCount, partialSums, aMagnitude, bMagnitude) + : pqDecodedCosineSimilarity256(encoded, clusterCount, partialSums, aMagnitude, bMagnitude); } - public static float decodedCosineSimilarity512(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) { + public static float pqDecodedCosineSimilarity512(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) { var sum = FloatVector.zero(FloatVector.SPECIES_512); var vaMagnitude = FloatVector.zero(FloatVector.SPECIES_512); var baseOffsets = encoded.get(); @@ -705,7 +705,7 @@ public static float decodedCosineSimilarity512(ArrayByteSequence encoded, int cl return (float) (sumResult / Math.sqrt(aMagnitudeResult * bMagnitude)); } - public static float decodedCosineSimilarity256(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) { + public static float pqDecodedCosineSimilarity256(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) { var sum = FloatVector.zero(FloatVector.SPECIES_256); var vaMagnitude = FloatVector.zero(FloatVector.SPECIES_256); var baseOffsets = encoded.get(); From bc91bd96ea471b64553abe623274fd742eeffe0d Mon Sep 17 00:00:00 2001 From: Michael Marshall Date: Wed, 20 Nov 2024 23:21:02 -0600 Subject: [PATCH 8/8] WIP: store raw representation of ByteSeqeunce in PQVectors --- .../github/jbellis/jvector/pq/PQDecoder.java | 5 +- .../github/jbellis/jvector/pq/PQVectors.java | 50 +++++++++++++------ .../jvector/vector/ArrayVectorProvider.java | 16 +++++- .../jbellis/jvector/vector/VectorUtil.java | 5 +- .../jvector/vector/VectorUtilSupport.java | 4 +- .../vector/types/VectorTypeSupport.java | 4 ++ .../jbellis/jvector/example/SiftSmall.java | 4 +- .../vector/MemorySegmentByteSequence.java | 4 ++ .../vector/MemorySegmentVectorProvider.java | 18 +++++++ .../vector/NativeVectorUtilSupport.java | 7 ++- .../vector/PanamaVectorUtilSupport.java | 7 ++- .../jbellis/jvector/vector/SimdOps.java | 8 ++- 12 files changed, 96 insertions(+), 36 deletions(-) diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQDecoder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQDecoder.java index bd678395..d2b3822d 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQDecoder.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQDecoder.java @@ -131,10 +131,7 @@ public float similarityTo(int node2) { } protected float decodedCosine(int node2) { - - ByteSequence encoded = cv.get(node2); - - return VectorUtil.pqDecodedCosineSimilarity(encoded, cv.pq.getClusterCount(), partialSums, aMagnitude, bMagnitude); + return VectorUtil.pqDecodedCosineSimilarity(cv, node2, cv.pq.getClusterCount(), partialSums, aMagnitude, bMagnitude); } } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQVectors.java index b58880ab..2bffb4a9 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQVectors.java @@ -29,33 +29,38 @@ import java.io.DataOutput; import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Objects; import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; public class PQVectors implements CompressedVectors { private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport(); final ProductQuantization pq; - private final List> compressedVectors; + // In order to decrease the memory footprint, we store the compressed vectors in their raw representation + // and then use the vectorTypeSupport to convert them to ByteSequence when needed and to interpret + // the raw Object. + private final List rawCompressedVectors; /** * Initialize the PQVectors with an initial List of vectors. This list may be * mutated, but caller is responsible for thread safety issues when doing so. */ - public PQVectors(ProductQuantization pq, List> compressedVectors) + public PQVectors(ProductQuantization pq, List compressedVectors) { this.pq = pq; - this.compressedVectors = compressedVectors; + this.rawCompressedVectors = compressedVectors; } public PQVectors(ProductQuantization pq, ByteSequence[] compressedVectors) { - this(pq, List.of(compressedVectors)); + this(pq, Arrays.stream(compressedVectors).map(ByteSequence::get).collect(Collectors.toList())); } @Override public int count() { - return compressedVectors.size(); + return rawCompressedVectors.size(); } @Override @@ -65,10 +70,10 @@ public void write(DataOutput out, int version) throws IOException pq.write(out, version); // compressed vectors - out.writeInt(compressedVectors.size()); + out.writeInt(rawCompressedVectors.size()); out.writeInt(pq.getSubspaceCount()); - for (var v : compressedVectors) { - vectorTypeSupport.writeByteSequence(out, v); + for (var v : rawCompressedVectors) { + vectorTypeSupport.writeBytes(out, v); } } @@ -81,7 +86,7 @@ public static PQVectors load(RandomAccessReader in) throws IOException { if (size < 0) { throw new IOException("Invalid compressed vector count " + size); } - List> compressedVectors = new ArrayList<>(size); + List compressedVectors = new ArrayList<>(size); int compressedDimension = in.readInt(); if (compressedDimension < 0) { @@ -90,7 +95,7 @@ public static PQVectors load(RandomAccessReader in) throws IOException { for (int i = 0; i < size; i++) { - ByteSequence vector = vectorTypeSupport.readByteSequence(in, compressedDimension); + Object vector = vectorTypeSupport.readBytes(in, compressedDimension); compressedVectors.add(vector); } @@ -109,12 +114,21 @@ public boolean equals(Object o) { PQVectors that = (PQVectors) o; if (!Objects.equals(pq, that.pq)) return false; - return Objects.equals(compressedVectors, that.compressedVectors); + if (rawCompressedVectors.size() != that.rawCompressedVectors.size()) return false; + for (int i = 0; i < rawCompressedVectors.size(); i++) { + // Because this method is not called on any hot path, we accept the overhead of creating ByteSequence + var a = vectorTypeSupport.createByteSequence(rawCompressedVectors.get(i)); + var b = vectorTypeSupport.createByteSequence(that.rawCompressedVectors.get(i)); + if (!a.equals(b)) { + return false; + } + } + return true; } @Override public int hashCode() { - return Objects.hash(pq, compressedVectors); + return Objects.hash(pq, rawCompressedVectors); } @Override @@ -188,7 +202,11 @@ public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat q, } public ByteSequence get(int ordinal) { - return compressedVectors.get(ordinal); + return vectorTypeSupport.createByteSequence(rawCompressedVectors.get(ordinal)); + } + + public Object getRaw(int ordinal) { + return rawCompressedVectors.get(ordinal); } public ProductQuantization getProductQuantization() { @@ -225,8 +243,8 @@ public long ramBytesUsed() { int AH_BYTES = RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; long codebooksSize = pq.ramBytesUsed(); - long listSize = (long) REF_BYTES * (1 + compressedVectors.size()); - long dataSize = (long) (OH_BYTES + AH_BYTES + pq.compressedVectorSize()) * compressedVectors.size(); + long listSize = (long) REF_BYTES * (1 + rawCompressedVectors.size()); + long dataSize = (long) (OH_BYTES + AH_BYTES + pq.compressedVectorSize()) * rawCompressedVectors.size(); return codebooksSize + listSize + dataSize; } @@ -234,7 +252,7 @@ public long ramBytesUsed() { public String toString() { return "PQVectors{" + "pq=" + pq + - ", count=" + compressedVectors.size() + + ", count=" + rawCompressedVectors.size() + '}'; } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArrayVectorProvider.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArrayVectorProvider.java index 8f3e05db..64be9fb3 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArrayVectorProvider.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArrayVectorProvider.java @@ -83,11 +83,17 @@ public ByteSequence createByteSequence(int length) } @Override - public ByteSequence readByteSequence(RandomAccessReader r, int size) throws IOException + public byte[] readBytes(RandomAccessReader r, int size) throws IOException { byte[] vector = new byte[size]; r.readFully(vector); - return new ArrayByteSequence(vector); + return vector; + } + + @Override + public ByteSequence readByteSequence(RandomAccessReader r, int size) throws IOException + { + return new ArrayByteSequence(readBytes(r, size)); } @Override @@ -102,4 +108,10 @@ public void writeByteSequence(DataOutput out, ByteSequence sequence) throws I ArrayByteSequence v = (ArrayByteSequence) sequence; out.write(v.get()); } + + @Override + public void writeBytes(DataOutput out, Object sequence) throws IOException + { + out.write((byte[]) sequence); + } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java index 595cb915..5b81d838 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java @@ -24,6 +24,7 @@ package io.github.jbellis.jvector.vector; +import io.github.jbellis.jvector.pq.PQVectors; import io.github.jbellis.jvector.vector.types.ByteSequence; import io.github.jbellis.jvector.vector.types.VectorFloat; @@ -195,7 +196,7 @@ public static float min(VectorFloat v) { return impl.min(v); } - public static float pqDecodedCosineSimilarity(ByteSequence encoded, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) { - return impl.pqDecodedCosineSimilarity(encoded, clusterCount, partialSums, aMagnitude, bMagnitude); + public static float pqDecodedCosineSimilarity(PQVectors cv, int node, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) { + return impl.pqDecodedCosineSimilarity(cv, node, clusterCount, partialSums, aMagnitude, bMagnitude); } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java index 320f71a1..23a80c74 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java @@ -24,6 +24,7 @@ package io.github.jbellis.jvector.vector; +import io.github.jbellis.jvector.pq.PQVectors; import io.github.jbellis.jvector.vector.types.ByteSequence; import io.github.jbellis.jvector.vector.types.VectorFloat; @@ -199,10 +200,11 @@ default void bulkShuffleQuantizedSimilarityCosine(ByteSequence shuffles, int float max(VectorFloat v); float min(VectorFloat v); - default float pqDecodedCosineSimilarity(ByteSequence encoded, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) + default float pqDecodedCosineSimilarity(PQVectors cv, int node, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) { float sum = 0.0f; float aMag = 0.0f; + ByteSequence encoded = cv.get(node); for (int m = 0; m < encoded.length(); ++m) { int centroidIndex = Byte.toUnsignedInt(encoded.get(m)); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/VectorTypeSupport.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/VectorTypeSupport.java index 40938937..52d116b4 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/VectorTypeSupport.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/VectorTypeSupport.java @@ -81,7 +81,11 @@ public interface VectorTypeSupport { ByteSequence readByteSequence(RandomAccessReader r, int size) throws IOException; + Object readBytes(RandomAccessReader r, int size) throws IOException; + void readByteSequence(RandomAccessReader r, ByteSequence sequence) throws IOException; void writeByteSequence(DataOutput out, ByteSequence sequence) throws IOException; + + void writeBytes(DataOutput out, Object bytes) throws IOException; } diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/SiftSmall.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/SiftSmall.java index f6053775..550c23d9 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/SiftSmall.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/SiftSmall.java @@ -215,7 +215,7 @@ public static void siftDiskAnnLTM(List> baseVectors, List> incrementallyCompressedVectors = new ArrayList<>(); + List incrementallyCompressedVectors = new ArrayList<>(); PQVectors pqv = new PQVectors(pq, incrementallyCompressedVectors); BuildScoreProvider bsp = BuildScoreProvider.pqBuildScoreProvider(VectorSimilarityFunction.EUCLIDEAN, pqv); @@ -235,7 +235,7 @@ public static void siftDiskAnnLTM(List> baseVectors, List v : baseVectors) { // compress the new vector and add it to the PQVectors (via incrementallyCompressedVectors) int ordinal = incrementallyCompressedVectors.size(); - incrementallyCompressedVectors.add(pq.encode(v)); + incrementallyCompressedVectors.add(pq.encode(v).get()); // write the full vector to disk writer.writeInline(ordinal, Feature.singleState(FeatureId.INLINE_VECTORS, new InlineVectors.State(v))); // now add it to the graph -- the previous steps must be completed first since the PQVectors diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/MemorySegmentByteSequence.java b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/MemorySegmentByteSequence.java index 4c541856..c73bfd41 100644 --- a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/MemorySegmentByteSequence.java +++ b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/MemorySegmentByteSequence.java @@ -72,6 +72,10 @@ public byte get(int n) { return segment.get(ValueLayout.JAVA_BYTE, n); } + public static byte get(MemorySegment ms, int n) { + return ms.get(ValueLayout.JAVA_BYTE, n); + } + @Override public void set(int n, byte value) { segment.set(ValueLayout.JAVA_BYTE, n, value); diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/MemorySegmentVectorProvider.java b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/MemorySegmentVectorProvider.java index a098ced7..e47fce03 100644 --- a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/MemorySegmentVectorProvider.java +++ b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/MemorySegmentVectorProvider.java @@ -23,6 +23,7 @@ import java.io.DataOutput; import java.io.IOException; +import java.lang.foreign.MemorySegment; import java.nio.Buffer; /** @@ -89,6 +90,14 @@ public ByteSequence readByteSequence(RandomAccessReader r, int size) throws I return vector; } + @Override + public MemorySegment readBytes(RandomAccessReader r, int size) throws IOException + { + var array = new byte[size]; + r.readFully(array); + return MemorySegment.ofArray(array); + } + @Override public void readByteSequence(RandomAccessReader r, ByteSequence sequence) throws IOException { r.readFully(((MemorySegmentByteSequence) sequence).get().asByteBuffer()); @@ -101,4 +110,13 @@ public void writeByteSequence(DataOutput out, ByteSequence sequence) throws I for (int i = 0; i < sequence.length(); i++) out.writeByte(sequence.get(i)); } + + @Override + public void writeBytes(DataOutput out, Object bytes) throws IOException + { + MemorySegment sequence = (MemorySegment) bytes; + int size = Math.toIntExact(sequence.byteSize()); + for (int i = 0; i < size; i++) + out.writeByte(MemorySegmentByteSequence.get(sequence, i)); + } } diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java index 0af7ebcb..5ce0d7f7 100644 --- a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java +++ b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java @@ -16,10 +16,12 @@ package io.github.jbellis.jvector.vector; +import io.github.jbellis.jvector.pq.PQVectors; import io.github.jbellis.jvector.vector.cnative.NativeSimdOps; import io.github.jbellis.jvector.vector.types.ByteSequence; import io.github.jbellis.jvector.vector.types.VectorFloat; +import java.lang.foreign.MemorySegment; import java.util.List; /** @@ -157,8 +159,9 @@ public void bulkShuffleQuantizedSimilarityCosine(ByteSequence shuffles, int c } @Override - public float pqDecodedCosineSimilarity(ByteSequence encoded, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) + public float pqDecodedCosineSimilarity(PQVectors cv, int node, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) { - return NativeSimdOps.pq_decoded_cosine_similarity_f32_512(((MemorySegmentByteSequence) encoded).get(), encoded.length(), clusterCount, ((MemorySegmentVectorFloat) partialSums).get(), ((MemorySegmentVectorFloat) aMagnitude).get(), bMagnitude); + MemorySegment encoded = (MemorySegment) cv.getRaw(node); + return NativeSimdOps.pq_decoded_cosine_similarity_f32_512(encoded, (int) encoded.byteSize(), clusterCount, ((MemorySegmentVectorFloat) partialSums).get(), ((MemorySegmentVectorFloat) aMagnitude).get(), bMagnitude); } } diff --git a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java index 4b2602cb..525f4821 100644 --- a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java +++ b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java @@ -16,6 +16,8 @@ package io.github.jbellis.jvector.vector; +import io.github.jbellis.jvector.pq.CompressedVectors; +import io.github.jbellis.jvector.pq.PQVectors; import io.github.jbellis.jvector.vector.types.ByteSequence; import io.github.jbellis.jvector.vector.types.VectorFloat; @@ -157,9 +159,10 @@ public void quantizePartials(float delta, VectorFloat partials, VectorFloat encoded, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) + public float pqDecodedCosineSimilarity(PQVectors cv, int node, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) { - return SimdOps.pqDecodedCosineSimilarity((ArrayByteSequence) encoded, clusterCount, (ArrayVectorFloat) partialSums, (ArrayVectorFloat) aMagnitude, bMagnitude); + byte[] encoded = (byte[]) cv.getRaw(node); + return SimdOps.pqDecodedCosineSimilarity(encoded, clusterCount, (ArrayVectorFloat) partialSums, (ArrayVectorFloat) aMagnitude, bMagnitude); } } diff --git a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java index 034aa987..8ac55a60 100644 --- a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java +++ b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java @@ -660,16 +660,15 @@ public static void quantizePartials(float delta, ArrayVectorFloat partials, Arra } } - public static float pqDecodedCosineSimilarity(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) { + public static float pqDecodedCosineSimilarity(byte[] encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) { return HAS_AVX512 ? pqDecodedCosineSimilarity512(encoded, clusterCount, partialSums, aMagnitude, bMagnitude) : pqDecodedCosineSimilarity256(encoded, clusterCount, partialSums, aMagnitude, bMagnitude); } - public static float pqDecodedCosineSimilarity512(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) { + public static float pqDecodedCosineSimilarity512(byte[] baseOffsets, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) { var sum = FloatVector.zero(FloatVector.SPECIES_512); var vaMagnitude = FloatVector.zero(FloatVector.SPECIES_512); - var baseOffsets = encoded.get(); var partialSumsArray = partialSums.get(); var aMagnitudeArray = aMagnitude.get(); @@ -705,10 +704,9 @@ public static float pqDecodedCosineSimilarity512(ArrayByteSequence encoded, int return (float) (sumResult / Math.sqrt(aMagnitudeResult * bMagnitude)); } - public static float pqDecodedCosineSimilarity256(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) { + public static float pqDecodedCosineSimilarity256(byte[] baseOffsets, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) { var sum = FloatVector.zero(FloatVector.SPECIES_256); var vaMagnitude = FloatVector.zero(FloatVector.SPECIES_256); - var baseOffsets = encoded.get(); var partialSumsArray = partialSums.get(); var aMagnitudeArray = aMagnitude.get();