Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
Xiang Song committed Jun 7, 2024
1 parent 1f907ca commit 19f2bd9
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
6 changes: 6 additions & 0 deletions python/graphstorm/dataloading/dataloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -1279,6 +1279,12 @@ def __next__(self):
# return pos, neg pairs
return cur_iter, self._neg_sample_type

def __len__(self):
    """ Get the total number of mini-batches in one pass over the dataloader

    Every entry of ``self._fixed_test_size`` contributes
    ``ceil(test_size / batch_size)`` mini-batches, i.e., a trailing partial
    batch still counts as one batch. The per-entry counts are summed.
    NOTE(review): the keys of ``self._fixed_test_size`` are presumably
    edge types -- confirm against where the dict is populated.

    Returns
    -------
    int: Number of mini-batches; 0 when ``self._fixed_test_size`` is empty.
    """
    # Only the sizes are needed, so walk the values instead of items().
    return sum(math.ceil(test_size / self._batch_size)
               for test_size in self._fixed_test_size.values())

@property
def fanout(self):
""" Get eval fanout
Expand Down
8 changes: 6 additions & 2 deletions tests/unit-tests/test_dataloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,9 @@ def test_GSgnnLinkPredictionTestDataLoader(batch_size, num_negative_edges):

total_edges = {etype: len(train_idxs[etype]) for etype in test_etypes}
num_pos_edges = {etype: 0 for etype in test_etypes}
num_samples = 0
for pos_neg_tuple, sample_type in dataloader:
num_samples += 1
assert sample_type == BUILTIN_LP_UNIFORM_NEG_SAMPLER
assert isinstance(pos_neg_tuple, dict)
assert len(pos_neg_tuple) == 1
Expand Down Expand Up @@ -793,6 +795,7 @@ def test_GSgnnLinkPredictionTestDataLoader(batch_size, num_negative_edges):
assert neg_src.shape[0] == pos_src.shape[0]
assert neg_src.shape[1] == num_negative_edges
assert th.all(neg_src < g.number_of_nodes(canonical_etype[0]))
assert len(dataloader) == num_samples

# The target idx size for ("n0", "r1", "n1") is 2
# The target idx size for ("n0", "r0", "n1") is 50
Expand Down Expand Up @@ -822,6 +825,7 @@ def test_GSgnnLinkPredictionTestDataLoader(batch_size, num_negative_edges):
expected_pos_pairs += expected_idx_len

assert num_samples == expected_samples
assert len(dataloader) == num_samples
assert num_pos_pairs == expected_pos_pairs

# After the test passes, destroy all process groups
Expand Down Expand Up @@ -2396,6 +2400,8 @@ def test_GSgnnMultiTaskDataLoader():
assert np.any(edge0_seeds_cnt.numpy() >= 0)

if __name__ == '__main__':
test_GSgnnLinkPredictionTestDataLoader(1, 1)
test_GSgnnLinkPredictionTestDataLoader(10, 20)
test_GSgnnMultiTaskDataLoader()
test_GSgnnLinkPredictionPredefinedTestDataLoader(1)
test_GSgnnLinkPredictionPredefinedTestDataLoader(10)
Expand All @@ -2419,8 +2425,6 @@ def test_GSgnnMultiTaskDataLoader():
test_node_dataloader_reconstruct()
test_GSgnnAllEtypeLinkPredictionDataLoader(10)
test_GSgnnAllEtypeLinkPredictionDataLoader(1)
test_GSgnnLinkPredictionTestDataLoader(1, 1)
test_GSgnnLinkPredictionTestDataLoader(10, 20)
test_GSgnnLinkPredictionJointTestDataLoader(1, 1)
test_GSgnnLinkPredictionJointTestDataLoader(10, 20)

Expand Down

0 comments on commit 19f2bd9

Please sign in to comment.