Skip to content

Commit

Permalink
dataloader fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
damaggu committed Apr 5, 2023
1 parent dbd4ced commit db78ca0
Showing 1 changed file with 52 additions and 24 deletions.
76 changes: 52 additions & 24 deletions SwissKnife/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,24 +215,44 @@ def categorize_data(self, num_classes, recurrent=False):
self.y_test_recurrent, num_classes=num_classes, dtype="int"
)

def normalize_data(self):
# TODO: double check this here
# self.mean = self.x_train[1000:-1000].mean(axis=0)
# self.std = np.std(self.x_train[1000:-1000], axis=0)
self.mean = self.x_train.mean(axis=0)
self.std = np.std(self.x_train, axis=0)
self.x_train = self.x_train - self.mean
self.x_train /= self.std
self.x_test = self.x_test - self.mean
self.x_test /= self.std

if not self.dlc_train is None:
self.mean_dlc = self.dlc_train.mean(axis=0)
self.std_dlc = self.dlc_train.std(axis=0)
self.dlc_train -= self.mean_dlc
self.dlc_test -= self.mean
self.dlc_train /= self.std_dlc
self.dlc_test /= self.std_dlc
def normalize_data(self, mode="default"):
if mode == "default":
# TODO: double check this here
# self.mean = self.x_train[1000:-1000].mean(axis=0)
# self.std = np.std(self.x_train[1000:-1000], axis=0)
self.mean = self.x_train.mean(axis=0)
self.std = np.std(self.x_train, axis=0)
self.x_train = self.x_train - self.mean
self.x_train /= self.std
self.x_test = self.x_test - self.mean
self.x_test /= self.std

if not self.dlc_train is None:
self.mean_dlc = self.dlc_train.mean(axis=0)
self.std_dlc = self.dlc_train.std(axis=0)
self.dlc_train -= self.mean_dlc
self.dlc_test -= self.mean
self.dlc_train /= self.std_dlc
self.dlc_test /= self.std_dlc
elif mode == "xception":
self.x_train /= 127.5
self.x_train -= 1.0
self.x_test /= 127.5
self.x_test -= 1.0

if not self.dlc_train is None:
self.dlc_train /= 127.5
self.dlc_train -= 1.0
self.dlc_test /= 127.5
self.dlc_test -= 1.0

else:
self.x_train /= 255.0
self.x_test /= 255.0
if not self.dlc_train is None:
self.dlc_train /= 255.0
self.dlc_test /= 255.0


def create_dataset(dataset, oneD, look_back=5):
"""
Expand Down Expand Up @@ -438,9 +458,9 @@ def undersample_data(self):

# TODO: undersample recurrent

def change_dtype(self):
self.x_train = np.asarray(self.x_train, dtype="uint8")
self.x_test = np.asarray(self.x_test, dtype="uint8")
def change_dtype(self, dtype="uint8"):
self.x_train = np.asarray(self.x_train, dtype=dtype)
self.x_test = np.asarray(self.x_test, dtype=dtype)

def get_input_shape(self, recurrent=False):
"""
Expand Down Expand Up @@ -484,23 +504,30 @@ def downscale_frames(self, factor=0.5):
self.x_test = np.asarray(im_re)

def prepare_data(
self, downscale=0.5, remove_behaviors=[], flatten=False
self, downscale=0.5, remove_behaviors=[], flatten=False, recurrent=False, normalization_mode='default'
):
print("preparing data")
print("changing dtype")
self.change_dtype()

print("removing behaviors")
for behavior in remove_behaviors:
self.remove_behavior(behavior=behavior)
print("downscaling")
if downscale:
self.downscale_frames(factor=downscale)
print("normalizing data")
if self.config["normalize_data"]:
self.normalize_data()
print("doing flow")
if self.config["do_flow"]:
self.create_flow_data()
print("encoding labels")
if self.config["encode_labels"]:
print("test")
self.encode_labels()
print("labels encoded")
print("using class weights")
if self.config["use_class_weights"]:
print("calc class weights")
self.class_weights = class_weight.compute_class_weight(
Expand All @@ -509,16 +536,17 @@ def prepare_data(
if self.config["undersample_data"]:
print("undersampling data")
self.undersample_data()
print("using generator")
if self.config["use_generator"]:
self.categorize_data(self.num_classes, recurrent=False)
self.categorize_data(self.num_classes, recurrent=recurrent)
else:
print("preparing recurrent data")
self.create_recurrent_data()
print("preparing flattened data")
if flatten:
self.create_flattened_data()
print("categorize data")
self.categorize_data(self.num_classes, recurrent=True)
self.categorize_data(self.num_classes, recurrent=recurrent)

print("data ready")

Expand Down

0 comments on commit db78ca0

Please sign in to comment.