-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathinput_data.py
53 lines (45 loc) · 1.7 KB
/
input_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
"""
****************NOTE*****************
CREDITS : Thomas Kipf
Since the datasets are the same as those in Kipf's implementation,
their preprocessing source was used as-is.
*************************************
"""
import pickle as pkl
import sys
import networkx as nx
import numpy as np
import scipy.sparse as sp
def parse_index_file(filename):
    """Read a text file holding one integer per line into a list.

    Args:
        filename: Path of the index file to read.

    Returns:
        A list of ints in the order they appear in the file. Lines that
        are empty after stripping whitespace are skipped.

    Raises:
        ValueError: If a non-blank line cannot be parsed as an integer.
    """
    # A context manager closes the handle deterministically instead of
    # leaving it to the garbage collector.
    with open(filename) as index_file:
        entries = (line.strip() for line in index_file)
        # Skip blank lines (e.g. a trailing empty line) rather than
        # letting int("") raise.
        return [int(entry) for entry in entries if entry]
def load_data(dataset, data_dir="data"):
    """Load the graph adjacency matrix and node feature matrix of a dataset.

    Reads the pickled files ``<data_dir>/ind.<dataset>.x``, ``.tx``,
    ``.allx`` and ``.graph`` plus the text file
    ``<data_dir>/ind.<dataset>.test.index``.

    Args:
        dataset: Dataset name used to build the file names. The value
            ``"citeseer"`` additionally triggers zero-row padding of the
            test features (see the inline comment below).
        data_dir: Directory that contains the ``ind.*`` files. Defaults to
            ``"data"``, which reproduces the previously hard-coded paths.

    Returns:
        A tuple ``(adj, features)``. ``adj`` is the result of
        ``nx.adjacency_matrix`` on the graph built from the pickled
        dict-of-lists. ``features`` is a SciPy LIL matrix made of ``allx``
        stacked on top of ``tx``, with the test rows reordered according
        to the test index file.
    """
    # load the data: x, tx, allx, graph
    names = ["x", "tx", "allx", "graph"]
    objects = []
    for name in names:
        path = "{}/ind.{}.{}".format(data_dir, dataset, name)
        with open(path, "rb") as f:
            # NOTE(security): unpickling can execute arbitrary code; only
            # load dataset files obtained from a trusted source.
            if sys.version_info > (3, 0):
                # latin1 decoding is passed so that pickles written under
                # Python 2 can be read on Python 3.
                objects.append(pkl.load(f, encoding="latin1"))
            else:
                objects.append(pkl.load(f))
    # NOTE(review): x, tx and allx are presumably scipy sparse matrices
    # with the same number of columns; confirm against the data files.
    x, tx, allx, graph = objects

    test_idx_reorder = parse_index_file(
        "{}/ind.{}.test.index".format(data_dir, dataset)
    )
    test_idx_range = np.sort(test_idx_reorder)

    if dataset == "citeseer":
        # Fix citeseer dataset (there are some isolated nodes in the graph)
        # Find isolated nodes, add them as zero-vecs into the right position
        lowest = min(test_idx_reorder)
        highest = max(test_idx_reorder)
        # One row per index in the contiguous range [lowest, highest];
        # indices missing from the test index file stay all-zero.
        tx_extended = sp.lil_matrix((highest - lowest + 1, x.shape[1]))
        tx_extended[test_idx_range - lowest, :] = tx
        tx = tx_extended

    features = sp.vstack((allx, tx)).tolil()
    # The rows at the sorted test positions are written back to the rows
    # named by the order in which the indices appear in the index file.
    features[test_idx_reorder, :] = features[test_idx_range, :]
    adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph))
    return adj, features