revert none return, add check
alexbarghi-nv committed Jul 10, 2024 · 1 parent 4b60d8d · commit e1fa6e0
Showing 4 changed files with 7 additions and 12 deletions.
@@ -25,7 +25,7 @@
     cugraph_comms_shutdown,
 )
 
-from utils import init_pytorch_worker
+from cugraph_dgl.tests.utils import init_pytorch_worker
 
 torch = import_optional("torch")
 dgl = import_optional("dgl")
2 changes: 2 additions & 0 deletions python/cugraph-dgl/cugraph_dgl/tests/test_graph.py
@@ -52,6 +52,8 @@ def test_graph_make_homogeneous_graph(direction):
     assert (
         graph.nodes() == torch.arange(num_nodes, dtype=torch.int64, device="cuda")
     ).all()
+
+    assert graph.nodes[None]["x"] is not None
     assert (graph.nodes[None]["x"] == torch.as_tensor(node_x, device="cuda")).all()
     assert (
         graph.nodes[None]["num"]
3 changes: 2 additions & 1 deletion python/cugraph-dgl/cugraph_dgl/tests/test_graph_mg.py
@@ -30,7 +30,7 @@
     cugraph_comms_get_raft_handle,
 )
 
-from utils import init_pytorch_worker
+from .utils import init_pytorch_worker
 
 pylibwholegraph = import_optional("pylibwholegraph")
 torch = import_optional("torch")
@@ -75,6 +75,7 @@ def run_test_graph_make_homogeneous_graph_mg(rank, uid, world_size, direction):
         == torch.arange(global_num_nodes, dtype=torch.int64, device="cuda")
     ).all()
     ix = torch.arange(len(node_x) * rank, len(node_x) * (rank + 1), dtype=torch.int64)
+    assert graph.nodes[ix]["x"] is not None
     assert (graph.nodes[ix]["x"] == torch.as_tensor(node_x, device="cuda")).all()
 
     assert (
12 changes: 2 additions & 10 deletions python/cugraph-dgl/cugraph_dgl/view.py
@@ -61,11 +61,7 @@ def __getitem__(self, key: str):
                 if self._graph._has_e_emb(t, key)
             }
 
-        return (
-            self._graph._get_e_emb(self._etype, key, self._edges)
-            if self._graph._has_e_emb(self._etype, key)
-            else None
-        )
+        return self._graph._get_e_emb(self._etype, key, self._edges)
 
     def __setitem__(self, key: str, val: Union[TensorType, Dict[str, TensorType]]):
         if isinstance(self._etype, list):
@@ -166,11 +162,7 @@ def __getitem__(self, key: str):
                 if self._graph._has_n_emb(t, key)
             }
         else:
-            return (
-                self._graph._get_n_emb(self._ntype, key, self._nodes)
-                if self._graph._has_n_emb(self._ntype, key)
-                else None
-            )
+            return self._graph._get_n_emb(self._ntype, key, self._nodes)
 
     def __setitem__(self, key: str, val: Union[TensorType, Dict[str, TensorType]]):
         if isinstance(self._ntype, list):
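Why this matters: with the `else None` fallback reverted, the view.py lookups always delegate to `_get_e_emb`/`_get_n_emb`, so an absent feature now fails loudly instead of silently propagating None, and the new `is not None` assertions in the tests guard exactly that. A minimal sketch of the before/after behavior, assuming a simple dict-backed store and a plain KeyError on a missing key (both illustrative assumptions, not the actual cugraph-dgl internals):

# Illustrative sketch only -- class names, the dict store, and the KeyError
# are assumptions, not the real cugraph_dgl.view implementation.

class NodeViewBefore:
    """Old behavior: a missing embedding was silently mapped to None."""
    def __init__(self, store):
        self._store = store  # e.g. {"x": <tensor>}
    def __getitem__(self, key):
        return self._store[key] if key in self._store else None

class NodeViewAfter:
    """New behavior: always delegate to the getter; a missing key fails loudly."""
    def __init__(self, store):
        self._store = store
    def __getitem__(self, key):
        return self._store[key]  # raises KeyError if the embedding is absent

before = NodeViewBefore({"x": [1.0, 2.0]})
after = NodeViewAfter({"x": [1.0, 2.0]})

assert before["missing"] is None  # old code: silent None
assert after["x"] is not None     # the pattern the new tests assert
try:
    after["missing"]
except KeyError:
    pass                          # absence now surfaces as an error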
