From 9218d9c148670351871e5dddbbbeec1efd33bb0b Mon Sep 17 00:00:00 2001
From: Alexandria Barghi
Date: Sun, 12 Nov 2023 10:19:45 -0800
Subject: [PATCH] add test for drop_last

---
Review note: the file-count check is written as
`assert len(files) == (2 if drop_last else 3)`. Without the parentheses the
conditional expression binds more loosely than `==`, so the drop_last=False
case reduced to `assert 3` and could never fail. The `index` blob id below is
the one recorded for the unparenthesized version of the test.

 .../cugraph_pyg/tests/test_cugraph_loader.py  | 56 +++++++++++--------
 1 file changed, 32 insertions(+), 24 deletions(-)

diff --git a/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_loader.py b/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_loader.py
index d0e8907157c..9813fa933ee 100644
--- a/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_loader.py
+++ b/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_loader.py
@@ -456,35 +456,43 @@ def test_cugraph_loader_e2e_csc(framework: str):
 
 
 @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
-@pytest.mark.parametrize('shuffle', [True, False])
-def test_shuffle(shuffle):
-    N = {'N': 10}
+@pytest.mark.parametrize("drop_last", [True, False])
+def test_drop_last(drop_last):
+    N = {"N": 10}
     G = {
-        ('N', 'e', 'N'): torch.stack([
-            torch.tensor([0, 1, 2, 3, 4]),
-            torch.tensor([5, 6, 7, 8, 9])
-        ])
+        ("N", "e", "N"): torch.stack(
+            [torch.tensor([0, 1, 2, 3, 4]), torch.tensor([5, 6, 7, 8, 9])]
+        )
     }
-    F = FeatureStore(backend='torch')
-    F.add_data(torch.arange(10), 'N', 'z')
+    F = FeatureStore(backend="torch")
+    F.add_data(torch.arange(10), "N", "z")
 
     store = CuGraphStore(F, G, N)
-    loader = CuGraphNeighborLoader(
-        (store, store),
-        input_nodes=torch.tensor([0, 1, 2, 3, 4]),
-        num_neighbors=[1],
-        batch_size=1,
-        shuffle=shuffle,
-        batches_per_partition=1,
-    )
+    with tempfile.TemporaryDirectory() as dir:
+        loader = CuGraphNeighborLoader(
+            (store, store),
+            input_nodes=torch.tensor([0, 1, 2, 3, 4]),
+            num_neighbors=[1],
+            batch_size=2,
+            shuffle=False,
+            drop_last=drop_last,
+            batches_per_partition=1,
+            directory=dir,
+        )
+
+        t = torch.tensor([])
+        for batch in loader:
+            t = torch.concat([t, batch["N"].z])
+
+        t = t.tolist()
+
+        files = os.listdir(dir)
+        assert len(files) == (2 if drop_last else 3)
+        assert "batch=0-0.parquet" in files
+        assert "batch=1-1.parquet" in files
+        if not drop_last:
+            assert "batch=2-2.parquet" in files
 
-    t = torch.tensor([])
-    for batch in loader:
-        t = torch.concat([t, batch['N'].z])
-
-    t = t.tolist()
-    if not shuffle:
-        assert sorted(t) == t
 
 @pytest.mark.parametrize("directory", ["local", "temp"])
 def test_load_directory(