Skip to content

Commit

Permalink
[Homo Optimization] Optimization on GSF (#686)
Browse files Browse the repository at this point in the history
*Issue #, if available:*

*Description of changes:*

Allow users to start a training or inference job without specifying
target_ntype/target_etype on a homogeneous graph.

By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Co-authored-by: xiang song(charlie.song) <[email protected]>
Co-authored-by: Xiang Song <[email protected]>
  • Loading branch information
3 people committed Dec 20, 2023
1 parent 61aada6 commit fa9643e
Show file tree
Hide file tree
Showing 9 changed files with 359 additions and 15 deletions.
8 changes: 4 additions & 4 deletions docs/source/configuration/configuration-run.rst
Original file line number Diff line number Diff line change
Expand Up @@ -381,20 +381,20 @@ Classification and Regression Task

Node Classification/Regression Specific
.........................................
- **target_ntype**: (**Required**) The node type for prediction.
- **target_ntype**: The node type for prediction.

- Yaml: ``target_ntype: movie``
- Argument: ``--target-ntype movie``
- Default value: This parameter must be provided by user.
- Default value: For a heterogeneous input graph, this parameter must be provided by the user. If it is not provided, GraphStorm assumes the input graph is homogeneous and sets ``target_ntype`` to "_N".

Edge Classification/Regression Specific
..........................................
- **target_etype**: (**Required**) The list of canonical edge types that will be added as a training target in edge classification/regression tasks, for example ``--train-etype query,clicks,asin`` or ``--train-etype query,clicks,asin query,search,asin``. A canonical edge type should be formatted as `src_node_type,relation_type,dst_node_type`. Currently, GraphStorm only supports single task edge classification/regression, i.e., it only accepts one canonical edge type.
- **target_etype**: The list of canonical edge types that will be added as training targets in edge classification/regression tasks, for example ``--target-etype query,clicks,asin`` or ``--target-etype query,clicks,asin query,search,asin``. A canonical edge type should be formatted as `src_node_type,relation_type,dst_node_type`. Currently, GraphStorm only supports single task edge classification/regression, i.e., it only accepts one canonical edge type.

- Yaml: ``target_etype:``
| ``- query,clicks,asin``
- Argument: ``--target-etype query,clicks,asin``
- Default value: This parameter must be provided by user.
- Default value: For a heterogeneous input graph, this parameter must be provided by the user. If it is not provided, GraphStorm assumes the input graph is homogeneous and sets ``target_etype`` to ("_N", "_E", "_N").
- **remove_target_edge_type**: When set to true, GraphStorm removes target_etype in message passing, i.e., any edge with target_etype will not be sampled during training and inference.

- Yaml: ``remove_target_edge_type: false``
Expand Down
16 changes: 11 additions & 5 deletions python/graphstorm/config/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import yaml
import torch as th
import torch.nn.functional as F
from dgl.distributed.constants import DEFAULT_NTYPE, DEFAULT_ETYPE

from .config import BUILTIN_GNN_ENCODER
from .config import BUILTIN_ENCODER
Expand Down Expand Up @@ -1573,9 +1574,12 @@ def target_ntype(self):
""" The node type for prediction
"""
# pylint: disable=no-member
assert hasattr(self, "_target_ntype"), \
"Must provide the target ntype through target_ntype"
return self._target_ntype
if hasattr(self, "_target_ntype"):
return self._target_ntype
else:
logging.warning("There is not target ntype provided, "
"will treat the input graph as a homogeneous graph")
return DEFAULT_NTYPE

@property
def eval_target_ntype(self):
Expand Down Expand Up @@ -1648,8 +1652,10 @@ def target_etype(self):
classification/regression. Support multiple tasks when needed.
"""
# pylint: disable=no-member
assert hasattr(self, "_target_etype"), \
"Edge classification task needs a target etype"
if not hasattr(self, "_target_etype"):
logging.warning("There is not target etype provided, "
"will treat the input graph as a homogeneous graph")
return [DEFAULT_ETYPE]
assert isinstance(self._target_etype, list), \
"target_etype must be a list in format: " \
"[\"query,clicks,asin\", \"query,search,asin\"]."
Expand Down
36 changes: 35 additions & 1 deletion python/graphstorm/dataloading/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import torch as th
import dgl
from dgl.distributed.constants import DEFAULT_NTYPE, DEFAULT_ETYPE
from torch.utils.data import Dataset
import pandas as pd

Expand Down Expand Up @@ -554,6 +555,15 @@ def __init__(self, graph_name, part_config, train_etypes, eval_etypes=None,
lm_feat_ntypes=lm_feat_ntypes,
lm_feat_etypes=lm_feat_etypes)

if self._train_etypes == [DEFAULT_ETYPE]:
# DGL Graph edge type is not canonical. It is just list[str].
assert self._g.ntypes == [DEFAULT_NTYPE] and \
self._g.etypes == [DEFAULT_ETYPE[1]], \
f"It is required to be a homogeneous graph when target_etype is not provided " \
f"or is set to {DEFAULT_ETYPE} on edge tasks, expect node type " \
f"to be {[DEFAULT_NTYPE]} and edge type to be {[DEFAULT_ETYPE[1]]}, " \
f"but get {self._g.ntypes} and {self._g.etypes}"

def prepare_data(self, g):
"""
Prepare the training, validation and testing edge set.
Expand Down Expand Up @@ -731,6 +741,14 @@ def __init__(self, graph_name, part_config, eval_etypes,
decoder_edge_feat,
lm_feat_ntypes=lm_feat_ntypes,
lm_feat_etypes=lm_feat_etypes)
if self._eval_etypes == [DEFAULT_ETYPE]:
# DGL Graph edge type is not canonical. It is just list[str].
assert self._g.ntypes == [DEFAULT_NTYPE] and \
self._g.etypes == [DEFAULT_ETYPE[1]], \
f"It is required to be a homogeneous graph when target_etype is not provided " \
f"or is set to {DEFAULT_ETYPE} on edge tasks, expect node type " \
f"to be {[DEFAULT_NTYPE]} and edge type to be {[DEFAULT_ETYPE[1]]}, " \
f"but get {self._g.ntypes} and {self._g.etypes}"

def prepare_data(self, g):
""" Prepare the testing edge set if any
Expand Down Expand Up @@ -916,7 +934,6 @@ def __init__(self, graph_name, part_config, train_ntypes, eval_ntypes=None,
assert isinstance(train_ntypes, list), \
"prediction ntypes for training has to be a string or a list of strings."
self._train_ntypes = train_ntypes

if eval_ntypes is not None:
if isinstance(eval_ntypes, str):
eval_ntypes = [eval_ntypes]
Expand All @@ -932,6 +949,14 @@ def __init__(self, graph_name, part_config, train_ntypes, eval_ntypes=None,
edge_feat_field=edge_feat_field,
lm_feat_ntypes=lm_feat_ntypes,
lm_feat_etypes=lm_feat_etypes)
if self._train_ntypes == [DEFAULT_NTYPE]:
# DGL Graph edge type is not canonical. It is just list[str].
assert self._g.ntypes == [DEFAULT_NTYPE] and \
self._g.etypes == [DEFAULT_ETYPE[1]], \
f"It is required to be a homogeneous graph when target_ntype is not provided " \
f"or is set to {DEFAULT_NTYPE} on node tasks, expect node type " \
f"to be {[DEFAULT_NTYPE]} and edge type to be {[DEFAULT_ETYPE[1]]}, " \
f"but get {self._g.ntypes} and {self._g.etypes}"

def prepare_data(self, g):
pb = g.get_partition_book()
Expand Down Expand Up @@ -1072,6 +1097,15 @@ def __init__(self, graph_name, part_config, eval_ntypes,
lm_feat_ntypes=lm_feat_ntypes,
lm_feat_etypes=lm_feat_etypes)

if self._eval_ntypes == [DEFAULT_NTYPE]:
# DGL Graph edge type is not canonical. It is just list[str].
assert self._g.ntypes == [DEFAULT_NTYPE] and \
self._g.etypes == [DEFAULT_ETYPE[1]], \
f"It is required to be a homogeneous graph when target_ntype is not provided " \
f"or is set to {DEFAULT_NTYPE} on node tasks, expect node type " \
f"to be {[DEFAULT_NTYPE]} and edge type to be {[DEFAULT_ETYPE[1]]}, " \
f"but get {self._g.ntypes} and {self._g.etypes}"

def prepare_data(self, g):
"""
Prepare the testing node set if any
Expand Down
23 changes: 21 additions & 2 deletions tests/end2end-tests/data_process/homogeneous_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,25 @@ echo "********* Test Node Classification on GConstruct Homogeneous Graph with re
python3 -m graphstorm.run.gs_node_classification --workspace $GS_HOME/training_scripts/gsgnn_np/ --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /tmp/movielen_100k_train_val_1p_4t_homogeneous_rev/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc.yaml --target-ntype _N
error_and_exit $?

echo "********* Test Edge Classification on GConstruct Homogeneous Graph with reverse edge ********"
echo "********* Test Edge Classification on GConstruct Homogeneous Graph with reverse edge********"
python3 -m graphstorm.run.gs_edge_classification --workspace $GS_HOME/training_scripts/gsgnn_ep/ --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /tmp/movielen_100k_train_val_1p_4t_homogeneous_rev/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_ec.yaml --target-etype _N,_E,_N
error_and_exit $?
error_and_exit $?

echo "********* Test Node Classification with homogeneous graph optimization********"
python3 -m graphstorm.run.gs_node_classification --workspace $GS_HOME/training_scripts/gsgnn_np/ --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /tmp/movielen_100k_train_val_1p_4t_homogeneous_rev/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_homogeneous.yaml --save-model-path /tmp/homogeneous_node_model
error_and_exit $?

echo "********* Test Node Classification with homogeneous graph optimization doing inference********"
python3 -m graphstorm.run.gs_node_classification --inference --workspace $GS_HOME/training_scripts/gsgnn_np/ --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /tmp/movielen_100k_train_val_1p_4t_homogeneous_rev/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_homogeneous.yaml --restore-model-path /tmp/homogeneous_node_model/epoch-2
error_and_exit $?

echo "********* Test Edge Classification with homogeneous graph optimization********"
python3 -m graphstorm.run.gs_edge_classification --workspace $GS_HOME/training_scripts/gsgnn_ep/ --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /tmp/movielen_100k_train_val_1p_4t_homogeneous_rev/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_ec_homogeneous.yaml --save-model-path /tmp/homogeneous_edge_model
error_and_exit $?

echo "********* Test Edge Classification with homogeneous graph optimization doing inference********"
python3 -m graphstorm.run.gs_edge_classification --inference --workspace $GS_HOME/training_scripts/gsgnn_ep/ --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /tmp/movielen_100k_train_val_1p_4t_homogeneous_rev/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_ec_homogeneous.yaml --restore-model-path /tmp/homogeneous_edge_model/epoch-2
error_and_exit $?

rm -rf /tmp/homogeneous_node_model
rm -rf /tmp/homogeneous_edge_model
108 changes: 108 additions & 0 deletions tests/unit-tests/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,92 @@ def generate_dummy_homo_graph(size='tiny', gen_mask=True):

return hetero_graph

def generate_dummy_homogeneous_failure_graph(size='tiny', gen_mask=True, type='node'):
    """
    Generate a dummy graph that intentionally violates the homogeneous graph format.

    A valid homogeneous graph has exactly one node type, "_N", and exactly one
    canonical edge type, ("_N", "_E", "_N"). This function builds graphs that
    break one of these two requirements so that unit tests can check that the
    homogeneous-graph validation rejects them:

    * ``type == 'node'``: the only node type is "_N", but the only edge type is
      ("_N", "fake_E", "_N") instead of ("_N", "_E", "_N").
    * any other value (e.g. ``'edge'``): the only edge type is the correct
      ("_N", "_E", "_N"), but an extra node type "fake_N" is declared in
      addition to "_N".

    In both cases 'feat' and 'label' data (and the masks, if requested) are
    attached only to node type "_N" and to the single edge type; "fake_N"
    nodes carry no data.

    Parameters
    ----------
    size: the size of dummy graph data, could be one of tiny, small, medium, large, and largest
    gen_mask: if True, attach 'train_mask', 'val_mask' and 'test_mask' to the
        "_N" nodes and to the edges.
    type: task type to generate failure case

    Returns
    -------
    hetero_graph: a DGL heterograph that is NOT a valid homogeneous graph.

    Raises
    ------
    KeyError: if ``size`` is not one of the supported size names.
    """
    size_dict = {
        'tiny': 10 ** 2,
        'small': 10 ** 4,
        'medium': 10 ** 6,
        'large': 10 ** 8,
        'largest': 10 ** 10
    }

    data_size = size_dict[size]
    # The graph always has twice as many edges as nodes of type "_N".
    num_edges = 2 * data_size

    ntype = "_N"
    num_nodes_dict = {ntype: data_size}
    if type == 'node':
        # Correct node type, wrong relation name.
        etype = ("_N", "fake_E", "_N")
    else:
        # Correct edge type, but one node type too many.
        etype = ("_N", "_E", "_N")
        num_nodes_dict["fake_N"] = data_size

    edges = {
        etype: (th.randint(data_size, (num_edges,)),
                th.randint(data_size, (num_edges,)))
    }

    hetero_graph = dgl.heterograph(edges, num_nodes_dict=num_nodes_dict)

    # set node and edge features
    # NOTE: both feature tensors are drawn before the labels so that the
    # sequence of random draws stays the same for a fixed seed.
    node_feat = th.randn(data_size, 2)
    edge_feat = th.randn(num_edges, 2)

    hetero_graph.nodes[ntype].data['feat'] = node_feat
    hetero_graph.nodes[ntype].data['label'] = \
        th.randint(10, (hetero_graph.number_of_nodes(ntype), ))

    hetero_graph.edges[etype].data['feat'] = edge_feat
    hetero_graph.edges[etype].data['label'] = \
        th.randint(10, (hetero_graph.number_of_edges(etype), ))

    # set train/val/test masks for nodes and edges
    if gen_mask:
        node_train_mask = generate_mask([0, 1], data_size)
        node_val_mask = generate_mask([2, 3], data_size)
        node_test_mask = generate_mask([4, 5], data_size)

        edge_train_mask = generate_mask([0, 1], num_edges)
        edge_val_mask = generate_mask([2, 3], num_edges)
        edge_test_mask = generate_mask([4, 5], num_edges)

        hetero_graph.nodes[ntype].data['train_mask'] = node_train_mask
        hetero_graph.nodes[ntype].data['val_mask'] = node_val_mask
        hetero_graph.nodes[ntype].data['test_mask'] = node_test_mask

        hetero_graph.edges[etype].data['train_mask'] = edge_train_mask
        hetero_graph.edges[etype].data['val_mask'] = edge_val_mask
        hetero_graph.edges[etype].data['test_mask'] = edge_test_mask

    return hetero_graph


def partion_and_load_distributed_graph(hetero_graph, dirname, graph_name='dummy'):
"""
Expand Down Expand Up @@ -431,6 +517,28 @@ def generate_dummy_dist_graph_multi_target_ntypes(dirname, size='tiny', graph_na
return partion_and_load_distributed_graph(hetero_graph=hetero_graph, dirname=dirname,
graph_name=graph_name)


def generate_dummy_dist_graph_homogeneous_failure_graph(dirname, size='tiny', graph_name='dummy',
                                                        gen_mask=True, type='node'):
    """
    Build an intentionally invalid homogeneous dummy graph and load it as a
    DGL distributed graph.

    The graph comes from ``generate_dummy_homogeneous_failure_graph`` and is
    then handed to ``partion_and_load_distributed_graph``.

    Parameters
    ----------
    dirname : the directory where the graph will be partitioned and stored.
    size: the size of dummy graph data, could be one of tiny, small, medium, large, and largest
    graph_name: string as a name
    gen_mask: forwarded to ``generate_dummy_homogeneous_failure_graph``.
    type: task type to generate failure case

    Returns
    -------
    The result of ``partion_and_load_distributed_graph``; presumably the DGL
    distributed graph and the path of the partition configuration file
    (verify against that helper).
    """
    failure_graph = generate_dummy_homogeneous_failure_graph(
        size=size,
        gen_mask=gen_mask,
        type=type,
    )
    return partion_and_load_distributed_graph(
        hetero_graph=failure_graph,
        dirname=dirname,
        graph_name=graph_name,
    )

def load_lm_graph(part_config):
with open(part_config) as f:
part_metadata = json.load(f)
Expand Down
7 changes: 4 additions & 3 deletions tests/unit-tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import math
import tempfile
from argparse import Namespace
from dgl.distributed.constants import DEFAULT_NTYPE, DEFAULT_ETYPE

import dgl
import torch as th
Expand Down Expand Up @@ -613,7 +614,7 @@ def test_node_class_info():
create_node_class_config(Path(tmpdirname), 'node_class_test')
args = Namespace(yaml_config_file=os.path.join(Path(tmpdirname), 'node_class_test_default.yaml'), local_rank=0)
config = GSConfig(args)
check_failure(config, "target_ntype")
assert config.target_ntype == DEFAULT_NTYPE
check_failure(config, "label_field")
assert config.multilabel == False
assert config.multilabel_weights == None
Expand Down Expand Up @@ -748,7 +749,7 @@ def test_node_regress_info():
create_node_regress_config(Path(tmpdirname), 'node_regress_test')
args = Namespace(yaml_config_file=os.path.join(Path(tmpdirname), 'node_regress_test_default.yaml'), local_rank=0)
config = GSConfig(args)
check_failure(config, "target_ntype")
assert config.target_ntype == DEFAULT_NTYPE
check_failure(config, "label_field")
assert len(config.eval_metric) == 1
assert config.eval_metric[0] == "rmse"
Expand Down Expand Up @@ -840,7 +841,7 @@ def test_edge_class_info():
create_edge_class_config(Path(tmpdirname), 'edge_class_test')
args = Namespace(yaml_config_file=os.path.join(Path(tmpdirname), 'edge_class_test_default.yaml'), local_rank=0)
config = GSConfig(args)
check_failure(config, "target_etype")
assert config.target_etype == [DEFAULT_ETYPE]
assert config.decoder_type == "DenseBiDecoder"
assert config.num_decoder_basis == 2
assert config.remove_target_edge_type == True
Expand Down
Loading

0 comments on commit fa9643e

Please sign in to comment.