Skip to content

Commit

Permalink
fix bug, write test
Browse files Browse the repository at this point in the history
  • Loading branch information
alexbarghi-nv committed Aug 26, 2024
1 parent a22ce90 commit 55d8ad1
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 5 deletions.
6 changes: 4 additions & 2 deletions python/cugraph/cugraph/gnn/data_loading/bulk_sampler_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,12 @@ def create_df_from_disjoint_series(series_list: List[cudf.Series]):


def create_df_from_disjoint_arrays(array_dict: Dict[str, cupy.array]):
    """Build a cuDF DataFrame from a dict of disjoint device arrays.

    Each key becomes a column name; each array becomes that column's
    values.  Entries whose array is ``None`` are skipped entirely so the
    resulting frame only contains columns that actually hold data.

    Parameters
    ----------
    array_dict: Dict[str, cupy.array]
        Maps column names to the (possibly ``None``) arrays backing them.

    Returns
    -------
    The DataFrame produced by ``create_df_from_disjoint_series``.
    """
    # Collect converted Series into a separate dict rather than writing
    # back into array_dict: this avoids mutating the caller's input while
    # iterating it, and lets us drop None entries (cudf.Series(None)
    # would raise) instead of converting them.
    series_dict = {}
    for k in list(array_dict.keys()):
        if array_dict[k] is not None:
            series_dict[k] = cudf.Series(array_dict[k], name=k)

    return create_df_from_disjoint_series(list(series_dict.values()))


def _write_samples_to_parquet_csr(
Expand Down
3 changes: 1 addition & 2 deletions python/cugraph/cugraph/gnn/data_loading/dist_io/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ def __write_minibatches_coo(self, minibatch_dict):
fanout_length = (len(minibatch_dict["label_hop_offsets"]) - 1) // len(
minibatch_dict["batch_id"]
)
rank_batch_offset = minibatch_dict["batch_id"][0]

for p in range(
0, int(ceil(len(minibatch_dict["batch_id"]) / self.__batches_per_partition))
Expand All @@ -108,7 +107,7 @@ def __write_minibatches_coo(self, minibatch_dict):
]

batch_id_array_p = minibatch_dict["batch_id"][partition_start:partition_end]
start_batch_id = batch_id_array_p[0] - rank_batch_offset
start_batch_id = batch_id_array_p[0]

start_ix, end_ix = label_hop_offsets_array_p[[0, -1]]
majors_array_p = minibatch_dict["majors"][start_ix:end_ix]
Expand Down
4 changes: 4 additions & 0 deletions python/cugraph/cugraph/gnn/data_loading/dist_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,10 @@ def __sample_from_nodes_func(
)

if self.__writer is None:
# rename renumber_map -> map to match unbuffered format
minibatch_dict["map"] = minibatch_dict["renumber_map"]
del minibatch_dict["renumber_map"]

return iter([(minibatch_dict, current_batches[0], current_batches[-1])])
else:
self.__writer.write_minibatches(minibatch_dict)
Expand Down
60 changes: 59 additions & 1 deletion python/cugraph/cugraph/tests/sampling/test_dist_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from cugraph.datasets import karate
from cugraph.gnn import UniformNeighborSampler, DistSampleWriter
from cugraph.gnn.data_loading.bulk_sampler_io import create_df_from_disjoint_arrays

from pylibcugraph import SGGraph, ResourceHandle, GraphProperties

Expand All @@ -41,7 +42,7 @@


@pytest.fixture
def karate_graph():
def karate_graph() -> SGGraph:
el = karate.get_edgelist().reset_index().rename(columns={"index": "eid"})
G = SGGraph(
ResourceHandle(),
Expand Down Expand Up @@ -101,3 +102,60 @@ def test_dist_sampler_simple(
assert original_el.dst.iloc[edge_id.iloc[i]] == dst.iloc[i]

shutil.rmtree(samples_path)


@pytest.mark.sg
@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
@pytest.mark.parametrize("seeds_per_call", [4, 5, 10])
@pytest.mark.parametrize("compression", ["COO", "CSR"])
def test_dist_sampler_buffered_in_memory(
    scratch_dir: str, karate_graph: SGGraph, seeds_per_call: int, compression: str
):
    """Buffered (in-memory, writer=None) sampling must yield the same
    minibatches, batch ranges, and column data as unbuffered sampling
    that goes through a DistSampleWriter."""
    graph = karate_graph

    samples_path = os.path.join(scratch_dir, "test_bulk_sampler_buffered_in_memory")
    create_directory_with_overwrite(samples_path)

    seeds = cupy.arange(10, dtype="int64")

    # Sampler that spills minibatches through a parquet writer.
    unbuffered_sampler = UniformNeighborSampler(
        graph,
        writer=DistSampleWriter(samples_path),
        local_seeds_per_call=seeds_per_call,
        compression=compression,
    )

    # Sampler that keeps minibatches buffered in memory (no writer).
    buffered_sampler = UniformNeighborSampler(
        graph,
        writer=None,
        local_seeds_per_call=seeds_per_call,
        compression=compression,
    )

    def _to_frames(results):
        # Turn each (array_dict, start_batch, end_batch) triple into a
        # (DataFrame, start_batch, end_batch) triple for comparison.
        return [
            (create_df_from_disjoint_arrays(arrays), start, end)
            for arrays, start, end in results
        ]

    unbuffered_results = _to_frames(
        unbuffered_sampler.sample_from_nodes(
            seeds,
            batch_size=4,
        )
    )

    buffered_results = _to_frames(buffered_sampler.sample_from_nodes(seeds, batch_size=4))

    assert len(buffered_results) == len(unbuffered_results)

    for (br, bs, be), (ur, us, ue) in zip(buffered_results, unbuffered_results):
        # Start/end batch ids must agree per minibatch.
        assert bs == us
        assert be == ue

        # Column-wise equality; dropna() before comparing — NOTE(review):
        # presumably strips padding from packing disjoint arrays of
        # different lengths into one frame; confirm against the helpers.
        for col in ur.columns:
            assert (br[col].dropna() == ur[col].dropna()).all()

    shutil.rmtree(samples_path)

0 comments on commit 55d8ad1

Please sign in to comment.