options.py
import argparse
import dataclasses
import os
from dataclasses import dataclass
import yaml
@dataclass
class Options:
""" "Dataclass for housing experiment flags."""
random_seed: int = 0
################################### logs ###################################
# experiment name
name: str = "debug"
# log directory for training.
log_dir: str = os.path.join(os.path.expanduser("~"), "tmp/tensorboard")
# want to note something about the experiment?
notes: str = ""
# interval in number of training steps to log.
log_interval: int = 100
# interval in number of training steps to validate.
val_interval: int = 1000
# number of validation batches per validation step
val_batches: int = 100
################################### data ###################################
# which dataset should we use?
dataset: str = "scannet"
# base dataset path.
dataset_path: str = "/mnt/scannet-data-png"
# number of dataloader workers to use.
num_workers: int = 12
# where to look for a tuple file.
tuple_info_file_location: str = "/mnt/res_nas/mohameds/implicit_recon/multi_view_data/scannet/"
# the suffix of a tuple filename, which is concatenated to the split to get
# the final tuple filename.
mv_tuple_file_suffix: str = "_eight_view_deepvmvs.txt"
# the type of frame tuple to use. default is DVMVS keyframes. dense yields
# an optimal online tuple for each frame in the scan. dense_offline will
# create a tuple for every frame using frames in the past and future.
frame_tuple_type: str = "default"
# number of views the model should expect in a tuple.
model_num_views: int = 8
# similar to model_num_views, but used exclusively for data
# loading/processing.
num_images_in_tuple: int = None
    # file listing scans to use. Since we use tuple files for dataloading,
    # this is only relevant for generating tuple files and certain script
    # guidance.
dataset_scan_split_file: str = "/mnt/scannet-data-png2/scannetv2_train.txt"
    # the split to use; script dependent.
split: str = "train"
# image size input to the network. Used in dataloaders and projectors.
image_width: int = 512
image_height: int = 384
# used to shuffle tuple order for ablation.
shuffle_tuple: bool = False
# number of keyframes to keep around in the buffer for DVMVS tuples.
test_keyframe_buffer_size: int = 30
# full res supervision for implicit samples
full_depth_supervision: bool = True
############################## hyperparameters #############################
# learning rate
lr: float = 1e-4
# weight decay
wd: float = 1e-4
# number of sanity validation steps before training
num_sanity_val_steps: int = 0
# max number of iterations for training
max_steps: int = 110000
# batch size
batch_size: int = 16
# validation batch size during training
val_batch_size: int = 16
# number of GPUs to use for training.
gpus: int = 2
# precision to use for training.
precision: int = 16
    # stepped learning rate schedule. LR will drop by a factor of 10 at each
    # of these steps.
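    # (a mutable default like this list must be supplied via
    # dataclasses.field(default_factory=...) on a dataclass field)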
lr_steps: list = dataclasses.field(default_factory=lambda: [70000, 80000])
    # ratio of sampled depth pixels near the surface
near_surface_ratio: float = 0.25
    surface_noise_type: str = "additive"
# regularisation weight encouraging predictions far from 0.5
bd_regularisation_weight: float = 0.5
# whether to apply regularisation only near depth edges
bd_edge_regularision: bool = True
################################## models ##################################
# resumes with training state
resume: str = None
# loads model weights
load_weights_from_checkpoint: str = None
# attempt to load weights from a pretrained model for BD
lazy_load_weights_from_checkpoint: str = None
# image prior encoder
image_encoder_name: str = "efficientnet"
# final depth decoder.
depth_decoder_name: str = "unet_pp"
# loss
loss_type: str = "log_l1"
# matching encoder. resnet or fpn
matching_encoder_type: str = "resnet"
# number of channels for matching features
matching_feature_dims: int = 16
# scale to match features at. 1 means half the final depth output size, or a
# quarter of image resolution.
matching_scale: int = 1
# number of depth bins to use in the cost volume.
matching_num_depth_bins: int = 64
# min and max depth planes in the cost volume
min_matching_depth: float = 0.25
max_matching_depth: float = 5.0
# type of cost volume encoder.
cv_encoder_type: str = "multi_scale_encoder"
# type of cost volume to use. SimpleRecon's metadata model uses the
# 'mlp_feature_volume' model. Also available in this repo is a simple dot
# reduction model 'simple_cost_volume'
feature_volume_type: str = "mlp_feature_volume"
# whether to use temporal stability
use_prior: bool = False
################################# Inference ################################
# base paths for all outputs.
output_base_path: str = "/mnt/res_nas/mohameds/simple_recon_output/"
    # where to load rendered depth maps from. If None, the depth of a plane
    # fixed at 2.0m from the camera is used instead.
    rendered_depth_map_load_dir: str = None
# only run whatever it is this script is doing on a single frame.
single_debug_scan_id: str = None
# skip every skip_frames tuple when inferring depths. Useful for dense
# tuples
skip_frames: int = None
max_frames: int = None
    # mask the predicted depth map using a mask from the cost volume where
# true indicates available projected source view information. NOT used by
# default or for scores.
mask_pred_depth: bool = False
# cache predicted depths to disk when inferring
cache_depths: bool = False
    # if True, will load depth maps at the highest resolution available in
    # the dataset and use those when computing metrics against upscaled
    # predictions.
high_res_validation: bool = False
# fast cost volume for inference.
fast_cost_volume: bool = False
    # if False, will eval depth maps; if True, will eval IoU for binary depth
    # planes
binary_eval_depth: bool = False
    # use validation thresholds for evaluation
use_validation_thresholds: bool = False
# if true will eval depth planes for regression methods
regression_plane_eval: bool = False
# save out fewer outputs
skinny_cache_dump: bool = False
# temporal evaluation
temporal_eval: bool = False
eval_length: int = 15
eval_frame_multiplier: int = 8
warmup: int = 2
# multiplier inside the sigmoid operation - higher values give outputs nearer to 0 and 1
bd_sigmoid_multiplier: float = 1.0
############################### Visualization ##############################
# dump a quick depth visualization in test.py.
# visualization_scripts/visualize_scene_depth_output.py produces nicer
# visualizations than this though.
dump_depth_visualization: bool = False
class OptionsHandler:
"""A class for handling experiment options.
This class handles options files and optional CLI arguments for
experimentation.
The intended use looks like this:
optionsHandler = options.OptionsHandler()
# uses a config filename from args or populates flags from CLI
optionsHandler.parse_and_merge_options()
# optionally print
optionsHandler.pretty_print_options()
You could also load from a config file you choose and ignore one that
may be supplied in args.
    optionsHandler.parse_and_merge_options(
        config_filepaths=os.path.join("configs", "test_config.yaml"))
    Options will be populated from any supplied config files first, then
    overwritten by any changes provided in command-line args. If a required
    attribute is not defined in either, an Exception is raised.
I want to add a new arg! What should I do? Well, easy. Add an entry in
the Options class and specify a type and default value. If this needs to
be a required arg, set None for a default value and also add its name
as a string to the required_flags list in the OptionsHandler class's
initializer.
There are two config files allowed. --config_file, then
--data_config_file. Order of overriding (last overrides above):
- config_file
- data_config_file
- CLI arguments
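
    A hypothetical illustration of that order (the script name, file names
    and values below are illustrative only): if model_config.yaml sets
    batch_size to 12 and the run is

        python train.py --config_file model_config.yaml \
            --data_config_file scannet_config.yaml --batch_size 8

    then options.batch_size ends up as 8, since CLI arguments override both
    config files. Config files are loaded with yaml.load and merged via
    their __dict__, so they are expected to be object-style yaml of the kind
    written by save_options_as_yaml (tagged !!python/object:options.Options)
    rather than plain mappings.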
"""
    def __init__(self, required_flags=None):
        """Sets up the class and stores required flags."""
        if required_flags is None:
            required_flags = []
self.options = Options()
self.required_flags = required_flags
self.parser = argparse.ArgumentParser(description="SimpleRecon Options")
self.parser.add_argument("--config_file", type=str, default=None)
self.parser.add_argument("--data_config_file", type=str, default=None)
self.populate_argparse()
def parse_and_merge_options(self, config_filepaths=None, ignore_cl_args=False):
"""Parses flags from a config file and CL arguments.
Args:
config_filepaths: str filepath to a .yaml or list of filepaths
to config files ignore_cl_args: optionally ignore CLI
altogether, useful for debugging with a hardcoded config
filepath and in python notebooks.
Raises:
Exception: raised when required arguments aren't satisfied.
"""
# parse args
if not ignore_cl_args:
cl_args = self.parser.parse_args()
# load config file
if config_filepaths is not None:
# use config_filepath(s) provided here if available
if isinstance(config_filepaths, list):
for config_filepath in config_filepaths:
config_options = OptionsHandler.load_options_from_yaml(config_filepath)
self.merge_config_options(config_options)
else:
config_options = OptionsHandler.load_options_from_yaml(config_filepaths)
self.merge_config_options(config_options)
self.config_filepaths = config_filepaths
elif not ignore_cl_args and (
cl_args.config_file is not None or cl_args.data_config_file is not None
):
# if args tells us we should load from a file, then let's do that.
self.config_filepaths = []
# add from standard config first
if cl_args.config_file is not None:
config_options = OptionsHandler.load_options_from_yaml(cl_args.config_file)
self.merge_config_options(config_options)
self.config_filepaths.append(cl_args.config_file)
# then merge from a data config
if cl_args.data_config_file is not None:
config_options = OptionsHandler.load_options_from_yaml(cl_args.data_config_file)
self.merge_config_options(config_options)
self.config_filepaths.append(cl_args.data_config_file)
else:
# no config has been supplied. Let's hope that we have required
# arguments through command line.
print("Not reading from a config_file.")
config_options = None
self.config_filepaths = None
if not ignore_cl_args:
# merge args second and overwrite everything that's come before
self.merge_cl_args(cl_args)
# now check that all required arguments are satisfied
self.check_required_items()
def populate_argparse(self):
"""Populates argparse arguments using Options attributes."""
for field_name in self.options.__dataclass_fields__.keys():
field_info = self.options.__dataclass_fields__[field_name]
if field_info.type == bool:
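                # bool fields become store_true flags (defaulting to False);
                # merge_cl_args ignores False bools, so a CLI flag can only
                # switch a bool on, never off, and an absent flag never
                # overrides a True value from the defaults or a config file.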
self.parser.add_argument(f"--{field_name}", action="store_true")
else:
self.parser.add_argument(
f"--{field_name}",
type=field_info.type,
default=None,
)
def check_required_items(self):
"""Raises a flag if options isn't defined."""
for required_flag in self.required_flags:
if self.options.__getattribute__(required_flag) is None:
raise Exception(f"Error! Missing required config argument '{required_flag}'")
def merge_config_options(self, config_options):
""""""
# loop over loaded config and update those in self.options.
for field_name in config_options.__dict__.keys():
value = config_options.__getattribute__(field_name)
self.options.__setattr__(field_name, value)
    def merge_cl_args(self, cl_args):
        """Merges parsed CL arguments into self.options, overriding earlier values."""
        # loop over loaded args and update those in self.options.
for arg_pair in cl_args._get_kwargs():
# this should be the only argument that doesn't match here.
if arg_pair[0] == "config_file":
continue
if arg_pair[1] is not None:
# check if type bool and if false, in that case ignore
if isinstance(arg_pair[1], bool) and not arg_pair[1]:
continue
if arg_pair[0] == "prediction_mlp_channels":
array = "".join(arg_pair[1]).split("_")
array = [int(dim) for dim in array]
self.options.__setattr__(arg_pair[0], array)
else:
self.options.__setattr__(arg_pair[0], arg_pair[1])
def pretty_print_options(self):
print("########################### Options ###########################")
print("")
for field_name in self.options.__dataclass_fields__.keys():
print(" ", field_name + ":", self.options.__getattribute__(field_name))
print("")
print("###############################################################")
@staticmethod
    def load_options_from_yaml(config_filepath):
        # yaml.Loader (rather than SafeLoader) so that configs written by
        # save_options_as_yaml, which serialise the Options object, load back.
        with open(config_filepath, "r") as stream:
            return yaml.load(stream, Loader=yaml.Loader)
@staticmethod
def save_options_as_yaml(config_filepath, options):
with open(config_filepath, "w") as outfile:
yaml.dump(options, outfile, default_flow_style=False)
def handle_backwards_compat(opts):
# modify older experiment configs if needed
return opts
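

if __name__ == "__main__":
    # Minimal usage sketch mirroring the OptionsHandler docstring: merge any
    # supplied config file(s), then CLI overrides, then print the result.
    # Run with e.g. --config_file configs/test_config.yaml (path illustrative).
    options_handler = OptionsHandler()
    options_handler.parse_and_merge_options()
    options_handler.pretty_print_options()
    opts = handle_backwards_compat(options_handler.options)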