Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add on disk 4x compression with Faiss #2425

Open
wants to merge 9 commits into
base: 2.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Add binary index support for Lucene engine. (#2292)[https://github.com/opensearch-project/k-NN/pull/2292]
- Add expand_nested_docs Parameter support to NMSLIB engine (#2331)[https://github.com/opensearch-project/k-NN/pull/2331]
- Add cosine similarity support for faiss engine (#2376)[https://github.com/opensearch-project/k-NN/pull/2376]
- Add support for Faiss onDisk 4x compression (#2425)[https://github.com/opensearch-project/k-NN/pull/2425]
### Enhancements
- Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
- Allow method parameter override for training based indices (#2290) https://github.com/opensearch-project/k-NN/pull/2290]
Expand Down
43 changes: 43 additions & 0 deletions jni/include/faiss_index_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,25 @@ class IndexService {
*/
virtual void writeIndex(faiss::IOWriter* writer, jlong idMapAddress);

/**
* Initialize index from template
*
* @param jniUtil jni util
* @param env jni environment
* @param dim dimension of vectors
* @param numVectors number of vectors
* @param threadCount number of thread count to be used while adding data
* @param templateIndexJ template index
* @return memory address of the native index object
*/
virtual jlong initIndexFromTemplate(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, int dim, int numVectors, int threadCount, jbyteArray templateIndexJ);


virtual ~IndexService() = default;

protected:
virtual void allocIndex(faiss::Index * index, size_t dim, size_t numVectors);
virtual jlong initAndAllocateIndex(std::unique_ptr<faiss::Index> &index, size_t threadCount, size_t dim, size_t numVectors);

std::unique_ptr<FaissMethods> faissMethods;
}; // class IndexService
Expand Down Expand Up @@ -132,8 +147,22 @@ class BinaryIndexService final : public IndexService {
*/
void writeIndex(faiss::IOWriter* writer, jlong idMapAddress) final;

/**
* Initialize index from template
*
* @param jniUtil jni util
* @param env jni environment
* @param dim dimension of vectors
* @param numVectors number of vectors
* @param threadCount number of thread count to be used while adding data
* @param templateIndexJ template index
* @return memory address of the native index object
*/
jlong initIndexFromTemplate(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, int dim, int numVectors, int threadCount, jbyteArray templateIndexJ) final;

protected:
void allocIndex(faiss::Index * index, size_t dim, size_t numVectors) final;
jlong initAndAllocateIndex(std::unique_ptr<faiss::IndexBinary> &index, size_t threadCount, size_t dim, size_t numVectors);
}; // class BinaryIndexService

/**
Expand Down Expand Up @@ -191,8 +220,22 @@ class ByteIndexService final : public IndexService {
*/
void writeIndex(faiss::IOWriter* writer, jlong idMapAddress) final;

/**
* Initialize index from template
*
* @param jniUtil jni util
* @param env jni environment
* @param dim dimension of vectors
* @param numVectors number of vectors
* @param threadCount number of thread count to be used while adding data
* @param templateIndexJ template index
* @return memory address of the native index object
*/
jlong initIndexFromTemplate(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, int dim, int numVectors, int threadCount, jbyteArray templateIndexJ) final;

protected:
void allocIndex(faiss::Index * index, size_t dim, size_t numVectors) final;
jlong initAndAllocateIndex(std::unique_ptr<faiss::Index> &index, size_t threadCount, size_t dim, size_t numVectors) final;
}; // class ByteIndexService

}
Expand Down
3 changes: 3 additions & 0 deletions jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ namespace knn_jni {

void WriteIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jobject output, jlong indexAddr, IndexService *indexService);

jlong InitIndexFromTemplate(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong numDocs, jint dimJ, jobject parametersJ, jbyteArray templateIndexJ, IndexService *indexService);


// Create an index with ids and vectors. Instead of creating a new index, this function creates the index
// based off of the template index passed in. The index is serialized to indexPathJ.
void CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
Expand Down
26 changes: 26 additions & 0 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,32 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeBinaryIndex
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeByteIndex(JNIEnv *, jclass, jlong, jobject);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: initIndexFromTemplate
* Signature: ([IJILjava/lang/String;Ljava/util/Map;)V
*/
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initIndexFromTemplate(JNIEnv * env, jclass cls,
jlong numDocs, jint dimJ,
jobject parametersJ, jbyteArray templateIndexJ);
/*
* Class: org_opensearch_knn_jni_FaissService
* Method: initBinaryIndexFromTemplate
* Signature: ([IJILjava/lang/String;Ljava/util/Map;)V
*/
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initBinaryIndexFromTemplate(JNIEnv * env, jclass cls,
jlong numDocs, jint dimJ,
jobject parametersJ, jbyteArray templateIndexJ);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: initByteIndexFromTemplate
* Signature: ([IJILjava/lang/String;Ljava/util/Map;)V
*/
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initByteIndexFromTemplate(JNIEnv * env, jclass cls,
jlong numDocs, jint dimJ,
jobject parametersJ, jbyteArray templateIndexJ);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: createIndexFromTemplate
Expand Down
185 changes: 141 additions & 44 deletions jni/src/faiss_index_service.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,25 @@ void IndexService::allocIndex(faiss::Index * index, size_t dim, size_t numVector
}
}

jlong IndexService::initAndAllocateIndex(std::unique_ptr<faiss::Index> &index, size_t threadCount, size_t dim, size_t numVectors) {
// Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
if (threadCount != 0) {
omp_set_num_threads(threadCount);
}

std::unique_ptr<faiss::IndexIDMap> idMap (faissMethods->indexIdMap(index.get()));
//Makes sure the index is deleted when the destructor is called, this cannot be passed in the constructor
idMap->own_fields = true;

// TODO: allocIndex for IVF
allocIndex(dynamic_cast<faiss::Index *>(idMap->index), dim, numVectors);

//Release the ownership so as to make sure not delete the underlying index that is created. The index is needed later
//in insert and write operations
index.release();
return reinterpret_cast<jlong>(idMap.release());
}

jlong IndexService::initIndex(
knn_jni::JNIUtilInterface * jniUtil,
JNIEnv * env,
Expand All @@ -81,11 +100,6 @@ jlong IndexService::initIndex(
// Create index using Faiss factory method
std::unique_ptr<faiss::Index> index(faissMethods->indexFactory(dim, indexDescription.c_str(), metric));

// Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
if (threadCount != 0) {
omp_set_num_threads(threadCount);
}

// Add extra parameters that cant be configured with the index factory
SetExtraParameters<faiss::Index, faiss::IndexIVF, faiss::IndexHNSW>(jniUtil, env, parameters, index.get());

Expand All @@ -94,16 +108,7 @@ jlong IndexService::initIndex(
throw std::runtime_error("Index is not trained");
}

std::unique_ptr<faiss::IndexIDMap> idMap (faissMethods->indexIdMap(index.get()));
//Makes sure the index is deleted when the destructor is called, this cannot be passed in the constructor
idMap->own_fields = true;

allocIndex(dynamic_cast<faiss::Index *>(idMap->index), dim, numVectors);

//Release the ownership so as to make sure not delete the underlying index that is created. The index is needed later
//in insert and write operations
index.release();
return reinterpret_cast<jlong>(idMap.release());
return initAndAllocateIndex(index, threadCount, dim, numVectors);
}

void IndexService::insertToIndex(
Expand Down Expand Up @@ -155,6 +160,32 @@ void IndexService::writeIndex(
}
}

jlong IndexService::initIndexFromTemplate(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only difference between this and initIndex is the index creation call to faiss, can we abstract out the logic and reuse the rest please. you can pass in the pointer returned by faiss to reuse the logic and set the index uniq pointer maybe.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

refactored as discussed and also validated that the index is getting deleted

knn_jni::JNIUtilInterface * jniUtil,
JNIEnv * env,
int dim,
int numVectors,
int threadCount,
jbyteArray templateIndexJ
) {

// Get vector of bytes from jbytearray
int indexBytesCount = jniUtil->GetJavaBytesArrayLength(env, templateIndexJ);
jbyte * indexBytesJ = jniUtil->GetByteArrayElements(env, templateIndexJ, nullptr);

faiss::VectorIOReader vectorIoReader;
for (int i = 0; i < indexBytesCount; i++) {
vectorIoReader.data.push_back((uint8_t) indexBytesJ[i]);
}
jniUtil->ReleaseByteArrayElements(env, templateIndexJ, indexBytesJ, JNI_ABORT);

// Create faiss index
std::unique_ptr<faiss::Index> index;
index.reset(faiss::read_index(&vectorIoReader, 0));

return initAndAllocateIndex(index, threadCount, dim, numVectors);
}

BinaryIndexService::BinaryIndexService(std::unique_ptr<FaissMethods> _faissMethods)
: IndexService(std::move(_faissMethods)) {
}
Expand All @@ -166,6 +197,25 @@ void BinaryIndexService::allocIndex(faiss::Index * index, size_t dim, size_t num
}
}

jlong BinaryIndexService::initAndAllocateIndex(std::unique_ptr<faiss::IndexBinary> &index, size_t threadCount, size_t dim, size_t numVectors) {
// Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
if (threadCount != 0) {
omp_set_num_threads(threadCount);
}

std::unique_ptr<faiss::IndexBinaryIDMap> idMap(faissMethods->indexBinaryIdMap(index.get()));
//Makes sure the index is deleted when the destructor is called, this cannot be passed in the constructor
idMap->own_fields = true;

// TODO: allocIndex for IVF
allocIndex(dynamic_cast<faiss::Index *>(idMap->index), dim, numVectors);

//Release the ownership so as to make sure not delete the underlying index that is created. The index is needed later
//in insert and write operations
index.release();
return reinterpret_cast<jlong>(idMap.release());
}

jlong BinaryIndexService::initIndex(
knn_jni::JNIUtilInterface * jniUtil,
JNIEnv * env,
Expand All @@ -178,10 +228,6 @@ jlong BinaryIndexService::initIndex(
) {
// Create index using Faiss factory method
std::unique_ptr<faiss::IndexBinary> index(faissMethods->indexBinaryFactory(dim, indexDescription.c_str()));
// Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
if (threadCount != 0) {
omp_set_num_threads(threadCount);
}

// Add extra parameters that cant be configured with the index factory
SetExtraParameters<faiss::IndexBinary, faiss::IndexBinaryIVF, faiss::IndexBinaryHNSW>(jniUtil, env, parameters, index.get());
Expand All @@ -191,16 +237,7 @@ jlong BinaryIndexService::initIndex(
throw std::runtime_error("Index is not trained");
}

std::unique_ptr<faiss::IndexBinaryIDMap> idMap(faissMethods->indexBinaryIdMap(index.get()));
//Makes sure the index is deleted when the destructor is called
idMap->own_fields = true;

allocIndex(dynamic_cast<faiss::Index *>(idMap->index), dim, numVectors);

//Release the ownership so as to make sure not delete the underlying index that is created. The index is needed later
//in insert and write operations
index.release();
return reinterpret_cast<jlong>(idMap.release());
return initAndAllocateIndex(index, threadCount, dim, numVectors);
}

void BinaryIndexService::insertToIndex(
Expand Down Expand Up @@ -252,6 +289,35 @@ void BinaryIndexService::writeIndex(
}
}

jlong BinaryIndexService::initIndexFromTemplate(
naveentatikonda marked this conversation as resolved.
Show resolved Hide resolved
knn_jni::JNIUtilInterface * jniUtil,
JNIEnv * env,
int dim,
int numVectors,
int threadCount,
jbyteArray templateIndexJ
) {
if (dim % 8 != 0) {
throw std::runtime_error("Dimensions should be multiple of 8");
}

// Get vector of bytes from jbytearray
int indexBytesCount = jniUtil->GetJavaBytesArrayLength(env, templateIndexJ);
jbyte * indexBytesJ = jniUtil->GetByteArrayElements(env, templateIndexJ, nullptr);

faiss::VectorIOReader vectorIoReader;
for (int i = 0; i < indexBytesCount; i++) {
vectorIoReader.data.push_back((uint8_t) indexBytesJ[i]);
}
jniUtil->ReleaseByteArrayElements(env, templateIndexJ, indexBytesJ, JNI_ABORT);

// Create faiss index
std::unique_ptr<faiss::IndexBinary> index;
index.reset(faiss::read_index_binary(&vectorIoReader, 0));

return initAndAllocateIndex(index, threadCount, dim, numVectors);
}

ByteIndexService::ByteIndexService(std::unique_ptr<FaissMethods> _faissMethods)
: IndexService(std::move(_faissMethods)) {
}
Expand All @@ -264,6 +330,25 @@ void ByteIndexService::allocIndex(faiss::Index * index, size_t dim, size_t numVe
}
}

jlong ByteIndexService::initAndAllocateIndex(std::unique_ptr<faiss::Index> &index, size_t threadCount, size_t dim, size_t numVectors) {
// Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
if (threadCount != 0) {
omp_set_num_threads(threadCount);
}

std::unique_ptr<faiss::IndexIDMap> idMap (faissMethods->indexIdMap(index.get()));
//Makes sure the index is deleted when the destructor is called, this cannot be passed in the constructor
idMap->own_fields = true;

// TODO: allocIndex for IVF
allocIndex(dynamic_cast<faiss::Index *>(idMap->index), dim, numVectors);

//Release the ownership so as to make sure not delete the underlying index that is created. The index is needed later
//in insert and write operations
index.release();
return reinterpret_cast<jlong>(idMap.release());
}

jlong ByteIndexService::initIndex(
knn_jni::JNIUtilInterface * jniUtil,
JNIEnv * env,
Expand All @@ -277,11 +362,6 @@ jlong ByteIndexService::initIndex(
// Create index using Faiss factory method
std::unique_ptr<faiss::Index> index(faissMethods->indexFactory(dim, indexDescription.c_str(), metric));

// Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
if (threadCount != 0) {
omp_set_num_threads(threadCount);
}

// Add extra parameters that cant be configured with the index factory
SetExtraParameters<faiss::Index, faiss::IndexIVF, faiss::IndexHNSW>(jniUtil, env, parameters, index.get());

Expand All @@ -290,16 +370,7 @@ jlong ByteIndexService::initIndex(
throw std::runtime_error("Index is not trained");
}

std::unique_ptr<faiss::IndexIDMap> idMap (faissMethods->indexIdMap(index.get()));
//Makes sure the index is deleted when the destructor is called, this cannot be passed in the constructor
idMap->own_fields = true;

allocIndex(dynamic_cast<faiss::Index *>(idMap->index), dim, numVectors);

//Release the ownership so as to make sure not delete the underlying index that is created. The index is needed later
//in insert and write operations
index.release();
return reinterpret_cast<jlong>(idMap.release());
return initAndAllocateIndex(index, threadCount, dim, numVectors);
}

void ByteIndexService::insertToIndex(
Expand Down Expand Up @@ -368,5 +439,31 @@ void ByteIndexService::writeIndex(
throw std::runtime_error("Failed to write index to disk");
}
}

jlong ByteIndexService::initIndexFromTemplate(
knn_jni::JNIUtilInterface * jniUtil,
JNIEnv * env,
int dim,
int numVectors,
int threadCount,
jbyteArray templateIndexJ
) {

// Get vector of bytes from jbytearray
int indexBytesCount = jniUtil->GetJavaBytesArrayLength(env, templateIndexJ);
jbyte * indexBytesJ = jniUtil->GetByteArrayElements(env, templateIndexJ, nullptr);

faiss::VectorIOReader vectorIoReader;
for (int i = 0; i < indexBytesCount; i++) {
vectorIoReader.data.push_back((uint8_t) indexBytesJ[i]);
}
jniUtil->ReleaseByteArrayElements(env, templateIndexJ, indexBytesJ, JNI_ABORT);

// Create faiss index
std::unique_ptr<faiss::Index> index;
index.reset(faiss::read_index(&vectorIoReader, 0));

return initAndAllocateIndex(index, threadCount, dim, numVectors);
}
} // namespace faiss_wrapper
} // namesapce knn_jni
Loading
Loading