From b8726ade76d8e4deaef390e2d058c426c0b61b3f Mon Sep 17 00:00:00 2001 From: arch Date: Sat, 27 Nov 2021 19:26:57 +0100 Subject: [PATCH] add random shift to training --- config.yaml | 6 +++--- utils/dataset.py | 10 ++++++++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/config.yaml b/config.yaml index 400f330..5ae70e0 100644 --- a/config.yaml +++ b/config.yaml @@ -4,7 +4,7 @@ general: checkpoint_dir: 'checkpoint' train_dir: './data/train' batch_size: 1 # data loader is only implemented for case == 1! - select: 'model3' + select: 'model1' test_file: './data/test/example1.mkv' model1: @@ -17,8 +17,8 @@ model1: convlstm_hidden_dim: 64 seq_len: 8 lr: 0.0001 - lr_milestones: [7, 12, 15] - epochs: 25 + lr_milestones: [4, 7, 9] + epochs: 10 model2: name: 'funpos2' diff --git a/utils/dataset.py b/utils/dataset.py index 205596b..a1f46c8 100644 --- a/utils/dataset.py +++ b/utils/dataset.py @@ -86,6 +86,10 @@ def load_next_video_frames(self): self.read_next_frame() + def get_uniform_random_int(self, lower, upper): + return torch.randint(lower, upper+1, (1,)).numpy()[0] + + def open_next_video(self): self.inc_video_idx() with open("".join(self.videos[self.video_idx][:-4]) + '.param', "r") as f: @@ -93,6 +97,12 @@ def open_next_video(self): with open("".join(self.videos[self.video_idx][:-4]) + '.labels', "r") as f: l = json.load(f) self.labels = {int(k):l[k] for k in l.keys()} + self.param['zoom'] = [ + self.param['zoom'][0] + self.get_uniform_random_int(-10, 10), + self.param['zoom'][1] + self.get_uniform_random_int(-10, 10), + self.param['zoom'][2] + self.get_uniform_random_int(-10, 10), + self.param['zoom'][3] + self.get_uniform_random_int(-10, 10) + ] self.param['resize'] = (self.img_width, self.img_height) if self.stream is not None: self.stream.stop()