From a50b657d0cc707f3c18424affa788605946fbbc5 Mon Sep 17 00:00:00 2001
From: JalenCato
Date: Fri, 1 Nov 2024 00:06:58 +0000
Subject: [PATCH] hard negative for gspartition

---
 python/graphstorm/gpartition/__init__.py        |  1 +
 .../gpartition/dist_partition_graph.py          | 22 +++++-
 .../gpartition/post_hard_negative.py            | 73 +++++++++++++++++++
 python/graphstorm/model/utils.py                |  3 +
 4 files changed, 98 insertions(+), 1 deletion(-)
 create mode 100644 python/graphstorm/gpartition/post_hard_negative.py

diff --git a/python/graphstorm/gpartition/__init__.py b/python/graphstorm/gpartition/__init__.py
index c7957002c2..818d6ca79a 100644
--- a/python/graphstorm/gpartition/__init__.py
+++ b/python/graphstorm/gpartition/__init__.py
@@ -19,3 +19,4 @@
 from .metis_partition import (ParMetisPartitionAlgorithm)
 from .partition_config import (ParMETISConfig)
 from .partition_algo_base import LocalPartitionAlgorithm
+from .post_hard_negative import shuffle_hard_negative_nids
\ No newline at end of file
diff --git a/python/graphstorm/gpartition/dist_partition_graph.py b/python/graphstorm/gpartition/dist_partition_graph.py
index da50ce8ca6..8f26def57f 100644
--- a/python/graphstorm/gpartition/dist_partition_graph.py
+++ b/python/graphstorm/gpartition/dist_partition_graph.py
@@ -38,6 +38,7 @@
     ParMetisPartitionAlgorithm,
     ParMETISConfig,
     RandomPartitionAlgorithm,
+    shuffle_hard_negative_nids,
 )
 from graphstorm.utils import get_log_level
 
@@ -189,12 +190,31 @@ def main():
             dirs_exist_ok=True,
         )
 
+    # Remap hard negative node IDs after partitioning.
+    if args.gsprocessing_config:
+        shuffle_hard_negative_nids(f"{args.input_path}/{args.gsprocessing_config}",
+                                   args.num_parts, args.output_path)
+    else:
+        for filename in os.listdir(args.input_path):
+            if filename.endswith("_with_transformations.json"):
+                gsprocessing_config = filename
+                shuffle_hard_negative_nids(f"{args.input_path}/{gsprocessing_config}",
+                                           args.num_parts, args.output_path)
+                break
+        else:
+            # Do not raise an error here, to avoid introducing a breaking change,
+            # but log a warning so users know the mapping was skipped.
+ logging.info("Skip the hard negative node ID mapping, " + "upgrade the latest GSProcessing to solve the warning here.") + def parse_args() -> argparse.Namespace: """Parses arguments for the script""" argparser = argparse.ArgumentParser("Partition DGL graphs for node and edge classification " + "or regression tasks") argparser.add_argument("--input-path", type=str, required=True, help="Path to input DGL chunked data.") + argparser.add_argument("--gsprocessing-config", type=str, + help="Path to the input GSProcessing config data.") argparser.add_argument("--metadata-filename", type=str, default="metadata.json", help="Name for the chunked DGL data metadata file.") argparser.add_argument("--output-path", type=str, required=True, @@ -224,4 +244,4 @@ def parse_args() -> argparse.Namespace: if __name__ == '__main__': - main() + main() \ No newline at end of file diff --git a/python/graphstorm/gpartition/post_hard_negative.py b/python/graphstorm/gpartition/post_hard_negative.py new file mode 100644 index 0000000000..32f4adf634 --- /dev/null +++ b/python/graphstorm/gpartition/post_hard_negative.py @@ -0,0 +1,73 @@ +import json +import os + +import torch as th +from dgl.data.utils import load_tensors, save_tensors +from graphstorm.model.utils import load_dist_nid_map + +def load_hard_negative_config(gsprocessing_config): + with open(gsprocessing_config, 'r') as file: + config = json.load(file) + + # Hard Negative only supports link prediction + edges_config = config['graph']['edges'] + mapping_edge_list = [] + for single_edge_config in edges_config: + if "features" not in single_edge_config: + continue + feature_dict = single_edge_config["features"] + for single_feature in feature_dict: + if single_feature["transformation"]["name"] \ + == "edge_dst_hard_negative": + edge_type = ":".join([single_edge_config["source"]["type"], + single_edge_config["relation"]["type"], + single_edge_config["dest"]["type"]]) + hard_neg_feat_name = single_feature['name'] + mapping_edge_list.append({"dst_node_type": single_edge_config["dest"]["type"], + "edge_type": edge_type, + "hard_neg_feat_name": hard_neg_feat_name}) + return mapping_edge_list + + +def shuffle_hard_negative_nids(gsprocessing_config, num_parts, output_path): + shuffled_edge_config = load_hard_negative_config(gsprocessing_config) + + node_type_list = [] + for single_shuffled_edge_config in shuffled_edge_config: + node_type = single_shuffled_edge_config["dst_node_type"] + node_type_list.append(node_type) + node_mapping = load_dist_nid_map(f"{output_path}/dist_graph", node_type_list) + gnid2pnid_mapping = {} + + def get_gnid2pnid_map(ntype): + if ntype in gnid2pnid_mapping: + return gnid2pnid_mapping[ntype] + else: + pnid2gnid_map = node_mapping[ntype] + gnid2pnid_map = th.argsort(pnid2gnid_map) + gnid2pnid_mapping[ntype] = gnid2pnid_map + # del ntype in node_mapping to save memory + del node_mapping[ntype] + return gnid2pnid_mapping[ntype] + + # iterate all the partitions to convert hard negative node ids. 
+    for i in range(num_parts):
+        part_path = os.path.join(f"{output_path}/dist_graph", f"part{i}")
+        edge_feat_path = os.path.join(part_path, "edge_feat.dgl")
+
+        # load edge features first
+        edge_feats = load_tensors(edge_feat_path)
+        for single_shuffled_edge_config in shuffled_edge_config:
+            etype = single_shuffled_edge_config["edge_type"]
+            neg_feat = single_shuffled_edge_config["hard_neg_feat_name"]
+            neg_ntype = single_shuffled_edge_config["dst_node_type"]
+            efeat_name = f"{etype}/{neg_feat}"
+            hard_nids = edge_feats[efeat_name].long()
+            hard_nid_idx = hard_nids > -1
+            gnid2pnid_map = get_gnid2pnid_map(neg_ntype)
+            hard_nids[hard_nid_idx] = gnid2pnid_map[hard_nids[hard_nid_idx]]
+
+        # replace the edge_feat.dgl with the updated one.
+        os.remove(edge_feat_path)
+        save_tensors(edge_feat_path, edge_feats)
+
diff --git a/python/graphstorm/model/utils.py b/python/graphstorm/model/utils.py
index 0969c95f5d..49fe896787 100644
--- a/python/graphstorm/model/utils.py
+++ b/python/graphstorm/model/utils.py
@@ -393,6 +393,9 @@ def _exchange_node_id_mapping(rank, world_size, device,
     # move mapping into CPU
     return gather_list[0].to(th.device("cpu"))
 
+def load_dist_nid_map(node_id_mapping_file, ntypes):
+    return _load_dist_nid_map(node_id_mapping_file, ntypes)
+
 def _load_dist_nid_map(node_id_mapping_file, ntypes):
     """ Load id mapping files in dist partition format. """
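
For reviewers, a minimal usage sketch of the new post-partition step, assuming a hypothetical data layout: the paths and partition count below are illustrative and not part of this patch. The config argument is the GSProcessing output whose name ends with "_with_transformations.json", and the output path must already contain the dist_graph/part0 ... part{N-1} folders written by the partition step.

    from graphstorm.gpartition import shuffle_hard_negative_nids

    # Hypothetical inputs; dist_partition_graph.py derives these from
    # --input-path, --gsprocessing-config (or suffix-based discovery),
    # --num-parts, and --output-path.
    gsprocessing_config = "/data/input/my_graph_with_transformations.json"
    num_parts = 2
    output_path = "/data/output"

    # Rewrites every "edge_dst_hard_negative" feature stored in each partition's
    # edge_feat.dgl from global node IDs to partition-shuffled node IDs, in place.
    shuffle_hard_negative_nids(gsprocessing_config, num_parts, output_path)

This is the same call that main() in dist_partition_graph.py now performs at the end of partitioning, so the step can also be run standalone against an already partitioned graph.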