Skip to content

Commit

Permalink
reduce boilerplate
Browse files Browse the repository at this point in the history
  • Loading branch information
tfeher committed Mar 10, 2024
1 parent a414627 commit 1df5de7
Show file tree
Hide file tree
Showing 30 changed files with 109 additions and 1,597 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@

/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
* NOTE: this file is to be used in source files generated by
* src/neighbors/detailivf_pq_compute_similarity_00_generate.py
*/

#pragma once

#include <raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh>
#include <raft/neighbors/detail/ivf_pq_fp_8bit.cuh>
#include <raft/neighbors/sample_filter.cuh>

#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \
OutT, LutT, IvfSampleFilterT) \
template auto \
raft::neighbors::ivf_pq::detail::compute_similarity_select<OutT, LutT, IvfSampleFilterT>( \
const cudaDeviceProp& dev_props, \
bool manage_local_topk, \
int locality_hint, \
double preferred_shmem_carveout, \
uint32_t pq_bits, \
uint32_t pq_dim, \
uint32_t precomp_data_count, \
uint32_t n_queries, \
uint32_t n_probes, \
uint32_t topk) \
->raft::neighbors::ivf_pq::detail::selected<OutT, LutT, IvfSampleFilterT>; \
\
template void \
raft::neighbors::ivf_pq::detail::compute_similarity_run<OutT, LutT, IvfSampleFilterT>( \
raft::neighbors::ivf_pq::detail::selected<OutT, LutT, IvfSampleFilterT> s, \
rmm::cuda_stream_view stream, \
uint32_t dim, \
uint32_t n_probes, \
uint32_t pq_dim, \
uint32_t n_queries, \
uint32_t queries_offset, \
raft::distance::DistanceType metric, \
raft::neighbors::ivf_pq::codebook_gen codebook_kind, \
uint32_t topk, \
uint32_t max_samples, \
const float* cluster_centers, \
const float* pq_centers, \
const uint8_t* const* pq_dataset, \
const uint32_t* cluster_labels, \
const uint32_t* _chunk_indices, \
const float* queries, \
const uint32_t* index_list, \
float* query_kths, \
IvfSampleFilterT sample_filter, \
LutT* lut_scores, \
OutT* _out_scores, \
uint32_t* _out_indices);

#define COMMA ,
60 changes: 3 additions & 57 deletions cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

header = """
/*
header = """/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -31,63 +30,11 @@
/*
* NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py
*
* Make changes there and run in this directory:
*
* > python ivf_pq_compute_similarity_00_generate.py
*
*/
#include <raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh>
#include <raft/neighbors/detail/ivf_pq_fp_8bit.cuh>
#include <raft/neighbors/sample_filter.cuh>
#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT, IvfSampleFilterT) \\
template auto raft::neighbors::ivf_pq::detail::compute_similarity_select<OutT, LutT, IvfSampleFilterT>( \\
const cudaDeviceProp& dev_props, \\
bool manage_local_topk, \\
int locality_hint, \\
double preferred_shmem_carveout, \\
uint32_t pq_bits, \\
uint32_t pq_dim, \\
uint32_t precomp_data_count, \\
uint32_t n_queries, \\
uint32_t n_probes, \\
uint32_t topk) -> raft::neighbors::ivf_pq::detail::selected<OutT, LutT, IvfSampleFilterT>; \\
\\
template void raft::neighbors::ivf_pq::detail::compute_similarity_run<OutT, LutT, IvfSampleFilterT>( \\
raft::neighbors::ivf_pq::detail::selected<OutT, LutT, IvfSampleFilterT> s, \\
rmm::cuda_stream_view stream, \\
uint32_t dim, \\
uint32_t n_probes, \\
uint32_t pq_dim, \\
uint32_t n_queries, \\
uint32_t queries_offset, \\
raft::distance::DistanceType metric, \\
raft::neighbors::ivf_pq::codebook_gen codebook_kind, \\
uint32_t topk, \\
uint32_t max_samples, \\
const float* cluster_centers, \\
const float* pq_centers, \\
const uint8_t* const* pq_dataset, \\
const uint32_t* cluster_labels, \\
const uint32_t* _chunk_indices, \\
const float* queries, \\
const uint32_t* index_list, \\
float* query_kths, \\
IvfSampleFilterT sample_filter, \\
LutT* lut_scores, \\
OutT* _out_scores, \\
uint32_t* _out_indices);
#define COMMA ,
"""

trailer = """
#undef COMMA
#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select
#include <raft/neighbors/detail/ivf_pq_compute_similarity_template.cuh>
"""

none_filter_int64 = "raft::neighbors::filtering::ivf_to_sample_filter" \
Expand Down Expand Up @@ -135,5 +82,4 @@
with open(path, "w") as f:
f.write(header)
f.write(f"instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select({OutT}, {LutT}, {FilterT});\n")
f.write(trailer)
print(f"src/neighbors/detail/{path}")
57 changes: 2 additions & 55 deletions cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@

/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,65 +16,13 @@

/*
* NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py
*
* Make changes there and run in this directory:
*
* > python ivf_pq_compute_similarity_00_generate.py
*
*/

#include <raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh>
#include <raft/neighbors/detail/ivf_pq_fp_8bit.cuh>

#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \
OutT, LutT, IvfSampleFilterT) \
template auto \
raft::neighbors::ivf_pq::detail::compute_similarity_select<OutT, LutT, IvfSampleFilterT>( \
const cudaDeviceProp& dev_props, \
bool manage_local_topk, \
int locality_hint, \
double preferred_shmem_carveout, \
uint32_t pq_bits, \
uint32_t pq_dim, \
uint32_t precomp_data_count, \
uint32_t n_queries, \
uint32_t n_probes, \
uint32_t topk) \
->raft::neighbors::ivf_pq::detail::selected<OutT, LutT, IvfSampleFilterT>; \
\
template void \
raft::neighbors::ivf_pq::detail::compute_similarity_run<OutT, LutT, IvfSampleFilterT>( \
raft::neighbors::ivf_pq::detail::selected<OutT, LutT, IvfSampleFilterT> s, \
rmm::cuda_stream_view stream, \
uint32_t dim, \
uint32_t n_probes, \
uint32_t pq_dim, \
uint32_t n_queries, \
uint32_t queries_offset, \
raft::distance::DistanceType metric, \
raft::neighbors::ivf_pq::codebook_gen codebook_kind, \
uint32_t topk, \
uint32_t max_samples, \
const float* cluster_centers, \
const float* pq_centers, \
const uint8_t* const* pq_dataset, \
const uint32_t* cluster_labels, \
const uint32_t* _chunk_indices, \
const float* queries, \
const uint32_t* index_list, \
float* query_kths, \
IvfSampleFilterT sample_filter, \
LutT* lut_scores, \
OutT* _out_scores, \
uint32_t* _out_indices);

#define COMMA ,
#include <raft/neighbors/detail/ivf_pq_compute_similarity_template.cuh>
instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(
float,
float,
raft::neighbors::filtering::ivf_to_sample_filter<
int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>);

#undef COMMA

#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@

/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,65 +16,13 @@

/*
* NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py
*
* Make changes there and run in this directory:
*
* > python ivf_pq_compute_similarity_00_generate.py
*
*/

#include <raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh>
#include <raft/neighbors/detail/ivf_pq_fp_8bit.cuh>

#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \
OutT, LutT, IvfSampleFilterT) \
template auto \
raft::neighbors::ivf_pq::detail::compute_similarity_select<OutT, LutT, IvfSampleFilterT>( \
const cudaDeviceProp& dev_props, \
bool manage_local_topk, \
int locality_hint, \
double preferred_shmem_carveout, \
uint32_t pq_bits, \
uint32_t pq_dim, \
uint32_t precomp_data_count, \
uint32_t n_queries, \
uint32_t n_probes, \
uint32_t topk) \
->raft::neighbors::ivf_pq::detail::selected<OutT, LutT, IvfSampleFilterT>; \
\
template void \
raft::neighbors::ivf_pq::detail::compute_similarity_run<OutT, LutT, IvfSampleFilterT>( \
raft::neighbors::ivf_pq::detail::selected<OutT, LutT, IvfSampleFilterT> s, \
rmm::cuda_stream_view stream, \
uint32_t dim, \
uint32_t n_probes, \
uint32_t pq_dim, \
uint32_t n_queries, \
uint32_t queries_offset, \
raft::distance::DistanceType metric, \
raft::neighbors::ivf_pq::codebook_gen codebook_kind, \
uint32_t topk, \
uint32_t max_samples, \
const float* cluster_centers, \
const float* pq_centers, \
const uint8_t* const* pq_dataset, \
const uint32_t* cluster_labels, \
const uint32_t* _chunk_indices, \
const float* queries, \
const uint32_t* index_list, \
float* query_kths, \
IvfSampleFilterT sample_filter, \
LutT* lut_scores, \
OutT* _out_scores, \
uint32_t* _out_indices);

#define COMMA ,
#include <raft/neighbors/detail/ivf_pq_compute_similarity_template.cuh>
instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(
float,
raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>,
raft::neighbors::filtering::ivf_to_sample_filter<
int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>);

#undef COMMA

#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select
Loading

0 comments on commit 1df5de7

Please sign in to comment.