Skip to content

Commit

Permalink
Merge pull request #516 from SWivid/bugfix/fix-vocoder-bug
Browse files Browse the repository at this point in the history
fix vocoder sample generation and issue #467
  • Loading branch information
SWivid authored Nov 23, 2024
2 parents 8ee20d3 + c65b9d3 commit 934592a
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 18 deletions.
6 changes: 3 additions & 3 deletions src/f5_tts/eval/eval_infer_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,13 +189,13 @@ def main():
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
if mel_spec_type == "vocos":
generated_wave = vocoder.decode(gen_mel_spec)
generated_wave = vocoder.decode(gen_mel_spec).cpu()
elif mel_spec_type == "bigvgan":
generated_wave = vocoder(gen_mel_spec)
generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()

if ref_rms_list[i] < target_rms:
generated_wave = generated_wave * ref_rms_list[i] / target_rms
torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave.cpu(), target_sample_rate)
torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate)

accelerator.wait_for_everyone()
if accelerator.is_main_process:
Expand Down
6 changes: 3 additions & 3 deletions src/f5_tts/infer/speech_edit.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,13 +181,13 @@
generated = generated[:, ref_audio_len:, :]
gen_mel_spec = generated.permute(0, 2, 1)
if mel_spec_type == "vocos":
generated_wave = vocoder.decode(gen_mel_spec)
generated_wave = vocoder.decode(gen_mel_spec).cpu()
elif mel_spec_type == "bigvgan":
generated_wave = vocoder(gen_mel_spec)
generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()

if rms < target_rms:
generated_wave = generated_wave * rms / target_rms

save_spectrogram(gen_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png")
torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave.cpu(), target_sample_rate)
torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave, target_sample_rate)
print(f"Generated wav: {generated_wave.shape}")
29 changes: 17 additions & 12 deletions src/f5_tts/model/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,26 +324,31 @@ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int
self.save_checkpoint(global_step)

if self.log_samples and self.accelerator.is_local_main_process:
ref_audio, ref_audio_len = vocoder.decode(batch["mel"][0].unsqueeze(0)), mel_lengths[0]
torchaudio.save(
f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio.cpu(), target_sample_rate
)
ref_audio_len = mel_lengths[0]
infer_text = [
text_inputs[0] + ([" "] if isinstance(text_inputs[0], list) else " ") + text_inputs[0]
]
with torch.inference_mode():
generated, _ = self.accelerator.unwrap_model(self.model).sample(
cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
text=[text_inputs[0] + [" "] + text_inputs[0]],
text=infer_text,
duration=ref_audio_len * 2,
steps=nfe_step,
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
)
generated = generated.to(torch.float32)
gen_audio = vocoder.decode(
generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.accelerator.device)
)
torchaudio.save(
f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio.cpu(), target_sample_rate
)
generated = generated.to(torch.float32)
gen_mel_spec = generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.accelerator.device)
ref_mel_spec = batch["mel"][0].unsqueeze(0)
if self.vocoder_name == "vocos":
gen_audio = vocoder.decode(gen_mel_spec).cpu()
ref_audio = vocoder.decode(ref_mel_spec).cpu()
elif self.vocoder_name == "bigvgan":
gen_audio = vocoder(gen_mel_spec).squeeze(0).cpu()
ref_audio = vocoder(ref_mel_spec).squeeze(0).cpu()

torchaudio.save(f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio, target_sample_rate)
torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)

if global_step % self.last_per_steps == 0:
self.save_checkpoint(global_step, last=True)
Expand Down

0 comments on commit 934592a

Please sign in to comment.