From 0ca2c9344b9a29d05a366cc0456310eeea0e8468 Mon Sep 17 00:00:00 2001 From: gferraro Date: Tue, 12 Nov 2024 09:52:15 +0100 Subject: [PATCH] added mask segment type as default --- src/ml_tools/dataset.py | 2 +- src/ml_tools/datasetstructures.py | 34 +++++++++++++++++++++---------- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/src/ml_tools/dataset.py b/src/ml_tools/dataset.py index 7e633b32..87e8ccaa 100644 --- a/src/ml_tools/dataset.py +++ b/src/ml_tools/dataset.py @@ -83,7 +83,7 @@ def __init__( self.excluded_tags = config.build.excluded_tags self.min_frame_mass = config.build.min_frame_mass self.filter_by_lq = config.build.filter_by_lq - self.segment_types = [SegmentType.ALL_RANDOM] + self.segment_types = [SegmentType.ALL_RANDOM_MASKED] self.max_segments = config.build.max_segments self.country = config.build.country self.max_frames = config.build.max_frames diff --git a/src/ml_tools/datasetstructures.py b/src/ml_tools/datasetstructures.py index ad4b1d1b..57d04dcc 100644 --- a/src/ml_tools/datasetstructures.py +++ b/src/ml_tools/datasetstructures.py @@ -30,6 +30,7 @@ class SegmentType(Enum): ALL_SECTIONS = 5 TOP_RANDOM = 6 ALL_RANDOM_NOMIN = 7 + ALL_RANDOM_MASKED = 8 class BaseSample(ABC): @@ -1071,9 +1072,13 @@ def get_segments( frame_indices = np.array(frame_indices) segment_count = max(1, len(frame_indices) // segment_frame_spacing) segment_count = int(segment_count) + mask_length = 25 + # probably only counts for all random if max_segments is not None and segment_type not in [SegmentType.ALL_SECTIONS]: segment_count = min(max_segments, segment_count) + # adjust size of mask if we take less segments + mask_length = max(mask_length, len(frame_indices) // segment_count) # take any segment_width frames, this could be done each epoch whole_indices = frame_indices random_frames = segment_type in [ @@ -1081,12 +1086,13 @@ def get_segments( SegmentType.ALL_RANDOM, SegmentType.ALL_RANDOM_NOMIN, SegmentType.TOP_RANDOM, + SegmentType.ALL_RANDOM_MASKED, None, ] - random_mask = True + for _ in range(repeats): used_indices = [] - if not random_mask: + if segment_type != SegmentType.ALL_RANDOM_MASKED or len(whole_indices) < 40: frame_indices = whole_indices.copy() if random_frames: @@ -1094,21 +1100,27 @@ def get_segments( np.random.shuffle(frame_indices) for i in range(segment_count): - if random_mask: - if len(whole_indices) < 40: - frame_indices = whole_indices.copy() - else: - mask_start = i * 25 - frame_indices = list(whole_indices[0:mask_start].copy()) - frame_indices.extend(whole_indices[mask_start + 25 :].copy()) + if segment_type == SegmentType.ALL_RANDOM_MASKED: + if len(whole_indices) > 40: + mask_start = i * mask_length + frame_indices = whole_indices[0:mask_start] + frame_indices = np.concatenate( + [frame_indices, whole_indices[mask_start + mask_length :]], + axis=0, + ) + # maybe some faster way of doing this... frame_indices = [ f for f in frame_indices if f not in used_indices ] frame_indices = np.uint32(frame_indices) - np.random.shuffle(frame_indices) + np.random.shuffle(frame_indices) # always get atleast one segment, not doing annymore - if len(frame_indices) == 0 or len(segments) >= min_segments: + if ( + len(frame_indices) == 0 + or min_segments is None + or len(segments) >= min_segments + ): if ( len(frame_indices) < segment_width / 2.0 and len(segments) > 1 ) or len(frame_indices) < segment_width / 4: