diff --git a/datasets/__init__.py b/datasets/__init__.py index 31c6a87..13e294b 100644 --- a/datasets/__init__.py +++ b/datasets/__init__.py @@ -62,7 +62,8 @@ def get_data_generator(dataset, data_root, classes = None): if dataset == 'cifar-10': - return CifarGenerator(data_root, classes, reenumerate = True, cifar10 = True, randzoom_range = 0.25) + return CifarGenerator(data_root, classes, reenumerate = True, cifar10 = True, + train_generator_kwargs = { 'horizontal_flip' : True, 'width_shift_range' : 0.15, 'height_shift_range' : 0.15, 'zoom_range' : 0.25 }) elif dataset == 'cifar-100':