Skip to content

Commit

Permalink
Add a strict naming rule for graph name
Browse files Browse the repository at this point in the history
  • Loading branch information
Xiang Song committed Sep 26, 2024
1 parent ccc931b commit 937d1d6
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,27 @@ def parse_args() -> argparse.Namespace:

return parser.parse_args()

def check_graph_name(graph_name):
""" Check whether the graph name is a valid graph name.
We enforce that the graph name adheres to the Python
identifier naming rules as in
https://docs.python.org/3/reference/lexical_analysis.html#identifiers,
with the exception that hyphens (-) are permitted.
This helps avoid the cases when an invalid graph name,
such as `/graph`, causes unexpected errors.
Note: Same as graphstorm.utils.check_graph_name.
Parameter
---------
graph_name: str
Graph Name.
"""
assert graph_name.replace('-', '_').isidentifier(), \
"GraphStorm expects the graph name adheres to the Python" \
"identifier naming rules with the exception that hyphens " \
f"(-) are permitted. But we get {graph_name}"

def main():
"""Main entry point for GSProcessing"""
Expand All @@ -572,6 +593,7 @@ def main():
level=gsprocessing_args.log_level,
format="[GSPROCESSING] %(asctime)s %(levelname)-8s %(message)s",
)
check_graph_name(gsprocessing_args.graph_name)

# Determine execution environment
if os.path.exists("/opt/ml/config/processingjobconfig.json"):
Expand Down
3 changes: 2 additions & 1 deletion python/graphstorm/gconstruct/construct_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import dgl
from dgl.distributed.constants import DEFAULT_NTYPE, DEFAULT_ETYPE

from ..utils import sys_tracker, get_log_level
from ..utils import sys_tracker, get_log_level, check_graph_name
from .file_io import parse_node_file_format, parse_edge_file_format
from .file_io import get_in_files
from .transform import parse_feat_ops, process_features, preprocess_features
Expand Down Expand Up @@ -742,6 +742,7 @@ def print_graph_info(g, node_data, edge_data, node_label_stats, edge_label_stats
def process_graph(args):
""" Process the graph.
"""
check_graph_name(args.graph_name)
logging.basicConfig(level=get_log_level(args.logging_level))
with open(args.conf_file, 'r', encoding="utf8") as json_file:
process_confs = json.load(json_file)
Expand Down
24 changes: 23 additions & 1 deletion python/graphstorm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,26 @@
USE_WHOLEGRAPH = False
GS_DEVICE = th.device('cpu')

def check_graph_name(graph_name):
""" Check whether the graph name is a valid graph name.
We enforce that the graph name adheres to the Python
identifier naming rules as in
https://docs.python.org/3/reference/lexical_analysis.html#identifiers,
with the exception that hyphens (-) are permitted.
This helps avoid the cases when an invalid graph name,
such as `/graph`, causes unexpected errors.
Parameter
---------
graph_name: str
Graph Name.
"""
assert graph_name.replace('-', '_').isidentifier(), \
"GraphStorm expects the graph name adheres to the Python" \
"identifier naming rules with the exception that hyphens " \
f"(-) are permitted. But we get {graph_name}"

def get_graph_name(part_config):
""" Get graph name from graph partition config file
Expand All @@ -45,7 +65,9 @@ def get_graph_name(part_config):
"""
with open(part_config, "r", encoding='utf-8') as f:
config = json.load(f)
return config["graph_name"]
graph_name = config["graph_name"]
check_graph_name(graph_name)
return graph_name

def setup_device(local_rank):
r"""Setup computation device.
Expand Down
35 changes: 35 additions & 0 deletions tests/unit-tests/test_gsf.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from graphstorm.gsf import (create_builtin_node_decoder,
create_builtin_edge_decoder,
create_builtin_lp_decoder)
from graphstorm.utils import check_graph_name
from graphstorm.config import (BUILTIN_TASK_NODE_CLASSIFICATION,
BUILTIN_TASK_NODE_REGRESSION,
BUILTIN_TASK_EDGE_CLASSIFICATION,
Expand Down Expand Up @@ -449,7 +450,41 @@ def test_create_builtin_lp_decoder():
assert decoder.gamma == 6.


def test_check_graph_name():
graph_name = "a"
check_graph_name(graph_name)
graph_name = "graph_name"
check_graph_name(graph_name)
graph_name = "graph-name"
check_graph_name(graph_name)

# test with invalid graph name
graph_name = "/graph_name"
invalid_name = False
try:
check_graph_name(graph_name)
except:
invalid_name = True
assert invalid_name

graph_name = "|graph_name"
invalid_name = False
try:
check_graph_name(graph_name)
except:
invalid_name = True
assert invalid_name

graph_name = "\graph_name"
invalid_name = False
try:
check_graph_name(graph_name)
except:
invalid_name = True
assert invalid_name

if __name__ == '__main__':
test_check_graph_name()
test_create_builtin_node_decoder()
test_create_builtin_edge_decoder()
test_create_builtin_lp_decoder()

0 comments on commit 937d1d6

Please sign in to comment.