Skip to content

Commit

Permalink
unify fusedl2nn with fuseddistanceNN, add deprecation warning for fus…
Browse files Browse the repository at this point in the history
…ed_l2_nn_min_arg, support only float for fused_distance_nn
  • Loading branch information
mdoijade committed Feb 22, 2024
1 parent 80e1358 commit e9090d6
Show file tree
Hide file tree
Showing 8 changed files with 199 additions and 159 deletions.
11 changes: 9 additions & 2 deletions cpp/include/raft/distance/detail/fused_distance_nn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <raft/distance/detail/distance_ops/l2_exp.cuh> // ops::l2_exp_distance_op
#include <raft/distance/detail/fused_distance_nn/cutlass_base.cuh>
#include <raft/distance/detail/fused_distance_nn/fused_cosine_nn.cuh>
#include <raft/distance/detail/fused_distance_nn/fused_l2_nn.cuh>
#include <raft/distance/detail/fused_distance_nn/helper_structs.cuh>
#include <raft/distance/detail/fused_distance_nn/simt_kernel.cuh>
#include <raft/distance/detail/pairwise_distance_base.cuh> // PairwiseDistances
Expand Down Expand Up @@ -76,11 +77,17 @@ void fusedDistanceNNImpl(OutT* min,
}

switch (metric) {
case DistanceType::CosineExpanded:
case raft::distance::DistanceType::CosineExpanded:
fusedCosineNN<DataT, OutT, IdxT, P, ReduceOpT, KVPReduceOpT>(
min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, stream);
break;
default: assert("only cosine metric is supported with fusedDistanceNN\n"); break;
case raft::distance::DistanceType::L2SqrtExpanded:
case raft::distance::DistanceType::L2Expanded:
// initOutBuffer is taken care of by fusedDistanceNNImpl(), so we pass false for it to fusedL2NNImpl.
fusedL2NNImpl<DataT, OutT, IdxT, P, ReduceOpT, KVPReduceOpT>(
min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, false, stream);
break;
default: assert("only cosine/l2 metric is supported with fusedDistanceNN\n"); break;
}
}

Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/distance/fused_l2_nn-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#include <cub/cub.cuh>
#include <limits>
#include <raft/core/resources.hpp>
#include <raft/distance/detail/fused_l2_nn.cuh>
#include <raft/distance/detail/fused_distance_nn/fused_l2_nn.cuh>
#include <raft/distance/fused_distance_nn_helpers.cuh>
#include <raft/linalg/contractions.cuh>
#include <raft/util/cuda_utils.cuh>
Expand Down
36 changes: 19 additions & 17 deletions cpp/include/raft_runtime/distance/fused_l2_nn.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -42,23 +42,25 @@ namespace raft::runtime::distance {
* @param[in] k gemm k
* @param[in] sqrt Whether the output `minDist` should contain L2-sqrt
*/
// float overload of fused_l2_nn_min_arg, declared here without a deprecation attribute.
// NOTE(review): this and the attributed declaration below look like the before/after sides
// of the same diff hunk -- confirm against the actual header.
void fused_l2_nn_min_arg(raft::resources const& handle,
int* min,
const float* x,
const float* y,
int m,
int n,
int k,
bool sqrt);
// float overload carrying [[deprecated]]; the attribute message directs callers to
// fused_distance_nn_min_arg instead.
[[deprecated("use fused_distance_nn_min_arg instead")]] void fused_l2_nn_min_arg(
raft::resources const& handle,
int* min,
const float* x,
const float* y,
int m,
int n,
int k,
bool sqrt);

// double overload of fused_l2_nn_min_arg, declared here without a deprecation attribute.
void fused_l2_nn_min_arg(raft::resources const& handle,
int* min,
const double* x,
const double* y,
int m,
int n,
int k,
bool sqrt);
// double overload carrying [[deprecated]]; the attribute message directs callers to
// fused_distance_nn_min_arg instead.
[[deprecated("use fused_distance_nn_min_arg instead")]] void fused_l2_nn_min_arg(
raft::resources const& handle,
int* min,
const double* x,
const double* y,
int m,
int n,
int k,
bool sqrt);

/** @} */ // end group fused_l2_nn_min_arg_runtime

Expand Down
75 changes: 6 additions & 69 deletions cpp/src/raft_runtime/distance/fused_distance_min_arg.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,86 +14,19 @@
* limitations under the License.
*/

#include "fused_distance_min_arg.hpp"
#include <raft/core/device_mdarray.hpp>
#include <raft/core/kvp.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/thrust_policy.hpp>
#include <raft/core/resources.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/distance/fused_distance_nn.cuh>
#include <raft/linalg/norm.cuh>
#include <thrust/for_each.h>
#include <thrust/tuple.h>

namespace raft::runtime::distance {

/**
 * @brief Unary functor that maps a raft::KeyValuePair to its `key` member.
 *
 * Callable from both host and device code.
 *
 * @tparam IndexT type of the pair's key (also the return type)
 * @tparam DataT  type of the pair's value (unused by the projection)
 */
template <typename IndexT, typename DataT>
struct KeyValueIndexOp {
  __host__ __device__ __forceinline__ IndexT
  operator()(const raft::KeyValuePair<IndexT, DataT>& pair) const
  {
    return pair.key;
  }
};

/**
 * @brief Runs fusedDistanceNNMinReduce with the CosineExpanded metric and stores the key of
 *        each resulting key/value pair in `min`.
 *
 * Steps performed, all on the CUDA stream obtained from `handle`:
 *  1. rowNorm over `x` (called with k, m) and over `y` (called with k, n), using L2Norm with a
 *     final raft::sqrt_op.
 *  2. fusedDistanceNNMinReduce into a temporary vector of m key/value pairs.
 *  3. thrust::transform copying each pair's `key` into `min`.
 *  4. resource::sync_stream(handle) before returning.
 *
 * @param[in]  handle raft resources providing the stream and thrust policy
 * @param[out] min    receives m keys (presumably device memory -- confirm with callers)
 * @param[in]  x      first input, treated as row-major
 * @param[in]  y      second input, treated as row-major
 * @param[in]  m      number of output entries; forwarded as the row count for `x`
 * @param[in]  n      forwarded as the row count for `y`
 * @param[in]  k      forwarded as the row length of `x` and `y`
 * @param[in]  sqrt   forwarded unchanged to fusedDistanceNNMinReduce
 */
template <typename value_t, typename idx_t>
void compute_fused_cosine_nn_min_arg(raft::resources const& handle,
                                     idx_t* min,
                                     const value_t* x,
                                     const value_t* y,
                                     idx_t m,
                                     idx_t n,
                                     idx_t k,
                                     bool sqrt)
{
  using kvp_t = raft::KeyValuePair<idx_t, value_t>;

  // Look the stream up once and reuse it for every allocation and launch below.
  const auto stream = resource::get_cuda_stream(handle);

  rmm::device_uvector<int> scratch(m, stream);
  auto nn_pairs = raft::make_device_vector<kvp_t>(handle, m);

  rmm::device_uvector<value_t> norms_x(m, stream);
  rmm::device_uvector<value_t> norms_y(n, stream);

  constexpr bool row_major = true;
  raft::linalg::rowNorm(
    norms_x.data(), x, k, m, raft::linalg::L2Norm, row_major, stream, raft::sqrt_op{});
  raft::linalg::rowNorm(
    norms_y.data(), y, k, n, raft::linalg::L2Norm, row_major, stream, raft::sqrt_op{});

  auto* pairs_begin = nn_pairs.data_handle();
  raft::distance::fusedDistanceNNMinReduce(pairs_begin,
                                           x,
                                           y,
                                           norms_x.data(),
                                           norms_y.data(),
                                           m,
                                           n,
                                           k,
                                           static_cast<void*>(scratch.data()),
                                           sqrt,
                                           true,  // NOTE(review): presumably initOutBuffer -- confirm
                                           row_major,
                                           raft::distance::DistanceType::CosineExpanded,
                                           0.0f,
                                           stream);

  // Keep only the key of every pair.
  thrust::transform(resource::get_thrust_policy(handle),
                    pairs_begin,
                    pairs_begin + m,
                    min,
                    KeyValueIndexOp<idx_t, value_t>{});
  resource::sync_stream(handle);
}

void fused_distance_nn_min_arg(raft::resources const& handle,
int* min,
const float* x,
Expand All @@ -110,7 +43,11 @@ void fused_distance_nn_min_arg(raft::resources const& handle,
case raft::distance::DistanceType::CosineExpanded:
compute_fused_cosine_nn_min_arg<float, int>(handle, min, x, y, m, n, k, sqrt);
break;
default: assert("only cosine metric is supported with fusedDistanceNN\n"); break;
case raft::distance::DistanceType::L2Expanded:
case raft::distance::DistanceType::L2SqrtExpanded:
compute_fused_l2_nn_min_arg<float, int>(handle, min, x, y, m, n, k, sqrt);
break;
default: assert("only Cosine/L2 metric is supported with fusedDistanceNN\n"); break;
}
}

Expand Down
143 changes: 143 additions & 0 deletions cpp/src/raft_runtime/distance/fused_distance_min_arg.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/core/device_mdarray.hpp>
#include <raft/core/kvp.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/thrust_policy.hpp>
#include <raft/core/resources.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/distance/fused_distance_nn.cuh>
#include <raft/distance/fused_l2_nn.cuh>
#include <raft/linalg/norm.cuh>

#include <rmm/device_uvector.hpp>

#include <thrust/for_each.h>
#include <thrust/transform.h>
#include <thrust/tuple.h>

namespace raft::runtime::distance {

/**
 * @brief Unary functor that maps a raft::KeyValuePair to its `key` member.
 *
 * Callable from both host and device code.
 *
 * @tparam IndexT type of the pair's key (also the return type)
 * @tparam DataT  type of the pair's value (unused by the projection)
 */
template <typename IndexT, typename DataT>
struct KeyValueIndexOp {
  __host__ __device__ __forceinline__ IndexT
  operator()(const raft::KeyValuePair<IndexT, DataT>& pair) const
  {
    return pair.key;
  }
};

/**
 * @brief Runs fusedL2NNMinReduce and stores the key of each resulting key/value pair in `min`.
 *
 * Steps performed, all on the CUDA stream obtained from `handle`:
 *  1. rowNorm over `x` (called with k, m) and over `y` (called with k, n), using L2Norm with
 *     rowNorm's default final operation.
 *  2. fusedL2NNMinReduce into a temporary vector of m key/value pairs.
 *  3. thrust::transform copying each pair's `key` into `min`.
 *  4. resource::sync_stream(handle) before returning.
 *
 * @param[in]  handle raft resources providing the stream and thrust policy
 * @param[out] min    receives m keys (presumably device memory -- confirm with callers)
 * @param[in]  x      first input, treated as row-major
 * @param[in]  y      second input, treated as row-major
 * @param[in]  m      number of output entries; forwarded as the row count for `x`
 * @param[in]  n      forwarded as the row count for `y`
 * @param[in]  k      forwarded as the row length of `x` and `y`
 * @param[in]  sqrt   forwarded unchanged to fusedL2NNMinReduce
 */
template <typename value_t, typename idx_t>
void compute_fused_l2_nn_min_arg(raft::resources const& handle,
                                 idx_t* min,
                                 const value_t* x,
                                 const value_t* y,
                                 idx_t m,
                                 idx_t n,
                                 idx_t k,
                                 bool sqrt)
{
  using kvp_t = raft::KeyValuePair<idx_t, value_t>;

  // Look the stream up once and reuse it for every allocation and launch below.
  const auto stream = resource::get_cuda_stream(handle);

  rmm::device_uvector<int> scratch(m, stream);
  auto nn_pairs            = raft::make_device_vector<kvp_t>(handle, m);
  constexpr bool row_major = true;

  rmm::device_uvector<value_t> norms_x(m, stream);
  rmm::device_uvector<value_t> norms_y(n, stream);
  raft::linalg::rowNorm(norms_x.data(), x, k, m, raft::linalg::L2Norm, row_major, stream);
  raft::linalg::rowNorm(norms_y.data(), y, k, n, raft::linalg::L2Norm, row_major, stream);

  auto* pairs_begin = nn_pairs.data_handle();
  raft::distance::fusedL2NNMinReduce(pairs_begin,
                                     x,
                                     y,
                                     norms_x.data(),
                                     norms_y.data(),
                                     m,
                                     n,
                                     k,
                                     static_cast<void*>(scratch.data()),
                                     sqrt,
                                     true,  // NOTE(review): presumably initOutBuffer -- confirm
                                     stream);

  // Keep only the key of every pair.
  thrust::transform(resource::get_thrust_policy(handle),
                    pairs_begin,
                    pairs_begin + m,
                    min,
                    KeyValueIndexOp<idx_t, value_t>{});
  resource::sync_stream(handle);
}

/**
 * @brief Runs fusedDistanceNNMinReduce with the CosineExpanded metric and stores the key of
 *        each resulting key/value pair in `min`.
 *
 * Steps performed, all on the CUDA stream obtained from `handle`:
 *  1. rowNorm over `x` (called with k, m) and over `y` (called with k, n), using L2Norm with a
 *     final raft::sqrt_op.
 *  2. fusedDistanceNNMinReduce into a temporary vector of m key/value pairs.
 *  3. thrust::transform copying each pair's `key` into `min`.
 *  4. resource::sync_stream(handle) before returning.
 *
 * @param[in]  handle raft resources providing the stream and thrust policy
 * @param[out] min    receives m keys (presumably device memory -- confirm with callers)
 * @param[in]  x      first input, treated as row-major
 * @param[in]  y      second input, treated as row-major
 * @param[in]  m      number of output entries; forwarded as the row count for `x`
 * @param[in]  n      forwarded as the row count for `y`
 * @param[in]  k      forwarded as the row length of `x` and `y`
 * @param[in]  sqrt   forwarded unchanged to fusedDistanceNNMinReduce
 */
template <typename value_t, typename idx_t>
void compute_fused_cosine_nn_min_arg(raft::resources const& handle,
                                     idx_t* min,
                                     const value_t* x,
                                     const value_t* y,
                                     idx_t m,
                                     idx_t n,
                                     idx_t k,
                                     bool sqrt)
{
  using kvp_t = raft::KeyValuePair<idx_t, value_t>;

  // Look the stream up once and reuse it for every allocation and launch below.
  const auto stream = resource::get_cuda_stream(handle);

  rmm::device_uvector<int> scratch(m, stream);
  auto nn_pairs = raft::make_device_vector<kvp_t>(handle, m);

  rmm::device_uvector<value_t> norms_x(m, stream);
  rmm::device_uvector<value_t> norms_y(n, stream);

  constexpr bool row_major = true;
  raft::linalg::rowNorm(
    norms_x.data(), x, k, m, raft::linalg::L2Norm, row_major, stream, raft::sqrt_op{});
  raft::linalg::rowNorm(
    norms_y.data(), y, k, n, raft::linalg::L2Norm, row_major, stream, raft::sqrt_op{});

  auto* pairs_begin = nn_pairs.data_handle();
  raft::distance::fusedDistanceNNMinReduce(pairs_begin,
                                           x,
                                           y,
                                           norms_x.data(),
                                           norms_y.data(),
                                           m,
                                           n,
                                           k,
                                           static_cast<void*>(scratch.data()),
                                           sqrt,
                                           true,  // NOTE(review): presumably initOutBuffer -- confirm
                                           row_major,
                                           raft::distance::DistanceType::CosineExpanded,
                                           0.0f,
                                           stream);

  // Keep only the key of every pair.
  thrust::transform(resource::get_thrust_policy(handle),
                    pairs_begin,
                    pairs_begin + m,
                    min,
                    KeyValueIndexOp<idx_t, value_t>{});
  resource::sync_stream(handle);
}

} // end namespace raft::runtime::distance
Loading

0 comments on commit e9090d6

Please sign in to comment.