Skip to content

Commit

Permalink
add optimization for gconstruct
Browse files Browse the repository at this point in the history
  • Loading branch information
jalencato committed Dec 11, 2023
1 parent 530b2de commit ea67d08
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions python/graphstorm/gconstruct/construct_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ def process_edge_data(process_confs, node_id_map, arr_merger,

return (edges, edge_data, label_stats)

def verify_confs(confs):
def verify_confs(confs, args=None):
""" Verify the configuration of the input data.
"""
if "version" not in confs:
Expand All @@ -591,6 +591,16 @@ def verify_confs(confs):
"The config file does not have a 'version' entry. Assuming gconstruct-v0.1")
ntypes = {conf['node_type'] for conf in confs["nodes"]}
etypes = [conf['relation'] for conf in confs["edges"]]
# Adjust input to DGL requirement if it is a homogeneous graph
if len(ntypes) == 1 and len(etypes) == 1 and not args.add_reverse_edges:
assert etypes[0][0] in ntypes, \
f"source node type {etypes[0][0]} does not exist. Please check your input data."
assert etypes[0][2] in ntypes, \
f"dest node type {etypes[0][2]} does not exist. Please check your input data."
logging.warning("Generated Graph is a homogeneous graph, so the node type will be "
"changed to _N and edge type should be changed to [_N, _E, _N]")
confs['nodes'][0]['node_type'] = "_N"
confs['edges'][0]['relation'] = ["_N", "_E", "_N"]
for etype in etypes:
assert len(etype) == 3, \
"The edge type must be (source node type, relation type, dest node type)."
Expand Down Expand Up @@ -668,7 +678,7 @@ def process_graph(args):
if args.num_processes_for_nodes is not None else args.num_processes
num_processes_for_edges = args.num_processes_for_edges \
if args.num_processes_for_edges is not None else args.num_processes
verify_confs(process_confs)
verify_confs(process_confs, args)
output_format = args.output_format
for out_format in output_format:
assert out_format in ["DGL", "DistDGL"], \
Expand Down

0 comments on commit ea67d08

Please sign in to comment.