From 6e8ac75f4b9b15f3c46ae10a1494ea1f7fbd5b76 Mon Sep 17 00:00:00 2001 From: Stefan Kahl Date: Fri, 23 Feb 2024 09:13:29 -0500 Subject: [PATCH 1/2] fix segments extraction --- segments.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/segments.py b/segments.py index ffb0465d..e49c68f2 100644 --- a/segments.py +++ b/segments.py @@ -152,9 +152,10 @@ def findSegments(afile: str, rfile: str): for i, line in enumerate(lines): if rtype == "table" and i > 0: + # TODO: Use header columns to get the right indices d = line.split("\t") - start = float(d[4]) - end = float(d[5]) + start = float(d[5]) + end = float(d[6]) species = d[-2] confidence = float(d[-1]) @@ -186,8 +187,8 @@ def findSegments(afile: str, rfile: str): species = d[3] confidence = float(d[4]) - # Check if confidence is high enough - if confidence >= cfg.MIN_CONFIDENCE: + # Check if confidence is high enough and label is not "nocall" + if confidence >= cfg.MIN_CONFIDENCE and species.lower() != "nocall": segments.append({"audio": afile, "start": start, "end": end, "species": species, "confidence": confidence}) return segments @@ -239,8 +240,8 @@ def extractSegments(item: tuple[tuple[str, list[dict]], float, dict[str]]): os.makedirs(outpath, exist_ok=True) # Save segment - seg_name = "{:.3f}_{}_{}.wav".format( - seg["confidence"], seg_cnt, seg["audio"].rsplit(os.sep, 1)[-1].rsplit(".", 1)[0] + seg_name = "{:.3f}_{}_{}_{:.1f}s_{:.1f}s.wav".format( + seg["confidence"], seg_cnt, seg["audio"].rsplit(os.sep, 1)[-1].rsplit(".", 1)[0], seg["start"], seg["end"] ) seg_path = os.path.join(outpath, seg_name) audio.saveSignal(seg_sig, seg_path) From ffb783702bf437c811cc3aac52a396cc4679d203 Mon Sep 17 00:00:00 2001 From: Stefan Kahl Date: Fri, 23 Feb 2024 09:21:27 -0500 Subject: [PATCH 2/2] fix threads --- embeddings.py | 4 ++-- segments.py | 3 ++- train.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/embeddings.py b/embeddings.py index 4ad70f46..f7e3acbb 100644 --- a/embeddings.py +++ b/embeddings.py @@ -183,11 +183,11 @@ def analyzeFile(item): # Set number of threads if os.path.isdir(cfg.INPUT_PATH): - cfg.CPU_THREADS = int(args.threads) + cfg.CPU_THREADS = max(1, int(args.threads)) cfg.TFLITE_THREADS = 1 else: cfg.CPU_THREADS = 1 - cfg.TFLITE_THREADS = int(args.threads) + cfg.TFLITE_THREADS = max(1, int(args.threads)) # Set batch size cfg.BATCH_SIZE = max(1, int(args.batchsize)) diff --git a/segments.py b/segments.py index e49c68f2..cfb8db73 100644 --- a/segments.py +++ b/segments.py @@ -4,6 +4,7 @@ """ import argparse import os +import multiprocessing from multiprocessing import Pool import numpy as np @@ -268,7 +269,7 @@ def extractSegments(item: tuple[tuple[str, list[dict]], float, dict[str]]): parser.add_argument( "--seg_length", type=float, default=3.0, help="Length of extracted segments in seconds. Defaults to 3.0." ) - parser.add_argument("--threads", type=int, default=4, help="Number of CPU threads.") + parser.add_argument("--threads", type=int, default=min(8, max(1, multiprocessing.cpu_count() // 2)), help="Number of CPU threads.") args = parser.parse_args() diff --git a/train.py b/train.py index d7ef0932..33e99d1b 100644 --- a/train.py +++ b/train.py @@ -398,7 +398,7 @@ def run_trial(self, trial, *args, **kwargs): cfg.TRAIN_CACHE_MODE = args.cache_mode.lower() cfg.TRAIN_CACHE_FILE = args.cache_file cfg.TFLITE_THREADS = 1 - cfg.CPU_THREADS = cfg.CPU_THREADS = max(1, int(args.threads)) + cfg.CPU_THREADS = max(1, int(args.threads)) cfg.BANDPASS_FMIN = max(0, min(cfg.SIG_FMAX, int(args.fmin))) cfg.BANDPASS_FMAX = max(cfg.SIG_FMIN, min(cfg.SIG_FMAX, int(args.fmax)))