Skip to content

Commit

Permalink
add test for drop_last
Browse files Browse the repository at this point in the history
  • Loading branch information
alexbarghi-nv committed Nov 12, 2023
1 parent a6c43ae commit 9218d9c
Showing 1 changed file with 32 additions and 24 deletions.
56 changes: 32 additions & 24 deletions python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,35 +456,43 @@ def test_cugraph_loader_e2e_csc(framework: str):


@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
@pytest.mark.parametrize('shuffle', [True, False])
def test_shuffle(shuffle):
N = {'N': 10}
@pytest.mark.parametrize("drop_last", [True, False])
def test_drop_last(drop_last):
N = {"N": 10}
G = {
('N', 'e', 'N'): torch.stack([
torch.tensor([0, 1, 2, 3, 4]),
torch.tensor([5, 6, 7, 8, 9])
])
("N", "e", "N"): torch.stack(
[torch.tensor([0, 1, 2, 3, 4]), torch.tensor([5, 6, 7, 8, 9])]
)
}
F = FeatureStore(backend='torch')
F.add_data(torch.arange(10), 'N', 'z')
F = FeatureStore(backend="torch")
F.add_data(torch.arange(10), "N", "z")

store = CuGraphStore(F, G, N)
loader = CuGraphNeighborLoader(
(store, store),
input_nodes=torch.tensor([0, 1, 2, 3, 4]),
num_neighbors=[1],
batch_size=1,
shuffle=shuffle,
batches_per_partition=1,
)
with tempfile.TemporaryDirectory() as dir:
loader = CuGraphNeighborLoader(
(store, store),
input_nodes=torch.tensor([0, 1, 2, 3, 4]),
num_neighbors=[1],
batch_size=2,
shuffle=False,
drop_last=drop_last,
batches_per_partition=1,
directory=dir,
)

t = torch.tensor([])
for batch in loader:
t = torch.concat([t, batch["N"].z])

t = t.tolist()

files = os.listdir(dir)
assert len(files) == 2 if drop_last else 3
assert "batch=0-0.parquet" in files
assert "batch=1-1.parquet" in files
if not drop_last:
assert "batch=2-2.parquet" in files

t = torch.tensor([])
for batch in loader:
t = torch.concat([t, batch['N'].z])

t = t.tolist()
if not shuffle:
assert sorted(t) == t

@pytest.mark.parametrize("directory", ["local", "temp"])
def test_load_directory(
Expand Down

0 comments on commit 9218d9c

Please sign in to comment.