diff --git a/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py b/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py index ad8d22e255e..200a82b460b 100644 --- a/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py +++ b/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py @@ -52,7 +52,9 @@ def __init__( graph_store: CuGraphStore, input_nodes: InputNodes = None, batch_size: int = 0, + *, shuffle: bool = False, + drop_last: bool = True, edge_types: Sequence[Tuple[str]] = None, directory: Union[str, tempfile.TemporaryDirectory] = None, input_files: List[str] = None, @@ -209,26 +211,31 @@ def __init__( # Truncate if we can't evenly divide the input array stop = (len(input_nodes) // batch_size) * batch_size - input_nodes = input_nodes[:stop] + input_nodes, remainder = cupy.array_split(input_nodes, [stop]) # Split into batches - input_nodes = cupy.split(input_nodes, len(input_nodes) // batch_size) + input_nodes = cupy.split(input_nodes, max(len(input_nodes) // batch_size, 1)) + + if not drop_last: + input_nodes.append(remainder) self.__num_batches = 0 for batch_num, batch_i in enumerate(input_nodes): - self.__num_batches += 1 - bulk_sampler.add_batches( - cudf.DataFrame( - { - "start": batch_i, - "batch": cupy.full( - batch_size, batch_num + starting_batch_id, dtype="int32" - ), - } - ), - start_col_name="start", - batch_col_name="batch", - ) + batch_len = len(batch_i) + if batch_len > 0: + self.__num_batches += 1 + bulk_sampler.add_batches( + cudf.DataFrame( + { + "start": batch_i, + "batch": cupy.full( + batch_len, batch_num + starting_batch_id, dtype="int32" + ), + } + ), + start_col_name="start", + batch_col_name="batch", + ) bulk_sampler.flush() self.__input_files = iter( diff --git a/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_loader.py b/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_loader.py index 27b73bf7d35..9813fa933ee 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_loader.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_loader.py @@ -456,6 +456,44 @@ def test_cugraph_loader_e2e_csc(framework: str): @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.parametrize("drop_last", [True, False]) +def test_drop_last(drop_last): + N = {"N": 10} + G = { + ("N", "e", "N"): torch.stack( + [torch.tensor([0, 1, 2, 3, 4]), torch.tensor([5, 6, 7, 8, 9])] + ) + } + F = FeatureStore(backend="torch") + F.add_data(torch.arange(10), "N", "z") + + store = CuGraphStore(F, G, N) + with tempfile.TemporaryDirectory() as dir: + loader = CuGraphNeighborLoader( + (store, store), + input_nodes=torch.tensor([0, 1, 2, 3, 4]), + num_neighbors=[1], + batch_size=2, + shuffle=False, + drop_last=drop_last, + batches_per_partition=1, + directory=dir, + ) + + t = torch.tensor([]) + for batch in loader: + t = torch.concat([t, batch["N"].z]) + + t = t.tolist() + + files = os.listdir(dir) + assert len(files) == 2 if drop_last else 3 + assert "batch=0-0.parquet" in files + assert "batch=1-1.parquet" in files + if not drop_last: + assert "batch=2-2.parquet" in files + + @pytest.mark.parametrize("directory", ["local", "temp"]) def test_load_directory( karate_gnn: Tuple[