Refactor spectral scale_obs to use existing normalization function (#2319)

The scale_obs function called a custom kernel to scale the elements of each column of a matrix by the L2 norm of that column.
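Concretely, for an m x n matrix obs, the intended operation is (notation mine, matching the kernel's parameter names):

\[ \text{obs}_{ij} \leftarrow \frac{\text{obs}_{ij}}{\lVert \text{obs}_{:,j} \rVert_2} = \frac{\text{obs}_{ij}}{\sqrt{\sum_{k=1}^{m} \text{obs}_{kj}^{2}}} \]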

There were two issues:
1. The kernel launch parameters could go out of bounds when the graph was too large: CUDA limits the grid's Y dimension to 65535, but the function had no logic to keep its Y value under that limit (see the sketch after this list).
2. A bug in the kernel meant the column norms were not being calculated correctly: the outer loop ran to completion before the scaling step, so only the norm of the last column handled by the block was actually kept, and every column in the block was then normalized by that one value instead of by its own norm.
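To make the first issue concrete, here is a sketch using the launch-configuration code from the deleted scale_obs below (xsize = 32 is the value the old code picked for m >= 32; the vertex-count threshold follows from the arithmetic):

// Old launch configuration (copied from the deleted scale_obs below).
// With xsize = 32, nthreads.y = 256 / 32 = 8, so nblocks.y = (n + 7) / 8
// with no upper bound. CUDA caps gridDim.y at 65535, so any graph with
// n > 524280 vertices requested an out-of-bounds grid and the launch failed.
unsigned int xsize = 32;
dim3 nthreads{xsize, 256 / xsize, 1};
dim3 nblocks{1, (n + nthreads.y - 1) / nthreads.y, 1};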

To simplify (there's going to be some optimization work done this summer), I replaced this with a call to the existing raft::linalg::row_normalize primitive (see the diff below), which scales the values correctly.
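For reference, a minimal sketch of the replacement (the wrapper function and its name are mine for illustration; the view construction and the row_normalize call are the ones in the diff below):

#include <raft/core/device_mdspan.hpp>
#include <raft/core/resources.hpp>
#include <raft/linalg/normalize.cuh>

// Hypothetical helper: eigVecs holds nEigVecs rows of length n (the matrix is
// already transposed at this point), so normalizing each row by its L2 norm
// scales each original column by its column norm.
template <typename value_t, typename index_t>
void scale_eigenvectors(raft::resources const& handle,
                        value_t* eigVecs,
                        index_t nEigVecs,
                        index_t n)
{
  auto view = raft::make_device_matrix_view(eigVecs, nEigVecs, n);
  raft::linalg::row_normalize(
    handle, raft::make_const_mdspan(view), view, raft::linalg::L2Norm);
}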

Authors:
  - Chuck Hastings (https://github.com/ChuckHastings)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #2319
ChuckHastings authored May 17, 2024
1 parent 7e37451 commit 12f0096
Showing 3 changed files with 7 additions and 82 deletions.
2 changes: 2 additions & 0 deletions cpp/include/raft/linalg/normalize.cuh
@@ -18,9 +18,11 @@

#include "detail/normalize.cuh"

#include <raft/core/device_mdspan.hpp>
#include <raft/core/operators.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/linalg/norm_types.hpp>
#include <raft/util/input_validation.hpp>

namespace raft {
namespace linalg {
6 changes: 4 additions & 2 deletions cpp/include/raft/spectral/detail/modularity_maximization.hpp
@@ -19,6 +19,7 @@
 #include <raft/core/resource/cublas_handle.hpp>
 #include <raft/core/resource/cuda_stream.hpp>
 #include <raft/linalg/detail/cublas_wrappers.hpp>
+#include <raft/linalg/normalize.cuh>
 #include <raft/spectral/cluster_solvers.cuh>
 #include <raft/spectral/detail/spectral_util.cuh>
 #include <raft/spectral/eigen_solvers.cuh>
@@ -101,8 +102,9 @@ std::tuple<vertex_t, weight_t, vertex_t> modularity_maximization(

   // notice that at this point the matrix has already been transposed, so we are scaling
   // columns
-  scale_obs(nEigVecs, n, eigVecs);
-  RAFT_CHECK_CUDA(stream);
+  auto dataset_view = raft::make_device_matrix_view(eigVecs, nEigVecs, n);
+  raft::linalg::row_normalize(
+    handle, raft::make_const_mdspan(dataset_view), dataset_view, raft::linalg::L2Norm);

   // Find partition clustering
   auto pair_cluster = cluster_solver.solve(handle, n, nEigVecs, eigVecs, clusters);
81 changes: 1 addition & 80 deletions cpp/include/raft/spectral/detail/spectral_util.cuh
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -39,85 +39,6 @@
 namespace raft {
 namespace spectral {

-template <typename index_type_t, typename value_type_t>
-RAFT_KERNEL scale_obs_kernel(index_type_t m, index_type_t n, value_type_t* obs)
-{
-  index_type_t i, j, k, index, mm;
-  value_type_t alpha, v, last;
-  bool valid;
-  // ASSUMPTION: kernel is launched with either 2, 4, 8, 16 or 32 threads in x-dimension
-
-  // compute alpha
-  mm    = (((m + blockDim.x - 1) / blockDim.x) * blockDim.x);  // m in multiple of blockDim.x
-  alpha = 0.0;
-
-  for (j = threadIdx.y + blockIdx.y * blockDim.y; j < n; j += blockDim.y * gridDim.y) {
-    for (i = threadIdx.x; i < mm; i += blockDim.x) {
-      // check if the thread is valid
-      valid = i < m;
-
-      // get the value of the last thread
-      last = __shfl_sync(warp_full_mask(), alpha, blockDim.x - 1, blockDim.x);
-
-      // if you are valid read the value from memory, otherwise set your value to 0
-      alpha = (valid) ? obs[i + j * m] : 0.0;
-      alpha = alpha * alpha;
-
-      // do prefix sum (of size warpSize=blockDim.x =< 32)
-      for (k = 1; k < blockDim.x; k *= 2) {
-        v = __shfl_up_sync(warp_full_mask(), alpha, k, blockDim.x);
-        if (threadIdx.x >= k) alpha += v;
-      }
-      // shift by last
-      alpha += last;
-    }
-  }
-
-  // scale by alpha
-  alpha = __shfl_sync(warp_full_mask(), alpha, blockDim.x - 1, blockDim.x);
-  alpha = raft::sqrt(alpha);
-  for (j = threadIdx.y + blockIdx.y * blockDim.y; j < n; j += blockDim.y * gridDim.y) {
-    for (i = threadIdx.x; i < m; i += blockDim.x) {  // blockDim.x=32
-      index      = i + j * m;
-      obs[index] = obs[index] / alpha;
-    }
-  }
-}
-
-template <typename index_type_t>
-index_type_t next_pow2(index_type_t n)
-{
-  index_type_t v;
-  // Reference:
-  // http://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2Float
-  v = n - 1;
-  v |= v >> 1;
-  v |= v >> 2;
-  v |= v >> 4;
-  v |= v >> 8;
-  v |= v >> 16;
-  return v + 1;
-}
-
-template <typename index_type_t, typename value_type_t>
-cudaError_t scale_obs(index_type_t m, index_type_t n, value_type_t* obs)
-{
-  index_type_t p2m;
-
-  // find next power of 2
-  p2m = next_pow2<index_type_t>(m);
-  // setup launch configuration
-  unsigned int xsize = std::max(2, std::min(p2m, 32));
-  dim3 nthreads{xsize, 256 / xsize, 1};
-
-  dim3 nblocks{1, (n + nthreads.y - 1) / nthreads.y, 1};
-
-  // launch scaling kernel (scale each column of obs by its norm)
-  scale_obs_kernel<index_type_t, value_type_t><<<nblocks, nthreads>>>(m, n, obs);
-
-  return cudaSuccess;
-}
-
 template <typename vertex_t, typename edge_t, typename weight_t>
 void transform_eigen_matrix(raft::resources const& handle,
                             edge_t n,
