Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
Xiang Song committed Jun 9, 2024
1 parent c62c3e3 commit 1b6d9f6
Showing 1 changed file with 108 additions and 39 deletions.
147 changes: 108 additions & 39 deletions python/graphstorm/gconstruct/remap_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,20 +655,44 @@ def _parse_gs_config(config):
node_id_mapping = os.path.join(os.path.dirname(part_config), "raw_id_mappings")
predict_dir = config.save_prediction_path
emb_dir = config.save_embed_path
task_type = config.task_type

pred_ntypes = []
pred_etypes = []
if task_type in (BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION):
pred_etypes = config.target_etype
pred_etypes = pred_etypes \
if isinstance(pred_etypes, list) else [pred_etypes]
pred_etypes = [list(pred_etype) for pred_etype in pred_etypes]
elif task_type in (BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION):
pred_ntypes = config.target_ntype
pred_ntypes = pred_ntypes \
if isinstance(pred_ntypes, list) else [pred_ntypes]

return node_id_mapping, predict_dir, emb_dir, pred_ntypes, pred_etypes
if config.multi_tasks is not None:
node_predict_dirs = []
edge_predict_dirs = []
# multi-task setting
tasks = config.multi_tasks
for task in tasks:
task_config = task.task_config
task_id = task.task_id
pred_path = os.path.join(predict_dir, task_id)
if task.task_type in [BUILTIN_TASK_NODE_CLASSIFICATION,
BUILTIN_TASK_NODE_REGRESSION]:
pred_ntype = task_config.target_ntype
pred_ntypes.append(pred_ntype)
node_predict_dirs.append(pred_path)
elif task_type in (BUILTIN_TASK_EDGE_CLASSIFICATION,
BUILTIN_TASK_EDGE_REGRESSION):
pred_etype = config.target_etype
pred_etypes.append(pred_etype)
edge_predict_dirs.append(pred_path)

predict_dir = (node_predict_dirs, edge_predict_dirs)
return node_id_mapping, predict_dir, emb_dir, pred_ntypes, pred_etypes
else:
task_type = config.task_type
if task_type in (BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION):
pred_etypes = config.target_etype
pred_etypes = pred_etypes \
if isinstance(pred_etypes, list) else [pred_etypes]
pred_etypes = [list(pred_etype) for pred_etype in pred_etypes]
elif task_type in (BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION):
pred_ntypes = config.target_ntype
pred_ntypes = pred_ntypes \
if isinstance(pred_ntypes, list) else [pred_ntypes]

return node_id_mapping, predict_dir, emb_dir, pred_ntypes, pred_etypes

def main(args, gs_config_args):
""" main function
Expand Down Expand Up @@ -755,7 +779,9 @@ def main(args, gs_config_args):
"Skip remapping node embeddings.")

################## remap prediction #############
if predict_dir is not None:
if predict_dir is not None and isinstance(predict_dir, str):
# predict_dir is a string
# There is only one prediction task.
assert os.path.exists(predict_dir), \
f"Prediction dir {predict_dir} does not exist."
# if pred_etypes (edges with prediction results)
Expand Down Expand Up @@ -816,6 +842,33 @@ def main(args, gs_config_args):
if "etypes" in info else []
if len(pred_ntypes) == 0:
pred_ntypes = info["ntypes"] if "ntypes" in info else []
elif predict_dir is not None and isinstance(predict_dir, tuple):
# This is multi-task learning.
# we only get predict_dir with type list
# from yaml config
node_predict_dirs, edge_predict_dirs = predict_dir

if len(node_predict_dirs) == 0 and \
len(edge_predict_dirs) == 0:
logging.info("Prediction results are empty."
"Skip remapping prediction result.")
pred_etypes = []
pred_ntypes = []
else:
# check the prediciton result paths
for pred_dir, pred_ntype in zip(node_predict_dirs, pred_ntypes):
assert os.path.exists(pred_dir), \
f"Prediction dir {pred_dir} does not exist."
assert os.path.exists(os.path.join(pred_dir, pred_ntype)), \
f"Prediction dir {os.path.join(pred_dir, pred_ntype)}" \
f"for {pred_ntype} does not exist."
for pred_dir, pred_etype in zip(edge_predict_dirs, pred_etypes):
assert os.path.exists(pred_dir), \
f"Prediction dir {pred_dir} does not exist."
assert os.path.exists(os.path.join(pred_dir, "_".join(pred_etype))), \
f"Prediction dir {os.path.join(pred_dir, "_".join(pred_etype))}" \
f"for {pred_etype} does not exist."

else:
pred_etypes = []
pred_ntypes = []
Expand Down Expand Up @@ -890,34 +943,50 @@ def main(args, gs_config_args):
files_to_remove += emb_files_to_remove

if len(pred_etypes) > 0:
pred_output = predict_dir
# We need to do ID remapping for edge prediction result
pred_files_to_remove = \
remap_edge_pred(pred_etypes,
predict_dir,
pred_output,
out_chunk_size,
num_proc,
rank,
world_size,
with_shared_fs,
output_func)
files_to_remove += pred_files_to_remove
if isinstance(predict_dir, tuple):
_, edge_predict_dirs = predict_dir
edge_pred_etypes = pred_etypes
else:
edge_predict_dirs = [predict_dir]
edge_pred_etypes = [pred_etypes]

for pred_dir, pred_et in zip(edge_predict_dirs, edge_pred_etypes):
pred_output = pred_dir
# We need to do ID remapping for edge prediction result
pred_files_to_remove = \
remap_edge_pred(pred_et,
pred_dir,
pred_output,
out_chunk_size,
num_proc,
rank,
world_size,
with_shared_fs,
output_func)
files_to_remove += pred_files_to_remove

if len(pred_ntypes) > 0:
pred_output = predict_dir
# We need to do ID remapping for node prediction result
pred_files_to_remove = \
remap_node_pred(pred_ntypes,
predict_dir,
pred_output,
out_chunk_size,
num_proc,
rank,
world_size,
with_shared_fs,
output_func)
files_to_remove += pred_files_to_remove
if isinstance(predict_dir, tuple):
node_predict_dirs, _ = predict_dir
node_pred_ntypes = pred_ntypes
else:
node_predict_dirs = [predict_dir]
node_pred_ntypes = [pred_ntypes]

for pred_dir, pred_nt in zip(node_predict_dirs, node_pred_ntypes):
pred_output = pred_dir
# We need to do ID remapping for node prediction result
pred_files_to_remove = \
remap_node_pred(pred_nt,
pred_dir,
pred_output,
out_chunk_size,
num_proc,
rank,
world_size,
with_shared_fs,
output_func)
files_to_remove += pred_files_to_remove

if args.preserve_input is False and len(files_to_remove) > 0:
# If files_to_remove is not empty, at least node_emb_dir or
Expand Down

0 comments on commit 1b6d9f6

Please sign in to comment.