Skip to content

Commit

Permalink
fix bulk sampler tests, re-enable other tests
Browse files Browse the repository at this point in the history
  • Loading branch information
alexbarghi-nv committed Dec 11, 2024
1 parent d934b72 commit b9437ea
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 8 deletions.
6 changes: 3 additions & 3 deletions python/cugraph/cugraph/tests/sampling/test_bulk_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def test_bulk_sampler_empty_batches(scratch_dir):
# [7, 8, 9] have no outgoing edges. The previous implementation returned and
# offsets array omitting seeds with no outgoing edges from the
# edge_label_offsets which is no longer the case
df = cudf.read_parquet(os.path.join(samples_path, "batch=0-2.parquet"))
df = cudf.read_parquet(os.path.join(samples_path, "batch=0-1.parquet"))

assert df[
(df.batch_id == 0) & (df.hop_id == 0)
Expand All @@ -293,12 +293,12 @@ def test_bulk_sampler_empty_batches(scratch_dir):
].destinations.sort_values().values_host.tolist() == [2, 3, 7, 8]

assert df[
(df.batch_id == 2) & (df.hop_id == 0)
(df.batch_id == 1) & (df.hop_id == 0)
].destinations.sort_values().values_host.tolist() == [7, 8]

assert len(df[(df.batch_id == 1) & (df.hop_id == 1)]) == 0

assert df.batch_id.max() == 2
assert df.batch_id.max() == 1

shutil.rmtree(samples_path)

Expand Down
6 changes: 3 additions & 3 deletions python/cugraph/cugraph/tests/sampling/test_bulk_sampler_mg.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def test_bulk_sampler_empty_batches(dask_client, scratch_dir):
# [7, 8, 9] have no outgoing edges. The previous implementation returned and
# offsets array omitting seeds with no outgoing edges from the
# edge_label_offsets which is no longer the case
df = cudf.read_parquet(os.path.join(samples_path, "batch=0-2.parquet"))
df = cudf.read_parquet(os.path.join(samples_path, "batch=0-1.parquet"))

assert df[
(df.batch_id == 0) & (df.hop_id == 0)
Expand All @@ -243,12 +243,12 @@ def test_bulk_sampler_empty_batches(dask_client, scratch_dir):
].destinations.sort_values().values_host.tolist() == [2, 3, 7, 8]

assert df[
(df.batch_id == 2) & (df.hop_id == 0)
(df.batch_id == 1) & (df.hop_id == 0)
].destinations.sort_values().values_host.tolist() == [7, 8]

assert len(df[(df.batch_id == 1) & (df.hop_id == 1)]) == 0

assert df.batch_id.max() == 2
assert df.batch_id.max() == 1

shutil.rmtree(samples_path)

Expand Down
2 changes: 0 additions & 2 deletions python/cugraph/cugraph/tests/sampling/test_dist_sampler_mg.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ def run_test_dist_sampler_simple(
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("seeds_per_rank", [8, 1])
@pytest.mark.parametrize("seeds_per_call", [4, 8])
@pytest.mark.skip("bleh")
@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not installed")
def test_dist_sampler_simple(
scratch_dir, batch_size, seeds_per_rank, fanout, equal_input_size, seeds_per_call
Expand Down Expand Up @@ -304,7 +303,6 @@ def run_test_dist_sampler_buffered_in_memory(
@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
@pytest.mark.parametrize("seeds_per_call", [4, 5, 10])
@pytest.mark.parametrize("compression", ["COO", "CSR"])
@pytest.mark.skip(reason="bleh")
def test_dist_sampler_buffered_in_memory(scratch_dir, seeds_per_call, compression):
uid = cugraph_comms_create_unique_id()

Expand Down

0 comments on commit b9437ea

Please sign in to comment.