diff --git a/docs/source/cli/graph-construction/single-machine-gconstruct.rst b/docs/source/cli/graph-construction/single-machine-gconstruct.rst index 45026d1f4e..a0b2c465d7 100644 --- a/docs/source/cli/graph-construction/single-machine-gconstruct.rst +++ b/docs/source/cli/graph-construction/single-machine-gconstruct.rst @@ -33,7 +33,7 @@ Full argument list of the ``gconstruct.construct_graph`` command * **-\-num-processes-for-nodes**: the number of processes to process node data simultaneously. Increase this number can speed up node data processing. * **-\-num-processes-for-edges**: the number of processes to process edge data simultaneously. Increase this number can speed up edge data processing. * **-\-output-dir**: (**Required**) the path of the output data files. -* **-\-graph-name**: (**Required**) the name assigned for the graph. +* **-\-graph-name**: (**Required**) the name assigned for the graph. The graph name must adhere to the `Python identifier naming rules`_ with the exception that hyphens (``-``) are permitted and the name can start with numbers. * **-\-remap-node-id**: boolean value to decide whether to rename node IDs or not. Adding this argument will set it to be true, otherwise false. * **-\-add-reverse-edges**: boolean value to decide whether to add reverse edges for the given graph. Adding this argument sets it to true; otherwise, it defaults to false. It is **strongly** suggested to include this argument for graph construction, as some nodes in the original data may not have in-degrees, and thus cannot update their presentations by aggregating messages from their neighbors. Adding this arugment helps prevent this issue. * **-\-output-format**: the format of constructed graph, options are ``DGL``, ``DistDGL``. Default is ``DistDGL``. It also accepts multiple graph formats at the same time separated by an space, for example ``--output-format "DGL DistDGL"``. The output format is explained in the :ref:`Output ` section above. diff --git a/docs/source/cli/model-training-inference/distributed/sagemaker.rst b/docs/source/cli/model-training-inference/distributed/sagemaker.rst index 5b111726cb..3acb78ee61 100644 --- a/docs/source/cli/model-training-inference/distributed/sagemaker.rst +++ b/docs/source/cli/model-training-inference/distributed/sagemaker.rst @@ -388,7 +388,7 @@ The rest of the arguments are passed on to ``sagemaker_train.py`` or ``sagemaker * **--task-type**: Task type. * **--graph-data-s3**: S3 location of the input graph. -* **--graph-name**: Name of the input graph. +* **--graph-name**: Name of the input graph. The graph name must adhere to the `Python identifier naming rules`_ with the exception that hyphens (``-``) are permitted and the name can start with numbers. * **--yaml-s3**: S3 location of yaml file for training and inference. * **--custom-script**: Custom training script provided by customers to run customer training logic. This should be a path to the Python script within the Docker image. * **--output-emb-s3**: S3 location to store GraphStorm generated node embeddings. This is an inference only argument. diff --git a/graphstorm-processing/graphstorm_processing/distributed_executor.py b/graphstorm-processing/graphstorm_processing/distributed_executor.py index c374056f56..3bc1bc4936 100644 --- a/graphstorm-processing/graphstorm_processing/distributed_executor.py +++ b/graphstorm-processing/graphstorm_processing/distributed_executor.py @@ -54,6 +54,7 @@ import json import logging import os +import re from pathlib import Path import tempfile import time @@ -540,7 +541,11 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--graph-name", type=str, - help="Name for the graph being processed.", + help="Name for the graph being processed." + "The graph name must adhere to the Python " + "identifier naming rules with the exception " + "that hyphens (-) are permitted and the name " + "can start with numbers", required=False, default=None, ) @@ -564,6 +569,33 @@ def parse_args() -> argparse.Namespace: return parser.parse_args() +def check_graph_name(graph_name): + """Check whether the graph name is a valid graph name. + + We enforce that the graph name adheres to the Python + identifier naming rules as in + https://docs.python.org/3/reference/lexical_analysis.html#identifiers, + with the exception that hyphens (-) are permitted + and the name can start with numbers. + This helps avoid the cases when an invalid graph name, + such as `/graph`, causes unexpected errors. + + Note: Same as graphstorm.utils.check_graph_name. + + Parameter + --------- + graph_name: str + Graph Name. + """ + gname = re.sub(r"^\d+", "", graph_name) + assert gname.replace("-", "_").isidentifier(), ( + "GraphStorm expects the graph name adheres to the Python" + "identifier naming rules with the exception that hyphens " + "(-) are permitted and the name can start with numbers. " + f"Got: {graph_name}" + ) + + def main(): """Main entry point for GSProcessing""" # Allows us to get typed arguments from the command line @@ -572,6 +604,7 @@ def main(): level=gsprocessing_args.log_level, format="[GSPROCESSING] %(asctime)s %(levelname)-8s %(message)s", ) + check_graph_name(gsprocessing_args.graph_name) # Determine execution environment if os.path.exists("/opt/ml/config/processingjobconfig.json"): diff --git a/python/graphstorm/gconstruct/construct_graph.py b/python/graphstorm/gconstruct/construct_graph.py index 9308c2003e..ceaee5adfd 100644 --- a/python/graphstorm/gconstruct/construct_graph.py +++ b/python/graphstorm/gconstruct/construct_graph.py @@ -30,7 +30,7 @@ import dgl from dgl.distributed.constants import DEFAULT_NTYPE, DEFAULT_ETYPE -from ..utils import sys_tracker, get_log_level +from ..utils import sys_tracker, get_log_level, check_graph_name from .file_io import parse_node_file_format, parse_edge_file_format from .file_io import get_in_files from .transform import parse_feat_ops, process_features, preprocess_features @@ -742,6 +742,7 @@ def print_graph_info(g, node_data, edge_data, node_label_stats, edge_label_stats def process_graph(args): """ Process the graph. """ + check_graph_name(args.graph_name) logging.basicConfig(level=get_log_level(args.logging_level)) with open(args.conf_file, 'r', encoding="utf8") as json_file: process_confs = json.load(json_file) @@ -909,7 +910,11 @@ def process_graph(args): argparser.add_argument("--output-dir", type=str, required=True, help="The path of the output data folder.") argparser.add_argument("--graph-name", type=str, required=True, - help="The graph name") + help="Name for the graph being processed." + "The graph name must adhere to the Python " + "identifier naming rules with the exception " + "that hyphens (-) are permitted " + "and the name can start with numbers",) argparser.add_argument("--remap-node-id", action='store_true', help="Whether or not to remap node IDs.") argparser.add_argument("--add-reverse-edges", action='store_true', diff --git a/python/graphstorm/utils.py b/python/graphstorm/utils.py index e54027178f..8f1511f406 100644 --- a/python/graphstorm/utils.py +++ b/python/graphstorm/utils.py @@ -20,6 +20,7 @@ import time import resource import logging +import re import psutil import pandas as pd @@ -31,6 +32,29 @@ USE_WHOLEGRAPH = False GS_DEVICE = th.device('cpu') +def check_graph_name(graph_name): + """ Check whether the graph name is a valid graph name. + + We enforce that the graph name adheres to the Python + identifier naming rules as in + https://docs.python.org/3/reference/lexical_analysis.html#identifiers, + with the exception that hyphens (-) are permitted + and the name can start with numbers. + This helps avoid the cases when an invalid graph name, + such as `/graph`, causes unexpected errors. + + Parameter + --------- + graph_name: str + Graph Name. + """ + gname = re.sub(r'^\d+', '', graph_name) + assert gname.replace('-', '_').isidentifier(), \ + "GraphStorm expects the graph name adheres to the Python" \ + "identifier naming rules with the exception that hyphens " \ + "(-) are permitted and the name can start with numbers. " \ + f"Got: {graph_name}." + def get_graph_name(part_config): """ Get graph name from graph partition config file @@ -45,7 +69,9 @@ def get_graph_name(part_config): """ with open(part_config, "r", encoding='utf-8') as f: config = json.load(f) - return config["graph_name"] + graph_name = config["graph_name"] + check_graph_name(graph_name) + return graph_name def setup_device(local_rank): r"""Setup computation device. diff --git a/tests/unit-tests/test_gsf.py b/tests/unit-tests/test_gsf.py index 338f53128d..031017b886 100644 --- a/tests/unit-tests/test_gsf.py +++ b/tests/unit-tests/test_gsf.py @@ -12,10 +12,12 @@ Unit tests for gsf.py """ +import pytest from graphstorm.gsf import (create_builtin_node_decoder, create_builtin_edge_decoder, create_builtin_lp_decoder) +from graphstorm.utils import check_graph_name from graphstorm.config import (BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION, BUILTIN_TASK_EDGE_CLASSIFICATION, @@ -449,7 +451,33 @@ def test_create_builtin_lp_decoder(): assert decoder.gamma == 6. +def test_check_graph_name(): + graph_name = "a" + check_graph_name(graph_name) + graph_name = "graph_name" + check_graph_name(graph_name) + graph_name = "graph-name" + check_graph_name(graph_name) + graph_name = "123-graph-name" + check_graph_name(graph_name) + graph_name = "_Graph-name" + check_graph_name(graph_name) + + # test with invalid graph name + graph_name = "/graph_name" + with pytest.raises(AssertionError): + check_graph_name(graph_name) + + graph_name = "|graph_name" + with pytest.raises(AssertionError): + check_graph_name(graph_name) + + graph_name = "\graph_name" + with pytest.raises(AssertionError): + check_graph_name(graph_name) + if __name__ == '__main__': + test_check_graph_name() test_create_builtin_node_decoder() test_create_builtin_edge_decoder() test_create_builtin_lp_decoder() \ No newline at end of file