forked from ma-shangao/rl_waypoint_mrta
-
Notifications
You must be signed in to change notification settings - Fork 2
/
dataset_preparation.py
94 lines (70 loc) · 3.29 KB
/
dataset_preparation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import pickle
import os
import numpy as np
import torch
from sklearn.datasets import make_blobs
from torch.utils.data import Dataset
from torch.utils.data.dataset import T_co
def prepare_blob_dataset(city_num: int = 50,
                         feature_dim: int = 2,
                         sample_num: int = 100000,
                         random_state=None
                         ) -> "tuple[np.ndarray, np.ndarray]":
    """Generate a batch of independent clustered point sets with ``make_blobs``.

    Each of the ``sample_num`` instances is produced by a separate
    ``make_blobs`` call with ``cluster_std=0.07`` and cluster centres drawn
    from the ``(0.0, 1.0)`` box. The number of clusters is left at the
    ``make_blobs`` default.

    Args:
        city_num: Number of points in every instance.
        feature_dim: Dimensionality of every point.
        sample_num: Number of instances to generate.
        random_state: ``None`` (default, unseeded as before), an int seed, or
            a ``np.random.RandomState``. One generator is shared by all
            ``make_blobs`` calls so that instances differ from each other
            while the whole batch stays reproducible.

    Returns:
        A pair ``(samples, labels)`` where ``samples`` has shape
        ``(sample_num, city_num, feature_dim)`` and ``labels`` has shape
        ``(sample_num, city_num)``. Both arrays come from ``np.zeros`` and are
        therefore float64; the cluster indices are stored as floats.
    """
    # Build the generator once: passing the same int seed to every call would
    # make all instances identical.
    if random_state is None or isinstance(random_state, np.random.RandomState):
        rng = random_state
    else:
        rng = np.random.RandomState(random_state)

    samples = np.zeros((sample_num, city_num, feature_dim))
    labels = np.zeros((sample_num, city_num))
    for idx in range(sample_num):
        samples[idx], labels[idx] = make_blobs(city_num,
                                               feature_dim,
                                               cluster_std=0.07,
                                               center_box=(0.0, 1.0),
                                               random_state=rng)
    return samples, labels
class BlobDataset(Dataset):
    """Torch dataset of clustered point sets generated eagerly at construction.

    All instances are created once in ``__init__`` via
    ``prepare_blob_dataset`` and held in memory as two tensors:
    ``samples`` (cast to float32) and ``labels`` (dtype as returned by the
    generator).
    """

    def __init__(self, city_num: int = 50, feature_dim: int = 2, sample_num: int = 100000):
        """Store the generation parameters and build the whole dataset.

        Args:
            city_num: Number of points per instance.
            feature_dim: Dimensionality of each point.
            sample_num: Number of instances to generate.
        """
        super().__init__()
        self.city_num = city_num
        self.feature_dim = feature_dim
        self.sample_num = sample_num
        self.samples, self.labels = self._generate_dataset()

    def __getitem__(self, index) -> T_co:
        """Return ``{'sample': ..., 'label': ...}`` for the given index."""
        return {'sample': self.samples[index], 'label': self.labels[index]}

    def __len__(self):
        """Return the number of stored instances."""
        return len(self.samples)

    def _generate_dataset(self):
        """Generate the numpy data and convert both arrays to torch tensors."""
        blob_points, blob_labels = prepare_blob_dataset(self.city_num,
                                                        self.feature_dim,
                                                        self.sample_num)
        sample_tensor = torch.from_numpy(blob_points).float()
        label_tensor = torch.from_numpy(blob_labels)
        return sample_tensor, label_tensor
# TSP dataset wrapper from https://github.com/wouterkool/attention-learn-to-route
class TSPDataset(Dataset):
    """Dataset of 2-D point sets, loaded from disk or sampled uniformly.

    ``self.data`` is a list with one float tensor per instance and
    ``self.size`` is its length.
    """

    def __init__(self, filename=None, size=20, num_samples=1000000, offset=0, distribution=None):
        """Load instances from ``filename`` or sample them at random.

        Args:
            filename: Path to a ``.npy`` or ``.pkl`` file. When ``None``,
                instances are sampled uniformly in the unit square.
            size: Number of points per randomly sampled instance (only used
                when ``filename`` is ``None``).
            num_samples: Maximum number of instances taken from the file, or
                the number of instances to sample.
            offset: Index of the first instance taken from the file.
            distribution: Accepted for interface compatibility; not used.

        Raises:
            AssertionError: If a ``.npy`` array is neither 2-D nor 3-D, or if
                the file extension is neither ``.npy`` nor ``.pkl``.
        """
        super(TSPDataset, self).__init__()
        # Not used by this class; kept so external readers of the attribute
        # keep working.
        self.data_set = []
        if filename is not None:
            extension = os.path.splitext(filename)[1]
            if extension == '.npy':
                data = np.load(filename)
                # The previous check `data.ndim == (2 or 3)` reduced to
                # `data.ndim == 2` and rejected every 3-D array.
                assert data.ndim in (2, 3), "data.ndim should either be 2 or 3"
                if data.ndim == 2:
                    # A single instance: add the leading batch axis.
                    data = np.expand_dims(data, axis=0)
            else:
                assert extension == '.pkl'
                # NOTE: pickle.load can execute arbitrary code; only open
                # files from trusted sources.
                with open(filename, 'rb') as f:
                    data = pickle.load(f)
            self.data = [torch.FloatTensor(row) for row in data[offset:offset + num_samples]]
        else:
            # Sample points randomly in [0, 1] square
            self.data = [torch.FloatTensor(size, 2).uniform_(0, 1) for _ in range(num_samples)]
        self.size = len(self.data)

    def __len__(self):
        """Return the number of instances."""
        return self.size

    def __getitem__(self, idx):
        """Return the tensor of the instance at ``idx``."""
        return self.data[idx]

    def data_normalisation(self):
        """Min-max scale every instance to [0, 1] in place.

        The minimum and maximum are taken over all entries of an instance
        (both coordinates together). An instance whose entries are all equal
        has zero range and turns into NaNs.
        """
        normalised = []
        for row in self.data:
            low, high = row.min(), row.max()
            normalised.append((row - low) / (high - low))
        self.data = normalised
if __name__ == '__main__':
    # Manual smoke test: build one dataset from a file on disk and one from
    # random sampling, then report the size of the file-backed one.
    file_backed = TSPDataset(filename='tmp/platforms.npy')
    random_backed = TSPDataset(size=20, num_samples=50)
    print(len(file_backed))