-
Notifications
You must be signed in to change notification settings - Fork 73
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
402 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
/* | ||
* Copyright (c) 2024, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <cuvs/neighbors/common.hpp> | ||
#include <raft/core/device_mdarray.hpp> | ||
#include <raft/core/device_mdspan.hpp> | ||
#include <raft/core/handle.hpp> | ||
#include <raft/core/host_mdspan.hpp> | ||
|
||
namespace cuvs::neighbors::quantization { | ||
struct params { | ||
float quantile = 0.99; | ||
|
||
bool is_computed = false; | ||
double min; | ||
double max; | ||
double scalar; | ||
}; | ||
|
||
raft::device_matrix<int8_t, int64_t> scalar_quantize( | ||
raft::resources const& res, | ||
cuvs::neighbors::quantization::params& params, | ||
raft::device_matrix_view<const double, int64_t> dataset); | ||
|
||
raft::device_matrix<int8_t, int64_t> scalar_quantize( | ||
raft::resources const& res, | ||
cuvs::neighbors::quantization::params& params, | ||
raft::device_matrix_view<const float, int64_t> dataset); | ||
|
||
raft::device_matrix<int8_t, int64_t> scalar_quantize( | ||
raft::resources const& res, | ||
cuvs::neighbors::quantization::params& params, | ||
raft::device_matrix_view<const half, int64_t> dataset); | ||
|
||
/*raft::host_matrix<int8_t, int64_t> | ||
scalar_quantize(raft::resources const& res, | ||
const cuvs::neighbors::quantization::params& params, | ||
raft::host_matrix_view<const float, int64_t> dataset);*/ | ||
|
||
} // namespace cuvs::neighbors::quantization |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
/* | ||
* Copyright (c) 2024, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <cuvs/neighbors/quantization.hpp> | ||
#include <raft/core/operators.hpp> | ||
#include <raft/linalg/unary_op.cuh> | ||
#include <raft/random/rng.cuh> | ||
#include <raft/random/sample_without_replacement.cuh> | ||
#include <thrust/execution_policy.h> | ||
#include <thrust/sort.h> | ||
|
||
namespace cuvs::neighbors::detail { | ||
|
||
template <typename T, typename QuantI = int8_t, typename TempT = double, typename TempI = int64_t> | ||
raft::device_matrix<QuantI, int64_t> scalar_quantize( | ||
raft::resources const& res, | ||
cuvs::neighbors::quantization::params& params, | ||
raft::device_matrix_view<const T, int64_t> dataset) | ||
{ | ||
cudaStream_t stream = raft::resource::get_cuda_stream(res); | ||
|
||
constexpr TempI q_type_min = static_cast<TempI>(std::numeric_limits<QuantI>::min()); | ||
constexpr TempI q_type_max = static_cast<TempI>(std::numeric_limits<QuantI>::max()); | ||
constexpr TempI range_q_type = q_type_max - q_type_min + TempI(1); | ||
|
||
size_t n_elements = dataset.extent(0) * dataset.extent(1); | ||
|
||
// conditional: search for quantiles / min / max | ||
if (!params.is_computed) { | ||
ASSERT(params.quantile > 0.5 && params.quantile <= 1.0, | ||
"quantile for scalar quantization needs to be within (.5, 1]"); | ||
|
||
double quantile_inv = 1.0 / params.quantile; | ||
|
||
// select subsample | ||
int seed = 137; | ||
constexpr size_t num_samples = 10000; | ||
raft::random::RngState rng(seed); | ||
size_t subset_size = std::min(num_samples, n_elements); | ||
auto subset = raft::make_device_vector<T>(res, subset_size); | ||
auto dataset_view = raft::make_device_vector_view<const T>(dataset.data_handle(), n_elements); | ||
raft::random::sample_without_replacement( | ||
res, rng, dataset_view, std::nullopt, subset.view(), std::nullopt); | ||
|
||
// quantile / sort and pick for now | ||
thrust::sort(raft::resource::get_thrust_policy(res), | ||
subset.data_handle(), | ||
subset.data_handle() + n_elements); | ||
|
||
int pos_max = raft::ceildiv((double)subset_size, quantile_inv) - 1; | ||
int pos_min = subset_size - pos_max - 1; | ||
|
||
T minmax_h[2]; | ||
raft::update_host(&(minmax_h[0]), subset.data_handle() + pos_min, 1, stream); | ||
raft::update_host(&(minmax_h[1]), subset.data_handle() + pos_max, 1, stream); | ||
raft::resource::sync_stream(res); | ||
|
||
// persist settings in params | ||
params.min = double(minmax_h[0]); | ||
params.max = double(minmax_h[1]); | ||
params.scalar = double(range_q_type) / (params.max - params.min + 1.0); | ||
params.is_computed = true; | ||
} | ||
|
||
// allocate target | ||
auto out = raft::make_device_matrix<QuantI, int64_t>(res, dataset.extent(0), dataset.extent(1)); | ||
|
||
// raft unary op or raft::linalg::map? | ||
// TempT / TempI as intermediate types | ||
raft::linalg::unaryOp(out.data_handle(), | ||
dataset.data_handle(), | ||
n_elements, | ||
raft::compose_op( | ||
raft::cast_op<QuantI>{}, | ||
raft::add_const_op<int>(q_type_min), | ||
[] __device__(TempI a) { | ||
return raft::max<TempI>(raft::min<TempI>(a, q_type_max - q_type_min), | ||
TempI(0)); | ||
}, | ||
raft::cast_op<TempI>{}, | ||
raft::add_const_op<TempT>(0.5), // for rounding | ||
raft::mul_const_op<TempT>(params.scalar), | ||
raft::sub_const_op<TempT>(params.min), | ||
raft::cast_op<TempT>{}), | ||
stream); | ||
|
||
return out; | ||
} | ||
|
||
/*template <typename T> | ||
raft::host_matrix<int8_t, int64_t> | ||
scalar_quantize(raft::resources const& res, | ||
const cuvs::neighbors::quantization::params& params, | ||
raft::host_matrix_view<const T, int64_t> dataset) | ||
{ | ||
}*/ | ||
|
||
} // namespace cuvs::neighbors::detail |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
/* | ||
* Copyright (c) 2024, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include "./detail/quantization.cuh" | ||
|
||
#include <cuvs/neighbors/quantization.hpp> | ||
|
||
namespace cuvs::neighbors::quantization { | ||
|
||
#define CUVS_INST_QUANTIZATION(T, QuantI) \ | ||
auto scalar_quantize(raft::resources const& res, \ | ||
params& params, \ | ||
raft::device_matrix_view<const T, int64_t> dataset) \ | ||
->raft::device_matrix<QuantI, int64_t> \ | ||
{ \ | ||
return detail::scalar_quantize<T, QuantI>(res, params, dataset); \ | ||
} | ||
|
||
CUVS_INST_QUANTIZATION(double, int8_t); | ||
CUVS_INST_QUANTIZATION(float, int8_t); | ||
CUVS_INST_QUANTIZATION(half, int8_t); | ||
|
||
#undef CUVS_INST_QUANTIZATION | ||
|
||
/* | ||
} \ | ||
auto scalar_quantize(raft::resources const& res, \ | ||
const params& params, \ | ||
raft::host_matrix_view<const T, int64_t> dataset) \ | ||
->raft::host_matrix<int8_t, int64_t> \ | ||
{ \ | ||
return detail::scalar_quantize<T>(res, params, dataset); \ | ||
*/ | ||
|
||
} // namespace cuvs::neighbors::quantization |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.