Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: store raw representation of ByteSequence in PQVectors #369

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -131,18 +131,7 @@ public float similarityTo(int node2) {
}

protected float decodedCosine(int node2) {
float sum = 0.0f;
float aMag = 0.0f;

ByteSequence<?> encoded = cv.get(node2);

for (int m = 0; m < encoded.length(); ++m) {
int centroidIndex = Byte.toUnsignedInt(encoded.get(m));
sum += partialSums.get((m * cv.pq.getClusterCount()) + centroidIndex);
aMag += aMagnitude.get((m * cv.pq.getClusterCount()) + centroidIndex);
}

return (float) (sum / Math.sqrt(aMag * bMagnitude));
return VectorUtil.pqDecodedCosineSimilarity(cv, node2, cv.pq.getClusterCount(), partialSums, aMagnitude, bMagnitude);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,33 +29,38 @@
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;

public class PQVectors implements CompressedVectors {
private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport();
final ProductQuantization pq;
private final List<ByteSequence<?>> compressedVectors;
// To decrease the memory footprint, we store each compressed vector in its raw representation
// and use the vectorTypeSupport both to interpret the raw Object and to convert it into a
// ByteSequence view when one is needed.
private final List<Object> rawCompressedVectors;

/**
* Initialize the PQVectors with an initial List of vectors. This list may be
* mutated, but caller is responsible for thread safety issues when doing so.
*/
public PQVectors(ProductQuantization pq, List<ByteSequence<?>> compressedVectors)
public PQVectors(ProductQuantization pq, List<Object> compressedVectors)
{
this.pq = pq;
this.compressedVectors = compressedVectors;
this.rawCompressedVectors = compressedVectors;
}

public PQVectors(ProductQuantization pq, ByteSequence<?>[] compressedVectors)
{
this(pq, List.of(compressedVectors));
this(pq, Arrays.stream(compressedVectors).map(ByteSequence::get).collect(Collectors.toList()));
}

@Override
public int count() {
return compressedVectors.size();
return rawCompressedVectors.size();
}

@Override
Expand All @@ -65,10 +70,10 @@ public void write(DataOutput out, int version) throws IOException
pq.write(out, version);

// compressed vectors
out.writeInt(compressedVectors.size());
out.writeInt(rawCompressedVectors.size());
out.writeInt(pq.getSubspaceCount());
for (var v : compressedVectors) {
vectorTypeSupport.writeByteSequence(out, v);
for (var v : rawCompressedVectors) {
vectorTypeSupport.writeBytes(out, v);
}
}

Expand All @@ -81,7 +86,7 @@ public static PQVectors load(RandomAccessReader in) throws IOException {
if (size < 0) {
throw new IOException("Invalid compressed vector count " + size);
}
List<ByteSequence<?>> compressedVectors = new ArrayList<>(size);
List<Object> compressedVectors = new ArrayList<>(size);

int compressedDimension = in.readInt();
if (compressedDimension < 0) {
Expand All @@ -90,7 +95,7 @@ public static PQVectors load(RandomAccessReader in) throws IOException {

for (int i = 0; i < size; i++)
{
ByteSequence<?> vector = vectorTypeSupport.readByteSequence(in, compressedDimension);
Object vector = vectorTypeSupport.readBytes(in, compressedDimension);
compressedVectors.add(vector);
}

Expand All @@ -109,12 +114,21 @@ public boolean equals(Object o) {

PQVectors that = (PQVectors) o;
if (!Objects.equals(pq, that.pq)) return false;
return Objects.equals(compressedVectors, that.compressedVectors);
if (rawCompressedVectors.size() != that.rawCompressedVectors.size()) return false;
for (int i = 0; i < rawCompressedVectors.size(); i++) {
// Because this method is not called on any hot path, we accept the overhead of creating ByteSequence
var a = vectorTypeSupport.createByteSequence(rawCompressedVectors.get(i));
var b = vectorTypeSupport.createByteSequence(that.rawCompressedVectors.get(i));
if (!a.equals(b)) {
return false;
}
}
return true;
}

@Override
public int hashCode() {
return Objects.hash(pq, compressedVectors);
return Objects.hash(pq, rawCompressedVectors);
}

@Override
Expand Down Expand Up @@ -188,7 +202,11 @@ public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat<?> q,
}

public ByteSequence<?> get(int ordinal) {
return compressedVectors.get(ordinal);
return vectorTypeSupport.createByteSequence(rawCompressedVectors.get(ordinal));
}

/**
 * Returns the stored raw (uninterpreted) compressed representation of the vector at
 * {@code ordinal}. The concrete type of the returned Object is provider-specific
 * (e.g. byte[] or MemorySegment) and must be interpreted via the matching VectorTypeSupport.
 */
public Object getRaw(int ordinal) {
    var raw = rawCompressedVectors.get(ordinal);
    return raw;
}

public ProductQuantization getProductQuantization() {
Expand Down Expand Up @@ -225,16 +243,16 @@ public long ramBytesUsed() {
int AH_BYTES = RamUsageEstimator.NUM_BYTES_ARRAY_HEADER;

long codebooksSize = pq.ramBytesUsed();
long listSize = (long) REF_BYTES * (1 + compressedVectors.size());
long dataSize = (long) (OH_BYTES + AH_BYTES + pq.compressedVectorSize()) * compressedVectors.size();
long listSize = (long) REF_BYTES * (1 + rawCompressedVectors.size());
long dataSize = (long) (OH_BYTES + AH_BYTES + pq.compressedVectorSize()) * rawCompressedVectors.size();
return codebooksSize + listSize + dataSize;
}

@Override
public String toString() {
return "PQVectors{" +
"pq=" + pq +
", count=" + compressedVectors.size() +
", count=" + rawCompressedVectors.size() +
'}';
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,17 @@ public ByteSequence<?> createByteSequence(int length)
}

@Override
public ByteSequence<?> readByteSequence(RandomAccessReader r, int size) throws IOException
public byte[] readBytes(RandomAccessReader r, int size) throws IOException
{
byte[] vector = new byte[size];
r.readFully(vector);
return new ArrayByteSequence(vector);
return vector;
}

@Override
public ByteSequence<?> readByteSequence(RandomAccessReader r, int size) throws IOException
{
    // Read the raw bytes first, then wrap them in an on-heap sequence view.
    byte[] raw = readBytes(r, size);
    return new ArrayByteSequence(raw);
}

@Override
Expand All @@ -102,4 +108,10 @@ public void writeByteSequence(DataOutput out, ByteSequence<?> sequence) throws I
ArrayByteSequence v = (ArrayByteSequence) sequence;
out.write(v.get());
}

@Override
public void writeBytes(DataOutput out, Object sequence) throws IOException
{
    // For this provider the raw representation is always a plain byte[].
    byte[] raw = (byte[]) sequence;
    out.write(raw);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

package io.github.jbellis.jvector.vector;

import io.github.jbellis.jvector.pq.PQVectors;
import io.github.jbellis.jvector.vector.types.ByteSequence;
import io.github.jbellis.jvector.vector.types.VectorFloat;

Expand Down Expand Up @@ -194,4 +195,8 @@ public static float max(VectorFloat<?> v) {
public static float min(VectorFloat<?> v) {
return impl.min(v);
}

/**
 * Computes the decoded cosine similarity between the PQ-encoded vector at {@code node}
 * and the query, delegating to the active SIMD implementation.
 *
 * @param cv           the PQ-compressed vectors
 * @param node         ordinal of the encoded vector to score
 * @param clusterCount number of centroids per subspace
 * @param partialSums  precomputed per-centroid dot-product contributions
 * @param aMagnitude   precomputed per-centroid squared-magnitude contributions
 * @param bMagnitude   squared magnitude of the query vector
 */
public static float pqDecodedCosineSimilarity(PQVectors cv, int node, int clusterCount, VectorFloat<?> partialSums, VectorFloat<?> aMagnitude, float bMagnitude) {
    final float score = impl.pqDecodedCosineSimilarity(cv, node, clusterCount, partialSums, aMagnitude, bMagnitude);
    return score;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

package io.github.jbellis.jvector.vector;

import io.github.jbellis.jvector.pq.PQVectors;
import io.github.jbellis.jvector.vector.types.ByteSequence;
import io.github.jbellis.jvector.vector.types.VectorFloat;

Expand Down Expand Up @@ -199,4 +200,19 @@ default void bulkShuffleQuantizedSimilarityCosine(ByteSequence<?> shuffles, int
float max(VectorFloat<?> v);
float min(VectorFloat<?> v);

/**
 * Scalar fallback for PQ decoded cosine similarity. For each subspace m, looks up the
 * precomputed dot-product and squared-magnitude contributions of the centroid selected
 * by the encoded vector, accumulates both, and combines them with the query magnitude.
 */
default float pqDecodedCosineSimilarity(PQVectors cv, int node, int clusterCount, VectorFloat<?> partialSums, VectorFloat<?> aMagnitude, float bMagnitude)
{
    ByteSequence<?> codes = cv.get(node);
    float dotAccumulator = 0.0f;
    float magnitudeAccumulator = 0.0f;

    int subspaceCount = codes.length();
    for (int subspace = 0; subspace < subspaceCount; subspace++) {
        // Centroid codes are stored as unsigned bytes; offset into the flattened
        // [subspace][centroid] lookup tables.
        int offset = subspace * clusterCount + Byte.toUnsignedInt(codes.get(subspace));
        dotAccumulator += partialSums.get(offset);
        magnitudeAccumulator += aMagnitude.get(offset);
    }

    return (float) (dotAccumulator / Math.sqrt(magnitudeAccumulator * bMagnitude));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,11 @@ public interface VectorTypeSupport {

ByteSequence<?> readByteSequence(RandomAccessReader r, int size) throws IOException;

Object readBytes(RandomAccessReader r, int size) throws IOException;

void readByteSequence(RandomAccessReader r, ByteSequence<?> sequence) throws IOException;

void writeByteSequence(DataOutput out, ByteSequence<?> sequence) throws IOException;

void writeBytes(DataOutput out, Object bytes) throws IOException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ public static void siftDiskAnnLTM(List<VectorFloat<?>> baseVectors, List<VectorF

// as we build the index we'll compress the new vectors and add them to this List backing a PQVectors;
// this is used to score the construction searches
List<ByteSequence<?>> incrementallyCompressedVectors = new ArrayList<>();
List<Object> incrementallyCompressedVectors = new ArrayList<>();
PQVectors pqv = new PQVectors(pq, incrementallyCompressedVectors);
BuildScoreProvider bsp = BuildScoreProvider.pqBuildScoreProvider(VectorSimilarityFunction.EUCLIDEAN, pqv);

Expand All @@ -235,7 +235,7 @@ public static void siftDiskAnnLTM(List<VectorFloat<?>> baseVectors, List<VectorF
for (VectorFloat<?> v : baseVectors) {
// compress the new vector and add it to the PQVectors (via incrementallyCompressedVectors)
int ordinal = incrementallyCompressedVectors.size();
incrementallyCompressedVectors.add(pq.encode(v));
incrementallyCompressedVectors.add(pq.encode(v).get());
// write the full vector to disk
writer.writeInline(ordinal, Feature.singleState(FeatureId.INLINE_VECTORS, new InlineVectors.State(v)));
// now add it to the graph -- the previous steps must be completed first since the PQVectors
Expand Down
43 changes: 43 additions & 0 deletions jvector-native/src/main/c/jvector_simd.c
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,49 @@ float assemble_and_sum_f32_512(const float* data, int dataBase, const unsigned c
return res;
}

/*
 * AVX-512 PQ decoded cosine similarity.
 *
 * baseOffsets holds one unsigned-byte centroid code per subspace; for subspace i the
 * flattened lookup index is clusterCount * i + baseOffsets[i] (see the scalar tail loop).
 * partialSums / aMagnitude are the precomputed per-centroid dot-product and
 * squared-magnitude contribution tables; bMagnitude is the query's squared magnitude.
 *
 * NOTE(review): the main loop increments indexRegister BEFORE its first use, so it
 * assumes initialIndexRegister holds {-16..-1} (yielding indexes {0..15} on the first
 * pass) and indexIncrement holds 16 in every lane — confirm against their definitions
 * elsewhere in this file.
 */
float pq_decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int baseOffsetsLength, int clusterCount, const float* partialSums, const float* aMagnitude, float bMagnitude) {
    __m512 sum = _mm512_setzero_ps();
    __m512 vaMagnitude = _mm512_setzero_ps();
    int i = 0;
    // Process 16 subspaces per iteration; the tail loop below handles the remainder.
    int limit = baseOffsetsLength - (baseOffsetsLength % 16);
    __m512i indexRegister = initialIndexRegister;
    __m512i scale = _mm512_set1_epi32(clusterCount);


    for (; i < limit; i += 16) {
        // Load 16 centroid codes and widen them to 32-bit integers
        __m128i baseOffsetsRaw = _mm_loadu_si128((__m128i *)(baseOffsets + i));
        __m512i baseOffsetsInt = _mm512_cvtepu8_epi32(baseOffsetsRaw);

        indexRegister = _mm512_add_epi32(indexRegister, indexIncrement);
        // Scale the subspace indexes (not the base offsets) by the cluster count
        __m512i scaledOffsets = _mm512_mullo_epi32(indexRegister, scale);

        // Final gather offsets: scaled subspace index + centroid code
        __m512i convOffsets = _mm512_add_epi32(scaledOffsets, baseOffsetsInt);

        // Gather and accumulate the partial dot products and partial magnitudes
        __m512 partialSumVals = _mm512_i32gather_ps(convOffsets, partialSums, 4);
        sum = _mm512_add_ps(sum, partialSumVals);

        __m512 aMagnitudeVals = _mm512_i32gather_ps(convOffsets, aMagnitude, 4);
        vaMagnitude = _mm512_add_ps(vaMagnitude, aMagnitudeVals);
    }

    // Horizontal reduction of both accumulators
    float sumResult = _mm512_reduce_add_ps(sum);
    float aMagnitudeResult = _mm512_reduce_add_ps(vaMagnitude);

    // Scalar tail: remaining subspaces when baseOffsetsLength is not a multiple of 16
    for (; i < baseOffsetsLength; i++) {
        int offset = clusterCount * i + baseOffsets[i];
        sumResult += partialSums[offset];
        aMagnitudeResult += aMagnitude[offset];
    }

    return sumResult / sqrtf(aMagnitudeResult * bMagnitude);
}

void calculate_partial_sums_dot_f32_512(const float* codebook, int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, float* partialSums) {
int codebookBase = codebookIndex * clusterCount;
for (int i = 0; i < clusterCount; i++) {
Expand Down
1 change: 1 addition & 0 deletions jvector-native/src/main/c/jvector_simd.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ void bulk_quantized_shuffle_dot_f32_512(const unsigned char* shuffles, int codeb
void bulk_quantized_shuffle_euclidean_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartials, float delta, float minDistance, float* results);
void bulk_quantized_shuffle_cosine_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartialSums, float sumDelta, float minDistance, const char* quantizedPartialMagnitudes, float magnitudeDelta, float minMagnitude, float queryMagnitudeSquared, float* results);
float assemble_and_sum_f32_512(const float* data, int dataBase, const unsigned char* baseOffsets, int baseOffsetsLength);
float pq_decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int baseOffsetsLength, int clusterCount, const float* partialSums, const float* aMagnitude, float bMagnitude);
void calculate_partial_sums_dot_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums);
void calculate_partial_sums_euclidean_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums);
void calculate_partial_sums_best_dot_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums, float* partialBestDistances);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ public byte get(int n) {
return segment.get(ValueLayout.JAVA_BYTE, n);
}

/**
 * Reads the byte at offset {@code n} from the given memory segment without
 * requiring a MemorySegmentByteSequence wrapper.
 */
public static byte get(MemorySegment ms, int n) {
    final byte value = ms.get(ValueLayout.JAVA_BYTE, n);
    return value;
}

@Override
public void set(int n, byte value) {
segment.set(ValueLayout.JAVA_BYTE, n, value);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import java.io.DataOutput;
import java.io.IOException;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.nio.Buffer;

/**
Expand Down Expand Up @@ -89,6 +90,14 @@ public ByteSequence<?> readByteSequence(RandomAccessReader r, int size) throws I
return vector;
}

@Override
public MemorySegment readBytes(RandomAccessReader r, int size) throws IOException
{
    // Fill a heap array, then expose it as a MemorySegment — the raw
    // representation used by this provider.
    byte[] backing = new byte[size];
    r.readFully(backing);
    return MemorySegment.ofArray(backing);
}

@Override
public void readByteSequence(RandomAccessReader r, ByteSequence<?> sequence) throws IOException {
r.readFully(((MemorySegmentByteSequence) sequence).get().asByteBuffer());
Expand All @@ -101,4 +110,13 @@ public void writeByteSequence(DataOutput out, ByteSequence<?> sequence) throws I
for (int i = 0; i < sequence.length(); i++)
out.writeByte(sequence.get(i));
}

@Override
public void writeBytes(DataOutput out, Object bytes) throws IOException
{
    // Raw representation for this provider is a MemorySegment.
    MemorySegment segment = (MemorySegment) bytes;
    // Bulk-copy the segment once and emit it in a single write call instead of
    // issuing one virtual writeByte per element.
    out.write(segment.toArray(ValueLayout.JAVA_BYTE));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@

package io.github.jbellis.jvector.vector;

import io.github.jbellis.jvector.pq.PQVectors;
import io.github.jbellis.jvector.vector.cnative.NativeSimdOps;
import io.github.jbellis.jvector.vector.types.ByteSequence;
import io.github.jbellis.jvector.vector.types.VectorFloat;

import java.lang.foreign.MemorySegment;
import java.util.List;

/**
Expand Down Expand Up @@ -155,4 +157,11 @@ public void bulkShuffleQuantizedSimilarityCosine(ByteSequence<?> shuffles, int c
NativeSimdOps.bulk_quantized_shuffle_cosine_f32_512(((MemorySegmentByteSequence) shuffles).get(), codebookCount, ((MemorySegmentByteSequence) quantizedPartialSums).get(), sumDelta, minDistance,
((MemorySegmentByteSequence) quantizedPartialSquaredMagnitudes).get(), magnitudeDelta, minMagnitude, queryMagnitudeSquared, ((MemorySegmentVectorFloat) results).get());
}

/**
 * Native AVX-512 path: scores the PQ-encoded vector at {@code node} against the query
 * using the raw MemorySegment representation directly, avoiding a ByteSequence wrapper.
 */
@Override
public float pqDecodedCosineSimilarity(PQVectors cv, int node, int clusterCount, VectorFloat<?> partialSums, VectorFloat<?> aMagnitude, float bMagnitude)
{
    var codes = (MemorySegment) cv.getRaw(node);
    var codeCount = (int) codes.byteSize();
    return NativeSimdOps.pq_decoded_cosine_similarity_f32_512(
            codes, codeCount, clusterCount,
            ((MemorySegmentVectorFloat) partialSums).get(),
            ((MemorySegmentVectorFloat) aMagnitude).get(),
            bMagnitude);
}
}
Loading
Loading