Skip to content

Commit

Permalink
add random shift to training
Browse files Browse the repository at this point in the history
  • Loading branch information
arch committed Nov 27, 2021
1 parent 34d2856 commit b8726ad
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
6 changes: 3 additions & 3 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ general:
checkpoint_dir: 'checkpoint'
train_dir: './data/train'
batch_size: 1 # data loader is only implemented for case == 1!
select: 'model3'
select: 'model1'
test_file: './data/test/example1.mkv'

model1:
Expand All @@ -17,8 +17,8 @@ model1:
convlstm_hidden_dim: 64
seq_len: 8
lr: 0.0001
lr_milestones: [7, 12, 15]
epochs: 25
lr_milestones: [4, 7, 9]
epochs: 10

model2:
name: 'funpos2'
Expand Down
10 changes: 10 additions & 0 deletions utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,23 @@ def load_next_video_frames(self):
self.read_next_frame()


def get_uniform_random_int(self, lower, upper):
return torch.randint(lower, upper+1, (1,)).numpy()[0]


def open_next_video(self):
self.inc_video_idx()
with open("".join(self.videos[self.video_idx][:-4]) + '.param', "r") as f:
self.param = json.load(f)
with open("".join(self.videos[self.video_idx][:-4]) + '.labels', "r") as f:
l = json.load(f)
self.labels = {int(k):l[k] for k in l.keys()}
self.param['zoom'] = [
self.param['zoom'][0] + self.get_uniform_random_int(-10, 10),
self.param['zoom'][1] + self.get_uniform_random_int(-10, 10),
self.param['zoom'][2] + self.get_uniform_random_int(-10, 10),
self.param['zoom'][3] + self.get_uniform_random_int(-10, 10)
]
self.param['resize'] = (self.img_width, self.img_height)
if self.stream is not None:
self.stream.stop()
Expand Down

0 comments on commit b8726ad

Please sign in to comment.