Skip to content

Commit

Permalink
simplify advanced demo
Browse files Browse the repository at this point in the history
  • Loading branch information
Flux9665 committed Oct 7, 2024
1 parent 335a8ba commit e18c227
Showing 1 changed file with 11 additions and 56 deletions.
67 changes: 11 additions & 56 deletions run_advanced_GUI_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,10 @@
from PyQt5.QtWidgets import QComboBox
from PyQt5.QtWidgets import QFileDialog
from PyQt5.QtWidgets import QHBoxLayout
from PyQt5.QtWidgets import QLabel
from PyQt5.QtWidgets import QLineEdit
from PyQt5.QtWidgets import QMainWindow
from PyQt5.QtWidgets import QMessageBox
from PyQt5.QtWidgets import QPushButton
from PyQt5.QtWidgets import QSlider
from PyQt5.QtWidgets import QVBoxLayout
from PyQt5.QtWidgets import QWidget
from huggingface_hub import hf_hub_download
Expand Down Expand Up @@ -144,9 +142,6 @@ def __init__(self, tts_interface: ToucanTTSInterface):
self.audio_file_path = None
self.result_audio = None
self.min_duration = 1
self.slider_val = 100
self.durations_are_scaled = False
self.prev_slider_val_for_denorm = 100

self.setWindowTitle("TTS Model Interface")
self.setGeometry(100, 100, 1200, 900)
Expand Down Expand Up @@ -177,7 +172,7 @@ def __init__(self, tts_interface: ToucanTTSInterface):
# Initialize plots
self.init_plots()

# Initialize slider and buttons
# Initialize buttons
self.init_controls()

# Initialize Timer for TTS Cooldown
Expand All @@ -189,10 +184,6 @@ def __init__(self, tts_interface: ToucanTTSInterface):
def clear_all_widgets(self):
self.spectrogram_view.setParent(None)
self.pitch_plot.setParent(None)
self.upper_row.setParent(None)
self.slider_label.setParent(None)
self.mod_slider.setParent(None)
self.slider_value_label.setParent(None)
self.generate_button.setParent(None)
self.load_audio_button.setParent(None)
self.save_audio_button.setParent(None)
Expand All @@ -218,6 +209,7 @@ def load_data(self, durations, pitch, spectrogram):

self.durations = durations
self.cumulative_durations = np.cumsum(self.durations)
self.pitch = pitch
self.spectrogram = spectrogram

# Display Spectrogram
Expand Down Expand Up @@ -245,7 +237,7 @@ def load_data(self, durations, pitch, spectrogram):
# Display Durations
self.duration_lines = []
for i, cum_dur in enumerate(self.cumulative_durations):
line = pg.InfiniteLine(pos=cum_dur, angle=90, pen=pg.mkPen('orange', width=4))
line = pg.InfiniteLine(pos=cum_dur, angle=90, pen=pg.mkPen('orange', width=2))
self.spectrogram_view.addItem(line)
line.setMovable(True)
# Use lambda with default argument to capture current i
Expand Down Expand Up @@ -274,28 +266,6 @@ def init_controls(self):
self.controls_layout = QVBoxLayout()
self.main_layout.addLayout(self.controls_layout)

# Upper row layout for slider
self.upper_row = QHBoxLayout()
self.controls_layout.addLayout(self.upper_row)

# Slider Label
self.slider_label = QLabel("Faster")
self.upper_row.addWidget(self.slider_label)

# Slider
self.mod_slider = QSlider(Qt.Horizontal)
self.mod_slider.setMinimum(70)
self.mod_slider.setMaximum(130)
self.mod_slider.setValue(self.slider_val)
self.mod_slider.setTickPosition(QSlider.TicksBelow)
self.mod_slider.setTickInterval(10)
self.mod_slider.valueChanged.connect(self.on_slider_changed)
self.upper_row.addWidget(self.mod_slider)

# Slider Value Display
self.slider_value_label = QLabel("Slower")
self.upper_row.addWidget(self.slider_value_label)

# Lower row layout for buttons
self.lower_row = QHBoxLayout()
self.controls_layout.addLayout(self.lower_row)
Expand Down Expand Up @@ -406,18 +376,12 @@ def on_user_input_changed(self, text):
# Mark that an update is required
self.mark_tts_update()

def on_slider_changed(self, value):
# Update the slider label
# self.slider_value_label.setText(f"Durations at {value}%")
self.slider_val = value
# print(f"Slider changed to {scaling_factor * 100}% speed")
# Mark that an update is required
self.mark_tts_update()

def generate_new_prosody(self):
"""
Generate new prosody.
"""
if self.text_input.text().strip() == "":
return
wave, mel, durations, pitch = self.tts_backend(text=self.text_input.text(),
view=False,
duration_scaling_factor=1.0,
Expand All @@ -433,9 +397,6 @@ def generate_new_prosody(self):
prosody_creativity=0.8,
return_everything=True)
# reset and clear everything
self.slider_val = 100
self.prev_slider_val_for_denorm = self.slider_val
self.durations_are_scaled = False
self.clear_all_widgets()
self.init_plots()
self.init_controls()
Expand Down Expand Up @@ -510,7 +471,8 @@ def save_audio_file(self):

def play_audio(self):
# print("playing current audio...")
sounddevice.play(self.result_audio, samplerate=24000)
if self.result_audio is not None:
sounddevice.play(self.result_audio, samplerate=24000)

def update_result_audio(self, audio_array):
"""
Expand All @@ -525,7 +487,7 @@ def mark_tts_update(self):
Marks that a TTS update is required and starts/resets the timer.
"""
self.tts_update_required = True
self.tts_timer.start(600) # 600 milliseconds
self.tts_timer.start(800) # 800 milliseconds delay before the model starts to compute something

def run_tts(self):
"""
Expand Down Expand Up @@ -553,16 +515,12 @@ def run_tts(self):
phonemes = self.tts_backend.text2phone.get_phone_string(text=text)
self.phonemes = phonemes.replace(" ", "")

forced_durations = None if self.durations is None or len(self.durations) != len(self.phonemes) else insert_zeros_at_indexes(self.durations, self.word_boundaries)
if forced_durations is not None and self.durations_are_scaled:
forced_durations = torch.LongTensor([forced_duration / (self.prev_slider_val_for_denorm / 100) for forced_duration in forced_durations]).unsqueeze(0) # revert scaling
elif forced_durations is not None:
forced_durations = torch.LongTensor(forced_durations).unsqueeze(0)
forced_durations = None if self.durations is None or len(self.durations) != len(self.phonemes) else torch.LongTensor(insert_zeros_at_indexes(self.durations, self.word_boundaries)).unsqueeze(0)
forced_pitch = None if self.pitch is None or len(self.pitch) != len(self.phonemes) else torch.tensor(insert_zeros_at_indexes(self.pitch, self.word_boundaries)).unsqueeze(0)

wave, mel, durations, pitch = self.tts_backend(text,
view=False,
duration_scaling_factor=self.slider_val / 100,
duration_scaling_factor=1.0,
pitch_variance_scale=1.0,
energy_variance_scale=1.0,
pause_duration_scaling_factor=1.0,
Expand All @@ -576,9 +534,6 @@ def run_tts(self):
return_everything=True)

self.word_boundaries = find_zero_indexes(durations)
self.prev_slider_val_for_denorm = self.slider_val
if self.slider_val != 100:
self.durations_are_scaled = True

self.load_data(durations=durations.cpu().numpy(), pitch=pitch.cpu().numpy(), spectrogram=mel.cpu().transpose(0, 1).numpy())

Expand All @@ -602,7 +557,7 @@ def main():
}
QPushButton {
background-color: #808000;
background-color: #b9770e;
border: 1px solid #ffffff;
color: #ffffff;
padding: 8px 16px;
Expand Down

0 comments on commit e18c227

Please sign in to comment.