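"""Build the develop/test datasets for a chosen classification task and save them to disk.

Supported tasks: sequential MNIST (smnist), permuted MNIST (pmnist), sequential
CIFAR10 in RGB (scifar10) and grayscale (scifar10gs), Pathfinder (pathfinder),
Path-X (pathx), ListOps (listops), and IMDB sentiment (imdb).
"""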
import argparse
import logging
import os

import tensorflow as tf
import tensorflow_datasets as tfds
import torch
from torchtext.datasets import IMDB
from torchvision import datasets, transforms

from lra_benchmarks.data.listops import listops
from lra_benchmarks.data.pathfinder import Pathfinder32, Pathfinder64, Pathfinder128, Pathfinder256
from lra_benchmarks.listops import input_pipeline
from src.torch_dataset.sequantial_image import SequentialImage2Classify
from src.torch_dataset.torch_listops import ListOpsDataset
from src.torch_dataset.torch_pathfinder import PathfinderDataset
from src.torch_dataset.torch_text import TextDataset
from src.utils.experiments import read_yaml_to_dict, set_seed
from src.utils.saving import save_data

tasks = ['smnist', 'pmnist', 'scifar10gs', 'scifar10', 'pathfinder', 'pathx', 'listops', 'imdb']
parser = argparse.ArgumentParser(description='Build a classification task.')
parser.add_argument('--seed', type=int, default=42, help='Random seed.')
parser.add_argument('--task', required=True, choices=tasks, help='Name of the classification task.')
# parse_known_args tolerates extra flags meant for other scripts; a second
# parse_args() call would reject them, so the arguments are parsed only once.
args, unknown = parser.parse_known_args()

logging.basicConfig(level=logging.INFO)
logging.info(f"Setting seed: {args.seed}")
set_seed(args.seed)
logging.info(f"Building task: {args.task}")
if args.task == 'smnist':
    transform = transforms.Compose([
        transforms.ToTensor(),  # Convert image to a PyTorch tensor with values in [0, 1] and shape (C, H, W)
        # transforms.Normalize((0.1307,), (0.3081,)),
    ])
    develop_dataset = SequentialImage2Classify(datasets.MNIST(root='../data_storage/',
                                                              train=True,
                                                              transform=transform,
                                                              download=True))
    test_dataset = SequentialImage2Classify(datasets.MNIST(root='../data_storage/',
                                                           train=False,
                                                           transform=transform,
                                                           download=True))
elif args.task == 'pmnist':
    transform = transforms.Compose([
        transforms.ToTensor(),  # Convert image to a PyTorch tensor with values in [0, 1] and shape (C, H, W)
        # transforms.Normalize((0.1307,), (0.3081,)),
    ])
    permutation = torch.randperm(28 * 28)
    develop_dataset = SequentialImage2Classify(dataset=datasets.MNIST(root='../data_storage/',
                                                                      train=True,
                                                                      transform=transform,
                                                                      download=True),
                                               permutation=permutation)
    test_dataset = SequentialImage2Classify(dataset=datasets.MNIST(root='../data_storage/',
                                                                   train=False,
                                                                   transform=transform,
                                                                   download=True),
                                            permutation=permutation)
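    # The permutation is sampled once (after set_seed, so it is reproducible)
    # and the same object is passed to both splits, so develop and test see an
    # identical fixed pixel order.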
elif args.task == 'scifar10':
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    develop_dataset = SequentialImage2Classify(datasets.CIFAR10(root='../data_storage/',
                                                                train=True,
                                                                transform=transform,
                                                                download=True))
    test_dataset = SequentialImage2Classify(datasets.CIFAR10(root='../data_storage/',
                                                             train=False,
                                                             transform=transform,
                                                             download=True))
elif args.task == 'scifar10gs':
    transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.ToTensor(),
        # transforms.Normalize(mean=122.6 / 255.0, std=61.0 / 255.0),
    ])
    develop_dataset = SequentialImage2Classify(datasets.CIFAR10(root='../data_storage/',
                                                                train=True,
                                                                transform=transform,
                                                                download=True))
    test_dataset = SequentialImage2Classify(datasets.CIFAR10(root='../data_storage/',
                                                             train=False,
                                                             transform=transform,
                                                             download=True))
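    # For reference (assuming SequentialImage2Classify flattens images into
    # pixel sequences, as the name suggests): MNIST yields sequences of
    # 28 * 28 = 784 steps and CIFAR10 of 32 * 32 = 1024 steps.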
elif args.task in ['pathfinder', 'pathx']:
    setting = read_yaml_to_dict(os.path.join('configs', args.task, 'setting.yaml'))
    data = setting.get('data', {})
    resolution = str(data.get('resolution'))
    level = data.get('level')
    learning = setting.get('learning')
    test_split = learning.get('test_split')
    train_split = str(int((1 - test_split) * 100)) + '%'
    pathfinders = {
        '32': Pathfinder32,
        '64': Pathfinder64,
        '128': Pathfinder128,
        '256': Pathfinder256,
    }
    builder_class = pathfinders[resolution]
    try:
        builder_dataset = builder_class()
        builder_dataset.download_and_prepare()
        # Load the dataset with shuffled files, slicing the chosen difficulty
        # level into develop/test portions
        develop_dataset, test_dataset = builder_dataset.as_dataset(
            split=[f'{level}[:{train_split}]', f'{level}[{train_split}:]'],
            decoders={'image': tfds.decode.SkipDecoding()})
        # Filter out examples with empty images
        develop_dataset = develop_dataset.filter(lambda x: tf.strings.length(x['image']) > 0)
        test_dataset = test_dataset.filter(lambda x: tf.strings.length(x['image']) > 0)

        def decode(x):
            return {
                'input': tf.cast(tf.image.decode_png(x['image']), dtype=tf.int32),
                'label': x['label'],
            }

        develop_dataset = develop_dataset.map(decode, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        test_dataset = test_dataset.map(decode, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        # Convert the TensorFlow datasets to PyTorch datasets
        develop_dataset = PathfinderDataset(develop_dataset)
        test_dataset = PathfinderDataset(test_dataset)
    except FileNotFoundError:
        logging.error(f"Pathfinder dataset (level '{level}') not found. To obtain the data, "
                      "please download it from gs://long-range-arena/lra_release. "
                      "If permissions fail, you may download the entire gzipped archive at "
                      "https://storage.googleapis.com/long-range-arena/lra_release.gz "
                      "and put the data one directory above the project directory.")
        raise  # Without the data there is nothing to convert or save below.
elif args.task == 'listops':
    setting = read_yaml_to_dict(os.path.join('configs', args.task, 'setting.yaml'))
    data = setting.get('data', {})
    num_dev_samples = data['num_dev_samples']
    num_test_samples = data['num_test_samples']
    max_depth = data['max_depth']
    max_args = data['max_args']
    max_length = data['max_length']
    min_length = data['min_length']
    listops(task_name=args.task, num_develop_samples=num_dev_samples, num_test_samples=num_test_samples,
            max_depth=max_depth, max_args=max_args,
            max_length=max_length, min_length=min_length,
            output_dir=os.path.join('..', 'data_storage'))
    develop_dataset, test_dataset, encoder = input_pipeline.get_datasets(
        n_devices=4,
        task_name=args.task,
        data_dir=os.path.join('..', 'data_storage'),
        max_length=max_length)
    # Convert the TensorFlow datasets to PyTorch datasets
    develop_dataset = ListOpsDataset(develop_dataset, padding_idx=0)
    test_dataset = ListOpsDataset(test_dataset, padding_idx=0)
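    # For reference, ListOps inputs are bracketed prefix expressions such as
    #   [MAX 4 3 [MIN 2 3 ] 1 0 ]  ->  4
    # (example from the original ListOps task). The encoder returned by
    # input_pipeline.get_datasets maps tokens to ids; ListOpsDataset wraps the
    # tf.data pipeline for PyTorch, treating id 0 as padding (padding_idx=0).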
elif args.task == 'imdb':
    setting = read_yaml_to_dict(os.path.join('configs', args.task, 'setting.yaml'))
    data = setting.get('data', {})
    max_length = data['max_length']
    level = data['level']
    min_freq = data['min_freq']
    append_bos = data['append_bos']
    append_eos = data['append_eos']
    # Download and load the IMDB dataset
    develop_dataset, test_dataset = IMDB(root='../data_storage/')
    develop_dataset = TextDataset(dataset=develop_dataset, max_length=max_length, level=level, min_freq=min_freq,
                                  append_bos=append_bos, append_eos=append_eos, padding_idx=0)
    test_dataset = TextDataset(dataset=test_dataset, max_length=max_length, level=level, min_freq=min_freq,
                               append_bos=append_bos, append_eos=append_eos, padding_idx=0)
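    # Note: `level`, `min_freq`, and the BOS/EOS flags all come from
    # setting.yaml; in the LRA text benchmark this task is typically run
    # character-level, with reviews truncated or padded to max_length
    # (padding_idx=0).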
else:
    raise ValueError(f"Task '{args.task}' not found.")

logging.info('Saving datasets')
save_data(develop_dataset, os.path.join('..', 'datasets', args.task, 'develop_dataset'))
save_data(test_dataset, os.path.join('..', 'datasets', args.task, 'test_dataset'))