Skip to content

Commit

Permalink
Download review plots (#529)
Browse files Browse the repository at this point in the history
* huggingface buttons + allow download of review plots

* huggingface instead of primary button style
  • Loading branch information
Josef-Haupt authored Dec 17, 2024
1 parent 9bd2b77 commit be14661
Show file tree
Hide file tree
Showing 7 changed files with 39 additions and 7 deletions.
2 changes: 1 addition & 1 deletion birdnet_analyzer/gui/multi_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def select_directory_wrapper(): # Nishant - Function modified for For Folder se

locale_radio = gu.locale()

start_batch_analysis_btn = gr.Button(loc.localize("analyze-start-button-label"))
start_batch_analysis_btn = gr.Button(loc.localize("analyze-start-button-label"), variant="huggingface")

result_grid = gr.Matrix(
headers=[
Expand Down
33 changes: 32 additions & 1 deletion birdnet_analyzer/gui/review.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import base64
import io
import os
import random
from functools import partial

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from scipy.special import expit
from sklearn import linear_model

Expand Down Expand Up @@ -43,6 +46,7 @@ def collect_files(directory):

def create_log_plot(positives, negatives, fig_num=None):
f = plt.figure(fig_num, figsize=(12, 6))
f.tight_layout()
f.set_dpi(300)
f.clf()

Expand Down Expand Up @@ -137,14 +141,19 @@ def create_log_plot(positives, negatives, fig_num=None):

with gr.Column() as review_item_col:
with gr.Row():
spectrogram_image = gr.Plot(label=loc.localize("review-tab-spectrogram-plot-label"))
with gr.Column():
spectrogram_image = gr.Plot(label=loc.localize("review-tab-spectrogram-plot-label"), show_label=False)
with gr.Row():
spectrogram_dl_btn = gr.Button("Download spectrogram", size="sm")
regression_dl_btn = gr.Button("Download regression", size="sm")

with gr.Column():
with gr.Row():
skip_btn = gr.Button(loc.localize("review-tab-skip-button-label"))
undo_btn = gr.Button(loc.localize("review-tab-undo-button-label"))
positive_btn = gr.Button(loc.localize("review-tab-pos-button-label"))
negative_btn = gr.Button(loc.localize("review-tab-neg-button-label"))

with gr.Group():
review_audio = gr.Audio(
type="filepath", sources=[], show_download_button=False, autoplay=True
Expand Down Expand Up @@ -340,6 +349,21 @@ def undo_review(next_review_state):
def toggle_autoplay(value):
return gr.Audio(autoplay=value)

def download_plot(plot, filename=""):
imgdata = base64.b64decode(plot.plot.split(",", 1)[1])
res = gu._WINDOW.create_file_dialog(
gu.webview.SAVE_DIALOG, file_types=("PNG (*.png)", "Webp (*.webp)", "JPG (*.jpg)"), save_filename=filename
)

if res:
if res.endswith(".webp"):
with open(res, "wb") as f:
f.write(imgdata)
else:
output_format = res.rsplit(".", 1)[-1].upper()
img = Image.open(io.BytesIO(imgdata))
img.save(res, output_format if output_format in ["PNG", "JPEG"] else "PNG")

autoplay_checkbox.change(toggle_autoplay, inputs=autoplay_checkbox, outputs=review_audio)

review_change_output = [
Expand All @@ -355,6 +379,13 @@ def toggle_autoplay(value):
undo_btn,
]

spectrogram_dl_btn.click(
partial(download_plot, filename="spectrogram"), show_progress=False, inputs=spectrogram_image
)
regression_dl_btn.click(
partial(download_plot, filename="regression"), show_progress=False, inputs=species_regression_plot
)

species_dropdown.change(
select_subdir,
show_progress=True,
Expand Down
2 changes: 1 addition & 1 deletion birdnet_analyzer/gui/segments.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def select_directory_to_state_and_tb(state_key):
minimum=1,
)

extract_segments_btn = gr.Button(loc.localize("segments-tab-extract-button-label"))
extract_segments_btn = gr.Button(loc.localize("segments-tab-extract-button-label"), variant="huggingface")

result_grid = gr.Matrix(
headers=[
Expand Down
4 changes: 2 additions & 2 deletions birdnet_analyzer/gui/single_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def build_single_analysis_tab():
with gr.Tab(loc.localize("single-tab-title")):
audio_input = gr.Audio(type="filepath", label=loc.localize("single-audio-label"), sources=["upload"])
with gr.Group():
spectogram_output = gr.Plot(label=loc.localize("review-tab-spectrogram-plot-label"), visible=False)
spectogram_output = gr.Plot(label=loc.localize("review-tab-spectrogram-plot-label"), visible=False, show_label=False)
generate_spectrogram_cb = gr.Checkbox(
value=True,
label=loc.localize("single-tab-spectrogram-checkbox-label"),
Expand Down Expand Up @@ -169,7 +169,7 @@ def try_generate_spectrogram(audio_path, generate_spectrogram):
elem_classes="matrix-mh-200",
elem_id="single-file-output",
)
single_file_analyze = gr.Button(loc.localize("analyze-start-button-label"))
single_file_analyze = gr.Button(loc.localize("analyze-start-button-label"), variant="huggingface")
hidden_segment_audio = gr.Audio(visible=False, autoplay=True, type="numpy")

def play_selected_audio(evt: gr.SelectData, audio_path):
Expand Down
2 changes: 1 addition & 1 deletion birdnet_analyzer/gui/species.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def select_directory_and_update_tb(name_tb):
info=loc.localize("species-tab-sort-radio-info"),
)

start_btn = gr.Button(loc.localize("species-tab-start-button-label"))
start_btn = gr.Button(loc.localize("species-tab-start-button-label"), variant="huggingface")
start_btn.click(
run_species_list,
inputs=[
Expand Down
2 changes: 1 addition & 1 deletion birdnet_analyzer/gui/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def on_cache_mode_change(value):
)

train_history_plot = gr.Plot()
start_training_button = gr.Button(loc.localize("training-tab-start-training-button-label"))
start_training_button = gr.Button(loc.localize("training-tab-start-training-button-label"), variant="huggingface")

start_training_button.click(
start_training,
Expand Down
1 change: 1 addition & 0 deletions birdnet_analyzer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def spectrogram_from_file(path, fig_num=None, fig_size=None):
f.clf()

ax = f.add_subplot(111)
f.tight_layout()
D = librosa.stft(s, n_fft=1024, hop_length=512) # STFT of y
S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max)

Expand Down

0 comments on commit be14661

Please sign in to comment.