diff --git a/.github/workflow_scripts/e2e_check.sh b/.github/workflow_scripts/e2e_check.sh
index 9851a35529..8c122c9f9d 100644
--- a/.github/workflow_scripts/e2e_check.sh
+++ b/.github/workflow_scripts/e2e_check.sh
@@ -8,6 +8,7 @@ sh ./tests/end2end-tests/create_data.sh
 sh ./tests/end2end-tests/tools/test_mem_est.sh
 sh ./tests/end2end-tests/data_process/test.sh
 sh ./tests/end2end-tests/data_process/movielens_test.sh
+sh ./tests/end2end-tests/data_process/homogeneous_test.sh
 sh ./tests/end2end-tests/custom-gnn/run_test.sh
 bash ./tests/end2end-tests/graphstorm-nc/test.sh
 bash ./tests/end2end-tests/graphstorm-lp/test.sh
diff --git a/docs/source/configuration/configuration-run.rst b/docs/source/configuration/configuration-run.rst
index 1f4b71d12c..f0c1afa854 100644
--- a/docs/source/configuration/configuration-run.rst
+++ b/docs/source/configuration/configuration-run.rst
@@ -381,20 +381,20 @@ Classification and Regression Task
 Node Classification/Regression Specific
 .........................................

-- **target_ntype**: (**Required**) The node type for prediction.
+- **target_ntype**: The node type for prediction.

   - Yaml: ``target_ntype: movie``
   - Argument: ``--target-ntype movie``
-  - Default value: This parameter must be provided by user.
+  - Default value: For a heterogeneous input graph, this parameter must be provided by the user. If it is not provided, GraphStorm assumes the input graph is homogeneous and sets ``target_ntype`` to ``_N``.

 Edge Classification/Regression Specific
 ..........................................

-- **target_etype**: (**Required**) The list of canonical edge types that will be added as a training target in edge classification/regression tasks, for example ``--train-etype query,clicks,asin`` or ``--train-etype query,clicks,asin query,search,asin``. A canonical edge type should be formatted as `src_node_type,relation_type,dst_node_type`. Currently, GraphStorm only supports single task edge classification/regression, i.e., it only accepts one canonical edge type.
+- **target_etype**: The list of canonical edge types that will be added as training targets in edge classification/regression tasks, for example ``--train-etype query,clicks,asin`` or ``--train-etype query,clicks,asin query,search,asin``. A canonical edge type should be formatted as `src_node_type,relation_type,dst_node_type`. Currently, GraphStorm only supports single-task edge classification/regression, i.e., it only accepts one canonical edge type.

   - Yaml: ``target_etype:``
       | ``- query,clicks,asin``
   - Argument: ``--target-etype query,clicks,asin``
-  - Default value: This parameter must be provided by user.
+  - Default value: For a heterogeneous input graph, this parameter must be provided by the user. If it is not provided, GraphStorm assumes the input graph is homogeneous and sets ``target_etype`` to ``("_N", "_E", "_N")``.

- **remove_target_edge_type**: When set to true, GraphStorm removes target_etype in message passing, i.e., any edge with target_etype will not be sampled during training and inference.
  - Yaml: ``remove_target_edge_type: false``
diff --git a/python/graphstorm/config/argument.py b/python/graphstorm/config/argument.py
index a5d58c88ce..bbada03934 100644
--- a/python/graphstorm/config/argument.py
+++ b/python/graphstorm/config/argument.py
@@ -25,6 +25,7 @@ import yaml
 import torch as th
 import torch.nn.functional as F
+from dgl.distributed.constants import DEFAULT_NTYPE, DEFAULT_ETYPE

 from .config import BUILTIN_GNN_ENCODER
 from .config import BUILTIN_ENCODER
@@ -1573,9 +1574,12 @@ def target_ntype(self):
         """ The node type for prediction
         """
         # pylint: disable=no-member
-        assert hasattr(self, "_target_ntype"), \
-            "Must provide the target ntype through target_ntype"
-        return self._target_ntype
+        if hasattr(self, "_target_ntype"):
+            return self._target_ntype
+        else:
+            logging.warning("No target_ntype is provided; the input graph "
+                            "will be treated as a homogeneous graph")
+            return DEFAULT_NTYPE

     @property
     def eval_target_ntype(self):
@@ -1648,8 +1652,10 @@ def target_etype(self):
         classification/regression. Support multiple tasks when needed.
         """
         # pylint: disable=no-member
-        assert hasattr(self, "_target_etype"), \
-            "Edge classification task needs a target etype"
+        if not hasattr(self, "_target_etype"):
+            logging.warning("No target_etype is provided; the input graph "
+                            "will be treated as a homogeneous graph")
+            return [DEFAULT_ETYPE]
         assert isinstance(self._target_etype, list), \
             "target_etype must be a list in format: " \
             "[\"query,clicks,asin\", \"query,search,asin\"]."
diff --git a/python/graphstorm/dataloading/dataset.py b/python/graphstorm/dataloading/dataset.py
index 90fd4ee67b..227ceff8bc 100644
--- a/python/graphstorm/dataloading/dataset.py
+++ b/python/graphstorm/dataloading/dataset.py
@@ -23,6 +23,7 @@ import torch as th

 import dgl
+from dgl.distributed.constants import DEFAULT_NTYPE, DEFAULT_ETYPE
 from torch.utils.data import Dataset
 import pandas as pd
@@ -554,6 +555,15 @@ def __init__(self, graph_name, part_config, train_etypes, eval_etypes=None,
                          lm_feat_ntypes=lm_feat_ntypes,
                          lm_feat_etypes=lm_feat_etypes)
+        if self._train_etypes == [DEFAULT_ETYPE]:
+            # DistGraph.etypes returns non-canonical edge types, i.e., a plain list[str].
+            assert self._g.ntypes == [DEFAULT_NTYPE] and \
+                self._g.etypes == [DEFAULT_ETYPE[1]], \
+                f"The input graph must be homogeneous when target_etype is not provided " \
+                f"or is set to {DEFAULT_ETYPE} on edge tasks: expected node types " \
+                f"{[DEFAULT_NTYPE]} and edge types {[DEFAULT_ETYPE[1]]}, " \
+                f"but got {self._g.ntypes} and {self._g.etypes}"
+
     def prepare_data(self, g):
         """ Prepare the training, validation and testing edge set.
@@ -731,6 +741,14 @@ def __init__(self, graph_name, part_config, eval_etypes, decoder_edge_feat,
                          lm_feat_ntypes=lm_feat_ntypes,
                          lm_feat_etypes=lm_feat_etypes)
+        if self._eval_etypes == [DEFAULT_ETYPE]:
+            # DistGraph.etypes returns non-canonical edge types, i.e., a plain list[str].
+            assert self._g.ntypes == [DEFAULT_NTYPE] and \
+                self._g.etypes == [DEFAULT_ETYPE[1]], \
+                f"The input graph must be homogeneous when target_etype is not provided " \
+                f"or is set to {DEFAULT_ETYPE} on edge tasks: expected node types " \
+                f"{[DEFAULT_NTYPE]} and edge types {[DEFAULT_ETYPE[1]]}, " \
+                f"but got {self._g.ntypes} and {self._g.etypes}"

     def prepare_data(self, g):
         """ Prepare the testing edge set if any
@@ -916,7 +934,6 @@ def __init__(self, graph_name, part_config, train_ntypes, eval_ntypes=None,
         assert isinstance(train_ntypes, list), \
             "prediction ntypes for training has to be a string or a list of strings."
         self._train_ntypes = train_ntypes
-
         if eval_ntypes is not None:
             if isinstance(eval_ntypes, str):
                 eval_ntypes = [eval_ntypes]
@@ -932,6 +949,14 @@ def __init__(self, graph_name, part_config, train_ntypes, eval_ntypes=None,
                          edge_feat_field=edge_feat_field,
                          lm_feat_ntypes=lm_feat_ntypes,
                          lm_feat_etypes=lm_feat_etypes)
+        if self._train_ntypes == [DEFAULT_NTYPE]:
+            # DistGraph.etypes returns non-canonical edge types, i.e., a plain list[str].
+            assert self._g.ntypes == [DEFAULT_NTYPE] and \
+                self._g.etypes == [DEFAULT_ETYPE[1]], \
+                f"The input graph must be homogeneous when target_ntype is not provided " \
+                f"or is set to {DEFAULT_NTYPE} on node tasks: expected node types " \
+                f"{[DEFAULT_NTYPE]} and edge types {[DEFAULT_ETYPE[1]]}, " \
+                f"but got {self._g.ntypes} and {self._g.etypes}"

     def prepare_data(self, g):
         pb = g.get_partition_book()
@@ -1072,6 +1097,15 @@ def __init__(self, graph_name, part_config, eval_ntypes,
                          lm_feat_ntypes=lm_feat_ntypes,
                          lm_feat_etypes=lm_feat_etypes)
+        if self._eval_ntypes == [DEFAULT_NTYPE]:
+            # DistGraph.etypes returns non-canonical edge types, i.e., a plain list[str].
+            assert self._g.ntypes == [DEFAULT_NTYPE] and \
+                self._g.etypes == [DEFAULT_ETYPE[1]], \
+                f"The input graph must be homogeneous when target_ntype is not provided " \
+                f"or is set to {DEFAULT_NTYPE} on node tasks: expected node types " \
+                f"{[DEFAULT_NTYPE]} and edge types {[DEFAULT_ETYPE[1]]}, " \
+                f"but got {self._g.ntypes} and {self._g.etypes}"
+
     def prepare_data(self, g):
         """ Prepare the testing node set if any
diff --git a/python/graphstorm/gconstruct/construct_graph.py b/python/graphstorm/gconstruct/construct_graph.py
index 0065da5403..259399328b 100644
--- a/python/graphstorm/gconstruct/construct_graph.py
+++ b/python/graphstorm/gconstruct/construct_graph.py
@@ -28,6 +28,7 @@ import numpy as np
 import torch as th
 import dgl
+from dgl.distributed.constants import DEFAULT_NTYPE, DEFAULT_ETYPE

 from ..utils import sys_tracker, get_log_level
 from .file_io import parse_node_file_format, parse_edge_file_format
@@ -582,8 +583,23 @@ def process_edge_data(process_confs, node_id_map, arr_merger,

     return (edges, edge_data, label_stats)

+def is_homogeneous(confs):
+    """ Check whether the input configuration describes a homogeneous graph.
+
+    Parameters
+    ----------
+    confs: dict
+        A dict containing all user input config
+    """
+    ntypes = {conf['node_type'] for conf in confs["nodes"]}
+    etypes = set(tuple(conf['relation']) for conf in confs["edges"])
+    return len(ntypes) == 1 and len(etypes) == 1
+
 def verify_confs(confs):
     """ Verify the configuration of the input data.
+
+    Parameters
+    ----------
+    confs: dict
+        A dict containing all user input config
     """
     if "version" not in confs:
         # TODO: Make a requirement with v1.0 launch
@@ -599,6 +615,14 @@ def verify_confs(confs):
             f"source node type {src_type} does not exist. Please check your input data."
         assert dst_type in ntypes, \
             f"dest node type {dst_type} does not exist. Please check your input data."
+    # Rewrite the config into the DGL homogeneous graph format if the graph is homogeneous
+    if is_homogeneous(confs):
+        logging.warning("The generated graph is homogeneous, so the node type will be "
+                        "changed to _N and the edge type to [_N, _E, _N]")
+        for node in confs['nodes']:
+            node['node_type'] = DEFAULT_NTYPE
+        for edge in confs['edges']:
+            edge['relation'] = list(DEFAULT_ETYPE)

 def print_graph_info(g, node_data, edge_data, node_label_stats, edge_label_stats):
     """ Print graph information.
@@ -698,12 +722,35 @@ def process_graph(args):

     if args.add_reverse_edges:
         edges1 = {}
-        for etype in edges:
-            e = edges[etype]
+        if is_homogeneous(process_confs):
+            logging.warning("For a homogeneous graph, reverse edges are added under the "
+                            "same edge type as the original edges, whereas for a "
+                            "heterogeneous graph they are added under a new edge type "
+                            "with a -rev suffix")
+            e = edges[DEFAULT_ETYPE]
             assert isinstance(e, tuple) and len(e) == 2
-            assert isinstance(etype, tuple) and len(etype) == 3
-            edges1[etype] = e
-            edges1[etype[2], etype[1] + "-rev", etype[0]] = (e[1], e[0])
+            edges1[DEFAULT_ETYPE] = (np.concatenate([e[0], e[1]]),
+                                     np.concatenate([e[1], e[0]]))
+            # Duplicate the edge features so the tensor sizes match the doubled edge count.
+            # Masks are only generated for the original edges.
+            if edge_data:
+                data = edge_data[DEFAULT_ETYPE]
+                logging.warning("Reverse edges of a homogeneous graph share the same "
+                                "features as the original edges")
+                for key, value in data.items():
+                    if key not in ["train_mask", "test_mask", "val_mask"]:
+                        data[key] = np.concatenate([value, value])
+                    else:
+                        data[key] = np.concatenate([value, np.zeros(value.shape,
+                                                                    dtype=value.dtype)])
+
+        else:
+            for etype in edges:
+                e = edges[etype]
+                assert isinstance(e, tuple) and len(e) == 2
+                assert isinstance(etype, tuple) and len(etype) == 3
+                edges1[etype] = e
+                edges1[etype[2], etype[1] + "-rev", etype[0]] = (e[1], e[0])
         edges = edges1
         sys_tracker.check('Add reverse edges')
     g = dgl.heterograph(edges, num_nodes_dict=num_nodes)
diff --git a/tests/end2end-tests/data_gen/movielens_homogeneous.json b/tests/end2end-tests/data_gen/movielens_homogeneous.json
new file mode 100644
index 0000000000..018776e82e
--- /dev/null
+++ b/tests/end2end-tests/data_gen/movielens_homogeneous.json
@@ -0,0 +1,63 @@
+{
+    "version": "gconstruct-v0.1",
+    "nodes": [
+        {
+            "node_id_col": "id",
+            "node_type": "movie",
+            "format": {"name": "parquet"},
+            "files": "/data/ml-100k/movie.parquet",
+            "features": [
+                {
+                    "feature_col": "title",
+                    "transform": {
+                        "name": "bert_hf",
+                        "bert_model": "bert-base-uncased",
+                        "max_seq_length": 16
+                    }
+                }
+            ],
+            "labels": [
+                {
+                    "label_col": "label",
+                    "task_type": "classification",
+                    "split_pct": [0.8, 0.1, 0.1]
+                }
+            ]
+        },
+        {
+            "node_type": "movie",
+            "format": {"name": "parquet"},
+            "files": "/data/ml-100k/movie.parquet",
+            "features": [
+                {
+                    "feature_col": "id"
+                }
+            ]
+        }
+    ],
+    "edges": [
+        {
+            "source_id_col": "src_id",
+            "dest_id_col": "dst_id",
+            "relation": ["movie", "rating", "movie"],
+            "format": {"name": "parquet"},
+            "files": "/data/ml-100k/edges_homogeneous.parquet",
+            "features": [
+                {
+                    "feature_col": "rate"
+                }],
+            "labels": [
+                {
+                    "label_col": "rate",
+                    "task_type": "classification",
+                    "split_pct": [0.1, 0.1, 0.1]
+                }
+            ]
+        },
+        {
+            "relation": ["movie", "rating", "movie"],
+            "format": {"name": "parquet"},
+            "files": "/data/ml-100k/edges_homogeneous.parquet"
+        }
+    ]
+}
\ No newline at end of file
diff --git a/tests/end2end-tests/data_gen/process_movielens.py b/tests/end2end-tests/data_gen/process_movielens.py
index 90fdcd1702..a9ca90873e 100644
--- a/tests/end2end-tests/data_gen/process_movielens.py
+++ b/tests/end2end-tests/data_gen/process_movielens.py
@@ -90,6 +90,11 @@ def write_data_parquet(data, data_file):
 edge_data = {'src_id': edges[0], 'dst_id': edges[1], 'rate': edges[2]}
 write_data_parquet(edge_data, '/data/ml-100k/edges.parquet')

+# generate data for the homogeneous optimization test: use item ids for both endpoints to form a homogeneous movie-to-movie graph
+edges = pandas.read_csv('/data/ml-100k/u.data', delimiter='\t',
+                        header=None)
+edge_data = {'src_id': edges[1], 'dst_id': edges[1], 'rate': edges[2]}
+write_data_parquet(edge_data, '/data/ml-100k/edges_homogeneous.parquet')
+
 # generate synthetic user data with label
 user_labels = np.random.randint(11, size=feat.shape[0])
 user_data = {'id': user['id'].values, 'feat': feat, 'occupation': user['occupation'], 'label': user_labels}
diff --git a/tests/end2end-tests/data_process/check_homogeneous.py b/tests/end2end-tests/data_process/check_homogeneous.py
new file mode 100644
index 0000000000..daeb6f0ada
--- /dev/null
+++ b/tests/end2end-tests/data_process/check_homogeneous.py
@@ -0,0 +1,60 @@
+"""
+    Copyright 2023 Contributors
+
+    Licensed under the Apache License, Version 2.0 (the "License");
+    you may not use this file except in compliance with the License.
+    You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing, software
+    distributed under the License is distributed on an "AS IS" BASIS,
+    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+    See the License for the specific language governing permissions and
+    limitations under the License.
+
+"""
+import os
+import argparse
+
+import dgl
+from dgl.distributed.constants import DEFAULT_NTYPE, DEFAULT_ETYPE
+from numpy.testing import assert_almost_equal
+
+
+def check_reverse_edge(args):
+
+    g_orig = dgl.load_graphs(os.path.join(args.orig_graph_path, "graph.dgl"))[0][0]
+    g_rev = dgl.load_graphs(os.path.join(args.rev_graph_path, "graph.dgl"))[0][0]
+    assert g_orig.ntypes == g_rev.ntypes
+    assert g_orig.etypes == g_rev.etypes
+    assert g_orig.number_of_nodes(DEFAULT_NTYPE) == g_rev.number_of_nodes(DEFAULT_NTYPE)
+    assert 2 * g_orig.number_of_edges(DEFAULT_ETYPE) == g_rev.number_of_edges(DEFAULT_ETYPE)
+    for ntype in g_orig.ntypes:
+        assert g_orig.number_of_nodes(ntype) == g_rev.number_of_nodes(ntype)
+        for name in g_orig.nodes[ntype].data:
+            # Skip '*_mask' fields because the train/val/test split is random.
+            if 'mask' not in name:
+                assert_almost_equal(g_orig.nodes[ntype].data[name].numpy(),
+                                    g_rev.nodes[ntype].data[name].numpy())
+
+    # Check edge features
+    g_orig_feat = dgl.data.load_tensors(os.path.join(args.orig_graph_path, "edge_feat.dgl"))
+    g_rev_feat = dgl.data.load_tensors(os.path.join(args.rev_graph_path, "edge_feat.dgl"))
+    for feat_type in g_orig_feat.keys():
+        if "mask" not in feat_type:
+            assert_almost_equal(g_orig_feat[feat_type].numpy(),
+                                g_rev_feat[feat_type].numpy()[:g_orig.number_of_edges(DEFAULT_ETYPE)])
+        else:
+            assert_almost_equal(g_rev_feat[feat_type].numpy()[g_orig.number_of_edges(DEFAULT_ETYPE):],
+                                [0] * g_orig.number_of_edges(DEFAULT_ETYPE))
+
+if __name__ == '__main__':
+    argparser = argparse.ArgumentParser("Check reverse edges of a homogeneous graph")
+    argparser.add_argument("--orig-graph-path", type=str, default="/tmp/movielen_100k_train_val_1p_4t_homogeneous/part0/",
+                           help="Path to the graph constructed without reverse edges")
+    argparser.add_argument("--rev-graph-path", type=str, default="/tmp/movielen_100k_train_val_1p_4t_homogeneous_rev/part0/",
+                           help="Path to the graph constructed with reverse edges")
+
+    args = argparser.parse_args()
+
+    check_reverse_edge(args)
\ No newline at end of file
diff --git a/tests/end2end-tests/data_process/homogeneous_test.sh b/tests/end2end-tests/data_process/homogeneous_test.sh
new file mode 100644
index 0000000000..ea2cc197e3
--- /dev/null
+++ b/tests/end2end-tests/data_process/homogeneous_test.sh
@@ -0,0 +1,69 @@
+#!/bin/bash
+
+service ssh restart
+
+GS_HOME=$(pwd)
+NUM_TRAINERS=4
+export PYTHONPATH=$GS_HOME/python/
+cd $GS_HOME/training_scripts/gsgnn_np
+echo "127.0.0.1" > ip_list.txt
+cd $GS_HOME/training_scripts/gsgnn_ep
+echo "127.0.0.1" > ip_list.txt
+
+error_and_exit () {
+	# check the exec status of launch.py
+	status=$1
+	echo $status
+
+	if test $status -ne 0
+	then
+		exit -1
+	fi
+}
+
+
+echo "********* Test Homogeneous Graph Optimization ********"
+python3 -m graphstorm.gconstruct.construct_graph --conf-file $GS_HOME/tests/end2end-tests/data_gen/movielens_homogeneous.json --num-processes 1 --output-dir /tmp/movielen_100k_train_val_1p_4t_homogeneous --graph-name movie-lens-100k
+error_and_exit $?
+
+echo "********* Test Node Classification on GConstruct Homogeneous Graph ********"
+python3 -m graphstorm.run.gs_node_classification --workspace $GS_HOME/training_scripts/gsgnn_np/ --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /tmp/movielen_100k_train_val_1p_4t_homogeneous/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc.yaml --target-ntype _N
+error_and_exit $?
+
+echo "********* Test Edge Classification on GConstruct Homogeneous Graph ********"
+python3 -m graphstorm.run.gs_edge_classification --workspace $GS_HOME/training_scripts/gsgnn_ep/ --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /tmp/movielen_100k_train_val_1p_4t_homogeneous/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_ec.yaml --target-etype _N,_E,_N
+error_and_exit $?
+
+echo "********* Test Homogeneous Graph Optimization on Reverse Edges ********"
+python3 -m graphstorm.gconstruct.construct_graph --conf-file $GS_HOME/tests/end2end-tests/data_gen/movielens_homogeneous.json --num-processes 1 --output-dir /tmp/movielen_100k_train_val_1p_4t_homogeneous_rev --graph-name movie-lens-100k --add-reverse-edges
+error_and_exit $?
+
+python3 $GS_HOME/tests/end2end-tests/data_process/check_homogeneous.py
+error_and_exit $?
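+# The check above verifies the reverse-edge optimization: node data must match
+# between the two builds, the edge count must double, the duplicated edge
+# features must equal the original ones, and the train/val/test masks on the
+# reverse half must be all zeros.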
+
+echo "********* Test Node Classification on GConstruct Homogeneous Graph with Reverse Edges ********"
+python3 -m graphstorm.run.gs_node_classification --workspace $GS_HOME/training_scripts/gsgnn_np/ --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /tmp/movielen_100k_train_val_1p_4t_homogeneous_rev/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc.yaml --target-ntype _N
+error_and_exit $?
+
+echo "********* Test Edge Classification on GConstruct Homogeneous Graph with Reverse Edges ********"
+python3 -m graphstorm.run.gs_edge_classification --workspace $GS_HOME/training_scripts/gsgnn_ep/ --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /tmp/movielen_100k_train_val_1p_4t_homogeneous_rev/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_ec.yaml --target-etype _N,_E,_N
+error_and_exit $?
+
+echo "********* Test Node Classification with Homogeneous Graph Optimization ********"
+python3 -m graphstorm.run.gs_node_classification --workspace $GS_HOME/training_scripts/gsgnn_np/ --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /tmp/movielen_100k_train_val_1p_4t_homogeneous_rev/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_homogeneous.yaml --save-model-path /tmp/homogeneous_node_model
+error_and_exit $?
+
+echo "********* Test Node Classification Inference with Homogeneous Graph Optimization ********"
+python3 -m graphstorm.run.gs_node_classification --inference --workspace $GS_HOME/training_scripts/gsgnn_np/ --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /tmp/movielen_100k_train_val_1p_4t_homogeneous_rev/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_homogeneous.yaml --restore-model-path /tmp/homogeneous_node_model/epoch-2
+error_and_exit $?
+
+echo "********* Test Edge Classification with Homogeneous Graph Optimization ********"
+python3 -m graphstorm.run.gs_edge_classification --workspace $GS_HOME/training_scripts/gsgnn_ep/ --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /tmp/movielen_100k_train_val_1p_4t_homogeneous_rev/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_ec_homogeneous.yaml --save-model-path /tmp/homogeneous_edge_model
+error_and_exit $?
+
+echo "********* Test Edge Classification Inference with Homogeneous Graph Optimization ********"
+python3 -m graphstorm.run.gs_edge_classification --inference --workspace $GS_HOME/training_scripts/gsgnn_ep/ --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /tmp/movielen_100k_train_val_1p_4t_homogeneous_rev/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_ec_homogeneous.yaml --restore-model-path /tmp/homogeneous_edge_model/epoch-2
+error_and_exit $?
+
+rm -rf /tmp/homogeneous_node_model
+rm -rf /tmp/homogeneous_edge_model
diff --git a/tests/unit-tests/data_utils.py b/tests/unit-tests/data_utils.py
index e6e9f094b8..8ce83c7779 100644
--- a/tests/unit-tests/data_utils.py
+++ b/tests/unit-tests/data_utils.py
@@ -343,6 +343,92 @@ def generate_dummy_homo_graph(size='tiny', gen_mask=True):

     return hetero_graph

+
+def generate_dummy_homogeneous_failure_graph(size='tiny', gen_mask=True, type='node'):
+    """
+    Generate a dummy graph that deliberately fails the homogeneous-graph checks.
+
+    In a homogeneous graph, the node type must be ["_N"] and the edge type must be
+    [("_N", "_E", "_N")].
+    Any deviation from this specification is an invalid input for a homogeneous graph, and this
+    function creates test cases that intentionally violate it. For type="node", it produces a graph
+    with the correct node type ["_N"] but with the altered edge type [("_N", "fake_E", "_N")].
+    Conversely, for type="edge", it generates a graph with an extra node type, i.e., node types
+    ["_N", "fake_N"], while keeping the correct edge type [("_N", "_E", "_N")]. The unit tests are
+    expected to flag errors in both scenarios.
+
+    Parameters
+    ----------
+    size: the size of the dummy graph data, one of tiny, small, medium, large, and largest
+    type: task type to generate the failure case for
+
+    Returns
+    -------
+    hg: a graph that violates the homogeneous format in one of the two ways above.
+    """
+    size_dict = {
+        'tiny': 1e+2,
+        'small': 1e+4,
+        'medium': 1e+6,
+        'large': 1e+8,
+        'largest': 1e+10
+    }
+
+    data_size = int(size_dict[size])
+
+    if type == 'node':
+        ntype = "_N"
+        etype = ("_N", "fake_E", "_N")
+
+        num_nodes_dict = {
+            ntype: data_size,
+        }
+    else:
+        ntype = "_N"
+        etype = ("_N", "_E", "_N")
+        num_nodes_dict = {
+            ntype: data_size,
+            "fake_N": data_size
+        }
+
+    edges = {
+        etype: (th.randint(data_size, (2 * data_size,)),
+                th.randint(data_size, (2 * data_size,)))
+    }
+
+    hetero_graph = dgl.heterograph(edges, num_nodes_dict=num_nodes_dict)
+
+    # set node and edge features
+    node_feat = {ntype: th.randn(data_size, 2)}
+
+    edge_feat = {etype: th.randn(2 * data_size, 2)}
+
+    hetero_graph.nodes[ntype].data['feat'] = node_feat[ntype]
+    hetero_graph.nodes[ntype].data['label'] = th.randint(10, (hetero_graph.number_of_nodes(ntype), ))
+
+    hetero_graph.edges[etype].data['feat'] = edge_feat[etype]
+    hetero_graph.edges[etype].data['label'] = th.randint(10, (hetero_graph.number_of_edges(etype), ))
+
+    # set train/val/test masks for nodes and edges
+    if gen_mask:
+        target_ntype = [ntype]
+        target_etype = [etype]
+
+        node_train_mask = generate_mask([0,1], data_size)
+        node_val_mask = generate_mask([2,3], data_size)
+        node_test_mask = generate_mask([4,5], data_size)
+
+        edge_train_mask = generate_mask([0,1], 2 * data_size)
+        edge_val_mask = generate_mask([2,3], 2 * data_size)
+        edge_test_mask = generate_mask([4,5], 2 * data_size)
+
+        hetero_graph.nodes[target_ntype[0]].data['train_mask'] = node_train_mask
+        hetero_graph.nodes[target_ntype[0]].data['val_mask'] = node_val_mask
+        hetero_graph.nodes[target_ntype[0]].data['test_mask'] = node_test_mask
+
+        hetero_graph.edges[target_etype[0]].data['train_mask'] = edge_train_mask
+        hetero_graph.edges[target_etype[0]].data['val_mask'] = edge_val_mask
+        hetero_graph.edges[target_etype[0]].data['test_mask'] = edge_test_mask
+
+    return hetero_graph
+
 def partion_and_load_distributed_graph(hetero_graph, dirname, graph_name='dummy'):
     """
@@ -431,6 +517,28 @@ def generate_dummy_dist_graph_multi_target_ntypes(dirname, size='tiny', graph_name='dummy',
     return partion_and_load_distributed_graph(hetero_graph=hetero_graph, dirname=dirname,
                                               graph_name=graph_name)
+
+def generate_dummy_dist_graph_homogeneous_failure_graph(dirname, size='tiny', graph_name='dummy',
+                                                        gen_mask=True, type='node'):
+    """
+    Generate a dummy DGL distributed graph that is an invalid homogeneous-graph input.
+
+    Parameters
+    ----------
+    dirname: the directory where the graph will be partitioned and stored.
+    size: the size of the dummy graph data, one of tiny, small, medium, large, and largest
+    graph_name: string as a name
+    type: task type to generate the failure case for
+
+    Returns
+    -------
+    dist_graph: a DGL distributed graph
+    part_config: the path of the partition configuration file
+    """
+    hetero_graph = generate_dummy_homogeneous_failure_graph(size=size, gen_mask=gen_mask, type=type)
+    return partion_and_load_distributed_graph(hetero_graph=hetero_graph, dirname=dirname,
+                                              graph_name=graph_name)
+
 def load_lm_graph(part_config):
     with open(part_config) as f:
         part_metadata = json.load(f)
diff --git a/tests/unit-tests/gconstruct/test_construct_graph.py b/tests/unit-tests/gconstruct/test_construct_graph.py
index a03a7cbec7..d7c9ae6650 100644
--- a/tests/unit-tests/gconstruct/test_construct_graph.py
+++ b/tests/unit-tests/gconstruct/test_construct_graph.py
@@ -22,11 +22,12 @@
 import numpy as np
 import dgl
 import torch as th
+import copy

 from functools import partial
 from numpy.testing import assert_equal, assert_almost_equal

-from graphstorm.gconstruct.construct_graph import parse_edge_data
+from graphstorm.gconstruct.construct_graph import parse_edge_data, verify_confs, is_homogeneous
 from graphstorm.gconstruct.file_io import write_data_parquet, read_data_parquet
 from graphstorm.gconstruct.file_io import write_data_json, read_data_json
 from graphstorm.gconstruct.file_io import write_data_csv, read_data_csv
@@ -1705,6 +1706,59 @@ def test_gc():
     assert not os.path.isdir("/tmp_featurewrapper2"), \
         "Directory /tmp_featurewrapper2 should not exist after gc"

+
+def test_homogeneous():
+    # single node type and edge type input
+    conf = {
+        "version": "gconstruct-v0.1", "nodes": [
+            {"node_id_col": "id", "node_type": "movie", "format": {"name": "parquet"},
+             "files": "/data/ml-100k/movie.parquet", "features": [
+                {"feature_col": "title", "transform": {
+                    "name": "bert_hf", "bert_model": "bert-base-uncased", "max_seq_length": 16}}],
+             "labels": [{"label_col": "label", "task_type": "classification", "split_pct": [0.8, 0.1, 0.1]}]}],
+        "edges": [
+            {"source_id_col": "src_id", "dest_id_col": "dst_id", "relation": ["movie", "rating", "movie"],
+             "format": {"name": "parquet"}, "files": "/data/ml-100k/edges_homo.parquet", "labels": [
+                {"label_col": "rate", "task_type": "classification", "split_pct": [0.1, 0.1, 0.1]}]}]
+    }
+    assert is_homogeneous(conf)
+    verify_confs(conf)
+    assert conf['nodes'][0]["node_type"] == "_N"
+    assert conf['edges'][0]['relation'] == ["_N", "_E", "_N"]
+    conf["edges"][0]["relation"] = ["movie_fake", "rating", "movie"]
+    conf["nodes"].append(copy.deepcopy(conf["nodes"][0]))
+    conf["nodes"][0]["node_type"] = "movie"
+    conf["nodes"][1]["node_type"] = "movie_fake"
+    assert not is_homogeneous(conf)
+
+    # multiple node types and edge types input
+    conf = {
+        "version": "gconstruct-v0.1", "nodes": [
+            {"node_id_col": "id", "node_type": "movie", "format": {"name": "parquet"},
+             "files": "/data/ml-100k/movie.parquet", "features": [
+                {"feature_col": "title", "transform": {
+                    "name": "bert_hf", "bert_model": "bert-base-uncased", "max_seq_length": 16}}],
+             "labels": [{"label_col": "label", "task_type": "classification", "split_pct": [0.8, 0.1, 0.1]}]},
+            {"node_type": "movie", "format": {"name": "parquet"}, "files": "/data/ml-100k/movie.parquet",
+             "features": [{"feature_col": "id"}]}],
+        "edges": [
+            {"source_id_col": "src_id", "dest_id_col": "dst_id", "relation": ["movie", "rating", "movie"],
+             "format": {"name": "parquet"}, "files": "/data/ml-100k/edges_homo.parquet",
"labels": [ + {"label_col": "rate", "task_type": "classification", "split_pct": [0.1, 0.1, 0.1]}]}, + {"relation": ["movie", "rating", "movie"], "format": {"name": "parquet"}, + "files": "/data/ml-100k/edges_homo.parquet"}] + } + assert is_homogeneous(conf) + verify_confs(conf) + assert conf['nodes'][0]["node_type"] == "_N" + assert conf['edges'][0]['relation'] == ["_N", "_E", "_N"] + conf["edges"][0]["relation"] = ["movie_fake", "rating", "movie"] + conf["nodes"].append(copy.deepcopy(conf["nodes"][0])) + conf["nodes"][0]["node_type"] = "movie" + conf["nodes"][1]["node_type"] = "movie_fake" + assert not is_homogeneous(conf) + if __name__ == '__main__': test_parse_edge_data() test_multiprocessing_checks() @@ -1723,4 +1777,5 @@ def test_gc(): test_label() test_multicolumn(None) test_multicolumn("/") - test_feature_wrapper() \ No newline at end of file + test_feature_wrapper() + test_homogeneous() diff --git a/tests/unit-tests/test_config.py b/tests/unit-tests/test_config.py index 48cc4c2146..0b97aa4dc0 100644 --- a/tests/unit-tests/test_config.py +++ b/tests/unit-tests/test_config.py @@ -21,6 +21,7 @@ import math import tempfile from argparse import Namespace +from dgl.distributed.constants import DEFAULT_NTYPE, DEFAULT_ETYPE import dgl import torch as th @@ -613,7 +614,7 @@ def test_node_class_info(): create_node_class_config(Path(tmpdirname), 'node_class_test') args = Namespace(yaml_config_file=os.path.join(Path(tmpdirname), 'node_class_test_default.yaml'), local_rank=0) config = GSConfig(args) - check_failure(config, "target_ntype") + assert config.target_ntype == DEFAULT_NTYPE check_failure(config, "label_field") assert config.multilabel == False assert config.multilabel_weights == None @@ -748,7 +749,7 @@ def test_node_regress_info(): create_node_regress_config(Path(tmpdirname), 'node_regress_test') args = Namespace(yaml_config_file=os.path.join(Path(tmpdirname), 'node_regress_test_default.yaml'), local_rank=0) config = GSConfig(args) - check_failure(config, "target_ntype") + assert config.target_ntype == DEFAULT_NTYPE check_failure(config, "label_field") assert len(config.eval_metric) == 1 assert config.eval_metric[0] == "rmse" @@ -840,7 +841,7 @@ def test_edge_class_info(): create_edge_class_config(Path(tmpdirname), 'edge_class_test') args = Namespace(yaml_config_file=os.path.join(Path(tmpdirname), 'edge_class_test_default.yaml'), local_rank=0) config = GSConfig(args) - check_failure(config, "target_etype") + assert config.target_etype == [DEFAULT_ETYPE] assert config.decoder_type == "DenseBiDecoder" assert config.num_decoder_basis == 2 assert config.remove_target_edge_type == True diff --git a/tests/unit-tests/test_dataloading.py b/tests/unit-tests/test_dataloading.py index 18613e2314..0c688d586e 100644 --- a/tests/unit-tests/test_dataloading.py +++ b/tests/unit-tests/test_dataloading.py @@ -29,6 +29,7 @@ from data_utils import ( generate_dummy_dist_graph, generate_dummy_dist_graph_reconstruct, + generate_dummy_dist_graph_homogeneous_failure_graph, create_distill_data, ) @@ -1354,6 +1355,106 @@ def test_inbatch_joint_neg_sampler(num_pos, num_neg): assert_equal(in_batch_dst[i*(num_pos-1):(i+1)*(num_pos-1)].numpy(), np.arange(num_pos)[tmp_idx]) +def test_GSgnnTrainData_homogeneous(): + # initialize the torch distributed environment + th.distributed.init_process_group(backend='gloo', + init_method='tcp://127.0.0.1:23456', + rank=0, + world_size=1) + tr_ntypes = ["_N"] + va_ntypes = ["_N"] + + with tempfile.TemporaryDirectory() as tmpdirname: + # generate the test dummy homogeneous 
+        # distributed graph and test that GSgnnNodeTrainData can be created
+        # on a homogeneous graph
+        dist_graph, part_config = generate_dummy_dist_graph(graph_name='dummy',
+                                                            dirname=tmpdirname,
+                                                            is_homo=True)
+        _ = GSgnnNodeTrainData(graph_name='dummy', part_config=part_config,
+                               train_ntypes=tr_ntypes, eval_ntypes=va_ntypes,
+                               label_field='label')
+
+        # generate a test dummy distributed graph that violates the homogeneous format.
+        # The input is expected to be a homogeneous graph with "_N" as the node type and
+        # ("_N", "_E", "_N") as the edge type, so an AssertionError should be raised.
+        dist_graph, part_config = generate_dummy_dist_graph_homogeneous_failure_graph(graph_name='dummy',
+                                                                                      dirname=tmpdirname)
+        try:
+            _ = GSgnnNodeTrainData(graph_name='dummy', part_config=part_config,
+                                   train_ntypes=tr_ntypes, eval_ntypes=va_ntypes,
+                                   label_field='label')
+            # raise RuntimeError rather than `assert False`: an AssertionError raised
+            # here would be swallowed by the except clause below
+            raise RuntimeError("expected an AssertionError for non-homogeneous graph input")
+        except AssertionError as _:
+            pass
+
+        # generate the test dummy homogeneous distributed graph and
+        # test that GSgnnNodeInferData can be created on a homogeneous graph
+        dist_graph, part_config = generate_dummy_dist_graph(graph_name='dummy',
+                                                            dirname=tmpdirname,
+                                                            is_homo=True)
+        _ = GSgnnNodeInferData(graph_name='dummy', part_config=part_config,
+                               eval_ntypes=va_ntypes)
+
+        # generate a test dummy distributed graph that violates the homogeneous format.
+        # The input is expected to be a homogeneous graph with "_N" as the node type and
+        # ("_N", "_E", "_N") as the edge type, so an AssertionError should be raised.
+        dist_graph, part_config = generate_dummy_dist_graph_homogeneous_failure_graph(graph_name='dummy',
+                                                                                      dirname=tmpdirname)
+        try:
+            _ = GSgnnNodeInferData(graph_name='dummy', part_config=part_config,
+                                   eval_ntypes=va_ntypes)
+            raise RuntimeError("expected an AssertionError for non-homogeneous graph input")
+        except AssertionError as _:
+            pass
+
+    tr_etypes = [("_N", "_E", "_N")]
+    va_etypes = [("_N", "_E", "_N")]
+
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        # generate the test dummy homogeneous distributed graph and
+        # test that GSgnnEdgeTrainData can be created on a homogeneous graph
+        dist_graph, part_config = generate_dummy_dist_graph(graph_name='dummy',
+                                                            dirname=os.path.join(tmpdirname, 'dummy'),
+                                                            is_homo=True)
+        _ = GSgnnEdgeTrainData(graph_name='dummy', part_config=part_config,
+                               train_etypes=tr_etypes, eval_etypes=va_etypes,
+                               label_field='label')
+
+        # generate a test dummy distributed graph that violates the homogeneous format.
+        # The input is expected to be a homogeneous graph with "_N" as the node type and
+        # ("_N", "_E", "_N") as the edge type, so an AssertionError should be raised.
+        dist_graph, part_config = generate_dummy_dist_graph_homogeneous_failure_graph(graph_name='dummy',
+                                                                                      dirname=os.path.join(tmpdirname, 'dummy'))
+        try:
+            _ = GSgnnEdgeTrainData(graph_name='dummy', part_config=part_config,
+                                   train_etypes=tr_etypes, eval_etypes=va_etypes,
+                                   label_field='label')
+            raise RuntimeError("expected an AssertionError for non-homogeneous graph input")
+        except AssertionError as _:
+            pass
+
+        # generate the test dummy homogeneous distributed graph and
+        # test that GSgnnEdgeInferData can be created on a homogeneous graph
+        dist_graph, part_config = generate_dummy_dist_graph(graph_name='dummy',
+                                                            dirname=os.path.join(tmpdirname, 'dummy'),
+                                                            is_homo=True)
+        _ = GSgnnEdgeInferData(graph_name='dummy', part_config=part_config,
+                               eval_etypes=va_etypes)
+
+        # generate a test dummy distributed graph that violates the homogeneous format.
+        # The input is expected to be a homogeneous graph with "_N" as the node type and
+        # ("_N", "_E", "_N") as the edge type, so an AssertionError should be raised.
+        dist_graph, part_config = generate_dummy_dist_graph_homogeneous_failure_graph(graph_name='dummy',
+                                                                                      dirname=os.path.join(tmpdirname, 'dummy'))
+        try:
+            _ = GSgnnEdgeInferData(graph_name='dummy', part_config=part_config,
+                                   eval_etypes=va_etypes)
+            raise RuntimeError("expected an AssertionError for non-homogeneous graph input")
+        except AssertionError as _:
+            pass
+
+    # after the tests pass, destroy the process group
+    th.distributed.destroy_process_group()

 if __name__ == '__main__':
     test_inbatch_joint_neg_sampler(10, 20)
@@ -1387,3 +1488,5 @@ def test_inbatch_joint_neg_sampler(num_pos, num_neg):
     test_DistillDistributedFileSampler(num_files=7, is_train=True, \
                                        infinite=False, shuffle=True)
     test_DistillDataloaderGenerator("gloo", 7, True)
+
+    test_GSgnnTrainData_homogeneous()
diff --git a/training_scripts/gsgnn_ep/ml_ec_homogeneous.yaml b/training_scripts/gsgnn_ep/ml_ec_homogeneous.yaml
new file mode 100644
index 0000000000..7e519f9276
--- /dev/null
+++ b/training_scripts/gsgnn_ep/ml_ec_homogeneous.yaml
@@ -0,0 +1,38 @@
+---
+version: 1.0
+gsf:
+  basic:
+    backend: gloo
+    verbose: false
+    save_perf_results_path: null
+  gnn:
+    model_encoder_type: rgcn
+    fanout: "4"
+    num_layers: 1
+    hidden_size: 128
+    use_mini_batch_infer: true
+  input:
+    restore_model_path: null
+  output:
+    save_model_path: null
+    save_embed_path: null
+  hyperparam:
+    dropout: 0.
+    lr: 0.001
+    lm_tune_lr: 0.0001
+    num_epochs: 3
+    batch_size: 64
+    wd_l2norm: 0
+    no_validation: false
+    eval_frequency: 1000
+  rgcn:
+    num_bases: -1
+    use_self_loop: true
+    sparse_optimizer_lr: 1e-2
+    use_node_embeddings: false
+  edge_classification:
+    label_field: "rate"
+    multilabel: false
+    num_classes: 6
+    num_decoder_basis: 32
+    exclude_training_targets: false
diff --git a/training_scripts/gsgnn_np/ml_nc_homogeneous.yaml b/training_scripts/gsgnn_np/ml_nc_homogeneous.yaml
new file mode 100644
index 0000000000..7f12b9e011
--- /dev/null
+++ b/training_scripts/gsgnn_np/ml_nc_homogeneous.yaml
@@ -0,0 +1,35 @@
+---
+version: 1.0
+gsf:
+  basic:
+    backend: gloo
+    verbose: false
+    save_perf_results_path: null
+  gnn:
+    model_encoder_type: rgcn
+    fanout: "4"
+    num_layers: 1
+    hidden_size: 128
+    use_mini_batch_infer: true
+  input:
+    restore_model_path: null
+  output:
+    save_model_path: null
+    save_embed_path: null
+  hyperparam:
+    dropout: 0.
+    lr: 0.001
+    lm_tune_lr: 0.0001
+    num_epochs: 3
+    batch_size: 128
+    wd_l2norm: 0
+    no_validation: false
+  rgcn:
+    num_bases: -1
+    use_self_loop: true
+    sparse_optimizer_lr: 1e-2
+    use_node_embeddings: false
+  node_classification:
+    label_field: "label"
+    multilabel: false
+    num_classes: 19
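
The fallback logic in argument.py and the dataset checks above rely on DGL's built-in default type names. The following is a minimal sketch (not part of this patch, assuming only that dgl and torch are installed) of those defaults:

    # DGL builds homogeneous graphs with the default type names that
    # GSConfig.target_ntype / target_etype now fall back to.
    import dgl
    import torch as th
    from dgl.distributed.constants import DEFAULT_NTYPE, DEFAULT_ETYPE

    g = dgl.graph((th.tensor([0, 1]), th.tensor([1, 2])))
    assert g.ntypes == [DEFAULT_NTYPE]            # ["_N"]
    assert g.etypes == [DEFAULT_ETYPE[1]]         # ["_E"]; g.etypes is non-canonical
    assert g.canonical_etypes == [DEFAULT_ETYPE]  # [("_N", "_E", "_N")]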