Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing Googletests and re-enabling in CI #1904

Merged
merged 20 commits into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ci/test_cpp.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ trap "EXITCODE=1" ERR
set +e

# Run libraft gtests from libraft-tests package
cd "$CONDA_PREFIX"/bin/gtests/libraft
ctest -j8 --output-on-failure

rapids-logger "Test script exiting with value: $EXITCODE"
Expand Down
42 changes: 33 additions & 9 deletions cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,22 @@

namespace raft::distance::detail::ops {

/**
* Reserve 1 digit of precision from each floating-point type
* for round-off error tolerance.
* @tparam DataT
*/
template <typename DataT>
__device__ constexpr DataT get_clamp_precision()
{
switch (sizeof(DataT)) {
case 2: return 1e-3;
case 4: return 1e-6;
case 8: return 1e-15;
default: return 0;
}
}

// Epilogue operator for CUTLASS based kernel
template <typename DataT, typename AccT>
struct l2_exp_cutlass_op {
Expand All @@ -31,11 +47,13 @@ struct l2_exp_cutlass_op {
__device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept
{
AccT outVal = aNorm + bNorm - DataT(2.0) * accVal;
// outVal could be negative due to numerical instability, especially when
// calculating self distance.
// clamp to 0 to avoid potential NaN in sqrt
outVal = outVal * (raft::abs(outVal) >= DataT(0.0001));
return sqrt ? raft::sqrt(outVal) : outVal;

/**
* Self-neighboring points should have (aNorm == bNorm) == accVal and the dot product (accVal)
* can sometimes have round-off errors, which will cause (aNorm == bNorm) ~ accVal instead.
*/
outVal = outVal * !((outVal * outVal < get_clamp_precision<DataT>()) * (aNorm == bNorm));
return sqrt ? raft::sqrt(outVal * (outVal > 0)) : outVal;
}

__device__ AccT operator()(DataT aData) const noexcept { return aData; }
Expand Down Expand Up @@ -86,10 +104,16 @@ struct l2_exp_distance_op {
for (int i = 0; i < Policy::AccRowsPerTh; ++i) {
#pragma unroll
for (int j = 0; j < Policy::AccColsPerTh; ++j) {
DataT val = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j];
// val could be negative due to numerical instability, especially when
// calculating self distance. Clamp to 0 to avoid potential NaN in sqrt
acc[i][j] = val * (raft::abs(val) >= DataT(0.0001));
DataT accVal = acc[i][j];
DataT val = regxn[i] + regyn[j] - (DataT)2.0 * accVal;

/**
* Self-neighboring points should have (aNorm == bNorm) == accVal and the dot product
* (accVal) can sometimes have round-off errors, which will cause (aNorm == bNorm) ~ accVal
* instead.
*/
acc[i][j] =
val * (val > 0) * !((val * val < get_clamp_precision<DataT>()) * (regxn[i] == regyn[j]));
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
}
}
if (sqrt) {
Expand Down
14 changes: 5 additions & 9 deletions cpp/include/raft/neighbors/detail/knn_brute_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <cstdint>
#include <iostream>
#include <raft/core/resources.hpp>
#include <raft/distance/detail/distance_ops/l2_exp.cuh>
#include <raft/distance/distance.cuh>
#include <raft/distance/distance_types.hpp>
#include <raft/linalg/map.cuh>
Expand Down Expand Up @@ -186,6 +187,7 @@ void tiled_brute_force_knn(const raft::resources& handle,
auto row_norms = search_norms.data();
auto col_norms = precomputed_index_norms ? precomputed_index_norms : index_norms.data();
auto dist = temp_distances.data();
bool sqrt = metric == raft::distance::DistanceType::L2SqrtExpanded;

raft::linalg::map_offset(
handle,
Expand All @@ -194,15 +196,9 @@ void tiled_brute_force_knn(const raft::resources& handle,
IndexType row = i + (idx / current_centroid_size);
IndexType col = j + (idx % current_centroid_size);

auto val = row_norms[row] + col_norms[col] - 2.0 * dist[idx];

// due to numerical instability (especially around self-distance)
// the distances here could be slightly negative, which will
// cause NaN values in the subsequent sqrt. Clamp to 0
val = val * (val >= 0.0001);
if (metric == raft::distance::DistanceType::L2SqrtExpanded) { val = sqrt(val); }
val = distance_epilogue(val, row, col);
return val;
raft::distance::detail::ops::l2_exp_cutlass_op<ElementType, ElementType> l2_op(sqrt);
auto val = l2_op(row_norms[row], col_norms[col], dist[idx]);
return distance_epilogue(val, row, col);
});
} else if (metric == raft::distance::DistanceType::CosineExpanded) {
auto row_norms = search_norms.data();
Expand Down
3 changes: 2 additions & 1 deletion cpp/test/distance/fused_l2_nn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ RAFT_KERNEL naiveKernel(raft::KeyValuePair<int, DataT>* min,
auto diff = midx >= m || nidx >= n ? DataT(0) : x[xidx] - y[yidx];
acc += diff * diff;
}

if (Sqrt) { acc = raft::sqrt(acc); }
ReduceOpT redOp;
typedef cub::WarpReduce<raft::KeyValuePair<int, DataT>> WarpReduce;
Expand Down Expand Up @@ -343,7 +344,7 @@ const std::vector<Inputs<double>> inputsd = {
{0.00001, 128, 32, 33, 1234ULL}, {0.00001, 128, 64, 33, 1234ULL},
{0.00001, 128, 128, 65, 1234ULL}, {0.00001, 64, 128, 129, 1234ULL},

{0.00001, 1805, 134, 2, 1234ULL}, {0.00001, 8192, 1024, 25, 1234ULL},
{0.00001, 1805, 134, 2, 1234ULL}, //{0.00001, 8192, 1024, 25, 1234ULL},
};
typedef FusedL2NNTest<double, false> FusedL2NNTestD_Sq;
TEST_P(FusedL2NNTestD_Sq, Result)
Expand Down
6 changes: 3 additions & 3 deletions cpp/test/neighbors/ann_cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ class AnnCagraTest : public ::testing::TestWithParam<AnnCagraInputs> {
distances_Cagra,
ps.n_queries,
ps.k,
0.001,
0.003,
min_recall));
EXPECT_TRUE(eval_distances(handle_,
database.data(),
Expand Down Expand Up @@ -515,7 +515,7 @@ class AnnCagraFilterTest : public ::testing::TestWithParam<AnnCagraInputs> {
distances_Cagra,
ps.n_queries,
ps.k,
0.001,
0.003,
min_recall));
EXPECT_TRUE(eval_distances(handle_,
database.data(),
Expand Down Expand Up @@ -628,7 +628,7 @@ class AnnCagraFilterTest : public ::testing::TestWithParam<AnnCagraInputs> {
distances_Cagra,
ps.n_queries,
ps.k,
0.001,
0.003,
min_recall));
EXPECT_TRUE(eval_distances(handle_,
database.data(),
Expand Down
2 changes: 1 addition & 1 deletion cpp/test/neighbors/ann_ivf_pq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ class ivf_pq_test : public ::testing::TestWithParam<ivf_pq_inputs> {
// Hence, encoding-decoding chain often leads to altering both the PQ codes and the
// reconstructed data.
compare_vectors_l2(
handle_, vectors_1.view(), vectors_2.view(), label, compression_ratio, 0.025);
handle_, vectors_1.view(), vectors_2.view(), label, compression_ratio, 0.04); // 0.025);
}

void check_packing(index<IdxT>* index, uint32_t label)
Expand Down
33 changes: 12 additions & 21 deletions docs/source/raft_ann_benchmarks.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,6 @@ You can see the exact versions as well in the dockerhub site:

[//]: # (```)



## How to run the benchmarks

We provide a collection of lightweight Python scripts to run the benchmarks. There are 4 general steps to running the benchmarks and visualizing the results.
Expand Down Expand Up @@ -118,17 +116,6 @@ will be written at location `datasets/glove-100-inner/`.
### Step 2: Build and Search Index
The script `raft-ann-bench.run` will build and search indices for a given dataset and its
specified configuration.
To confirgure which algorithms are available, we use `algos.yaml`.
To configure building/searching indices for a dataset, look at [index configuration](#json-index-config).
An entry in `algos.yaml` looks like:
```yaml
raft_ivf_pq:
executable: RAFT_IVF_PQ_ANN_BENCH
requires_gpu: true
```
`executable` : specifies the name of the binary that will build/search the index. It is assumed to be
available in `raft/cpp/build/`.
`requires_gpu` : denotes whether an algorithm requires GPU to run.

The usage of the script `raft-ann-bench.run` is:
```bash
Expand Down Expand Up @@ -294,8 +281,6 @@ options:
Path to billion-scale dataset groundtruth file (default: None)
```



### Running with Docker containers

Two methods are provided for running the benchmarks with the Docker containers.
Expand Down Expand Up @@ -410,14 +395,8 @@ The table below contains the possible settings for the `algo` field. Each unique
| HNSWlib | `hnswlib` |
| RAFT | `raft_brute_force`, `raft_cagra`, `raft_ivf_flat`, `raft_ivf_pq` |




By default, the index will be placed in `bench/ann/data/<dataset_name>/index/<name>`. Using `sift-128-euclidean` for the dataset with the `algo` example above, the indexes would be placed in `bench/ann/data/sift-128-euclidean/index/algo_name/param1_val1-param2_val2`.




## Adding a new ANN algorithm

### Implementation and Configuration
Expand Down Expand Up @@ -490,6 +469,7 @@ How to interpret these JSON objects is totally left to the implementation and sh
}
```


### Adding a CMake Target
In `raft/cpp/bench/ann/CMakeLists.txt`, we provide a `CMake` function to configure a new Benchmark target with the following signature:
```
Expand All @@ -511,3 +491,14 @@ ConfigureAnnBench(
```

This will create an executable called `HNSWLIB_ANN_BENCH`, which can then be used to run `HNSWLIB` benchmarks.

Add a new entry to `algos.yaml` to map the name of the algorithm to its binary executable and specify whether the algorithm requires GPU support.
```yaml
raft_ivf_pq:
executable: RAFT_IVF_PQ_ANN_BENCH
requires_gpu: true
```

`executable` : specifies the name of the binary that will build/search the index. It is assumed to be
available in `raft/cpp/build/`.
`requires_gpu` : denotes whether an algorithm requires GPU to run.
8 changes: 2 additions & 6 deletions python/pylibraft/pylibraft/test/test_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
from pylibraft.distance import pairwise_distance


@pytest.mark.parametrize("n_rows", [32, 100])
@pytest.mark.parametrize("n_cols", [40, 100])
@pytest.mark.parametrize("n_rows", [50, 100])
@pytest.mark.parametrize("n_cols", [10, 50])
@pytest.mark.parametrize(
"metric",
[
Expand Down Expand Up @@ -63,8 +63,6 @@ def test_distance(n_rows, n_cols, inplace, metric, order, dtype):
else:
expected = cdist(input1, input1, metric)

expected[expected <= 1e-5] = 0.0

input1_device = device_ndarray(input1)
output_device = device_ndarray(output) if inplace else None

Expand All @@ -79,6 +77,4 @@ def test_distance(n_rows, n_cols, inplace, metric, order, dtype):

actual = output_device.copy_to_host()

actual[actual <= 1e-5] = 0.0

assert np.allclose(expected, actual, atol=1e-3, rtol=1e-3)