diff --git a/python/graphstorm/gconstruct/remap_result.py b/python/graphstorm/gconstruct/remap_result.py index 30fec2b3aa..35ae3e3127 100644 --- a/python/graphstorm/gconstruct/remap_result.py +++ b/python/graphstorm/gconstruct/remap_result.py @@ -655,20 +655,44 @@ def _parse_gs_config(config): node_id_mapping = os.path.join(os.path.dirname(part_config), "raw_id_mappings") predict_dir = config.save_prediction_path emb_dir = config.save_embed_path - task_type = config.task_type + pred_ntypes = [] pred_etypes = [] - if task_type in (BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION): - pred_etypes = config.target_etype - pred_etypes = pred_etypes \ - if isinstance(pred_etypes, list) else [pred_etypes] - pred_etypes = [list(pred_etype) for pred_etype in pred_etypes] - elif task_type in (BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION): - pred_ntypes = config.target_ntype - pred_ntypes = pred_ntypes \ - if isinstance(pred_ntypes, list) else [pred_ntypes] - - return node_id_mapping, predict_dir, emb_dir, pred_ntypes, pred_etypes + if config.multi_tasks is not None: + node_predict_dirs = [] + edge_predict_dirs = [] + # multi-task setting + tasks = config.multi_tasks + for task in tasks: + task_config = task.task_config + task_id = task.task_id + pred_path = os.path.join(predict_dir, task_id) + if task.task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, + BUILTIN_TASK_NODE_REGRESSION]: + pred_ntype = task_config.target_ntype + pred_ntypes.append(pred_ntype) + node_predict_dirs.append(pred_path) + elif task_type in (BUILTIN_TASK_EDGE_CLASSIFICATION, + BUILTIN_TASK_EDGE_REGRESSION): + pred_etype = config.target_etype + pred_etypes.append(pred_etype) + edge_predict_dirs.append(pred_path) + + predict_dir = (node_predict_dirs, edge_predict_dirs) + return node_id_mapping, predict_dir, emb_dir, pred_ntypes, pred_etypes + else: + task_type = config.task_type + if task_type in (BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION): + pred_etypes = config.target_etype + pred_etypes = pred_etypes \ + if isinstance(pred_etypes, list) else [pred_etypes] + pred_etypes = [list(pred_etype) for pred_etype in pred_etypes] + elif task_type in (BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION): + pred_ntypes = config.target_ntype + pred_ntypes = pred_ntypes \ + if isinstance(pred_ntypes, list) else [pred_ntypes] + + return node_id_mapping, predict_dir, emb_dir, pred_ntypes, pred_etypes def main(args, gs_config_args): """ main function @@ -755,7 +779,9 @@ def main(args, gs_config_args): "Skip remapping node embeddings.") ################## remap prediction ############# - if predict_dir is not None: + if predict_dir is not None and isinstance(predict_dir, str): + # predict_dir is a string + # There is only one prediction task. assert os.path.exists(predict_dir), \ f"Prediction dir {predict_dir} does not exist." # if pred_etypes (edges with prediction results) @@ -816,6 +842,33 @@ def main(args, gs_config_args): if "etypes" in info else [] if len(pred_ntypes) == 0: pred_ntypes = info["ntypes"] if "ntypes" in info else [] + elif predict_dir is not None and isinstance(predict_dir, tuple): + # This is multi-task learning. + # we only get predict_dir with type list + # from yaml config + node_predict_dirs, edge_predict_dirs = predict_dir + + if len(node_predict_dirs) == 0 and \ + len(edge_predict_dirs) == 0: + logging.info("Prediction results are empty." + "Skip remapping prediction result.") + pred_etypes = [] + pred_ntypes = [] + else: + # check the prediciton result paths + for pred_dir, pred_ntype in zip(node_predict_dirs, pred_ntypes): + assert os.path.exists(pred_dir), \ + f"Prediction dir {pred_dir} does not exist." + assert os.path.exists(os.path.join(pred_dir, pred_ntype)), \ + f"Prediction dir {os.path.join(pred_dir, pred_ntype)}" \ + f"for {pred_ntype} does not exist." + for pred_dir, pred_etype in zip(edge_predict_dirs, pred_etypes): + assert os.path.exists(pred_dir), \ + f"Prediction dir {pred_dir} does not exist." + assert os.path.exists(os.path.join(pred_dir, "_".join(pred_etype))), \ + f"Prediction dir {os.path.join(pred_dir, "_".join(pred_etype))}" \ + f"for {pred_etype} does not exist." + else: pred_etypes = [] pred_ntypes = [] @@ -890,34 +943,50 @@ def main(args, gs_config_args): files_to_remove += emb_files_to_remove if len(pred_etypes) > 0: - pred_output = predict_dir - # We need to do ID remapping for edge prediction result - pred_files_to_remove = \ - remap_edge_pred(pred_etypes, - predict_dir, - pred_output, - out_chunk_size, - num_proc, - rank, - world_size, - with_shared_fs, - output_func) - files_to_remove += pred_files_to_remove + if isinstance(predict_dir, tuple): + _, edge_predict_dirs = predict_dir + edge_pred_etypes = pred_etypes + else: + edge_predict_dirs = [predict_dir] + edge_pred_etypes = [pred_etypes] + + for pred_dir, pred_et in zip(edge_predict_dirs, edge_pred_etypes): + pred_output = pred_dir + # We need to do ID remapping for edge prediction result + pred_files_to_remove = \ + remap_edge_pred(pred_et, + pred_dir, + pred_output, + out_chunk_size, + num_proc, + rank, + world_size, + with_shared_fs, + output_func) + files_to_remove += pred_files_to_remove if len(pred_ntypes) > 0: - pred_output = predict_dir - # We need to do ID remapping for node prediction result - pred_files_to_remove = \ - remap_node_pred(pred_ntypes, - predict_dir, - pred_output, - out_chunk_size, - num_proc, - rank, - world_size, - with_shared_fs, - output_func) - files_to_remove += pred_files_to_remove + if isinstance(predict_dir, tuple): + node_predict_dirs, _ = predict_dir + node_pred_ntypes = pred_ntypes + else: + node_predict_dirs = [predict_dir] + node_pred_ntypes = [pred_ntypes] + + for pred_dir, pred_nt in zip(node_predict_dirs, node_pred_ntypes): + pred_output = pred_dir + # We need to do ID remapping for node prediction result + pred_files_to_remove = \ + remap_node_pred(pred_nt, + pred_dir, + pred_output, + out_chunk_size, + num_proc, + rank, + world_size, + with_shared_fs, + output_func) + files_to_remove += pred_files_to_remove if args.preserve_input is False and len(files_to_remove) > 0: # If files_to_remove is not empty, at least node_emb_dir or