Skip to content

Commit

Permalink
update view.py
Browse files Browse the repository at this point in the history
  • Loading branch information
alexbarghi-nv committed Oct 22, 2024
1 parent b86bdd3 commit 7897f2f
Showing 1 changed file with 43 additions and 0 deletions.
43 changes: 43 additions & 0 deletions python/cugraph-dgl/cugraph_dgl/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings

from collections import defaultdict
from collections.abc import MutableMapping
Expand All @@ -20,11 +21,53 @@

import cugraph_dgl
from cugraph_dgl.typing import TensorType
from cugraph_dgl.utils.cugraph_conversion_utils import _cast_to_torch_tensor

torch = import_optional("torch")
dgl = import_optional("dgl")


class EmbeddingView:
def __init__(self, storage: "dgl.storages.base.FeatureStorage", ld: int):
self.__ld = ld
self.__storage = storage

def __getitem__(self, u: TensorType) -> "torch.Tensor":
u = _cast_to_torch_tensor(u)
try:
return self.__storage.fetch(
u,
"cuda",
)
except RuntimeError as ex:
warnings.warn(
"Got error accessing data, trying again with index on device: "
+ str(ex)
)
return self.__storage.fetch(
u.cuda(),
"cuda",
)

def __call__(self):
warnings.warn(
"Getting an entire embedding tensor is not recommended "
" as it wastes memory. Consider indexing to get only the "
"required elements of the embedding tensor."
)
return self[torch.arange(self.__ld, dtype=torch.int64)]

@property
def shape(self) -> "torch.Size":
try:
f = self.__storage.fetch(torch.tensor([0]), "cpu")
except RuntimeError:
f = self.__storage.fetch(torch.tensor([0], device="cuda"), "cuda")
sz = [s for s in f.shape]
sz[0] = self.__ld
return torch.Size(tuple(sz))


class HeteroEdgeDataView(MutableMapping):
"""
Duck-typed version of DGL's HeteroEdgeDataView.
Expand Down

0 comments on commit 7897f2f

Please sign in to comment.