dataset_transform.py
import torch

from graph_partition import (SubgraphsData, cal_coarsen_adj, combine_subgraphs,
                             metis_subgraph, random_subgraph, to_sparse)
from pe import LapPE, RWSE, random_walk

class PositionalEncodingTransform(object):
    """Precompute node-level positional encodings (random-walk and/or Laplacian)."""

    def __init__(self, rw_dim=0, lap_dim=0):
        super().__init__()
        self.rw_dim = rw_dim    # number of random-walk encoding steps (0 disables)
        self.lap_dim = lap_dim  # number of Laplacian eigenvectors (0 disables)

    def __call__(self, data):
        if self.rw_dim > 0:
            data.rw_pos_enc = RWSE(
                data.edge_index, self.rw_dim, data.num_nodes)
        if self.lap_dim > 0:
            data.lap_pos_enc = LapPE(
                data.edge_index, self.lap_dim, data.num_nodes)
        return data
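
# Usage sketch (illustrative, not part of the original file): the transform is
# meant to be applied to a PyG-style `Data` object, typically as a dataset
# `pre_transform`, so the encodings are computed once and cached. The dimensions
# below are assumptions, not values taken from this repository:
#
#   pe_transform = PositionalEncodingTransform(rw_dim=16, lap_dim=8)
#   data = pe_transform(data)   # adds data.rw_pos_enc and data.lap_pos_enc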

class GraphPartitionTransform(object):
    """Partition a graph into patches (via METIS or random sampling) and attach
    the patch-level bookkeeping needed downstream."""

    def __init__(self, n_patches, metis=True, drop_rate=0.0, num_hops=1,
                 is_directed=False, patch_rw_dim=0, patch_num_diff=0):
        super().__init__()
        self.n_patches = n_patches            # number of patches (subgraphs) per graph
        self.drop_rate = drop_rate            # edge drop rate used by the METIS partitioning
        self.num_hops = num_hops              # k-hop expansion of each patch
        self.is_directed = is_directed
        self.patch_rw_dim = patch_rw_dim      # random-walk PE dimension on the coarsened graph
        self.patch_num_diff = patch_num_diff  # number of diffusion steps on the coarsened graph
        self.metis = metis                    # METIS partition if True, random partition otherwise

    def _diffuse(self, A):
        # Diffuse the patch adjacency by `patch_num_diff` random-walk steps,
        # i.e. compute (D^-1 A)^k with k = patch_num_diff.
        if self.patch_num_diff == 0:
            return A
        Dinv = A.sum(dim=-1).clamp(min=1).pow(-1).unsqueeze(-1)  # D^-1
        RW = A * Dinv
        M = RW
        M_power = M
        for _ in range(self.patch_num_diff - 1):
            M_power = torch.matmul(M_power, M)
        return M_power
    def __call__(self, data):
        data = SubgraphsData(**{k: v for k, v in data})

        # Partition the graph into `n_patches` node/edge masks.
        if self.metis:
            node_masks, edge_masks = metis_subgraph(
                data, n_patches=self.n_patches, drop_rate=self.drop_rate,
                num_hops=self.num_hops, is_directed=self.is_directed)
        else:
            node_masks, edge_masks = random_subgraph(
                data, n_patches=self.n_patches, num_hops=self.num_hops)
        subgraphs_nodes, subgraphs_edges = to_sparse(node_masks, edge_masks)
        combined_subgraphs = combine_subgraphs(
            data.edge_index, subgraphs_nodes, subgraphs_edges,
            num_selected=self.n_patches, num_nodes=data.num_nodes)

        # Coarsened (patch-level) adjacency and its edge list.
        coarsen_adj = cal_coarsen_adj(node_masks)
        coarsen_rows_batch, coarsen_cols_batch = torch.nonzero(
            coarsen_adj, as_tuple=True)
        data.coarsen_edge_attr = coarsen_adj[coarsen_rows_batch, coarsen_cols_batch]
        data.subgraphs_batch_row = coarsen_rows_batch
        data.subgraphs_batch_col = coarsen_cols_batch

        # Optional patch-level positional encoding / diffusion.
        if self.patch_num_diff > -1 or self.patch_rw_dim > 0:
            if self.patch_rw_dim > 0:
                data.patch_pe = random_walk(coarsen_adj, self.patch_rw_dim)
            if self.patch_num_diff > -1:
                data.coarsen_adj = self._diffuse(coarsen_adj).unsqueeze(0)

        # Mark which patches contain at least one node.
        subgraphs_batch = subgraphs_nodes[0]
        mask = torch.zeros(self.n_patches).bool()
        mask[subgraphs_batch] = True

        data.subgraphs_batch = subgraphs_batch
        data.subgraphs_batch_edge = subgraphs_edges[0]
        data.subgraphs_nodes_mapper = subgraphs_nodes[1]
        data.subgraphs_edges_mapper = subgraphs_edges[1]
        data.combined_subgraphs = combined_subgraphs
        data.mask = mask.unsqueeze(0)
        data.__num_nodes__ = data.num_nodes  # set number of nodes of the current graph
        return data
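
A minimal end-to-end sketch of how the two transforms could be wired into a PyTorch Geometric dataset. The dataset (ZINC) and every hyper-parameter value below are illustrative assumptions, not taken from this file; only the two transform classes come from dataset_transform.py.

    # Illustrative usage only; ZINC and the hyper-parameters are assumptions.
    from torch_geometric.datasets import ZINC

    from dataset_transform import GraphPartitionTransform, PositionalEncodingTransform

    # Node-level positional encodings are computed once and cached via pre_transform;
    # graph partitioning is applied on the fly via transform.
    pre_transform = PositionalEncodingTransform(rw_dim=16, lap_dim=8)
    transform = GraphPartitionTransform(n_patches=32, metis=True, patch_rw_dim=8)

    train_dataset = ZINC('data/ZINC', subset=True, split='train',
                         pre_transform=pre_transform, transform=transform)
    data = train_dataset[0]  # a SubgraphsData with patch-level fields (mask, patch_pe, ...)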