diff --git a/hallo/datasets/audio_processor.py b/hallo/datasets/audio_processor.py
index 683282fb..50738970 100644
--- a/hallo/datasets/audio_processor.py
+++ b/hallo/datasets/audio_processor.py
@@ -73,7 +73,7 @@ def __init__(
         self.wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_model_path, local_files_only=True)
 
-    def preprocess(self, wav_file: str):
+    def preprocess(self, wav_file: str, clip_length: int):
         """
         Preprocess a WAV audio file by separating the vocals from the background and resampling it to a 16 kHz sample rate.
         The separated vocal track is then converted into wav2vec2 for further processing or analysis.
 
@@ -106,8 +106,12 @@ def preprocess(self, wav_file: str):
         speech_array, sampling_rate = librosa.load(vocal_audio_file, sr=self.sample_rate)
         audio_feature = np.squeeze(self.wav2vec_feature_extractor(speech_array, sampling_rate=sampling_rate).input_values)
         seq_len = math.ceil(len(audio_feature) / self.sample_rate * self.fps)
+        audio_length = seq_len
 
         audio_feature = torch.from_numpy(audio_feature).float().to(device=self.device)
 
+        if seq_len % clip_length != 0:
+            audio_feature = torch.nn.functional.pad(audio_feature, (0, (clip_length - seq_len % clip_length) * (self.sample_rate // self.fps)), 'constant', 0.0)
+            seq_len += clip_length - seq_len % clip_length
         audio_feature = audio_feature.unsqueeze(0)
         with torch.no_grad():
@@ -121,7 +125,7 @@ def preprocess(self, wav_file: str):
 
         audio_emb = audio_emb.cpu().detach()
 
-        return audio_emb
+        return audio_emb, audio_length
 
     def get_embedding(self, wav_file: str):
         """preprocess wav audio file convert to embeddings
diff --git a/scripts/inference.py b/scripts/inference.py
index c2ef0bbb..8dfefbf8 100644
--- a/scripts/inference.py
+++ b/scripts/inference.py
@@ -178,7 +178,7 @@ def inference_process(args: argparse.Namespace):
         os.path.basename(audio_separator_model_file),
         os.path.join(save_path, "audio_preprocess")
     ) as audio_processor:
-        audio_emb = audio_processor.preprocess(driving_audio_path)
+        audio_emb, audio_length = audio_processor.preprocess(driving_audio_path, clip_length)
 
     # 4. build modules
     sched_kwargs = OmegaConf.to_container(config.noise_scheduler_kwargs)
@@ -339,6 +339,7 @@ def inference_process(args: argparse.Namespace):
 
     tensor_result = torch.cat(tensor_result, dim=2)
     tensor_result = tensor_result.squeeze(0)
+    tensor_result = tensor_result[:, :audio_length]
     output_file = config.output
     # save the result after all iteration
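
For context on what the new padding branch computes, below is a minimal standalone sketch of the arithmetic, separate from the patch itself. The concrete numbers (16 kHz sample rate, 25 fps, a clip_length of 16, and the waveform length) are illustrative assumptions and are not taken from this repository's config.

import math

# Illustrative values (assumptions, not from the patch): wav2vec input rate,
# video frame rate, and number of frames generated per clip.
sample_rate = 16000
fps = 25
clip_length = 16

num_samples = 100_000                                 # hypothetical length of the extracted waveform features
seq_len = math.ceil(num_samples / sample_rate * fps)  # 157 video frames covered by the audio
audio_length = seq_len                                # remembered so the final video can be trimmed back

if seq_len % clip_length != 0:
    pad_frames = clip_length - seq_len % clip_length  # 3 extra frames to reach a multiple of 16
    pad_samples = pad_frames * (sample_rate // fps)   # 3 * 640 = 1920 zero samples appended on the right
    seq_len += pad_frames                             # 160 frames after padding

print(audio_length, seq_len)  # -> 157 160

After generation, the inference script slices the assembled video tensor back to audio_length frames (tensor_result[:, :audio_length]), so the zero-padded tail never shows up in the saved output.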