From ea67d08cde4060ba53eb31281a943b23029e83dd Mon Sep 17 00:00:00 2001 From: JalenCato Date: Mon, 11 Dec 2023 20:16:29 +0000 Subject: [PATCH] add optimization for gconstruct --- python/graphstorm/gconstruct/construct_graph.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/python/graphstorm/gconstruct/construct_graph.py b/python/graphstorm/gconstruct/construct_graph.py index 0065da5403..bf3cb6bcb6 100644 --- a/python/graphstorm/gconstruct/construct_graph.py +++ b/python/graphstorm/gconstruct/construct_graph.py @@ -582,7 +582,7 @@ def process_edge_data(process_confs, node_id_map, arr_merger, return (edges, edge_data, label_stats) -def verify_confs(confs): +def verify_confs(confs, args=None): """ Verify the configuration of the input data. """ if "version" not in confs: @@ -591,6 +591,16 @@ def verify_confs(confs): "The config file does not have a 'version' entry. Assuming gconstruct-v0.1") ntypes = {conf['node_type'] for conf in confs["nodes"]} etypes = [conf['relation'] for conf in confs["edges"]] + # Adjust input to DGL requirement if it is a homogeneous graph + if len(ntypes) == 1 and len(etypes) == 1 and not getattr(args, "add_reverse_edges", False): + assert etypes[0][0] in ntypes, \ + f"source node type {etypes[0][0]} does not exist. Please check your input data." + assert etypes[0][2] in ntypes, \ + f"dest node type {etypes[0][2]} does not exist. Please check your input data." + logging.warning("Generated Graph is a homogeneous graph, so the node type will be " + "changed to _N and edge type should be changed to [_N, _E, _N]") + confs['nodes'][0]['node_type'] = "_N" + confs['edges'][0]['relation'] = ["_N", "_E", "_N"] for etype in etypes: assert len(etype) == 3, \ "The edge type must be (source node type, relation type, dest node type)." 
@@ -668,7 +678,7 @@ def process_graph(args): if args.num_processes_for_nodes is not None else args.num_processes num_processes_for_edges = args.num_processes_for_edges \ if args.num_processes_for_edges is not None else args.num_processes - verify_confs(process_confs) + verify_confs(process_confs, args) output_format = args.output_format for out_format in output_format: assert out_format in ["DGL", "DistDGL"], \