Skip to content

Commit

Permalink
fix mg bulk sampler tests
Browse files Browse the repository at this point in the history
  • Loading branch information
alexbarghi-nv committed Sep 21, 2023
1 parent 17e9013 commit c5543b2
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
4 changes: 0 additions & 4 deletions python/cugraph/cugraph/gnn/data_loading/bulk_sampler_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ def _write_samples_to_parquet_coo(
if partition_info != "sg" and (not isinstance(partition_info, dict)):
raise ValueError("Invalid value of partition_info")

print('offsets:', offsets)
offsets = offsets[:-1]

# Offsets is always in order, so the last batch id is always the highest
Expand All @@ -73,9 +72,6 @@ def _write_samples_to_parquet_coo(
start_batch_id = offsets_p.batch_id.iloc[0]
end_batch_id = offsets_p.batch_id.iloc[len(offsets_p) - 1]

print('partition:', start_batch_id, end_batch_id)
print('max batch id:', max_batch_id)

reached_end = end_batch_id == max_batch_id

start_ix = offsets_p.offsets.iloc[0]
Expand Down
12 changes: 8 additions & 4 deletions python/cugraph/cugraph/tests/sampling/test_bulk_sampler_io_mg.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,12 @@ def test_bulk_sampler_io(scratch_dir):
divisions=[0, 8, 11]
)

offsets = cudf.DataFrame({"offsets": [0, 0], "batch_id": [0, 1]})
offsets = dask_cudf.from_cudf(offsets, npartitions=2)
assert len(results) == 12

offsets = cudf.DataFrame({"offsets": [0, 8, 0, 4], "batch_id": [0, None, 1, None]})
offsets = dask_cudf.from_cudf(offsets, npartitions=1).repartition(
divisions=[0, 2, 3]
)

samples_path = os.path.join(scratch_dir, "mg_test_bulk_sampler_io")
create_directory_with_overwrite(samples_path)
Expand Down Expand Up @@ -149,9 +153,9 @@ def test_bulk_sampler_io_empty_batch(scratch_dir):
)

# some batches are missing
offsets = cudf.DataFrame({"offsets": [0, 8, 0, 4], "batch_id": [0, 3, 4, 10]})
offsets = cudf.DataFrame({"offsets": [0, 8, 12, 0, 4, 8], "batch_id": [0, 3, None, 4, 10, None]})
offsets = dask_cudf.from_cudf(offsets, npartitions=1).repartition(
divisions=[0, 2, 3]
divisions=[0, 3, 5]
)

samples_path = os.path.join(scratch_dir, "mg_test_bulk_sampler_io_empty_batch")
Expand Down

0 comments on commit c5543b2

Please sign in to comment.