Skip to content

Commit

Permalink
Merge branch 'branch-23.10' into 23.10-cagra-remove
Browse files Browse the repository at this point in the history
  • Loading branch information
cjnolet authored Oct 2, 2023
2 parents 7ebf87f + 120cff4 commit 378fc52
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 17 deletions.
2 changes: 1 addition & 1 deletion ci/test_wheel_raft_dask.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ RAPIDS_PY_WHEEL_NAME="pylibraft_${RAPIDS_PY_CUDA_SUFFIX}" rapids-download-wheels
python -m pip install --no-deps ./local-pylibraft-dep/pylibraft*.whl

# Always install latest dask for testing
python -m pip install git+https://github.com/dask/dask.git@main git+https://github.com/dask/distributed.git@main git+https://github.com/rapidsai/[email protected]
python -m pip install git+https://github.com/dask/dask.git@2023.9.2 git+https://github.com/dask/distributed.git@2023.9.2 git+https://github.com/rapidsai/[email protected]

# echo to expand wildcard before adding `[extra]` requires for pip
python -m pip install $(echo ./dist/raft_dask*.whl)[test]
Expand Down
6 changes: 3 additions & 3 deletions conda/environments/all_cuda-118_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ dependencies:
- cupy>=12.0.0
- cxx-compiler
- cython>=3.0.0
- dask-core>=2023.7.1
- dask-core==2023.9.2
- dask-cuda==23.10.*
- dask>=2023.7.1
- distributed>=2023.7.1
- dask==2023.9.2
- distributed==2023.9.2
- doxygen>=1.8.20
- gcc_linux-64=11.*
- gmock>=1.13.0
Expand Down
6 changes: 3 additions & 3 deletions conda/environments/all_cuda-120_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ dependencies:
- cupy>=12.0.0
- cxx-compiler
- cython>=3.0.0
- dask-core>=2023.7.1
- dask-core==2023.9.2
- dask-cuda==23.10.*
- dask>=2023.7.1
- distributed>=2023.7.1
- dask==2023.9.2
- distributed==2023.9.2
- doxygen>=1.8.20
- gcc_linux-64=11.*
- gmock>=1.13.0
Expand Down
6 changes: 3 additions & 3 deletions conda/recipes/raft-dask/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ requirements:
- cudatoolkit
{% endif %}
- {{ pin_compatible('cuda-version', max_pin='x', min_pin='x') }}
- dask >=2023.7.1
- dask-core >=2023.7.1
- dask ==2023.9.2
- dask-core ==2023.9.2
- dask-cuda ={{ minor_version }}
- distributed >=2023.7.1
- distributed ==2023.9.2
- joblib >=0.11
- nccl >=2.9.9
- pylibraft {{ version }}
Expand Down
25 changes: 23 additions & 2 deletions cpp/include/raft/neighbors/brute_force_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,11 @@ struct index : ann::index {
/** Dataset norms */
[[nodiscard]] inline auto norms() const -> device_vector_view<const T, int64_t, row_major>
{
return make_const_mdspan(norms_.value().view());
return norms_view_.value();
}

/** Whether ot not this index has dataset norms */
[[nodiscard]] inline bool has_norms() const noexcept { return norms_.has_value(); }
[[nodiscard]] inline bool has_norms() const noexcept { return norms_view_.has_value(); }

[[nodiscard]] inline T metric_arg() const noexcept { return metric_arg_; }

Expand Down Expand Up @@ -102,10 +102,30 @@ struct index : ann::index {
norms_(std::move(norms)),
metric_arg_(metric_arg)
{
if (norms_) { norms_view_ = make_const_mdspan(norms_.value().view()); }
update_dataset(res, dataset);
resource::sync_stream(res);
}

/** Construct a brute force index from dataset
*
* This class stores a non-owning reference to the dataset and norms here.
* Having precomputed norms gives us a performance advantage at query time.
*/
index(raft::resources const& res,
raft::device_matrix_view<const T, int64_t, row_major> dataset_view,
std::optional<raft::device_vector_view<const T, int64_t>> norms_view,
raft::distance::DistanceType metric,
T metric_arg = 0.0)
: ann::index(),
metric_(metric),
dataset_(make_device_matrix<T, int64_t>(res, 0, 0)),
dataset_view_(dataset_view),
norms_view_(norms_view),
metric_arg_(metric_arg)
{
}

private:
/**
* Replace the dataset with a new dataset.
Expand Down Expand Up @@ -135,6 +155,7 @@ struct index : ann::index {
raft::distance::DistanceType metric_;
raft::device_matrix<T, int64_t, row_major> dataset_;
std::optional<raft::device_vector<T, int64_t>> norms_;
std::optional<raft::device_vector_view<const T, int64_t>> norms_view_;
raft::device_matrix_view<const T, int64_t, row_major> dataset_view_;
T metric_arg_;
};
Expand Down
6 changes: 3 additions & 3 deletions dependencies.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -430,15 +430,15 @@ dependencies:
common:
- output_types: [conda, pyproject]
packages:
- dask>=2023.7.1
- dask==2023.9.2
- dask-cuda==23.10.*
- distributed>=2023.7.1
- distributed==2023.9.2
- joblib>=0.11
- numba>=0.57
- *numpy
- output_types: conda
packages:
- dask-core>=2023.7.1
- dask-core==2023.9.2
- ucx>=1.13.0
- ucx-proc=*=gpu
- &ucx_py_conda ucx-py==0.34.*
Expand Down
4 changes: 2 additions & 2 deletions python/raft-dask/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ license = { text = "Apache 2.0" }
requires-python = ">=3.9"
dependencies = [
"dask-cuda==23.10.*",
"dask>=2023.7.1",
"distributed>=2023.7.1",
"dask==2023.9.2",
"distributed==2023.9.2",
"joblib>=0.11",
"numba>=0.57",
"numpy>=1.21",
Expand Down

0 comments on commit 378fc52

Please sign in to comment.