Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update API for node2vec and biased random walks #4841

Open
wants to merge 49 commits into
base: branch-25.02
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 48 commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
86b2038
add support for rng state
jnke2016 Dec 28, 2024
38690a6
update test to take rng state parameter
jnke2016 Dec 28, 2024
fd5b387
add support for rng state
jnke2016 Dec 28, 2024
9d56b5f
deprecate old API
jnke2016 Dec 28, 2024
615837e
add new API for node2vec random walks
jnke2016 Dec 28, 2024
21a76bb
add mg node2vec random walks to the python API
jnke2016 Dec 28, 2024
86a13d3
update docstrings
jnke2016 Dec 28, 2024
4c8744f
enable mg node2vec_random walks
jnke2016 Dec 28, 2024
4da1c7e
update argument list in function call
jnke2016 Dec 28, 2024
6984645
support optional weights
jnke2016 Dec 28, 2024
d04588a
update docstring and deprecate arguments
jnke2016 Dec 31, 2024
cb6a294
add new API for uniform_random_walks
jnke2016 Dec 31, 2024
7936026
deprecate method
jnke2016 Dec 31, 2024
7a5056f
update copyrights
jnke2016 Dec 31, 2024
e2e4694
add uniform random walks
jnke2016 Dec 31, 2024
877265b
add new API for node2vec random walks
jnke2016 Dec 31, 2024
bb77237
deprecate legacy implementation
jnke2016 Dec 31, 2024
618fe76
add random state argumment and update copyright
jnke2016 Dec 31, 2024
a1d004c
update header file to take as input a random state
jnke2016 Dec 31, 2024
bea2a2f
add support for rng state as input
jnke2016 Dec 31, 2024
755acc7
update tests to support rng state as input
jnke2016 Dec 31, 2024
ef00fa5
add biased random walks to the PLC API
jnke2016 Jan 1, 2025
ae4833c
add biased random walks to the python API
jnke2016 Jan 1, 2025
8314291
update docstrings and init file
jnke2016 Jan 1, 2025
1603bcd
fix typo
jnke2016 Jan 1, 2025
0a03b29
update copyright
jnke2016 Jan 1, 2025
88e405d
add mg implementation of biased and uniform random walks
jnke2016 Jan 1, 2025
c8265e7
update docstrings
jnke2016 Jan 2, 2025
0a4d29b
deprecate legacy implementation
jnke2016 Jan 2, 2025
4e0eff9
remove unused import
jnke2016 Jan 2, 2025
7c85269
update MG C tests
jnke2016 Jan 10, 2025
067d53b
remove unused variable and update the number of arrays passed at the …
jnke2016 Jan 10, 2025
b94a6ea
update copyright and remove debug print
jnke2016 Jan 11, 2025
e83722e
fix renumbering bug
jnke2016 Jan 11, 2025
2c1a034
enable MG tests and fix bugs
jnke2016 Jan 11, 2025
d521dfc
fix style
jnke2016 Jan 11, 2025
6418b96
remove unsued import
jnke2016 Jan 11, 2025
af3b31f
add type annotations
jnke2016 Jan 13, 2025
f0e3b0f
fix style
jnke2016 Jan 13, 2025
74648d4
deprecated old test suite
jnke2016 Jan 13, 2025
70b8d8b
add sg tests for uniform random walks
jnke2016 Jan 13, 2025
9b270c7
update copyright
jnke2016 Jan 13, 2025
9e850d2
update tests
jnke2016 Jan 14, 2025
10c471a
add support of multi column seeds
jnke2016 Jan 14, 2025
b94bad2
add support of multi column seeds for 'select_random_vertices'
jnke2016 Jan 14, 2025
af68c3f
add multi column tests
jnke2016 Jan 14, 2025
0c8f85e
add support of multi column seeds
jnke2016 Jan 14, 2025
37d3a47
add test for biased random walks
jnke2016 Jan 14, 2025
e7a5952
add mg ECG
jnke2016 Jan 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion cpp/include/cugraph_c/sampling_algorithms.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
* Copyright (c) 2021-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -41,6 +41,7 @@ typedef struct {
* @brief Compute uniform random walks
*
* @param [in] handle Handle for accessing resources
* @param [in,out] rng_state State of the random number generator, updated with each call
* @param [in] graph Pointer to graph. NOTE: Graph might be modified if the storage
* needs to be transposed
* @param [in] start_vertices Array of source vertices
Expand All @@ -52,6 +53,7 @@ typedef struct {
*/
cugraph_error_code_t cugraph_uniform_random_walks(
const cugraph_resource_handle_t* handle,
cugraph_rng_state_t* rng_state,
cugraph_graph_t* graph,
const cugraph_type_erased_device_array_view_t* start_vertices,
size_t max_length,
Expand All @@ -62,6 +64,7 @@ cugraph_error_code_t cugraph_uniform_random_walks(
* @brief Compute biased random walks
*
* @param [in] handle Handle for accessing resources
* @param [in,out] rng_state State of the random number generator, updated with each call
* @param [in] graph Pointer to graph. NOTE: Graph might be modified if the storage
* needs to be transposed
* @param [in] start_vertices Array of source vertices
Expand All @@ -73,6 +76,7 @@ cugraph_error_code_t cugraph_uniform_random_walks(
*/
cugraph_error_code_t cugraph_biased_random_walks(
const cugraph_resource_handle_t* handle,
cugraph_rng_state_t* rng_state,
cugraph_graph_t* graph,
const cugraph_type_erased_device_array_view_t* start_vertices,
size_t max_length,
Expand All @@ -83,6 +87,7 @@ cugraph_error_code_t cugraph_biased_random_walks(
* @brief Compute random walks using the node2vec framework.
*
* @param [in] handle Handle for accessing resources
* @param [in,out] rng_state State of the random number generator, updated with each call
* @param [in] graph Pointer to graph. NOTE: Graph might be modified if the storage
* needs to be transposed
* @param [in] start_vertices Array of source vertices
Expand All @@ -98,6 +103,7 @@ cugraph_error_code_t cugraph_biased_random_walks(
*/
cugraph_error_code_t cugraph_node2vec_random_walks(
const cugraph_resource_handle_t* handle,
cugraph_rng_state_t* rng_state,
cugraph_graph_t* graph,
const cugraph_type_erased_device_array_view_t* start_vertices,
size_t max_length,
Expand Down
54 changes: 27 additions & 27 deletions cpp/src/c_api/random_walks.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
* Copyright (c) 2022-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -154,19 +154,20 @@ namespace {

struct uniform_random_walks_functor : public cugraph::c_api::abstract_functor {
raft::handle_t const& handle_;
// FIXME: rng_state_ should be passed as a parameter
cugraph::c_api::cugraph_rng_state_t* rng_state_{nullptr};
cugraph::c_api::cugraph_graph_t* graph_{nullptr};
cugraph::c_api::cugraph_type_erased_device_array_view_t const* start_vertices_{nullptr};
size_t max_length_{0};
cugraph::c_api::cugraph_random_walk_result_t* result_{nullptr};

uniform_random_walks_functor(cugraph_resource_handle_t const* handle,
cugraph_rng_state_t* rng_state,
cugraph_graph_t* graph,
cugraph_type_erased_device_array_view_t const* start_vertices,
size_t max_length)
: abstract_functor(),
handle_(*reinterpret_cast<cugraph::c_api::cugraph_resource_handle_t const*>(handle)->handle_),
rng_state_(reinterpret_cast<cugraph::c_api::cugraph_rng_state_t*>(rng_state)),
graph_(reinterpret_cast<cugraph::c_api::cugraph_graph_t*>(graph)),
start_vertices_(
reinterpret_cast<cugraph::c_api::cugraph_type_erased_device_array_view_t const*>(
Expand Down Expand Up @@ -224,10 +225,6 @@ struct uniform_random_walks_functor : public cugraph::c_api::abstract_functor {
graph_view.local_vertex_partition_range_last(),
false);

// FIXME: remove once rng_state passed as parameter
rng_state_ = reinterpret_cast<cugraph::c_api::cugraph_rng_state_t*>(
new cugraph::c_api::cugraph_rng_state_t{raft::random::RngState{0}});

auto [paths, weights] = cugraph::uniform_random_walks(
handle_,
rng_state_->rng_state_,
Expand Down Expand Up @@ -261,19 +258,20 @@ struct uniform_random_walks_functor : public cugraph::c_api::abstract_functor {

struct biased_random_walks_functor : public cugraph::c_api::abstract_functor {
raft::handle_t const& handle_;
// FIXME: rng_state_ should be passed as a parameter
cugraph::c_api::cugraph_rng_state_t* rng_state_{nullptr};
cugraph::c_api::cugraph_graph_t* graph_{nullptr};
cugraph::c_api::cugraph_type_erased_device_array_view_t const* start_vertices_{nullptr};
size_t max_length_{0};
cugraph::c_api::cugraph_random_walk_result_t* result_{nullptr};

biased_random_walks_functor(cugraph_resource_handle_t const* handle,
cugraph_rng_state_t* rng_state,
cugraph_graph_t* graph,
cugraph_type_erased_device_array_view_t const* start_vertices,
size_t max_length)
: abstract_functor(),
handle_(*reinterpret_cast<cugraph::c_api::cugraph_resource_handle_t const*>(handle)->handle_),
rng_state_(reinterpret_cast<cugraph::c_api::cugraph_rng_state_t*>(rng_state)),
graph_(reinterpret_cast<cugraph::c_api::cugraph_graph_t*>(graph)),
start_vertices_(
reinterpret_cast<cugraph::c_api::cugraph_type_erased_device_array_view_t const*>(
Expand All @@ -293,8 +291,6 @@ struct biased_random_walks_functor : public cugraph::c_api::abstract_functor {
// FIXME: Think about how to handle SG vice MG
if constexpr (!cugraph::is_candidate<vertex_t, edge_t, weight_t>::value) {
unsupported();
} else if constexpr (multi_gpu) {
unsupported();
} else {
// random walks expects store_transposed == false
if constexpr (store_transposed) {
Expand Down Expand Up @@ -333,10 +329,6 @@ struct biased_random_walks_functor : public cugraph::c_api::abstract_functor {
graph_view.local_vertex_partition_range_last(),
false);

// FIXME: remove once rng_state passed as parameter
rng_state_ = reinterpret_cast<cugraph::c_api::cugraph_rng_state_t*>(
new cugraph::c_api::cugraph_rng_state_t{raft::random::RngState{0}});

auto [paths, weights] = cugraph::biased_random_walks(
handle_,
rng_state_->rng_state_,
Expand All @@ -348,8 +340,13 @@ struct biased_random_walks_functor : public cugraph::c_api::abstract_functor {
//
// Need to unrenumber the vertices in the resulting paths
//
cugraph::unrenumber_local_int_vertices<vertex_t>(
handle_, paths.data(), paths.size(), number_map->data(), 0, paths.size() - 1, false);
cugraph::unrenumber_int_vertices<vertex_t, multi_gpu>(
handle_,
paths.data(),
paths.size(),
number_map->data(),
graph_view.vertex_partition_range_lasts(),
false);

result_ = new cugraph::c_api::cugraph_random_walk_result_t{
false,
Expand All @@ -365,7 +362,6 @@ struct biased_random_walks_functor : public cugraph::c_api::abstract_functor {

struct node2vec_random_walks_functor : public cugraph::c_api::abstract_functor {
raft::handle_t const& handle_;
// FIXME: rng_state_ should be passed as a parameter
cugraph::c_api::cugraph_rng_state_t* rng_state_{nullptr};
cugraph::c_api::cugraph_graph_t* graph_{nullptr};
cugraph::c_api::cugraph_type_erased_device_array_view_t const* start_vertices_{nullptr};
Expand All @@ -375,13 +371,15 @@ struct node2vec_random_walks_functor : public cugraph::c_api::abstract_functor {
cugraph::c_api::cugraph_random_walk_result_t* result_{nullptr};

node2vec_random_walks_functor(cugraph_resource_handle_t const* handle,
cugraph_rng_state_t* rng_state,
cugraph_graph_t* graph,
cugraph_type_erased_device_array_view_t const* start_vertices,
size_t max_length,
double p,
double q)
: abstract_functor(),
handle_(*reinterpret_cast<cugraph::c_api::cugraph_resource_handle_t const*>(handle)->handle_),
rng_state_(reinterpret_cast<cugraph::c_api::cugraph_rng_state_t*>(rng_state)),
graph_(reinterpret_cast<cugraph::c_api::cugraph_graph_t*>(graph)),
start_vertices_(
reinterpret_cast<cugraph::c_api::cugraph_type_erased_device_array_view_t const*>(
Expand All @@ -403,8 +401,6 @@ struct node2vec_random_walks_functor : public cugraph::c_api::abstract_functor {
// FIXME: Think about how to handle SG vice MG
if constexpr (!cugraph::is_candidate<vertex_t, edge_t, weight_t>::value) {
unsupported();
} else if constexpr (multi_gpu) {
unsupported();
} else {
// random walks expects store_transposed == false
if constexpr (store_transposed) {
Expand Down Expand Up @@ -443,10 +439,6 @@ struct node2vec_random_walks_functor : public cugraph::c_api::abstract_functor {
graph_view.local_vertex_partition_range_last(),
false);

// FIXME: remove once rng_state passed as parameter
rng_state_ = reinterpret_cast<cugraph::c_api::cugraph_rng_state_t*>(
new cugraph::c_api::cugraph_rng_state_t{raft::random::RngState{0}});

auto [paths, weights] = cugraph::node2vec_random_walks(
handle_,
rng_state_->rng_state_,
Expand All @@ -464,8 +456,13 @@ struct node2vec_random_walks_functor : public cugraph::c_api::abstract_functor {
//
// Need to unrenumber the vertices in the resulting paths
//
cugraph::unrenumber_local_int_vertices<vertex_t>(
handle_, paths.data(), paths.size(), number_map->data(), 0, paths.size(), false);
cugraph::unrenumber_int_vertices<vertex_t, multi_gpu>(
handle_,
paths.data(),
paths.size(),
number_map->data(),
graph_view.vertex_partition_range_lasts(),
false);

result_ = new cugraph::c_api::cugraph_random_walk_result_t{
false,
Expand Down Expand Up @@ -546,6 +543,7 @@ void cugraph_random_walk_result_free(cugraph_random_walk_result_t* result)

cugraph_error_code_t cugraph_uniform_random_walks(
const cugraph_resource_handle_t* handle,
cugraph_rng_state_t* rng_state,
cugraph_graph_t* graph,
const cugraph_type_erased_device_array_view_t* start_vertices,
size_t max_length,
Expand All @@ -560,13 +558,14 @@ cugraph_error_code_t cugraph_uniform_random_walks(
"vertex type of graph and start_vertices must match",
*error);

uniform_random_walks_functor functor(handle, graph, start_vertices, max_length);
uniform_random_walks_functor functor(handle, rng_state, graph, start_vertices, max_length);

return cugraph::c_api::run_algorithm(graph, functor, result, error);
}

cugraph_error_code_t cugraph_biased_random_walks(
const cugraph_resource_handle_t* handle,
cugraph_rng_state_t* rng_state,
cugraph_graph_t* graph,
const cugraph_type_erased_device_array_view_t* start_vertices,
size_t max_length,
Expand All @@ -581,13 +580,14 @@ cugraph_error_code_t cugraph_biased_random_walks(
"vertex type of graph and start_vertices must match",
*error);

biased_random_walks_functor functor(handle, graph, start_vertices, max_length);
biased_random_walks_functor functor(handle, rng_state, graph, start_vertices, max_length);

return cugraph::c_api::run_algorithm(graph, functor, result, error);
}

cugraph_error_code_t cugraph_node2vec_random_walks(
const cugraph_resource_handle_t* handle,
cugraph_rng_state_t* rng_state,
cugraph_graph_t* graph,
const cugraph_type_erased_device_array_view_t* start_vertices,
size_t max_length,
Expand All @@ -604,7 +604,7 @@ cugraph_error_code_t cugraph_node2vec_random_walks(
"vertex type of graph and start_vertices must match",
*error);

node2vec_random_walks_functor functor(handle, graph, start_vertices, max_length, p, q);
node2vec_random_walks_functor functor(handle, rng_state, graph, start_vertices, max_length, p, q);

return cugraph::c_api::run_algorithm(graph, functor, result, error);
}
Loading
Loading