Skip to content

Commit

Permalink
Remove tail nodes outside of training set
Browse files Browse the repository at this point in the history
  • Loading branch information
wangz10 committed Nov 24, 2023
1 parent 77bc16c commit 3d41849
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions tools/partition_graph_lp.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,13 @@
int(num_nodes * (args.train_pct + args.val_pct))]
test_idx = shuffled_index[int(num_nodes * (args.train_pct + args.val_pct)): ]
# 2. find all out-edges for the sets of head nodes:
train_eids = g.out_edges(train_idx, form='eid', etype=target_e)
val_eids = g.out_edges(val_idx, form='eid', etype=target_e)
test_eids = g.out_edges(test_idx, form='eid', etype=target_e)
# we remove edges with tail nodes outside of the training set
_, train_v, train_eids = g.out_edges(train_idx, form='all', etype=target_e)
train_eids = train_eids[np.in1d(train_v, train_idx)]
_, val_v, val_eids = g.out_edges(val_idx, form='all', etype=target_e)
val_eids = val_eids[np.in1d(val_v, train_idx)]
_, test_v, test_eids = g.out_edges(test_idx, form='all', etype=target_e)
test_eids = test_eids[np.in1d(test_v, train_idx)]
# 3. build boolean edge masks
g.edges[target_e].data['train_mask'][train_eids] = True
g.edges[target_e].data['val_mask'][val_eids] = True
Expand Down

0 comments on commit 3d41849

Please sign in to comment.