diff --git a/python/cugraph-dgl/cugraph_dgl/view.py b/python/cugraph-dgl/cugraph_dgl/view.py
index dbc53e7..7c4d95f 100644
--- a/python/cugraph-dgl/cugraph_dgl/view.py
+++ b/python/cugraph-dgl/cugraph_dgl/view.py
@@ -11,6 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import warnings
 from collections import defaultdict
 from collections.abc import MutableMapping
 
@@ -20,11 +21,53 @@
 import cugraph_dgl
 from cugraph_dgl.typing import TensorType
+from cugraph_dgl.utils.cugraph_conversion_utils import _cast_to_torch_tensor
 
 torch = import_optional("torch")
 dgl = import_optional("dgl")
 
 
+class EmbeddingView:
+    def __init__(self, storage: "dgl.storages.base.FeatureStorage", ld: int):
+        self.__ld = ld
+        self.__storage = storage
+
+    def __getitem__(self, u: TensorType) -> "torch.Tensor":
+        u = _cast_to_torch_tensor(u)
+        try:
+            return self.__storage.fetch(
+                u,
+                "cuda",
+            )
+        except RuntimeError as ex:
+            warnings.warn(
+                "Got error accessing data, trying again with index on device: "
+                + str(ex)
+            )
+            return self.__storage.fetch(
+                u.cuda(),
+                "cuda",
+            )
+
+    def __call__(self):
+        warnings.warn(
+            "Getting an entire embedding tensor is not recommended "
+            " as it wastes memory. Consider indexing to get only the "
+            "required elements of the embedding tensor."
+        )
+        return self[torch.arange(self.__ld, dtype=torch.int64)]
+
+    @property
+    def shape(self) -> "torch.Size":
+        try:
+            f = self.__storage.fetch(torch.tensor([0]), "cpu")
+        except RuntimeError:
+            f = self.__storage.fetch(torch.tensor([0], device="cuda"), "cuda")
+        sz = [s for s in f.shape]
+        sz[0] = self.__ld
+        return torch.Size(tuple(sz))
+
+
 class HeteroEdgeDataView(MutableMapping):
     """
     Duck-typed version of DGL's HeteroEdgeDataView.