Skip to content

Commit

Permalink
Target dirs and imports (#419)
Browse files Browse the repository at this point in the history
* Prevent exception when emeddings result is empty

* Make target dir for result file

* add keras-tuner

* move save results to utils
  • Loading branch information
kahst authored Aug 26, 2024
1 parent 8b4c4f5 commit 2ca2b3a
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 14 deletions.
17 changes: 6 additions & 11 deletions analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,8 @@ def generate_raven_table(timestamps: list[str], result: dict[str, list], afile_p
out_string += (
f"{selection_id}\tSpectrogram 1\t1\t0\t3\t{low_freq}\t{high_freq}\tnocall\tnocall\t1.0\t{afile_path}\t0\n"
)

with open(result_path, "w", encoding="utf-8") as rfile:
rfile.write(out_string)

utils.save_result_file(result_path, out_string)


def generate_audacity(timestamps: list[str], result: dict[str, list], result_path: str) -> str:
Expand All @@ -97,8 +96,7 @@ def generate_audacity(timestamps: list[str], result: dict[str, list], result_pat
# Write result string to file
out_string += rstring

with open(result_path, "w", encoding="utf-8") as rfile:
rfile.write(out_string)
utils.save_result_file(result_path, out_string)


def generate_rtable(timestamps: list[str], result: dict[str, list], afile_path: str, result_path: str) -> str:
Expand Down Expand Up @@ -131,8 +129,7 @@ def generate_rtable(timestamps: list[str], result: dict[str, list], afile_path:
# Write result string to file
out_string += rstring

with open(result_path, "w", encoding="utf-8") as rfile:
rfile.write(out_string)
utils.save_result_file(result_path, out_string)


def generate_kaleidoscope(timestamps: list[str], result: dict[str, list], afile_path: str, result_path: str) -> str:
Expand Down Expand Up @@ -167,8 +164,7 @@ def generate_kaleidoscope(timestamps: list[str], result: dict[str, list], afile_
# Write result string to file
out_string += rstring

with open(result_path, "w", encoding="utf-8") as rfile:
rfile.write(out_string)
utils.save_result_file(result_path, out_string)


def generate_csv(timestamps: list[str], result: dict[str, list], afile_path: str, result_path: str) -> str:
Expand All @@ -187,8 +183,7 @@ def generate_csv(timestamps: list[str], result: dict[str, list], afile_path: str
# Write result string to file
out_string += rstring

with open(result_path, "w", encoding="utf-8") as rfile:
rfile.write(out_string)
utils.save_result_file(result_path, out_string)


def saveResultFiles(r: dict[str, list], result_files: dict[str, str], afile_path: str):
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ gradio
pywebview
tqdm
bottle
requests
requests
keras-tuner
9 changes: 7 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def _loadAudioFile(f, label_vector, config):
except Exception as e:
# Print Error
print(f"\t Error when loading file {f}", flush=True)
print(f"\t {e}", flush=True)
return np.array([]), np.array([])

# Crop training samples
Expand Down Expand Up @@ -169,8 +170,12 @@ def _loadTrainingData(cache_mode="none", cache_file="", progress_callback=None):
with tqdm.tqdm(total=len(tasks), desc=f" - loading '{folder}'", unit='f') as progress_bar:
for task in tasks:
result = task.get()
x_train += result[0]
y_train += result[1]
# Make sure result is not empty
# Empty results might be caused by errors when loading the audio file
# TODO: We should check for embeddings size in result, otherwise we can't add them to the training data
if len(result[0]) > 0:
x_train += result[0]
y_train += result[1]
num_files_processed += 1
progress_bar.update(1)
if progress_callback:
Expand Down
15 changes: 15 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,3 +640,18 @@ def save_model_params(file_path):
cfg.TRAIN_WITH_LABEL_SMOOTHING,
)
)

def save_result_file(result_path: str, out_string: str):
"""Saves the result to a file.
Args:
result_path: The path to the result file.
out_string: The string to be written to the file.
"""

# Make directory if it doesn't exist
os.makedirs(os.path.dirname(result_path), exist_ok=True)

# Write the result to the file
with open(result_path, "w", encoding="utf-8") as rfile:
rfile.write(out_string)

0 comments on commit 2ca2b3a

Please sign in to comment.