From e8ffc55abce9f393251684593cadb89ff42b9cac Mon Sep 17 00:00:00 2001 From: Ishan Chattopadhyaya Date: Fri, 19 Apr 2024 05:14:54 +0530 Subject: [PATCH] Reducing extra data copy --- cuda/src/CudaIndexJni.cu | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/cuda/src/CudaIndexJni.cu b/cuda/src/CudaIndexJni.cu index 369e300..2bac870 100644 --- a/cuda/src/CudaIndexJni.cu +++ b/cuda/src/CudaIndexJni.cu @@ -28,33 +28,21 @@ JNIEXPORT jint JNICALL Java_com_searchscale_lucene_vectorsearch_jni_CuVSIndexJni std::cout<<"CUDA devices: "<array->hostmatrix->devicematrix), - // TODO: it might possible to do it once (JNI -> Device) for better efficiency. + // Copy the arrays from JNI to local variables. long startTime = ms(); jsize numDocs = env->GetArrayLength(docIds); std::vector docs (numDocs); env->GetIntArrayRegion( docIds, 0, numDocs, &docs[0] ); // TODO: This docid to index mapping should be persisted and used during search std::vector data(numVectors * dimension); env->GetFloatArrayRegion( dataVectors, 0, numVectors * dimension, &data[0] ); - auto datasetHost = raft::make_host_matrix(dev_resources, numVectors, dimension); - auto dataset = raft::make_device_matrix(dev_resources, numVectors, dimension); - int p = 0; - for(size_t i = 0; i < numDocs ; i ++) { - for(size_t j = 0; j < dimension; ++j) { - datasetHost(i, j) = data[p++]; // TODO: Is there a better SIMD friendly way to copy? - } - } - cudaStream_t stream = raft::resource::get_cuda_stream(dev_resources); - raft::copy(dataset.data_handle(), datasetHost.data_handle(), datasetHost.size(), stream); - raft::resource::sync_stream(dev_resources, stream); + auto extents = raft::make_extents(numVectors, dimension); + auto dataset = raft::make_mdspan(&data[0], extents); std::cout<<"Data copying time (CPU to GPU): "<<(ms()-startTime)<(dev_resources, index_params, raft::make_const_mdspan(dataset.view())); + auto ind = raft::neighbors::cagra::build(dev_resources, index_params, raft::make_const_mdspan(dataset)); std::cout << "Cagra Index building time: " << (ms()-startTime) << std::endl; // Serialize the index into a file