Skip to content

Commit

Permalink
wholegraph test
Browse files Browse the repository at this point in the history
  • Loading branch information
alexbarghi-nv committed Sep 25, 2023
1 parent e17b587 commit b64326d
Showing 1 changed file with 37 additions and 0 deletions.
37 changes: 37 additions & 0 deletions python/cugraph/cugraph/tests/data_store/test_gnn_feat_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,40 @@ def test_feature_storage_pytorch_backend():
expected = ar3[indices_to_fetch]
assert isinstance(output_fs, torch.Tensor)
np.testing.assert_array_equal(output_fs.numpy(), expected)


@pytest.mark.sg
def test_feature_storage_wholegraph_backend():
try:
import torch
except ModuleNotFoundError:
pytest.skip('pytorch not installed')

fs = FeatureStore(backend='wholegraph')

ar1 = np.random.randint(low=0, high=100, size=100_000)
ar2 = np.random.randint(low=0, high=100, size=100_000)
ar3 = np.random.randint(low=0, high=100, size=100_000).reshape(-1, 10)
fs = FeatureStore(backend="torch")
fs.add_data(ar1, "type1", "feat1")
fs.add_data(ar2, "type1", "feat2")
fs.add_data(ar3, "type2", "feat1")


indices_to_fetch = np.random.randint(low=0, high=len(ar1), size=1024)
output_fs = fs.get_data(indices_to_fetch, type_name="type1", feat_name="feat1")
expected = ar1[indices_to_fetch]
assert isinstance(output_fs, torch.Tensor)
np.testing.assert_array_equal(output_fs.numpy(), expected)

indices_to_fetch = np.random.randint(low=0, high=len(ar2), size=1024)
output_fs = fs.get_data(indices_to_fetch, type_name="type1", feat_name="feat2")
expected = ar2[indices_to_fetch]
assert isinstance(output_fs, torch.Tensor)
np.testing.assert_array_equal(output_fs.numpy(), expected)

indices_to_fetch = np.random.randint(low=0, high=len(ar3), size=1024)
output_fs = fs.get_data(indices_to_fetch, type_name="type2", feat_name="feat1")
expected = ar3[indices_to_fetch]
assert isinstance(output_fs, torch.Tensor)
np.testing.assert_array_equal(output_fs.numpy(), expected)

0 comments on commit b64326d

Please sign in to comment.