-
Notifications
You must be signed in to change notification settings - Fork 25
/
data_loader.py
98 lines (77 loc) · 3.22 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import warnings
import pickle as pkl
import sys, os
import scipy.sparse as sp
import networkx as nx
import torch
import numpy as np
# from sklearn import datasets
# from sklearn.preprocessing import LabelBinarizer, scale
# from sklearn.model_selection import train_test_split
# from ogb.nodeproppred import DglNodePropPredDataset
# import copy
from utils import sparse_mx_to_torch_sparse_tensor #, dgl_graph_to_torch_sparse
# NOTE(review): this runs at import time and silences warnings of every
# category for the whole process, not just this module (so e.g. deprecation
# warnings raised by the libraries used below are hidden too). Consider scoping
# it with ``warnings.catch_warnings()`` around the specific noisy calls.
warnings.simplefilter("ignore")
def parse_index_file(filename):
    """Parse an index file containing one integer per line.

    Args:
        filename: Path to a text file with one integer per line.

    Returns:
        list[int]: The integers in the order they appear in the file.
            Blank (whitespace-only) lines are skipped.

    Raises:
        ValueError: If a non-blank line is not a valid integer.
    """
    # A context manager closes the handle deterministically; a bare ``open``
    # in the for-loop header would leave it open until garbage collection.
    with open(filename) as f:
        # Skip blank lines so a trailing empty line doesn't hit int('').
        return [int(line.strip()) for line in f if line.strip()]
def sample_mask(idx, l):
    """Create a boolean mask of length ``l`` that is True at positions ``idx``.

    Args:
        idx: Positions to set, anything NumPy accepts as an index
            (list, ``range``, integer array).
        l: Length of the mask.

    Returns:
        np.ndarray: 1-D boolean array of length ``l``.
    """
    # ``np.bool`` was deprecated in NumPy 1.20 and removed in 1.24; the builtin
    # ``bool`` is the supported spelling and needs no separate cast step.
    mask = np.zeros(l, dtype=bool)
    mask[idx] = True
    return mask
def load_citation_network(dataset_str, sparse=None, data_dir="data"):
    """Load a citation dataset stored as pickled ``ind.*`` files.

    Reads ``ind.<dataset_str>.{x,y,tx,ty,allx,ally,graph}`` and
    ``ind.<dataset_str>.test.index`` from ``data_dir``. It then assembles node
    features, integer class labels, train/val/test masks and the adjacency
    matrix.

    Args:
        dataset_str: Dataset name as it appears in the file names. The value
            ``"citeseer"`` gets extra handling that pads test features and
            labels for indices missing from ``test.index``.
        sparse: If truthy, ``adj`` is converted with
            ``sparse_mx_to_torch_sparse_tensor``. Otherwise it is a dense
            ``float32`` NumPy array.
        data_dir: Directory holding the ``ind.*`` files. Defaults to
            ``"data"``, the location this loader has always used.

    Returns:
        tuple: ``(features, nfeats, labels, nclasses, train_mask, val_mask,
        test_mask, adj)``, where:

        - ``features`` is a float32 tensor of shape ``(num_nodes, nfeats)``.
        - ``labels`` is an int64 tensor with one class index per node.
        - ``nclasses`` is ``labels.max() + 1``.
        - The three masks are boolean tensors of length ``num_nodes``.
        - ``adj`` is as described under ``sparse``.

    Raises:
        FileNotFoundError: If any expected file is missing.
    """
    names = ['x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph']
    objects = []
    for name in names:
        path = os.path.join(data_dir, "ind.{}.{}".format(dataset_str, name))
        # NOTE: pickle.load can execute arbitrary code -- only point this at
        # trusted dataset files.
        with open(path, 'rb') as f:
            if sys.version_info > (3, 0):
                # Presumably these were pickled under Python 2; 'latin1' lets
                # Python 3 decode the embedded byte strings.
                objects.append(pkl.load(f, encoding='latin1'))
            else:
                objects.append(pkl.load(f))
    x, y, tx, ty, allx, ally, graph = tuple(objects)

    test_idx_reorder = parse_index_file(
        os.path.join(data_dir, "ind.{}.test.index".format(dataset_str)))
    test_idx_range = np.sort(test_idx_reorder)

    if dataset_str == 'citeseer':
        # Fix citeseer dataset (there are some isolated nodes in the graph).
        # Pad tx/ty with zero rows so every index between the min and max
        # test index has a slot. Then place the real test rows at their
        # offsets.
        test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder) + 1)
        tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))
        tx_extended[test_idx_range - min(test_idx_range), :] = tx
        tx = tx_extended
        ty_extended = np.zeros((len(test_idx_range_full), y.shape[1]))
        ty_extended[test_idx_range - min(test_idx_range), :] = ty
        ty = ty_extended

    # Stack training+test rows, then move the rows at the sorted test indices
    # to the positions listed (in file order) in test.index.
    features = sp.vstack((allx, tx)).tolil()
    features[test_idx_reorder, :] = features[test_idx_range, :]

    adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph))
    if not sparse:
        adj = np.array(adj.todense(), dtype='float32')
    else:
        adj = sparse_mx_to_torch_sparse_tensor(adj)

    # Same row reordering as for the features above.
    labels = np.vstack((ally, ty))
    labels[test_idx_reorder, :] = labels[test_idx_range, :]

    # Split: the first len(y) nodes train, the next 500 validate, and the
    # nodes listed in test.index are the test set.
    idx_test = test_idx_range.tolist()
    idx_train = range(len(y))
    idx_val = range(len(y), len(y) + 500)

    num_nodes = labels.shape[0]
    train_mask = sample_mask(idx_train, num_nodes)
    val_mask = sample_mask(idx_val, num_nodes)
    test_mask = sample_mask(idx_test, num_nodes)

    # Explicit dtypes instead of relying on legacy-constructor casting of
    # np.matrix / float64 inputs.
    features = torch.from_numpy(np.asarray(features.todense(), dtype=np.float32))
    labels = torch.from_numpy(np.asarray(labels, dtype=np.int64))
    train_mask = torch.BoolTensor(train_mask)
    val_mask = torch.BoolTensor(val_mask)
    test_mask = torch.BoolTensor(test_mask)

    nfeats = features.shape[1]

    # Rows that are not exactly one-hot are assigned to class 0. This covers,
    # for example, citeseer's zero-padded test rows. The vector width comes
    # from ``labels`` itself, so this works for any number of classes.
    not_one_hot = labels.sum(dim=1) != 1
    labels[not_one_hot] = 0
    labels[not_one_hot, 0] = 1
    # Every row now has exactly one 1; its column is the class index.
    labels = (labels == 1).nonzero()[:, 1]
    nclasses = torch.max(labels).item() + 1
    return features, nfeats, labels, nclasses, train_mask, val_mask, test_mask, adj
def load_data(args):
    """Load the citation dataset selected on the command line.

    Args:
        args: Object exposing ``dataset`` (the dataset name passed to
            ``load_citation_network``) and ``sparse`` (forwarded as that
            function's ``sparse`` flag).

    Returns:
        Whatever ``load_citation_network`` returns for those arguments.
    """
    dataset_name = args.dataset
    use_sparse = args.sparse
    return load_citation_network(dataset_name, sparse=use_sparse)