Skip to content

Commit

Permalink
fix copy bug
Browse files Browse the repository at this point in the history
  • Loading branch information
alexbarghi-nv committed Apr 18, 2024
1 parent ba6ee3b commit c01b548
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
7 changes: 4 additions & 3 deletions cpp/src/c_api/uniform_neighbor_sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@

#include <raft/core/handle.hpp>

#include <iostream>

namespace cugraph {
namespace c_api {

Expand Down Expand Up @@ -157,6 +155,10 @@ struct uniform_neighbor_sampling_functor : public cugraph::c_api::abstract_funct
auto number_map = reinterpret_cast<rmm::device_uvector<vertex_t>*>(graph_->number_map_);

rmm::device_uvector<vertex_t> start_vertices(start_vertices_->size_, handle_.get_stream());
raft::copy(start_vertices.data(),
start_vertices_->as_type<vertex_t>(),
start_vertices.size(),
handle_.get_stream());

std::optional<rmm::device_uvector<label_t>> start_vertex_labels{std::nullopt};

Expand Down Expand Up @@ -252,7 +254,6 @@ struct uniform_neighbor_sampling_functor : public cugraph::c_api::abstract_funct
if (options_.renumber_results_) {
if (options_.compression_type_ == cugraph_compression_type_t::COO) {
// COO
std::cout << "retain seeds? " << options_.retain_seeds_ << std::endl;

rmm::device_uvector<vertex_t> output_majors(0, handle_.get_stream());
rmm::device_uvector<vertex_t> output_renumber_map(0, handle_.get_stream());
Expand Down
2 changes: 2 additions & 0 deletions python/pylibcugraph/pylibcugraph/uniform_neighbor_sample.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,8 @@ def uniform_neighbor_sample(ResourceHandle resource_handle,
cugraph_type_erased_host_array_view_free(fan_out_ptr)
if batch_id_list is not None:
cugraph_type_erased_device_array_view_free(batch_id_ptr)
if label_offsets is not None:
cugraph_type_erased_device_array_view_free(label_offsets_ptr)

# Have the SamplingResult instance assume ownership of the result data.
result = SamplingResult()
Expand Down

0 comments on commit c01b548

Please sign in to comment.