Skip to content

Commit

Permalink
Merge branch 'awslabs:main' into llmgnn
Browse files Browse the repository at this point in the history
  • Loading branch information
GentleZhu authored Dec 6, 2023
2 parents 0b19c67 + 8380f8c commit be4e42a
Showing 1 changed file with 44 additions and 4 deletions.
48 changes: 44 additions & 4 deletions tools/partition_graph_lp.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import os
import dgl
import numpy as np
import torch as th
import argparse
import time
Expand All @@ -44,6 +45,11 @@
help='The pct of train nodes/edges. Should be > 0 and < 1.')
# Fraction of the target nodes/edges reserved for validation.
argparser.add_argument(
    '--val-pct',
    default=0.1,
    type=float,
    help='The pct of validation nodes/edges. Should be > 0 and < 1.',
)
# Opt-in flag selecting the inductive link split instead of the default one.
argparser.add_argument(
    '--inductive-split',
    help=('split links for inductive settings: no overlapping nodes across '
          'splits.'),
    action='store_true',
)
# Seed passed to NumPy's RNG before the links are split.
argparser.add_argument(
    '--seed',
    default=42,
    type=int,
    help='random seed for splitting links',
)
# graph modification arguments
argparser.add_argument(
    '--add-reverse-edges',
    help='turn the graph into an undirected graph.',
    action='store_true',
)
Expand Down Expand Up @@ -81,6 +87,7 @@
# Parse the command line and echo the resulting options.
args = argparser.parse_args()
print(args)
# Reference point for the elapsed-time report printed further down.
start = time.time()
# Seed NumPy's global RNG so the shuffled link split is reproducible.
np.random.seed(args.seed)

# Initial value of the flag; the split-mask construction further down only
# runs when it is truthy.
constructed_graph = False

Expand Down Expand Up @@ -149,15 +156,48 @@
target_etypes = [target_etypes]

if constructed_graph:
    # Build boolean train/val/test edge masks for every target edge type.
    #
    # One node permutation is cached per node type so that all target edge
    # types sharing a head node type place those nodes in the same split.
    d_shuffled_nids = {}
    # Cumulative split fractions; identical for every edge type.
    train_frac = args.train_pct
    train_val_frac = args.train_pct + args.val_pct

    for target_e in target_etypes:
        num_edges = g.num_edges(target_e)
        # Build the masks locally and store them on the graph once they are
        # complete, rather than mutating the stored tensors in place.
        train_mask = th.full((num_edges,), False, dtype=th.bool)
        val_mask = th.full((num_edges,), False, dtype=th.bool)
        test_mask = th.full((num_edges,), False, dtype=th.bool)

        if not args.inductive_split:
            # Positional split: edges are assigned by edge ID order (no
            # shuffling); the first `train_pct` go to train, the next
            # `val_pct` to validation and the remainder to test.
            train_end = int(num_edges * train_frac)
            val_end = int(num_edges * train_val_frac)
            train_mask[:train_end] = True
            val_mask[train_end:val_end] = True
            test_mask[val_end:] = True
        else:
            # Inductive split for link prediction.
            # 1. Split the head nodes u into three disjoint sets
            #    (train/val/test) so that the model is evaluated on
            #    predicting links for unseen nodes.
            utype, _, vtype = target_e
            num_nodes = g.number_of_nodes(utype)
            # Draw a permutation only the first time a node type is seen;
            # drawing it eagerly on every iteration would discard the result
            # and needlessly advance the RNG state.
            if utype not in d_shuffled_nids:
                d_shuffled_nids[utype] = np.random.permutation(
                    np.arange(num_nodes))
            shuffled_index = d_shuffled_nids[utype]

            train_end = int(num_nodes * train_frac)
            val_end = int(num_nodes * train_val_frac)
            train_u = shuffled_index[:train_end]
            val_u = shuffled_index[train_end:val_end]
            test_u = shuffled_index[val_end:]

            # 2. Find all out-edges of the three sets of head nodes.
            _, train_v, train_eids = g.out_edges(
                train_u, form='all', etype=target_e)
            _, val_v, val_eids = g.out_edges(
                val_u, form='all', etype=target_e)
            _, test_v, test_eids = g.out_edges(
                test_u, form='all', etype=target_e)

            if utype == vtype:
                # Head and tail share a node type, so tails can leak across
                # splits. Keep only training edges whose tail is also a
                # training node; this is unnecessary when head and tail are
                # of different types.
                train_eids = train_eids[np.isin(train_v, train_u)]
                # Remove overlaps between val and test.
                val_eids = val_eids[~np.isin(val_v, test_u)]
                test_eids = test_eids[~np.isin(test_v, val_u)]

            # 3. Mark the selected edges: the edge masks prevent
            #    message-passing flow graphs from fetching edges outside of
            #    the splits.
            train_mask[train_eids] = True
            val_mask[val_eids] = True
            test_mask[test_eids] = True

        g.edges[target_e].data['train_mask'] = train_mask
        g.edges[target_e].data['val_mask'] = val_mask
        g.edges[target_e].data['test_mask'] = test_mask

# Report the time elapsed since `start` was taken, then the total node and
# edge counts of `g`.
print(f'load {args.dataset} takes {time.time() - start:.3f} seconds')
print(f'\n|V|={g.number_of_nodes()}, |E|={g.number_of_edges()}\n')
Expand Down

0 comments on commit be4e42a

Please sign in to comment.