Skip to content

Commit

Permalink
formatting
Browse files (browse the repository at this point in the history)
  • Loading branch information
SWivid committed Nov 2, 2024
1 parent dc67a68 commit f7e248e
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions src/f5_tts/model/trainer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(
gradient_accumulation_steps=grad_accumulation_steps,
**accelerate_kwargs,
)
self.device = self.accelerator.device

self.logger = logger
if self.logger == "wandb":
if exists(wandb_resume_id):
Expand Down Expand Up @@ -325,7 +325,9 @@ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int

if self.log_samples and self.accelerator.is_local_main_process:
ref_audio, ref_audio_len = vocoder.decode(batch["mel"][0].unsqueeze(0)), mel_lengths[0]
torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio.cpu(), target_sample_rate)
torchaudio.save(
f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio.cpu(), target_sample_rate
)
with torch.inference_mode():
generated, _ = self.accelerator.unwrap_model(self.model).sample(
cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
Expand All @@ -336,8 +338,12 @@ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int
sway_sampling_coef=sway_sampling_coef,
)
generated = generated.to(torch.float32)
gen_audio = vocoder.decode(generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.device))
torchaudio.save(f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio.cpu(), target_sample_rate)
gen_audio = vocoder.decode(
generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.accelerator.device)
)
torchaudio.save(
f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio.cpu(), target_sample_rate
)

if global_step % self.last_per_steps == 0:
self.save_checkpoint(global_step, last=True)
Expand Down

0 comments on commit f7e248e

Please sign in to comment.