Skip to content

Commit

Permalink
inner product in index build; debug statements
Browse files Browse the repository at this point in the history
  • Loading branch information
tarang-jain committed Apr 10, 2024
1 parent c506216 commit cb7fbba
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 34 deletions.
5 changes: 4 additions & 1 deletion cpp/include/raft/neighbors/cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <raft/core/host_device_accessor.hpp>
#include <raft/core/mdspan.hpp>
#include <raft/core/resources.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/neighbors/cagra_types.hpp>
#include <raft/neighbors/dataset.hpp>

Expand All @@ -48,6 +49,7 @@ namespace raft::neighbors::cagra {
*
* The following distance metrics are supported:
* - L2Expanded
* - InnerProduct
*
* Usage example:
* @code{.cpp}
Expand Down Expand Up @@ -79,6 +81,7 @@ template <typename DataT, typename IdxT, typename accessor>
void build_knn_graph(raft::resources const& res,
mdspan<const DataT, matrix_extent<int64_t>, row_major, accessor> dataset,
raft::host_matrix_view<IdxT, int64_t, row_major> knn_graph,
raft::distance::DistanceType metric,
std::optional<float> refine_rate = std::nullopt,
std::optional<ivf_pq::index_params> build_params = std::nullopt,
std::optional<ivf_pq::search_params> search_params = std::nullopt)
Expand All @@ -93,7 +96,7 @@ void build_knn_graph(raft::resources const& res,
dataset.data_handle(), dataset.extent(0), dataset.extent(1));

cagra::detail::build_knn_graph(
res, dataset_internal, knn_graph_internal, refine_rate, build_params, search_params);
res, dataset_internal, knn_graph_internal, metric, refine_rate, build_params, search_params);
}

/**
Expand Down
13 changes: 11 additions & 2 deletions cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "../../cagra_types.hpp"
#include "../../vpq_dataset.cuh"
#include "graph_core.cuh"
#include "raft/util/cudart_utils.hpp"

#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
Expand All @@ -44,11 +45,14 @@ template <typename DataT, typename IdxT, typename accessor>
void build_knn_graph(raft::resources const& res,
mdspan<const DataT, matrix_extent<int64_t>, row_major, accessor> dataset,
raft::host_matrix_view<IdxT, int64_t, row_major> knn_graph,
raft::distance::DistanceType metric,
std::optional<float> refine_rate = std::nullopt,
std::optional<ivf_pq::index_params> build_params = std::nullopt,
std::optional<ivf_pq::search_params> search_params = std::nullopt)
{
RAFT_EXPECTS(!build_params || build_params->metric == distance::DistanceType::L2Expanded,
std::cout << "metric from build_knn_graph" << metric<< std::endl;
RAFT_EXPECTS(!build_params || build_params->metric == metric, "Mismatch between index metric and IVF-PQ metric");
RAFT_EXPECTS(metric == distance::DistanceType::L2Expanded || metric == distance::DistanceType::InnerProduct,
"Currently only L2Expanded metric is supported");

uint32_t node_degree = knn_graph.extent(1);
Expand All @@ -65,6 +69,7 @@ void build_knn_graph(raft::resources const& res,
build_params->kmeans_trainset_fraction = dataset.extent(0) < 10000 ? 1 : 10;
build_params->kmeans_n_iters = 25;
build_params->add_data_on_build = true;
build_params->metric = metric;
}

// Make model name
Expand Down Expand Up @@ -148,6 +153,8 @@ void build_knn_graph(raft::resources const& res,
distances.data_handle(), batch.size(), distances.extent(1));

ivf_pq::search(res, *search_params, index, queries_view, neighbors_view, distances_view);
raft::resource::sync_stream(res);
raft::print_device_vector("distances vector", distances.data_handle(), distances.extent(1), std::cout);
if constexpr (is_host_mdspan_v<decltype(dataset)>) {
raft::copy(neighbors_host.data_handle(),
neighbors.data_handle(),
Expand All @@ -174,6 +181,7 @@ void build_knn_graph(raft::resources const& res,
refined_neighbors_host_view,
refined_distances_host_view,
build_params->metric);
raft::print_host_vector("host_distances", refined_distances_host.data_handle(), top_k, std::cout);
} else {
auto neighbor_candidates_view = make_device_matrix_view<const int64_t, uint64_t>(
neighbors.data_handle(), batch.size(), gpu_top_k);
Expand All @@ -197,6 +205,7 @@ void build_knn_graph(raft::resources const& res,
refined_neighbors_view.size(),
resource::get_cuda_stream(res));
resource::sync_stream(res);
raft::print_device_vector("device_distances", refined_distances.data_handle(), top_k, std::cout);
}
// omit itself & write out
// TODO(tfeher): do this in parallel with GPU processing of next batch
Expand Down Expand Up @@ -321,7 +330,7 @@ index<T, IdxT> build(
raft::make_host_matrix<IdxT, int64_t>(dataset.extent(0), intermediate_degree));

if (params.build_algo == graph_build_algo::IVF_PQ) {
build_knn_graph(res, dataset, knn_graph->view(), refine_rate, pq_build_params, search_params);
build_knn_graph(res, dataset, knn_graph->view(), params.metric, refine_rate, pq_build_params, search_params);

} else {
// Use nn-descent to build CAGRA knn graph
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ void search_main(raft::resources const& res,
strided_dset->n_rows(),
strided_dset->dim(),
strided_dset->stride());

std::cout << "index.metric from search_main" << index.metric() << std::endl;
search_main_core<dataset_desc_t, CagraSampleFilterT>(res,
params,
dataset_desc,
Expand Down
38 changes: 11 additions & 27 deletions cpp/include/raft/neighbors/detail/cagra/compute_distance_vpq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include "compute_distance.hpp"

#include <raft/distance/distance_types.hpp>
#include <raft/util/integer_utils.hpp>

namespace raft::neighbors::cagra::detail {
Expand Down Expand Up @@ -168,22 +169,13 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<half, DIS
// Loading query vector in smem
half2 diff2 = (reinterpret_cast<const half2*>(
query_ptr))[device::swizzling<std::uint32_t, DATASET_BLOCK_DIM / 2>(d / 2)];
if (metric == raft::distance::L2Expanded) {
// Loading PQ code book in smem
diff2 -= *(reinterpret_cast<half2*>(smem_pq_code_book_ptr +
(1 << PQ_BITS) * 2 * (m / 2) +
(2 * (pq_code & 0xff))));
diff2 -= vq_vals[d1 / vlen].val.data[(d1 % vlen) / 2];
norm2 += diff2 * diff2;
} else {
half2 multiplier2 = *(reinterpret_cast<half2*>(smem_pq_code_book_ptr +
(1 << PQ_BITS) * 2 * (m / 2) +
(2 * (pq_code & 0xff)))) +
vq_vals[d1 / vlen].val.data[(d1 % vlen) / 2];
norm2 -= diff2 * multiplier2;
}
pq_code >>= 8;
// Loading PQ code book in smem
diff2 -= *(reinterpret_cast<half2*>(
smem_pq_code_book_ptr + (1 << PQ_BITS) * 2 * (m / 2) + (2 * (pq_code & 0xff))));
diff2 -= vq_vals[d1 / vlen].val.data[(d1 % vlen) / 2];
norm2 += diff2 * diff2;
}
pq_code >>= 8;
}
}
norm += static_cast<float>(norm2.x + norm2.y);
Expand Down Expand Up @@ -219,17 +211,9 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<half, DIS
const std::uint32_t d = d1 + (PQ_LEN * k);
// if (d >= dataset_dim) break;
DISTANCE_T diff = query_ptr[d]; // (from smem)
if (metric == raft::distance::L2Expanded) {
// Loading PQ code book in smem
diff -= pq_scale * static_cast<float>(pq_vals.data[m]);
diff -= vq_scale * static_cast<float>(vq_vals[d1 / vlen].val.data[d1 % vlen]);
norm += diff * diff;
} else {
DISTANCE_T multiplier =
pq_scale * static_cast<float>(pq_vals.data[m]) +
vq_scale * static_cast<float>(vq_vals[d1 / vlen].val.data[d1 % vlen]);
norm -= diff * multiplier;
}
diff -= pq_scale * static_cast<float>(pq_vals.data[m]);
diff -= vq_scale * static_cast<float>(vq_vals[d1 / vlen].val.data[d1 % vlen]);
norm += diff * diff;
}
pq_code >>= 8;
}
Expand All @@ -245,4 +229,4 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<half, DIS
}
};

} // namespace raft::neighbors::cagra::detail
} // namespace raft::neighbors::cagra::detail
3 changes: 2 additions & 1 deletion (file path missing from this capture)
Original file line number Diff line number Diff line change
Expand Up @@ -973,7 +973,7 @@ struct search : search_plan_impl<DATASET_DESCRIPTOR_T, SAMPLE_FILTER_T> {
num_queries,
stream);
if (topk_distances_ptr) {
bool invert = this->metric == distance::InnerProduct;
bool invert = this->metric == distance::DistanceType::InnerProduct;
batched_memcpy(topk_distances_ptr,
topk,
result_distances_ptr,
Expand All @@ -982,6 +982,7 @@ struct search : search_plan_impl<DATASET_DESCRIPTOR_T, SAMPLE_FILTER_T> {
num_queries,
stream,
invert);
raft::print_device_vector("result_distances_ptr", result_distances_ptr, topk, std::cout);
}

if (num_executed_iterations) {
Expand Down
4 changes: 2 additions & 2 deletions cpp/test/neighbors/ann_cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -395,9 +395,9 @@ class AnnCagraSortTest : public ::testing::TestWithParam<AnnCagraInputs> {

if (ps.build_algo == graph_build_algo::IVF_PQ) {
if (ps.host_dataset) {
cagra::build_knn_graph<DataT, IdxT>(handle_, database_host_view, knn_graph.view());
cagra::build_knn_graph<DataT, IdxT>(handle_, database_host_view, knn_graph.view(), ps.metric);
} else {
cagra::build_knn_graph<DataT, IdxT>(handle_, database_view, knn_graph.view());
cagra::build_knn_graph<DataT, IdxT>(handle_, database_view, knn_graph.view(), ps.metric);
}
} else {
auto nn_descent_idx_params = experimental::nn_descent::index_params{};
Expand Down

0 comments on commit cb7fbba

Please sign in to comment.