diff --git a/src/f5_tts/model/trainer.py b/src/f5_tts/model/trainer.py index 96f5d74b..bd96da7c 100644 --- a/src/f5_tts/model/trainer.py +++ b/src/f5_tts/model/trainer.py @@ -61,7 +61,7 @@ def __init__( gradient_accumulation_steps=grad_accumulation_steps, **accelerate_kwargs, ) - self.device = self.accelerator.device + self.logger = logger if self.logger == "wandb": if exists(wandb_resume_id): @@ -325,7 +325,9 @@ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int if self.log_samples and self.accelerator.is_local_main_process: ref_audio, ref_audio_len = vocoder.decode(batch["mel"][0].unsqueeze(0)), mel_lengths[0] - torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio.cpu(), target_sample_rate) + torchaudio.save( + f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio.cpu(), target_sample_rate + ) with torch.inference_mode(): generated, _ = self.accelerator.unwrap_model(self.model).sample( cond=mel_spec[0][:ref_audio_len].unsqueeze(0), @@ -336,8 +338,12 @@ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int sway_sampling_coef=sway_sampling_coef, ) generated = generated.to(torch.float32) - gen_audio = vocoder.decode(generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.device)) - torchaudio.save(f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio.cpu(), target_sample_rate) + gen_audio = vocoder.decode( + generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.accelerator.device) + ) + torchaudio.save( + f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio.cpu(), target_sample_rate + ) if global_step % self.last_per_steps == 0: self.save_checkpoint(global_step, last=True)