Skip to content

Commit

Permalink
hard negative for gspartition
Browse files Browse the repository at this point in the history
  • Loading branch information
jalencato committed Nov 1, 2024
1 parent 8b702e5 commit a50b657
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 1 deletion.
1 change: 1 addition & 0 deletions python/graphstorm/gpartition/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@
from .metis_partition import (ParMetisPartitionAlgorithm)
from .partition_config import (ParMETISConfig)
from .partition_algo_base import LocalPartitionAlgorithm
from .post_hard_negative import shuffle_hard_negative_nids
22 changes: 21 additions & 1 deletion python/graphstorm/gpartition/dist_partition_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
ParMetisPartitionAlgorithm,
ParMETISConfig,
RandomPartitionAlgorithm,
shuffle_hard_negative_nids,
)
from graphstorm.utils import get_log_level

Expand Down Expand Up @@ -189,12 +190,31 @@ def main():
dirs_exist_ok=True,
)

# Hard Negative Mapping
if args.gsprocessing_config:
gsprocessing_config = args.gsprocessing_config
shuffle_hard_negative_nids(f"{args.input_path}/{gsprocessing_config}", args.output_path)
else:
for filename in os.listdir(args.input_path):
if filename.endswith("_with_transformations.json"):
gsprocessing_config = filename
shuffle_hard_negative_nids(f"{args.input_path}/{gsprocessing_config}",
args.num_parts, args.output_path)
break
else:
# Did not raise error here for not introducing the break change,
# but will raise warning here to warn customers.
logging.info("Skip the hard negative node ID mapping, "
"upgrade the latest GSProcessing to solve the warning here.")

def parse_args() -> argparse.Namespace:
"""Parses arguments for the script"""
argparser = argparse.ArgumentParser("Partition DGL graphs for node and edge classification "
+ "or regression tasks")
argparser.add_argument("--input-path", type=str, required=True,
help="Path to input DGL chunked data.")
argparser.add_argument("--gsprocessing-config", type=str,
help="Path to the input GSProcessing config data.")
argparser.add_argument("--metadata-filename", type=str, default="metadata.json",
help="Name for the chunked DGL data metadata file.")
argparser.add_argument("--output-path", type=str, required=True,
Expand Down Expand Up @@ -224,4 +244,4 @@ def parse_args() -> argparse.Namespace:


if __name__ == '__main__':
main()
main()
73 changes: 73 additions & 0 deletions python/graphstorm/gpartition/post_hard_negative.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import json
import os

import torch as th
from dgl.data.utils import load_tensors, save_tensors
from graphstorm.model.utils import load_dist_nid_map

def load_hard_negative_config(gsprocessing_config):
with open(gsprocessing_config, 'r') as file:
config = json.load(file)

# Hard Negative only supports link prediction
edges_config = config['graph']['edges']
mapping_edge_list = []
for single_edge_config in edges_config:
if "features" not in single_edge_config:
continue
feature_dict = single_edge_config["features"]
for single_feature in feature_dict:
if single_feature["transformation"]["name"] \
== "edge_dst_hard_negative":
edge_type = ":".join([single_edge_config["source"]["type"],
single_edge_config["relation"]["type"],
single_edge_config["dest"]["type"]])
hard_neg_feat_name = single_feature['name']
mapping_edge_list.append({"dst_node_type": single_edge_config["dest"]["type"],
"edge_type": edge_type,
"hard_neg_feat_name": hard_neg_feat_name})
return mapping_edge_list


def shuffle_hard_negative_nids(gsprocessing_config, num_parts, output_path):
shuffled_edge_config = load_hard_negative_config(gsprocessing_config)

node_type_list = []
for single_shuffled_edge_config in shuffled_edge_config:
node_type = single_shuffled_edge_config["dst_node_type"]
node_type_list.append(node_type)
node_mapping = load_dist_nid_map(f"{output_path}/dist_graph", node_type_list)
gnid2pnid_mapping = {}

def get_gnid2pnid_map(ntype):
if ntype in gnid2pnid_mapping:
return gnid2pnid_mapping[ntype]
else:
pnid2gnid_map = node_mapping[ntype]
gnid2pnid_map = th.argsort(pnid2gnid_map)
gnid2pnid_mapping[ntype] = gnid2pnid_map
# del ntype in node_mapping to save memory
del node_mapping[ntype]
return gnid2pnid_mapping[ntype]

# iterate all the partitions to convert hard negative node ids.
for i in range(num_parts):
part_path = os.path.join(f"{output_path}/dist_graph", f"part{i}")
edge_feat_path = os.path.join(part_path, "edge_feat.dgl")

# load edge features first
edge_feats = load_tensors(edge_feat_path)
for single_shuffled_edge_config in shuffled_edge_config:
etype = single_shuffled_edge_config["edge_type"]
neg_feat = single_shuffled_edge_config["hard_neg_feat_name"]
neg_ntype = single_shuffled_edge_config["dst_node_type"]
efeat_name = f"{etype}/{neg_feat}"
hard_nids = edge_feats[efeat_name].long()
hard_nid_idx = hard_nids > -1
gnid2pnid_map = get_gnid2pnid_map(neg_ntype)
hard_nids[hard_nid_idx] = gnid2pnid_map[hard_nids[hard_nid_idx]]

# replace the edge_feat.dgl with the updated one.
os.remove(edge_feat_path)
save_tensors(edge_feat_path, edge_feats)

3 changes: 3 additions & 0 deletions python/graphstorm/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,9 @@ def _exchange_node_id_mapping(rank, world_size, device,
# move mapping into CPU
return gather_list[0].to(th.device("cpu"))

def load_dist_nid_map(node_id_mapping_file, ntypes):
return _load_dist_nid_map(node_id_mapping_file, ntypes)

def _load_dist_nid_map(node_id_mapping_file, ntypes):
""" Load id mapping files in dist partition format.
"""
Expand Down

0 comments on commit a50b657

Please sign in to comment.