-
Notifications
You must be signed in to change notification settings - Fork 310
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
support negative sampling, coo in PLC
- Loading branch information
1 parent
e413c8a
commit 61fb2d3
Showing
4 changed files
with
293 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# Have cython use python 3 syntax | ||
# cython: language_level = 3 | ||
|
||
|
||
from pylibcugraph._cugraph_c.coo cimport ( | ||
cugraph_coo_t, | ||
) | ||
|
||
cdef class COO: | ||
cdef cugraph_coo_t* c_coo_ptr | ||
cdef set_ptr(self, cugraph_coo_t* ptr) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# Have cython use python 3 syntax | ||
# cython: language_level = 3 | ||
|
||
from pylibcugraph._cugraph_c.coo cimport ( | ||
cugraph_coo_t, | ||
cugraph_coo_free, | ||
cugraph_coo_get_sources, | ||
cugraph_coo_get_destinations, | ||
cugraph_coo_get_edge_weights, | ||
cugraph_coo_get_edge_id, | ||
cugraph_coo_get_edge_type, | ||
) | ||
|
||
cdef class COO: | ||
""" | ||
Cython interface to a cugraph_coo_t pointer. Instances of this | ||
call will take ownership of the pointer and free it under standard python | ||
GC rules (ie. when all references to it are no longer present). | ||
This class provides methods to return non-owning cupy ndarrays for the | ||
corresponding array members. Returning these cupy arrays increments the ref | ||
count on the COO instances from which the cupy arrays are | ||
referencing. | ||
""" | ||
def __cinit__(self): | ||
# This COO instance owns sample_result_ptr now. It will be | ||
# freed when this instance is deleted (see __dealloc__()) | ||
self.c_coo_ptr = NULL | ||
|
||
def __dealloc__(self): | ||
if self.c_coo_ptr is not NULL: | ||
cugraph_coo_free(self.c_coo_ptr) | ||
|
||
cdef set_ptr(self, cugraph_coo_t* ptr): | ||
self.c_coo_ptr = ptr | ||
|
||
def get_array(self, cugraph_type_erased_device_array_view_t* ptr): | ||
if ptr is NULL: | ||
return None | ||
|
||
return create_cupy_array_view_for_device_ptr( | ||
ptr, | ||
self, | ||
) | ||
|
||
def get_sources(self): | ||
if self.c_coo_ptr is NULL: | ||
raise ValueError("pointer not set, must call set_ptr() with a " | ||
"non-NULL value first.") | ||
return get_array( | ||
<cugraph_type_erased_device_array_view_t*>cugraph_coo_get_sources(self.c_sample_result_ptr) | ||
) | ||
|
||
def get_destinations(self): | ||
if self.c_coo_ptr is NULL: | ||
raise ValueError("pointer not set, must call set_ptr() with a " | ||
"non-NULL value first.") | ||
return get_array( | ||
<cugraph_type_erased_device_array_view_t*>cugraph_coo_get_destinations(self.c_sample_result_ptr) | ||
) | ||
|
||
def get_edge_ids(self): | ||
if self.c_coo_ptr is NULL: | ||
raise ValueError("pointer not set, must call set_ptr() with a " | ||
"non-NULL value first.") | ||
return get_array( | ||
<cugraph_type_erased_device_array_view_t*>cugraph_coo_get_edge_id(self.c_sample_result_ptr) | ||
) | ||
|
||
def get_edge_types(self): | ||
if self.c_coo_ptr is NULL: | ||
raise ValueError("pointer not set, must call set_ptr() with a " | ||
"non-NULL value first.") | ||
return get_array( | ||
<cugraph_type_erased_device_array_view_t*>cugraph_coo_get_edge_type(self.c_sample_result_ptr) | ||
) | ||
|
||
def get_edge_weights(self): | ||
if self.c_coo_ptr is NULL: | ||
raise ValueError("pointer not set, must call set_ptr() with a " | ||
"non-NULL value first.") | ||
return get_array( | ||
<cugraph_type_erased_device_array_view_t*>cugraph_coo_get_edge_weights(self.c_sample_result_ptr) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# Have cython use python 3 syntax | ||
# cython: language_level = 3 | ||
|
||
from libc.stdint cimport uintptr_t | ||
|
||
from pylibcugraph._cugraph_c.resource_handle cimport ( | ||
cugraph_resource_handle_t, | ||
bool_t, | ||
) | ||
from pylibcugraph._cugraph_c.error cimport ( | ||
cugraph_error_code_t, | ||
cugraph_error_t, | ||
) | ||
from pylibcugraph._cugraph_c.array cimport ( | ||
cugraph_type_erased_device_array_view_t, | ||
cugraph_type_erased_device_array_view_create, | ||
cugraph_type_erased_device_array_view_free, | ||
cugraph_type_erased_host_array_view_t, | ||
cugraph_type_erased_host_array_view_create, | ||
cugraph_type_erased_host_array_view_free, | ||
) | ||
from pylibcugraph._cugraph_c.graph cimport ( | ||
cugraph_graph_t, | ||
) | ||
from pylibcugraph._cugraph_c.sampling_algorithms cimport ( | ||
cugraph_negative_sampling, | ||
) | ||
from pylibcugraph._cugraph_c.coo cimport ( | ||
cugraph_coo_t, | ||
) | ||
from pylibcugraph.internal_types.coo cimport ( | ||
COO, | ||
) | ||
|
||
def negative_sampling(ResourceHandle resource_handle, | ||
_GPUGraph graph, | ||
size_t num_samples, | ||
random_state=None, | ||
vertices=None, | ||
src_bias=None, | ||
dst_bias=None, | ||
remove_duplicates=False, | ||
remove_false_negatives=False, | ||
exact_number_of_samples=False, | ||
do_expensive_check=False): | ||
""" | ||
Performs negative sampling, which is essentially a form of graph generation. | ||
By setting vertices, src_bias, and dst_bias, this function can perform | ||
biased negative sampling. | ||
Parameters | ||
---------- | ||
resource_handle: ResourceHandle | ||
Handle to the underlying device and host resources needed for | ||
referencing data and running algorithms. | ||
input_graph: SGGraph or MGGraph | ||
The stored cuGraph graph to create negative samples for. | ||
num_samples: int | ||
The number of negative edges to generate for each positive edge. | ||
random_state: int (Optional) | ||
Random state to use when generating samples. Optional argument, | ||
defaults to a hash of process id, time, and hostname. | ||
(See pylibcugraph.random.CuGraphRandomState) | ||
vertices: device array type (Optional) | ||
Vertex ids corresponding to the src/dst biases, if provided. | ||
Ignored if src/dst biases are not provided. | ||
src_bias: device array type (Optional) | ||
Probability per edge that a vertex is selected as a source vertex. | ||
Does not have to be normalized. Uses a uniform distribution if | ||
not provided. | ||
dst_bias: device array type (Optional) | ||
Probability per edge that a vertex is selected as a destination vertex. | ||
Does not have to be normalized. Uses a uniform distribution if | ||
not provided. | ||
remove_duplicates: bool (Optional) | ||
Whether to remove duplicate edges from the generated edgelist. | ||
Defaults to False (does not remove duplicates). | ||
remove_false_negatives: bool (Optional) | ||
Whether to remove false negatives from the generated edgelist. | ||
Defaults to False (does not check for and remove false negatives). | ||
exact_number_of_samples: bool (Optional) | ||
Whether to manually regenerate samples until the desired number | ||
as specified by num_samples has been generated. | ||
Defaults to False (does not regenerate if enough samples are not | ||
produced in the initial round). | ||
do_expensive_check: bool (Optional) | ||
Whether to perform an expensive error check at the C++ level. | ||
Defaults to False (no error check). | ||
Returns | ||
------- | ||
dict[str, cupy.ndarray] | ||
Generated edges in COO format. | ||
""" | ||
|
||
assert_CAI_type(vertices, "vertices", True) | ||
assert_CAI_type(src_bias, "src_bias", True) | ||
assert_CAI_type(dst_bias, "dst_bias", True) | ||
|
||
cdef cugraph_resource_handle_t* c_resource_handle_ptr = ( | ||
resource_handle.c_resource_handle_ptr | ||
) | ||
|
||
cdef cugraph_graph_t* c_graph_ptr = input_graph.c_graph_ptr | ||
|
||
cdef bool_t c_remove_duplicates = remove_duplicates | ||
cdef bool_t c_remove_false_negatives = remove_false_negatives | ||
cdef bool_t c_exact_number_of_samples = exact_number_of_samples | ||
cdef bool_t c_do_expensive_check = do_expensive_check | ||
|
||
cg_rng_state = CuGraphRandomState(resource_handle, random_state) | ||
|
||
cdef cugraph_rng_state_t* rng_state_ptr = \ | ||
cg_rng_state.rng_state_ptr | ||
|
||
cdef cugraph_type_erased_device_array_view_t* vertices_ptr = \ | ||
create_cugraph_type_erased_device_array_view_from_py_obj(vertices) | ||
cdef cugraph_type_erased_device_array_view_t* src_bias_ptr = \ | ||
create_cugraph_type_erased_device_array_view_from_py_obj(src_bias) | ||
cdef cugraph_type_erased_device_array_view_t* dst_bias_ptr = \ | ||
create_cugraph_type_erased_device_array_view_from_py_obj(dst_bias) | ||
|
||
cdef cugraph_coo_t* result_ptr | ||
cdef cugraph_error_code_t* err_ptr | ||
|
||
error_code = cugraph_negative_sampling( | ||
c_resource_handle_ptr, | ||
cg_rng_state.rng_state_ptr, | ||
c_graph_ptr, | ||
num_samples, | ||
vertices_ptr, | ||
src_bias_ptr, | ||
dst_bias_ptr, | ||
c_remove_duplicates, | ||
c_remove_false_negatives | ||
c_exact_number_of_samples, | ||
c_do_expensive_check, | ||
&result_ptr, | ||
&err_ptr, | ||
) | ||
assert_success(error_code, error_ptr, "cugraph_negative_sampling") | ||
|
||
coo = COO() | ||
coo.set_ptr(result_ptr) | ||
|
||
return { | ||
'sources': coo.get_sources(), | ||
'destinations': coo.get_destinations(), | ||
'edge_id': coo.get_edge_ids(), | ||
'edge_type': coo.get_edge_types(), | ||
'weight': coo.get_edge_weights(), | ||
} |