Skip to content

Commit

Permalink
Implemented inductive splits for link prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
wangz10 committed Nov 15, 2023
1 parent a708401 commit 77bc16c
Showing 1 changed file with 28 additions and 4 deletions.
32 changes: 28 additions & 4 deletions tools/partition_graph_lp.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import os
import dgl
import numpy as np
import torch as th
import argparse
import time
Expand All @@ -44,6 +45,9 @@
help='The pct of train nodes/edges. Should be > 0 and < 1.')
# Fraction of the nodes/edges assigned to the validation split.
argparser.add_argument('--val-pct', type=float, default=0.1,
                       help='The pct of validation nodes/edges. Should be > 0 and < 1.')
# Opt-in inductive split. Only the head (source) nodes of the target edge type
# are partitioned; each edge follows its head node. Tail nodes are not
# partitioned, so the help text states the guarantee for head nodes only.
argparser.add_argument('--inductive-split', action='store_true',
                       help='split links for inductive settings: no overlapping head '
                            + '(source) nodes across splits.')
# graph modification arguments
argparser.add_argument('--add-reverse-edges', action='store_true',
                       help='turn the graph into an undirected graph.')
Expand Down Expand Up @@ -154,10 +158,30 @@
# Build boolean train/val/test masks over the edges of the target edge type.
# The masks are filled locally and attached to the graph only once complete,
# so the result does not rely on in-place mutation through
# `g.edges[...].data[...]`.
train_mask = th.full((num_edges,), False, dtype=th.bool)
val_mask = th.full((num_edges,), False, dtype=th.bool)
test_mask = th.full((num_edges,), False, dtype=th.bool)
# Cumulative fraction at which the validation portion ends and test begins.
train_val_pct = args.train_pct + args.val_pct
if not args.inductive_split:
    # Positional split over edge IDs: the first `train_pct` of the edges go to
    # train, the next `val_pct` to validation, and the remainder to test.
    # NOTE(review): nothing is shuffled here; the split is only random if the
    # edge order was randomized elsewhere -- confirm.
    train_end = int(num_edges * args.train_pct)
    val_end = int(num_edges * train_val_pct)
    train_mask[:train_end] = True
    val_mask[train_end:val_end] = True
    test_mask[val_end:] = True
else:
    # Inductive split for link prediction.
    # 1. Randomly partition the head nodes into train/val/test.
    # NOTE(review): assumes `target_e` is a canonical edge type whose first
    # element is the head node type -- confirm.
    ntype = target_e[0]
    num_nodes = g.number_of_nodes(ntype)
    shuffled_index = np.random.permutation(np.arange(num_nodes))
    train_end = int(num_nodes * args.train_pct)
    val_end = int(num_nodes * train_val_pct)
    train_idx = shuffled_index[:train_end]
    val_idx = shuffled_index[train_end:val_end]
    test_idx = shuffled_index[val_end:]
    # 2. Collect the out-edges of each head-node set. Every edge has exactly
    #    one head node, so the three edge sets are disjoint. Tail nodes are
    #    not partitioned and may appear in more than one split.
    train_eids = g.out_edges(train_idx, form='eid', etype=target_e)
    val_eids = g.out_edges(val_idx, form='eid', etype=target_e)
    test_eids = g.out_edges(test_idx, form='eid', etype=target_e)
    # 3. Turn the edge-ID sets into boolean edge masks.
    train_mask[train_eids] = True
    val_mask[val_eids] = True
    test_mask[test_eids] = True
g.edges[target_e].data['train_mask'] = train_mask
g.edges[target_e].data['val_mask'] = val_mask
g.edges[target_e].data['test_mask'] = test_mask

# Report the elapsed time since `start` and the overall graph size.
print(f'load {args.dataset} takes {time.time() - start:.3f} seconds')
print(f'\n|V|={g.number_of_nodes()}, |E|={g.number_of_edges()}\n')
Expand Down

0 comments on commit 77bc16c

Please sign in to comment.