diff --git a/python/cugraph/cugraph/tests/data_store/test_gnn_feat_storage.py b/python/cugraph/cugraph/tests/data_store/test_gnn_feat_storage.py index 2d1537d11e3..38b701ae8ae 100644 --- a/python/cugraph/cugraph/tests/data_store/test_gnn_feat_storage.py +++ b/python/cugraph/cugraph/tests/data_store/test_gnn_feat_storage.py @@ -107,3 +107,40 @@ def test_feature_storage_pytorch_backend(): expected = ar3[indices_to_fetch] assert isinstance(output_fs, torch.Tensor) np.testing.assert_array_equal(output_fs.numpy(), expected) + + +@pytest.mark.sg +def test_feature_storage_wholegraph_backend(): + try: + import torch + except ModuleNotFoundError: + pytest.skip('pytorch not installed') + + fs = FeatureStore(backend='wholegraph') + + ar1 = np.random.randint(low=0, high=100, size=100_000) + ar2 = np.random.randint(low=0, high=100, size=100_000) + ar3 = np.random.randint(low=0, high=100, size=100_000).reshape(-1, 10) + fs = FeatureStore(backend="torch") + fs.add_data(ar1, "type1", "feat1") + fs.add_data(ar2, "type1", "feat2") + fs.add_data(ar3, "type2", "feat1") + + + indices_to_fetch = np.random.randint(low=0, high=len(ar1), size=1024) + output_fs = fs.get_data(indices_to_fetch, type_name="type1", feat_name="feat1") + expected = ar1[indices_to_fetch] + assert isinstance(output_fs, torch.Tensor) + np.testing.assert_array_equal(output_fs.numpy(), expected) + + indices_to_fetch = np.random.randint(low=0, high=len(ar2), size=1024) + output_fs = fs.get_data(indices_to_fetch, type_name="type1", feat_name="feat2") + expected = ar2[indices_to_fetch] + assert isinstance(output_fs, torch.Tensor) + np.testing.assert_array_equal(output_fs.numpy(), expected) + + indices_to_fetch = np.random.randint(low=0, high=len(ar3), size=1024) + output_fs = fs.get_data(indices_to_fetch, type_name="type2", feat_name="feat1") + expected = ar3[indices_to_fetch] + assert isinstance(output_fs, torch.Tensor) + np.testing.assert_array_equal(output_fs.numpy(), expected) \ No newline at end of file