Skip to content

Commit

Permalink
Support drop_last Argument in cuGraph-PyG Loader (#3995)
Browse files Browse the repository at this point in the history
Supports the `drop_last` argument in cuGraph-PyG for better compatibility with native PyG workflows.

Closes #3949 

Merge after #3985

Authors:
  - Alex Barghi (https://github.com/alexbarghi-nv)
  - Naim (https://github.com/naimnv)

Approvers:
  - Tingyu Wang (https://github.com/tingyu66)
  - Brad Rees (https://github.com/BradReesWork)
  - Vibhu Jawa (https://github.com/VibhuJawa)

URL: #3995
  • Loading branch information
alexbarghi-nv authored Nov 20, 2023
1 parent 6e765bb commit f3eecda
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 15 deletions.
37 changes: 22 additions & 15 deletions python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ def __init__(
graph_store: CuGraphStore,
input_nodes: InputNodes = None,
batch_size: int = 0,
*,
shuffle: bool = False,
drop_last: bool = True,
edge_types: Sequence[Tuple[str]] = None,
directory: Union[str, tempfile.TemporaryDirectory] = None,
input_files: List[str] = None,
Expand Down Expand Up @@ -209,26 +211,31 @@ def __init__(

# Truncate if we can't evenly divide the input array
stop = (len(input_nodes) // batch_size) * batch_size
input_nodes = input_nodes[:stop]
input_nodes, remainder = cupy.array_split(input_nodes, [stop])

# Split into batches
input_nodes = cupy.split(input_nodes, len(input_nodes) // batch_size)
input_nodes = cupy.split(input_nodes, max(len(input_nodes) // batch_size, 1))

if not drop_last:
input_nodes.append(remainder)

self.__num_batches = 0
for batch_num, batch_i in enumerate(input_nodes):
self.__num_batches += 1
bulk_sampler.add_batches(
cudf.DataFrame(
{
"start": batch_i,
"batch": cupy.full(
batch_size, batch_num + starting_batch_id, dtype="int32"
),
}
),
start_col_name="start",
batch_col_name="batch",
)
batch_len = len(batch_i)
if batch_len > 0:
self.__num_batches += 1
bulk_sampler.add_batches(
cudf.DataFrame(
{
"start": batch_i,
"batch": cupy.full(
batch_len, batch_num + starting_batch_id, dtype="int32"
),
}
),
start_col_name="start",
batch_col_name="batch",
)

bulk_sampler.flush()
self.__input_files = iter(
Expand Down
38 changes: 38 additions & 0 deletions python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,44 @@ def test_cugraph_loader_e2e_csc(framework: str):


@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
@pytest.mark.parametrize("drop_last", [True, False])
def test_drop_last(drop_last):
    """Check that ``drop_last`` controls whether the partial batch is written.

    Five input nodes are loaded with ``batch_size=2`` and
    ``batches_per_partition=1``, so the test expects one parquet file per
    batch: two files when the trailing one-node batch is dropped, three when
    it is kept.
    """
    N = {"N": 10}
    G = {
        ("N", "e", "N"): torch.stack(
            [torch.tensor([0, 1, 2, 3, 4]), torch.tensor([5, 6, 7, 8, 9])]
        )
    }
    F = FeatureStore(backend="torch")
    F.add_data(torch.arange(10), "N", "z")

    store = CuGraphStore(F, G, N)
    # ``tmp_dir`` rather than ``dir`` so the builtin is not shadowed.
    with tempfile.TemporaryDirectory() as tmp_dir:
        loader = CuGraphNeighborLoader(
            (store, store),
            input_nodes=torch.tensor([0, 1, 2, 3, 4]),
            num_neighbors=[1],
            batch_size=2,
            shuffle=False,
            drop_last=drop_last,
            batches_per_partition=1,
            directory=tmp_dir,
        )

        # Drive the loader through every batch and read the "z" feature of
        # each one, as the original test did, before inspecting the files.
        t = torch.tensor([])
        for batch in loader:
            t = torch.concat([t, batch["N"].z])

        # The directory must be listed while the context manager is still
        # open; it is removed on exit.
        files = os.listdir(tmp_dir)

        # Parentheses are required: without them the statement parses as
        # ``assert (len(files) == 2) if drop_last else 3`` and the
        # ``drop_last=False`` case asserts the always-truthy constant 3.
        expected_num_files = 2 if drop_last else 3
        assert len(files) == expected_num_files
        assert "batch=0-0.parquet" in files
        assert "batch=1-1.parquet" in files
        if not drop_last:
            assert "batch=2-2.parquet" in files


@pytest.mark.parametrize("directory", ["local", "temp"])
def test_load_directory(
karate_gnn: Tuple[
Expand Down

0 comments on commit f3eecda

Please sign in to comment.