Change _N and _E to dgl.distributed.constants DEFAULT_NTYPE and DEFAULT_ETYPE (#626)

*Issue #, if available:*

*Description of changes:*
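The diff replaces hardcoded '_N' / '_E' type-name literals with the constants DGL itself defines. As a quick sketch of what those constants hold (values as defined in dgl.distributed.constants):

    from dgl.distributed.constants import DEFAULT_NTYPE, DEFAULT_ETYPE

    # DGL's built-in type names for a homogeneous graph.
    print(DEFAULT_NTYPE)  # '_N'
    print(DEFAULT_ETYPE)  # ('_N', '_E', '_N'), the canonical edge-type triple

Referencing the constants keeps GraphStorm in sync with DGL if the default names ever change.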


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Co-authored-by: Xiang Song <[email protected]>
classicsong and Xiang Song authored Nov 8, 2023
1 parent 0e37f2c commit f82a89d
Showing 7 changed files with 43 additions and 31 deletions.
4 changes: 2 additions & 2 deletions python/graphstorm/data/ogbn_datasets.py
@@ -74,8 +74,8 @@ def __init__(self, raw_dir, dataset, edge_pct=1,
         self.is_homo = is_homo

         if self.is_homo:
-            self.node_type = '_N'
-            self.edge_type = '_E'
+            self.node_type = dgl.distributed.constants.DEFAULT_NTYPE
+            self.edge_type = dgl.distributed.constants.DEFAULT_ETYPE
         else:
             self.node_type = 'node'
             self.edge_type, self.rev_edge_type = 'interacts', 'rev-interacts'
4 changes: 3 additions & 1 deletion python/graphstorm/gsf.py
@@ -24,6 +24,8 @@
 import torch.nn.functional as F
 from dataclasses import dataclass
 from dgl.distributed import role
+from dgl.distributed.constants import DEFAULT_NTYPE
+from dgl.distributed.constants import DEFAULT_ETYPE

 from .utils import sys_tracker, get_rank, get_world_size, use_wholegraph
 from .config import BUILTIN_TASK_NODE_CLASSIFICATION
@@ -632,7 +634,7 @@ def check_homo(g):
     g: DGLGraph
         The graph used in training and testing
     """
-    if g.ntypes == ['_N'] and g.etypes == ['_E']:
+    if g.ntypes == [DEFAULT_NTYPE] and g.etypes == [DEFAULT_ETYPE]:
         return True
     return False

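For reference, a standalone sketch (not part of this commit) of where these default names come from: any graph built with dgl.graph() carries them.

    import dgl
    import torch as th

    g = dgl.graph((th.tensor([0, 1]), th.tensor([1, 2])))
    print(g.ntypes, g.etypes)  # ['_N'] ['_E'], DGL's homogeneous defaults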
6 changes: 3 additions & 3 deletions python/graphstorm/model/gat_encoder.py
@@ -89,11 +89,11 @@ def forward(self, g, inputs):
         ----------
         g : DGLHeteroGraph
             Input graph.
-        inputs : dict["_N", torch.Tensor]
+        inputs : dict[DEFAULT_NTYPE, torch.Tensor]
             Node feature for each node type.
         Returns
         -------
-        dict{"_N", torch.Tensor}
+        dict{DEFAULT_NTYPE, torch.Tensor}
             New node features for each node type.
         """
         # add self-loop during computation.
@@ -201,7 +201,7 @@ def forward(self, blocks, h):
         ----------
         blocks: DGL MFGs
             Sampled subgraph in DGL MFG
-        h: dict["_N", torch.Tensor]
+        h: dict[DEFAULT_NTYPE, torch.Tensor]
             Input node feature for each node type.
         """
         for layer, block in zip(self.layers, blocks):
13 changes: 7 additions & 6 deletions python/graphstorm/model/sage_encoder.py
@@ -18,6 +18,7 @@
 from torch import nn
 import torch.nn.functional as F
 import dgl.nn as dglnn
+from dgl.distributed.constants import DEFAULT_NTYPE

 from .ngnn_mlp import NGNNMLP
 from .gnn_encoder_base import GraphConvEncoder
@@ -124,16 +125,16 @@ def forward(self, g, inputs):
         ----------
         g : DGLHeteroGraph
             Input graph.
-        inputs : dict["_N", torch.Tensor]
+        inputs : dict[DEFAULT_NTYPE, torch.Tensor]
             Node feature for each node type.
         Returns
         -------
-        dict{"_N", torch.Tensor}
+        dict{DEFAULT_NTYPE, torch.Tensor}
             New node features for each node type.
         """
         g = g.local_var()

-        inputs = inputs['_N']
+        inputs = inputs[DEFAULT_NTYPE]
         h_conv = self.conv(g, inputs)
         if self.norm:
             h_conv = self.norm(h_conv)
@@ -142,7 +143,7 @@ def forward(self, g, inputs):
         if self.num_ffn_layers_in_gnn > 0:
             h_conv = self.ngnn_mlp(h_conv)

-        return {'_N': h_conv}
+        return {DEFAULT_NTYPE: h_conv}


 class SAGEEncoder(GraphConvEncoder):
@@ -168,7 +169,7 @@ class SAGEEncoder(GraphConvEncoder):
     Examples:
     ----------
     .. code:: python

         # Build model and do full-graph inference on SAGEEncoder
@@ -228,7 +229,7 @@ def forward(self, blocks, h):
         ----------
         blocks: DGL MFGs
             Sampled subgraph in DGL MFG
-        h: dict["_N", torch.Tensor]
+        h: dict[DEFAULT_NTYPE, torch.Tensor]
             Input node feature for each node type.
         """
         for layer, block in zip(self.layers, blocks):
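A standalone sketch (plain dgl.nn rather than GraphStorm's encoder classes) of the dict-in/dict-out convention the forward() methods above document: homogeneous features travel as a one-entry dict keyed by DEFAULT_NTYPE.

    import dgl
    import dgl.nn as dglnn
    import torch as th
    from dgl.distributed.constants import DEFAULT_NTYPE

    g = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 0])))
    inputs = {DEFAULT_NTYPE: th.randn(3, 4)}             # wrapped input
    conv = dglnn.SAGEConv(4, 8, aggregator_type='mean')
    h = conv(g, inputs[DEFAULT_NTYPE])                   # unwrap, convolve
    outputs = {DEFAULT_NTYPE: h}                         # re-wrap for the next layer
    print(outputs[DEFAULT_NTYPE].shape)                  # torch.Size([3, 8])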
24 changes: 13 additions & 11 deletions tests/unit-tests/data_utils.py
@@ -24,6 +24,8 @@
 import pandas as pd
 import dgl.distributed as dist
 import tempfile
+from dgl.distributed.constants import (DEFAULT_NTYPE,
+                                       DEFAULT_ETYPE)

 from transformers import AutoTokenizer
 from graphstorm import get_feat_size
@@ -292,31 +294,31 @@ def generate_dummy_homo_graph(size='tiny', gen_mask=True):
     data_size = int(size_dict[size])

     num_nodes_dict = {
-        "_N": data_size,
+        DEFAULT_NTYPE: data_size,
     }

     edges = {
-        ("_N", "_E", "_N"): (th.randint(data_size, (2 * data_size,)),
+        DEFAULT_ETYPE: (th.randint(data_size, (2 * data_size,)),
                              th.randint(data_size, (2 * data_size,)))
     }

     hetero_graph = dgl.heterograph(edges, num_nodes_dict=num_nodes_dict)

     # set node and edge features
-    node_feat = {'_N': th.randn(data_size, 2)}
+    node_feat = {DEFAULT_NTYPE: th.randn(data_size, 2)}

-    edge_feat = {'_E': th.randn(2 * data_size, 2)}
+    edge_feat = {DEFAULT_ETYPE: th.randn(2 * data_size, 2)}

-    hetero_graph.nodes['_N'].data['feat'] = node_feat['_N']
-    hetero_graph.nodes['_N'].data['label'] = th.randint(10, (hetero_graph.number_of_nodes('_N'), ))
+    hetero_graph.nodes[DEFAULT_NTYPE].data['feat'] = node_feat[DEFAULT_NTYPE]
+    hetero_graph.nodes[DEFAULT_NTYPE].data['label'] = th.randint(10, (hetero_graph.number_of_nodes(DEFAULT_NTYPE), ))

-    hetero_graph.edges['_E'].data['feat'] = edge_feat['_E']
-    hetero_graph.edges['_E'].data['label'] = th.randint(10, (hetero_graph.number_of_edges('_E'), ))
+    hetero_graph.edges[DEFAULT_ETYPE].data['feat'] = edge_feat[DEFAULT_ETYPE]
+    hetero_graph.edges[DEFAULT_ETYPE].data['label'] = th.randint(10, (hetero_graph.number_of_edges(DEFAULT_ETYPE), ))

     # set train/val/test masks for nodes and edges
     if gen_mask:
-        target_ntype = ['_N']
-        target_etype = [("_N", "_E", "_N")]
+        target_ntype = [DEFAULT_NTYPE]
+        target_etype = [DEFAULT_ETYPE]

         node_train_mask = generate_mask([0,1], data_size)
         node_val_mask = generate_mask([2,3], data_size)
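Because DEFAULT_ETYPE is the canonical ('_N', '_E', '_N') triple, it can key the heterograph edge dict directly. A standalone sketch of the fixture above:

    import dgl
    import torch as th
    from dgl.distributed.constants import DEFAULT_NTYPE, DEFAULT_ETYPE

    src, dst = th.randint(10, (20,)), th.randint(10, (20,))
    g = dgl.heterograph({DEFAULT_ETYPE: (src, dst)},
                        num_nodes_dict={DEFAULT_NTYPE: 10})
    assert g.canonical_etypes == [DEFAULT_ETYPE]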
Expand Down Expand Up @@ -532,7 +534,7 @@ def create_distill_data(tmpdirname, num_files):

     textual_embed_pddf = pd.DataFrame({
         "ids": id_col,
-        "textual_feats": textual_col, 
+        "textual_feats": textual_col,
         "embeddings": embeddings_col
     }).set_index("ids")
     textual_embed_pddf.to_parquet(os.path.join(tmpdirname, f"part-{part_i}.parquet"))
9 changes: 6 additions & 3 deletions tests/unit-tests/test_gnn.py
@@ -22,6 +22,7 @@
 from argparse import Namespace
 from types import MethodType
 from unittest.mock import patch
+from dgl.distributed.constants import DEFAULT_NTYPE

 import torch as th
 from torch import nn
@@ -238,7 +239,7 @@ def require_cache_embed(self):
                              embs4[ntype][0:len(embs4[ntype])].numpy())

     target_nidx = {"n1": th.arange(g.number_of_nodes("n0"))} \
-        if not is_homo else {"_N": th.arange(g.number_of_nodes("_N"))}
+        if not is_homo else {DEFAULT_NTYPE: th.arange(g.number_of_nodes(DEFAULT_NTYPE))}
     dataloader1 = GSgnnNodeDataLoader(data, target_nidx, fanout=[],
                                       batch_size=10, device="cuda:0", train_task=False)
     pred1, labels1 = node_mini_batch_predict(model, embs, dataloader1, return_label=True)
@@ -554,7 +555,8 @@ def test_sage_node_prediction(norm):
         # get the test dummy distributed graph
         _, part_config = generate_dummy_dist_graph(tmpdirname, is_homo=True)
         np_data = GSgnnNodeTrainData(graph_name='dummy', part_config=part_config,
-                                     train_ntypes=['_N'], label_field='label',
+                                     train_ntypes=[DEFAULT_NTYPE],
+                                     label_field='label',
                                      node_feat_field='feat')
         model = create_sage_node_model(np_data.g, norm)
         check_node_prediction(model, np_data, is_homo=True)
@@ -578,7 +580,8 @@ def test_gat_node_prediction(device):
         # get the test dummy distributed graph
         _, part_config = generate_dummy_dist_graph(tmpdirname, is_homo=True)
         np_data = GSgnnNodeTrainData(graph_name='dummy', part_config=part_config,
-                                     train_ntypes=['_N'], label_field='label',
+                                     train_ntypes=[DEFAULT_NTYPE],
+                                     label_field='label',
                                      node_feat_field='feat')
         model = create_gat_node_model(np_data.g)
         model = model.to(device)
14 changes: 9 additions & 5 deletions tools/partition_graph.py
@@ -17,11 +17,15 @@
 node regression, edge classification and edge regression.
 """

-import dgl
-import numpy as np
-import torch as th
 import argparse
 import time
+
+import numpy as np
+import torch as th
+
+import dgl
+from dgl.distributed.constants import DEFAULT_NTYPE
+from dgl.distributed.constants import DEFAULT_ETYPE

 from graphstorm.data import OGBTextFeatDataset
 from graphstorm.data import MovieLens100kNCDataset
 from graphstorm.data import ConstructedGraphDataset
@@ -140,13 +144,13 @@
     pred_ntypes = args.target_ntype.split(',') if args.target_ntype is not None else None
     if pred_ntypes is None:
         try:
-            pred_ntypes = [dataset.predict_category] if not args.is_homo else ['_N']
+            pred_ntypes = [dataset.predict_category] if not args.is_homo else [DEFAULT_NTYPE]
         except:
             pass
     pred_etypes = [tuple(args.target_etype.split(','))] if args.target_etype is not None else None
     if pred_etypes is None:
         try:
-            pred_etypes = [dataset.target_etype] if not args.is_homo else ['_E']
+            pred_etypes = [dataset.target_etype] if not args.is_homo else [DEFAULT_ETYPE]
         except:
             pass
     assert pred_ntypes is not None or pred_etypes is not None, \
