Skip to content

Commit

Permalink
add multiple segment type option
Browse files Browse the repository at this point in the history
  • Loading branch information
gferraro committed Nov 7, 2024
1 parent 7ca8817 commit 46b431b
Show file tree
Hide file tree
Showing 11 changed files with 226 additions and 269 deletions.
3 changes: 2 additions & 1 deletion src/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,7 +896,7 @@ def main():
{
"segment_frame_spacing": master_dataset.segment_spacing * 9,
"segment_width": master_dataset.segment_length,
"segment_type": master_dataset.segment_type,
"segment_types": master_dataset.segment_types,
"segment_min_avg_mass": master_dataset.segment_min_avg_mass,
"max_segments": master_dataset.max_segments,
"dont_filter_segment": True,
Expand Down Expand Up @@ -932,6 +932,7 @@ def main():
"counts": dataset_counts,
"by_label": False,
"config": attrs.asdict(config),
"segment_types": master_dataset.segment_types,
}

with open(meta_filename, "w") as f:
Expand Down
46 changes: 3 additions & 43 deletions src/ml_tools/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __init__(
self.excluded_tags = config.build.excluded_tags
self.min_frame_mass = config.build.min_frame_mass
self.filter_by_lq = config.build.filter_by_lq
self.segment_type = SegmentType.ALL_RANDOM
self.segment_types = [SegmentType.ALL_RANDOM]
self.max_segments = config.build.max_segments
self.country = config.build.country
self.max_frames = config.build.max_frames
Expand All @@ -100,7 +100,7 @@ def __init__(
self.segment_spacing = 1
self.segment_min_avg_mass = 10
self.min_frame_mass = 16
self.segment_type = SegmentType.ALL_RANDOM
self.segment_types = [SegmentType.ALL_RANDOM]
self.max_frames = 75

self.country_rectangle = BuildConfig.COUNTRY_LOCATIONS.get(self.country)
Expand Down Expand Up @@ -244,7 +244,7 @@ def load_clip(self, db_clip, dont_filter_segment=False):
track_header.get_segments(
segment_width,
segment_frame_spacing,
self.segment_type,
self.segment_types,
self.segment_min_avg_mass,
max_segments=self.max_segments,
dont_filter=dont_filter_segment,
Expand Down Expand Up @@ -504,46 +504,6 @@ def regroup(
def has_data(self):
return len(self.samples_by_id) > 0

#
# def recalculate_segments(self, segment_type=SegmentType.ALL_RANDOM):
# self.samples_by_bin.clear()
# self.samples_by_label.clear()
# del self.samples[:]
# del self.samples
# self.samples = []
# self.samples_by_label = {}
# self.samples_by_bin = {}
# logging.info("%s generating segments type %s", self.name, segment_type)
# start = time.time()
# empty_tracks = []
# filtered_stats = 0
#
# for track in self.tracks:
# segment_frame_spacing = int(
# round(self.segment_spacing * track.frames_per_second)
# )
# segment_width = self.segment_length
# track.calculate_segments(
# segment_frame_spacing,
# segment_width,
# segment_type,
# segment_min_mass=segment_min_avg_mass,
# )
# filtered_stats = filtered_stats + track.filtered_stats["segment_mass"]
# if len(track.segments) == 0:
# empty_tracks.append(track)
# continue
# for sample in track.segments:
# self.add_clip_sample_mappings(sample)
#
# self.rebuild_cdf()
# logging.info(
# "%s #segments %s filtered stats are %s took %s",
# self.name,
# len(self.samples),
# filtered_stats,
# time.time() - start,
# )
def remove_sample_by_id(self, id, bin_id):
del self.samples_by_id[id]
try:
Expand Down
Loading

0 comments on commit 46b431b

Please sign in to comment.