diff --git a/graphstorm-processing/graphstorm_processing/distributed_executor.py b/graphstorm-processing/graphstorm_processing/distributed_executor.py index c374056f56..580e07c8ce 100644 --- a/graphstorm-processing/graphstorm_processing/distributed_executor.py +++ b/graphstorm-processing/graphstorm_processing/distributed_executor.py @@ -563,6 +563,27 @@ def parse_args() -> argparse.Namespace: return parser.parse_args() +def check_graph_name(graph_name): + """ Check whether the graph name is a valid graph name. + + We enforce that the graph name adheres to the Python + identifier naming rules as in + https://docs.python.org/3/reference/lexical_analysis.html#identifiers, + with the exception that hyphens (-) are permitted. + This helps avoid the cases when an invalid graph name, + such as `/graph`, causes unexpected errors. + + Note: Same as graphstorm.utils.check_graph_name. + + Parameter + --------- + graph_name: str + Graph Name. + """ + assert graph_name.replace('-', '_').isidentifier(), \ + "GraphStorm expects the graph name adheres to the Python" \ + "identifier naming rules with the exception that hyphens " \ + f"(-) are permitted. But we get {graph_name}" def main(): """Main entry point for GSProcessing""" @@ -572,6 +593,7 @@ def main(): level=gsprocessing_args.log_level, format="[GSPROCESSING] %(asctime)s %(levelname)-8s %(message)s", ) + check_graph_name(gsprocessing_args.graph_name) # Determine execution environment if os.path.exists("/opt/ml/config/processingjobconfig.json"): diff --git a/python/graphstorm/gconstruct/construct_graph.py b/python/graphstorm/gconstruct/construct_graph.py index 9308c2003e..bf21c3f7dd 100644 --- a/python/graphstorm/gconstruct/construct_graph.py +++ b/python/graphstorm/gconstruct/construct_graph.py @@ -30,7 +30,7 @@ import dgl from dgl.distributed.constants import DEFAULT_NTYPE, DEFAULT_ETYPE -from ..utils import sys_tracker, get_log_level +from ..utils import sys_tracker, get_log_level, check_graph_name from .file_io import parse_node_file_format, parse_edge_file_format from .file_io import get_in_files from .transform import parse_feat_ops, process_features, preprocess_features @@ -742,6 +742,7 @@ def print_graph_info(g, node_data, edge_data, node_label_stats, edge_label_stats def process_graph(args): """ Process the graph. """ + check_graph_name(args.graph_name) logging.basicConfig(level=get_log_level(args.logging_level)) with open(args.conf_file, 'r', encoding="utf8") as json_file: process_confs = json.load(json_file) diff --git a/python/graphstorm/utils.py b/python/graphstorm/utils.py index e54027178f..a58c9daafd 100644 --- a/python/graphstorm/utils.py +++ b/python/graphstorm/utils.py @@ -31,6 +31,26 @@ USE_WHOLEGRAPH = False GS_DEVICE = th.device('cpu') +def check_graph_name(graph_name): + """ Check whether the graph name is a valid graph name. + + We enforce that the graph name adheres to the Python + identifier naming rules as in + https://docs.python.org/3/reference/lexical_analysis.html#identifiers, + with the exception that hyphens (-) are permitted. + This helps avoid the cases when an invalid graph name, + such as `/graph`, causes unexpected errors. + + Parameter + --------- + graph_name: str + Graph Name. + """ + assert graph_name.replace('-', '_').isidentifier(), \ + "GraphStorm expects the graph name adheres to the Python" \ + "identifier naming rules with the exception that hyphens " \ + f"(-) are permitted. But we get {graph_name}" + def get_graph_name(part_config): """ Get graph name from graph partition config file @@ -45,7 +65,9 @@ def get_graph_name(part_config): """ with open(part_config, "r", encoding='utf-8') as f: config = json.load(f) - return config["graph_name"] + graph_name = config["graph_name"] + check_graph_name(graph_name) + return graph_name def setup_device(local_rank): r"""Setup computation device. diff --git a/tests/unit-tests/test_gsf.py b/tests/unit-tests/test_gsf.py index 338f53128d..6be5958788 100644 --- a/tests/unit-tests/test_gsf.py +++ b/tests/unit-tests/test_gsf.py @@ -16,6 +16,7 @@ from graphstorm.gsf import (create_builtin_node_decoder, create_builtin_edge_decoder, create_builtin_lp_decoder) +from graphstorm.utils import check_graph_name from graphstorm.config import (BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION, BUILTIN_TASK_EDGE_CLASSIFICATION, @@ -449,7 +450,41 @@ def test_create_builtin_lp_decoder(): assert decoder.gamma == 6. +def test_check_graph_name(): + graph_name = "a" + check_graph_name(graph_name) + graph_name = "graph_name" + check_graph_name(graph_name) + graph_name = "graph-name" + check_graph_name(graph_name) + + # test with invalid graph name + graph_name = "/graph_name" + invalid_name = False + try: + check_graph_name(graph_name) + except: + invalid_name = True + assert invalid_name + + graph_name = "|graph_name" + invalid_name = False + try: + check_graph_name(graph_name) + except: + invalid_name = True + assert invalid_name + + graph_name = "\graph_name" + invalid_name = False + try: + check_graph_name(graph_name) + except: + invalid_name = True + assert invalid_name + if __name__ == '__main__': + test_check_graph_name() test_create_builtin_node_decoder() test_create_builtin_edge_decoder() test_create_builtin_lp_decoder() \ No newline at end of file