Skip to content

Commit

Permalink
Fix include errors, header, and unsafe locks in iface.hpp (#467)
Browse files Browse the repository at this point in the history
Fix a few issues with the internal header `neighbors/iface/iface.hpp`  leading to compile time errors and dangerous runtime behavior:

  - Add missing includes
  - Use `std::lock_guard` to avoid a deadlock on exception
  - Add NVIDIA header
  - Avoid an extra stream sync during search.

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Victor Lafargue (https://github.com/viclafargue)
  - Corey J. Nolet (https://github.com/cjnolet)
  - Ben Frederickson (https://github.com/benfred)

URL: #467
  • Loading branch information
achirkin authored Nov 14, 2024
1 parent fdb1180 commit bb9c669
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 24 deletions.
2 changes: 2 additions & 0 deletions cpp/src/neighbors/cagra_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
#include <cuvs/neighbors/cagra.h>
#include <cuvs/neighbors/cagra.hpp>

#include <fstream>

namespace {

template <typename T>
Expand Down
53 changes: 29 additions & 24 deletions cpp/src/neighbors/iface/iface.hpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,30 @@
#include <mutex>
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cuvs/neighbors/cagra.hpp>
#include <cuvs/neighbors/common.hpp>
#include <cuvs/neighbors/ivf_flat.hpp>
#include <cuvs/neighbors/ivf_pq.hpp>
#include <raft/core/device_resources.hpp>

#include <fstream>
#include <mutex>

namespace cuvs::neighbors {

using namespace raft;
Expand All @@ -16,7 +35,7 @@ void build(const raft::device_resources& handle,
const cuvs::neighbors::index_params* index_params,
raft::mdspan<const T, matrix_extent<int64_t>, row_major, Accessor> index_dataset)
{
interface.mutex_->lock();
std::lock_guard(*interface.mutex_);

if constexpr (std::is_same<AnnIndexType, ivf_flat::index<T, IdxT>>::value) {
auto idx = cuvs::neighbors::ivf_flat::build(
Expand All @@ -32,8 +51,6 @@ void build(const raft::device_resources& handle,
interface.index_.emplace(std::move(idx));
}
resource::sync_stream(handle);

interface.mutex_->unlock();
}

template <typename AnnIndexType, typename T, typename IdxT, typename Accessor1, typename Accessor2>
Expand All @@ -44,7 +61,7 @@ void extend(
std::optional<raft::mdspan<const IdxT, vector_extent<int64_t>, layout_c_contiguous, Accessor2>>
new_indices)
{
interface.mutex_->lock();
std::lock_guard(*interface.mutex_);

if constexpr (std::is_same<AnnIndexType, ivf_flat::index<T, IdxT>>::value) {
auto idx =
Expand All @@ -58,8 +75,6 @@ void extend(
RAFT_FAIL("CAGRA does not implement the extend method");
}
resource::sync_stream(handle);

interface.mutex_->unlock();
}

template <typename AnnIndexType, typename T, typename IdxT>
Expand All @@ -70,7 +85,7 @@ void search(const raft::device_resources& handle,
raft::device_matrix_view<IdxT, int64_t, row_major> neighbors,
raft::device_matrix_view<float, int64_t, row_major> distances)
{
// interface.mutex_->lock();
// std::lock_guard(*interface.mutex_);
if constexpr (std::is_same<AnnIndexType, ivf_flat::index<T, int64_t>>::value) {
cuvs::neighbors::ivf_flat::search(
handle,
Expand All @@ -94,9 +109,7 @@ void search(const raft::device_resources& handle,
neighbors,
distances);
}
resource::sync_stream(handle);

// interface.mutex_->unlock();
// resource::sync_stream(handle);
}

// for MG ANN only
Expand All @@ -108,7 +121,7 @@ void search(const raft::device_resources& handle,
raft::device_matrix_view<IdxT, int64_t, row_major> d_neighbors,
raft::device_matrix_view<float, int64_t, row_major> d_distances)
{
// interface.mutex_->lock();
// std::lock_guard(*interface.mutex_);

int64_t n_rows = h_queries.extent(0);
int64_t n_dims = h_queries.extent(1);
Expand All @@ -120,16 +133,14 @@ void search(const raft::device_resources& handle,
auto d_query_view = raft::make_const_mdspan(d_queries.view());

search(handle, interface, search_params, d_query_view, d_neighbors, d_distances);

// interface.mutex_->unlock();
}

template <typename AnnIndexType, typename T, typename IdxT>
void serialize(const raft::device_resources& handle,
const cuvs::neighbors::iface<AnnIndexType, T, IdxT>& interface,
std::ostream& os)
{
interface.mutex_->lock();
std::lock_guard(*interface.mutex_);

if constexpr (std::is_same<AnnIndexType, ivf_flat::index<T, IdxT>>::value) {
ivf_flat::serialize(handle, os, interface.index_.value());
Expand All @@ -138,16 +149,14 @@ void serialize(const raft::device_resources& handle,
} else if constexpr (std::is_same<AnnIndexType, cagra::index<T, IdxT>>::value) {
cagra::serialize(handle, os, interface.index_.value(), true);
}

interface.mutex_->unlock();
}

template <typename AnnIndexType, typename T, typename IdxT>
void deserialize(const raft::device_resources& handle,
cuvs::neighbors::iface<AnnIndexType, T, IdxT>& interface,
std::istream& is)
{
interface.mutex_->lock();
std::lock_guard(*interface.mutex_);

if constexpr (std::is_same<AnnIndexType, ivf_flat::index<T, IdxT>>::value) {
ivf_flat::index<T, IdxT> idx(handle);
Expand All @@ -162,16 +171,14 @@ void deserialize(const raft::device_resources& handle,
cagra::deserialize(handle, is, &idx);
interface.index_.emplace(std::move(idx));
}

interface.mutex_->unlock();
}

template <typename AnnIndexType, typename T, typename IdxT>
void deserialize(const raft::device_resources& handle,
cuvs::neighbors::iface<AnnIndexType, T, IdxT>& interface,
const std::string& filename)
{
interface.mutex_->lock();
std::lock_guard(*interface.mutex_);

std::ifstream is(filename, std::ios::in | std::ios::binary);
if (!is) { RAFT_FAIL("Cannot open file %s", filename.c_str()); }
Expand All @@ -191,8 +198,6 @@ void deserialize(const raft::device_resources& handle,
}

is.close();

interface.mutex_->unlock();
}

}; // namespace cuvs::neighbors
}; // namespace cuvs::neighbors
2 changes: 2 additions & 0 deletions cpp/src/neighbors/ivf_flat_c.cpp
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
#include <cuvs/neighbors/ivf_flat.h>
#include <cuvs/neighbors/ivf_flat.hpp>

#include <fstream>

namespace {

template <typename T, typename IdxT>
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/neighbors/mg/mg.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
#include <cuvs/neighbors/common.hpp>
#include <cuvs/neighbors/mg.hpp>

#include <fstream>

namespace cuvs::neighbors {
using namespace raft;

Expand Down
4 changes: 4 additions & 0 deletions examples/cpp/src/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
* limitations under the License.
*/

#pragma once

#include <cstdint>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_resources.hpp>
Expand All @@ -28,6 +30,8 @@
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>

#include <fstream>

// Fill dataset and queries with synthetic data.
void generate_dataset(raft::device_resources const &dev_resources,
raft::device_matrix_view<float, int64_t> dataset,
Expand Down

0 comments on commit bb9c669

Please sign in to comment.