diff --git a/clearvoice/clearvoice.py b/clearvoice/clearvoice.py index ddb36ab..cf8a602 100644 --- a/clearvoice/clearvoice.py +++ b/clearvoice/clearvoice.py @@ -35,16 +35,19 @@ def __init__(self, task, model_names): model = self.network_wrapper(task, model_name) self.models += [model] - def __call__(self, input_path, online_write=False, output_path=None): + def __call__(self, input_path, online_write=False, output_path=None, extract_noise=False): results = {} for model in self.models: - result = model.process(input_path, online_write, output_path) + result = model.process(input_path, online_write, output_path, extract_noise) if not online_write: - results[model.name] = result + if extract_noise: + results[model.name] = result # result is now an (enhanced, noise) tuple + else: + results[model.name] = result if not online_write: if len(results) == 1: - return results[model.name] + return next(iter(results.values())) else: return results diff --git a/clearvoice/networks.py b/clearvoice/networks.py index 9cdce3b..663e5ea 100644 --- a/clearvoice/networks.py +++ b/clearvoice/networks.py @@ -150,133 +150,160 @@ def load_model(self): self.model.load_state_dict(state) #print(f'Successfully loaded {model_name} for decoding') - def decode(self): - """ - Decodes the input audio data using the loaded model and ensures the output matches the original audio length. - + def decode(self, extract_noise=False): + """Decode the audio data using the model. + This method processes the audio through a speech model (e.g., for enhancement, separation, etc.), and truncates the resulting audio to match the original input's length. The method supports multiple speakers if the model handles multi-speaker audio. + Args: + extract_noise (bool): Whether to extract noise signal. + Returns: - output_audio: The decoded audio after processing, truncated to the input audio length. - If multi-speaker audio is processed, a list of truncated audio outputs per speaker is returned. + If extract_noise is True: + tuple: (enhanced_audio, noise_audio), both truncated to input audio length + else: + output_audio: The decoded audio after processing, truncated to the input audio length. + If multi-speaker audio is processed, a list of truncated audio outputs per speaker is returned.
""" - # Decode the audio using the loaded model on the given device (e.g., CPU or GPU) - output_audio = decode_one_audio(self.model, self.device, self.data['audio'], self.args) - - # Ensure the decoded output matches the length of the input audio - if isinstance(output_audio, list): - # If multi-speaker audio (a list of outputs), truncate each speaker's audio to input length - for spk in range(self.args.num_spks): - output_audio[spk] = output_audio[spk][:self.data['audio_len']] + # Decode the audio using the loaded model + output = decode_one_audio(self.model, self.device, self.data['audio'], self.args, extract_noise) + + if extract_noise: + enhanced_audio, noise_audio = output + # Ensure the decoded outputs match the length of the input audio + if isinstance(enhanced_audio, list): + # For multi-speaker audio + for spk in range(self.args.num_spks): + enhanced_audio[spk] = enhanced_audio[spk][:self.data['audio_len']] + noise_audio[spk] = noise_audio[spk][:self.data['audio_len']] + else: + # Single output + enhanced_audio = enhanced_audio[:self.data['audio_len']] + noise_audio = noise_audio[:self.data['audio_len']] + return enhanced_audio, noise_audio else: - # Single output, truncate to input audio length - output_audio = output_audio[:self.data['audio_len']] - - return output_audio + # Original functionality for non-noise extraction + output_audio = output + if isinstance(output_audio, list): + for spk in range(self.args.num_spks): + output_audio[spk] = output_audio[spk][:self.data['audio_len']] + else: + output_audio = output_audio[:self.data['audio_len']] + return output_audio + + def process(self, input_path, online_write=False, output_path=None, extract_noise=False): + """Process audio files using the model. - def process(self, input_path, online_write=False, output_path=None): - """ - Load and process audio files from the specified input path. Optionally, - write the output audio files to the specified output directory. - Args: - input_path (str): Path to the input audio files or folder. - online_write (bool): Whether to write the processed audio to disk in real-time. - output_path (str): Optional path for writing output files. If None, output - will be stored in self.result. - + input_path (str): Path to input audio file or directory + online_write (bool): Whether to write output files during processing + output_path (str): Path for output files (if online_write is True) + extract_noise (bool): Whether to extract noise signal + Returns: - dict or ndarray: Processed audio results either as a dictionary or as a single array, - depending on the number of audio files processed. - Returns None if online_write is enabled. 
+ If not online_write: + If single file: + enhanced_audio or (enhanced_audio, noise_audio) + If multiple files: + dict of enhanced_audio or dict of (enhanced_audio, noise_audio) """ - self.result = {} self.args.input_path = input_path - data_reader = DataReader(self.args) # Initialize a data reader to load the audio files - + data_reader = DataReader(self.args) - # Check if online writing is enabled if online_write: - output_wave_dir = self.args.output_dir # Set the default output directory - if isinstance(output_path, str): # If a specific output path is provided, use it + output_wave_dir = self.args.output_dir + if isinstance(output_path, str): output_wave_dir = os.path.join(output_path, self.name) - # Create the output directory if it does not exist if not os.path.isdir(output_wave_dir): os.makedirs(output_wave_dir) - - num_samples = len(data_reader) # Get the total number of samples to process - print(f'Running {self.name} ...') # Display the model being used + + num_samples = len(data_reader) + print(f'Running {self.name} ...') if self.args.task == 'target_speaker_extraction': from utils.video_process import process_tse assert online_write == True process_tse(self.args, self.model, self.device, data_reader, output_wave_dir) else: - # Disable gradient calculation for better efficiency during inference with torch.no_grad(): - for idx in tqdm(range(num_samples)): # Loop over all audio samples + for idx in tqdm(range(num_samples)): self.data = {} - # Read the audio, waveform ID, and audio length from the data reader input_audio, wav_id, input_len, scalar = data_reader[idx] - # Store the input audio and metadata in self.data self.data['audio'] = input_audio self.data['id'] = wav_id self.data['audio_len'] = input_len - - # Perform the audio decoding/processing - output_audio = self.decode() - # Perform audio renormalization - if not isinstance(output_audio, list): - output_audio = output_audio * scalar + if extract_noise: + # Get enhanced audio and noise + enhanced_audio, noise_audio = self.decode(extract_noise=True) - if online_write: - # If online writing is enabled, save the output audio to files - if isinstance(output_audio, list): - # In case of multi-speaker output, save each speaker's output separately - for spk in range(self.args.num_spks): - output_file = os.path.join(output_wave_dir, wav_id.replace('.wav', f'_s{spk+1}.wav')) - sf.write(output_file, output_audio[spk], self.args.sampling_rate) + if not isinstance(enhanced_audio, list): + enhanced_audio = enhanced_audio * scalar + noise_audio = noise_audio * scalar + + if online_write: + if isinstance(enhanced_audio, list): + # Handle multi-speaker case + for spk in range(self.args.num_spks): + # Save enhanced audio + output_file = os.path.join(output_wave_dir, + wav_id.replace('.wav', f'_s{spk+1}.wav')) + sf.write(output_file, enhanced_audio[spk], self.args.sampling_rate) + + # Save corresponding noise + noise_file = os.path.join(output_wave_dir, + wav_id.replace('.wav', f'_s{spk+1}_noise.wav')) + sf.write(noise_file, noise_audio[spk], self.args.sampling_rate) + else: + # Save enhanced audio + output_file = os.path.join(output_wave_dir, wav_id) + sf.write(output_file, enhanced_audio, self.args.sampling_rate) + + # Save noise + noise_file = os.path.join(output_wave_dir, + wav_id.replace('.wav', '_noise.wav')) + sf.write(noise_file, noise_audio, self.args.sampling_rate) else: - # Single-speaker or standard output - output_file = os.path.join(output_wave_dir, wav_id) - sf.write(output_file, output_audio, self.args.sampling_rate) + 
self.result[wav_id] = (enhanced_audio, noise_audio) else: - # If not writing to disk, store the output in the result dictionary - self.result[wav_id] = output_audio - - # Return the processed results if not writing to disk - if not online_write: - if len(self.result) == 1: - # If there is only one result, return it directly - return next(iter(self.result.values())) - else: - # Otherwise, return the entire result dictionary - return self.result - + # Original processing logic + output_audio = self.decode() + + if not isinstance(output_audio, list): + output_audio = output_audio * scalar + + if online_write: + if isinstance(output_audio, list): + for spk in range(self.args.num_spks): + output_file = os.path.join(output_wave_dir, + wav_id.replace('.wav', f'_s{spk+1}.wav')) + sf.write(output_file, output_audio[spk], self.args.sampling_rate) + else: + output_file = os.path.join(output_wave_dir, wav_id) + sf.write(output_file, output_audio, self.args.sampling_rate) + else: + self.result[wav_id] = output_audio + + if not online_write: + if len(self.result) == 1: + return next(iter(self.result.values())) + else: + return self.result def write(self, output_path, add_subdir=False, use_key=False): - """ - Write the processed audio results to the specified output path. + """Write the processed audio results to the specified output path. Args: - output_path (str): The directory or file path where processed audio will be saved. If not - provided, defaults to self.args.output_dir. - add_subdir (bool): If True, appends the model name as a subdirectory to the output path. - use_key (bool): If True, uses the result dictionary's keys (audio file IDs) for filenames. - - Returns: - None: Outputs are written to disk, no data is returned. + output_path (str): The directory or file path where processed audio will be saved. + add_subdir (bool): If True, appends the model name as a subdirectory. + use_key (bool): If True, uses the result dictionary's keys for filenames. """ - - # Ensure the output path is a string. 
If not provided, use the default output directory if not isinstance(output_path, str): output_path = self.args.output_dir - # If add_subdir is enabled, create a subdirectory for the model name if add_subdir: if os.path.isfile(output_path): print(f'File exists: {output_path}, remove it and try again!') @@ -285,35 +312,67 @@ def write(self, output_path, add_subdir=False, use_key=False): if not os.path.isdir(output_path): os.makedirs(output_path) - # Ensure proper directory setup when using keys for filenames if use_key and not os.path.isdir(output_path): if os.path.exists(output_path): print(f'File exists: {output_path}, remove it and try again!') return os.makedirs(output_path) - # If not using keys and output path is a directory, check for conflicts + if not use_key and os.path.isdir(output_path): print(f'Directory exists: {output_path}, remove it and try again!') return - # Iterate over the results dictionary to write the processed audio to disk for key in self.result: + result_value = self.result[key] + is_tuple = isinstance(result_value, tuple) + if use_key: - # If using keys, format filenames based on the result dictionary's keys (audio IDs) - if isinstance(self.result[key], list): # For multi-speaker outputs + if isinstance(result_value[0] if is_tuple else result_value, list): for spk in range(self.args.num_spks): - sf.write(os.path.join(output_path, key.replace('.wav', f'_s{spk+1}.wav')), - self.result[key][spk], self.args.sampling_rate) + # Save enhanced audio + enhanced_path = os.path.join(output_path, key.replace('.wav', f'_s{spk+1}.wav')) + sf.write(enhanced_path, + result_value[0][spk] if is_tuple else result_value[spk], + self.args.sampling_rate) + + # Save noise if available + if is_tuple: + noise_path = os.path.join(output_path, key.replace('.wav', f'_s{spk+1}_noise.wav')) + sf.write(noise_path, result_value[1][spk], self.args.sampling_rate) else: - sf.write(os.path.join(output_path, key), self.result[key], self.args.sampling_rate) + # Save enhanced audio + enhanced_path = os.path.join(output_path, key) + sf.write(enhanced_path, + result_value[0] if is_tuple else result_value, + self.args.sampling_rate) + + # Save noise if available + if is_tuple: + noise_path = os.path.join(output_path, key.replace('.wav', '_noise.wav')) + sf.write(noise_path, result_value[1], self.args.sampling_rate) else: - # If not using keys, write audio to the specified output path directly - if isinstance(self.result[key], list): # For multi-speaker outputs + if isinstance(result_value[0] if is_tuple else result_value, list): for spk in range(self.args.num_spks): - sf.write(output_path.replace('.wav', f'_s{spk+1}.wav'), - self.result[key][spk], self.args.sampling_rate) + # Save enhanced audio + enhanced_path = output_path.replace('.wav', f'_s{spk+1}.wav') + sf.write(enhanced_path, + result_value[0][spk] if is_tuple else result_value[spk], + self.args.sampling_rate) + + # Save noise if available + if is_tuple: + noise_path = output_path.replace('.wav', f'_s{spk+1}_noise.wav') + sf.write(noise_path, result_value[1][spk], self.args.sampling_rate) else: - sf.write(output_path, self.result[key], self.args.sampling_rate) + # Save enhanced audio + sf.write(output_path, + result_value[0] if is_tuple else result_value, + self.args.sampling_rate) + + # Save noise if available + if is_tuple: + noise_path = output_path.replace('.wav', '_noise.wav') + sf.write(noise_path, result_value[1], self.args.sampling_rate) # The model classes for specific sub-tasks diff --git a/clearvoice/streamlit_app.py 
b/clearvoice/streamlit_app.py index bbd8345..b36d9aa 100644 --- a/clearvoice/streamlit_app.py +++ b/clearvoice/streamlit_app.py @@ -2,19 +2,53 @@ from clearvoice import ClearVoice import os import tempfile +import soundfile as sf +import subprocess +import shutil st.set_page_config(page_title="ClearerVoice Studio", layout="wide") temp_dir = 'temp' + +def convert_video_to_wav(video_path): + """Convert a video file to a WAV audio file + + Args: + video_path: Path to the video file + + Returns: + str: Path to the converted WAV file + """ + wav_path = video_path.rsplit('.', 1)[0] + '.wav' + try: + # Use ffmpeg to extract the audio track and convert it to WAV format + cmd = f'ffmpeg -i "{video_path}" -vn -acodec pcm_s16le -ar 16000 -ac 1 "{wav_path}" -y' + subprocess.run(cmd, shell=True, check=True) + return wav_path + except subprocess.CalledProcessError as e: + st.error(f"Error converting video to audio: {str(e)}") + return None + def save_uploaded_file(uploaded_file): + """Save the uploaded file + + Args: + uploaded_file: File object uploaded via Streamlit + + Returns: + str: Path of the saved file + """ if uploaded_file is not None: - # Check if temp directory exists, create if not + # Ensure the temp directory exists if not os.path.exists(temp_dir): os.makedirs(temp_dir) - # Save to temp directory, overwrite if file exists + # Save the file temp_path = os.path.join(temp_dir, uploaded_file.name) with open(temp_path, 'wb') as f: - f.write(uploaded_file.getvalue()) + # Read large files in chunks to limit memory usage + CHUNK_SIZE = 1024 * 1024 # 1MB chunks + for chunk in iter(lambda: uploaded_file.read(CHUNK_SIZE), b''): + f.write(chunk) return temp_path return None @@ -26,37 +60,91 @@ def main(): with tabs[0]: st.header("Speech Enhancement") - # Model selection + # Model selection se_models = ['MossFormer2_SE_48K', 'FRCRN_SE_16K', 'MossFormerGAN_SE_16K'] selected_model = st.selectbox("Select Model", se_models) - # File upload - uploaded_file = st.file_uploader("Upload Audio File", type=['wav'], key='se') + # File upload - supports WAV and MP4 + uploaded_file = st.file_uploader("Upload Audio/Video File", + type=['wav', 'mp4'], + key='se', + help="Supports WAV and MP4 files with no size limit") if st.button("Start Processing", key='se_process'): if uploaded_file is not None: with st.spinner('Processing...'): - # Save uploaded file - input_path = save_uploaded_file(uploaded_file) - - # Initialize ClearVoice - myClearVoice = ClearVoice(task='speech_enhancement', - model_names=[selected_model]) - - # Process audio - output_wav = myClearVoice(input_path=input_path, - online_write=False) - - # Save processed audio - output_dir = os.path.join(temp_dir, "speech_enhancement_output") - os.makedirs(output_dir, exist_ok=True) - output_path = os.path.join(output_dir, f"output_{selected_model}.wav") - myClearVoice.write(output_wav, output_path=output_path) - - # Display audio - st.audio(output_path) + try: + # Save the uploaded file + input_path = save_uploaded_file(uploaded_file) + + # If it is a video file, convert it to WAV + if input_path.endswith('.mp4'): + st.info("Converting video to audio...") + wav_path = convert_video_to_wav(input_path) + if wav_path is None: + st.error("Failed to convert video to audio") + return + input_path = wav_path + + # Initialize ClearVoice + myClearVoice = ClearVoice(task='speech_enhancement', + model_names=[selected_model]) + + # Process the audio + enhanced_audio, noise_audio = myClearVoice(input_path=input_path, + online_write=False, + extract_noise=True) + + # Save the processed audio + output_dir = os.path.join(temp_dir, "speech_enhancement_output") + os.makedirs(output_dir, exist_ok=True) + + # Set the sampling rate + sampling_rate = 48000 if selected_model == 'MossFormer2_SE_48K' else 16000 + + # Save the enhanced speech + output_path = os.path.join(output_dir, f"enhanced_{selected_model}.wav") + sf.write(output_path,
enhanced_audio, sampling_rate) + + # Save the extracted noise + noise_path = os.path.join(output_dir, f"noise_{selected_model}.wav") + sf.write(noise_path, noise_audio, sampling_rate) + + # Display the original audio (if it is a WAV file) + if uploaded_file.name.endswith('.wav'): + st.subheader("Original Audio:") + st.audio(input_path) + + # Display the processed audio + st.subheader("Enhanced Speech:") + st.audio(output_path) + + st.subheader("Extracted Noise:") + st.audio(noise_path) + + # Provide download links + st.download_button( + label="Download Enhanced Audio", + data=open(output_path, 'rb'), + file_name=f"enhanced_{uploaded_file.name.rsplit('.', 1)[0]}.wav", + mime="audio/wav" + ) + + st.download_button( + label="Download Noise Audio", + data=open(noise_path, 'rb'), + file_name=f"noise_{uploaded_file.name.rsplit('.', 1)[0]}.wav", + mime="audio/wav" + ) + + except Exception as e: + st.error(f"An error occurred: {str(e)}") + finally: + # Clean up temporary files + if os.path.exists(input_path): + os.remove(input_path) else: - st.error("Please upload an audio file first") + st.error("Please upload an audio/video file first") with tabs[1]: st.header("Speech Separation") diff --git a/clearvoice/utils/decode.py b/clearvoice/utils/decode.py index dfe0083..682f389 100644 --- a/clearvoice/utils/decode.py +++ b/clearvoice/utils/decode.py @@ -17,35 +17,29 @@ # Constant for normalizing audio values MAX_WAV_VALUE = 32768.0 -def decode_one_audio(model, device, inputs, args): - """Decodes audio using the specified model based on the provided network type. - - This function selects the appropriate decoding function based on the specified - network in the arguments and processes the input audio data accordingly. - +def decode_one_audio(model, device, inputs, args, extract_noise=False): + """Select and call the appropriate decoding function based on network type. + Args: - model (nn.Module): The trained model used for decoding. - device (torch.device): The device (CPU or GPU) to perform computations on. - inputs (torch.Tensor): Input audio tensor. - args (Namespace): Contains arguments for network configuration. - - Returns: - list: A list of decoded audio outputs for each speaker.
+ model: The model to use for decoding + device: The device to run inference on + inputs: Input audio data + args: Arguments containing model configuration + extract_noise: Whether to extract noise signal """ - # Select decoding function based on the network type specified in args if args.network == 'FRCRN_SE_16K': - return decode_one_audio_frcrn_se_16k(model, device, inputs, args) + return decode_one_audio_frcrn_se_16k(model, device, inputs, args, extract_noise) elif args.network == 'MossFormer2_SE_48K': - return decode_one_audio_mossformer2_se_48k(model, device, inputs, args) + return decode_one_audio_mossformer2_se_48k(model, device, inputs, args, extract_noise) elif args.network == 'MossFormerGAN_SE_16K': - return decode_one_audio_mossformergan_se_16k(model, device, inputs, args) + return decode_one_audio_mossformergan_se_16k(model, device, inputs, args, extract_noise) elif args.network == 'MossFormer2_SS_16K': - return decode_one_audio_mossformer2_ss_16k(model, device, inputs, args) + return decode_one_audio_mossformer2_ss_16k(model, device, inputs, args, extract_noise) else: - print("No network found!") # Print error message if no valid network is specified - return + print("No network found!") + return -def decode_one_audio_mossformer2_ss_16k(model, device, inputs, args): +def decode_one_audio_mossformer2_ss_16k(model, device, inputs, args, extract_noise=False): """Decodes audio using the MossFormer2 model for speech separation at 16kHz. This function handles the audio decoding process by processing the input tensor @@ -57,6 +51,7 @@ def decode_one_audio_mossformer2_ss_16k(model, device, inputs, args): inputs (torch.Tensor): Input audio tensor of shape (B, T), where B is the batch size and T is the number of time steps. args (Namespace): Contains arguments for decoding configuration. + extract_noise (bool): Whether to extract noise signal Returns: list: A list of decoded audio outputs for each speaker. @@ -119,135 +114,149 @@ def decode_one_audio_mossformer2_ss_16k(model, device, inputs, args): out[spk] = out[spk] / rms_out * rms_input return out # Return the list of normalized outputs -def decode_one_audio_frcrn_se_16k(model, device, inputs, args): +def decode_one_audio_frcrn_se_16k(model, device, inputs, args, extract_noise=False): """Decodes audio using the FRCRN model for speech enhancement at 16kHz. - This function processes the input audio tensor either in segments or as a whole, - depending on the length of the input. The model's inference method is applied - to obtain the enhanced audio output. - Args: - model (nn.Module): The trained FRCRN model used for decoding. - device (torch.device): The device (CPU or GPU) to perform computations on. - inputs (torch.Tensor): Input audio tensor of shape (B, T), where B is the batch size - and T is the number of time steps. - args (Namespace): Contains arguments for decoding configuration. + model: The trained FRCRN model used for decoding + device: The device to perform computations on + inputs: Input audio tensor of shape (B, T) + args: Arguments containing model configuration + extract_noise: Whether to extract noise signal Returns: - numpy.ndarray: The decoded audio output, which has been enhanced by the model. 
+ If extract_noise is True: + tuple: (enhanced_audio, noise_audio) + else: + ndarray: enhanced_audio """ - decode_do_segment = False # Flag to determine if segmentation is needed + decode_do_segment = False + window = args.sampling_rate * args.decode_window + stride = int(window * 0.75) + b, t = inputs.shape - window = args.sampling_rate * args.decode_window # Decoding window length - stride = int(window * 0.75) # Decoding stride for segmenting the input - b, t = inputs.shape # Get batch size (b) and input length (t) - - # Check if input length exceeds one-time decode length to decide on segmentation + # Check whether segmented processing is needed if t > args.sampling_rate * args.one_time_decode_length: - decode_do_segment = True # Enable segment decoding for long sequences - - # Pad the inputs to meet the decoding window length requirements - if t < window: - # Pad with zeros if the input length is less than the window size - inputs = np.concatenate([inputs, np.zeros((inputs.shape[0], window - t))], axis=1) - elif t < window + stride: - # Pad the input if its length is less than the window plus stride - padding = window + stride - t - inputs = np.concatenate([inputs, np.zeros((inputs.shape[0], padding))], axis=1) - else: - # Ensure the input length is a multiple of the stride - if (t - window) % stride != 0: - padding = t - (t - window) // stride * stride - inputs = np.concatenate([inputs, np.zeros((inputs.shape[0], padding))], axis=1) + decode_do_segment = True - # Convert inputs to a PyTorch tensor and move to the specified device - inputs = torch.from_numpy(np.float32(inputs)).to(device) - b, t = inputs.shape # Update batch size and input length after conversion + # Convert the input to a PyTorch tensor and keep a copy of the original input + original_inputs = torch.from_numpy(np.float32(inputs)).to(device) + inputs = original_inputs.clone() - # Process the inputs in segments if necessary if decode_do_segment: - outputs = np.zeros(t) # Initialize the output array - give_up_length = (window - stride) // 2 # Calculate length to give up at each segment - current_idx = 0 # Initialize current index for segmentation + outputs = np.zeros(t) + if extract_noise: + noise_outputs = np.zeros(t) + give_up_length = (window - stride) // 2 + current_idx = 0 while current_idx + window <= t: - tmp_input = inputs[:, current_idx:current_idx + window] # Get segment input - tmp_output = model.inference(tmp_input).detach().cpu().numpy() # Inference on segment + tmp_input = inputs[:, current_idx:current_idx + window] + # Get the enhanced speech + tmp_output = model.inference(tmp_input).detach().cpu().numpy() - # For the first segment, use the whole segment minus the give-up length if current_idx == 0: outputs[current_idx:current_idx + window - give_up_length] = tmp_output[:-give_up_length] else: - # For subsequent segments, account for the give-up length - outputs[current_idx + give_up_length:current_idx + window - give_up_length] = tmp_output[give_up_length:-give_up_length] + outputs[current_idx + give_up_length:current_idx + window - give_up_length] = \ + tmp_output[give_up_length:-give_up_length] - current_idx += stride # Move to the next segment - else: - # If no segmentation is required, process the entire input - outputs = model.inference(inputs).detach().cpu().numpy() # Inference on full input + current_idx += stride - return outputs # Return the decoded audio output - -def decode_one_audio_mossformergan_se_16k(model, device, inputs, args): + if extract_noise: + # Compute the noise signal + original = original_inputs.cpu().numpy()[0, :t] + noise_outputs = original - outputs + return outputs, noise_outputs + return
outputs + else: + # Process the whole audio at once + enhanced = model.inference(inputs).detach().cpu().numpy() + + if extract_noise: + # Compute the noise signal + original = original_inputs.cpu().numpy()[0, :t] + noise = original - enhanced + return enhanced, noise + return enhanced + +def decode_one_audio_mossformergan_se_16k(model, device, inputs, args, extract_noise=False): """Decodes audio using the MossFormerGAN model for speech enhancement at 16kHz. - This function processes the input audio tensor either in segments or as a whole, - depending on the length of the input. The `_decode_one_audio_mossformergan_se_16k` + This function processes the input audio tensor either in segments or as a whole, + depending on the length of the input. The `_decode_one_audio_mossformergan_se_16k` function is called to perform the model inference and return the enhanced audio output. Args: model (nn.Module): The trained MossFormerGAN model used for decoding. device (torch.device): The device (CPU or GPU) for computation. - inputs (torch.Tensor): Input audio tensor of shape (B, T), where B is the batch size - and T is the number of time steps. + inputs (torch.Tensor): Input audio tensor of shape (B, T). args (Namespace): Contains arguments for decoding configuration. + extract_noise (bool): Whether to extract noise signal (default: False) Returns: - numpy.ndarray: The decoded audio output, which has been enhanced by the model. + If extract_noise is True: + tuple: (enhanced_audio, noise_audio) + else: + ndarray: enhanced_audio """ - decode_do_segment = False # Flag to determine if segmentation is needed - window = args.sampling_rate * args.decode_window # Decoding window length - stride = int(window * 0.75) # Decoding stride for segmenting the input - b, t = inputs.shape # Get batch size (b) and input length (t) + decode_do_segment = False + window = args.sampling_rate * args.decode_window + stride = int(window * 0.75) + b, t = inputs.shape # Check if input length exceeds one-time decode length to decide on segmentation if t > args.sampling_rate * args.one_time_decode_length: - decode_do_segment = True # Enable segment decoding for long sequences + decode_do_segment = True - # Convert inputs to a PyTorch tensor and move to the specified device + # Convert inputs to PyTorch tensor and compute normalization factor inputs = torch.from_numpy(np.float32(inputs)).to(device) - - # Compute normalization factor based on the input norm_factor = torch.sqrt(inputs.size(-1) / torch.sum((inputs ** 2.0), dim=-1)) + b, t = inputs.shape - b, t = inputs.shape # Update batch size and input length after conversion - - # Process the inputs in segments if necessary if decode_do_segment: - outputs = np.zeros(t) # Initialize the output array - give_up_length = (window - stride) // 2 # Calculate length to give up at each segment - current_idx = 0 # Initialize current index for segmentation + outputs = np.zeros(t) + if extract_noise: + noise_outputs = np.zeros(t) + give_up_length = (window - stride) // 2 + current_idx = 0 while current_idx + window <= t: - tmp_input = inputs[:, current_idx:current_idx + window] # Get segment input - tmp_output = _decode_one_audio_mossformergan_se_16k(model, device, tmp_input, norm_factor, args) # Inference on segment + tmp_input = inputs[:, current_idx:current_idx + window] + if extract_noise: + tmp_output, tmp_noise = _decode_one_audio_mossformergan_se_16k( + model, device, tmp_input, norm_factor, args, extract_noise) + else: + tmp_output = _decode_one_audio_mossformergan_se_16k( + model, device, tmp_input, norm_factor, args,
extract_noise) - # For the first segment, use the whole segment minus the give-up length if current_idx == 0: outputs[current_idx:current_idx + window - give_up_length] = tmp_output[:-give_up_length] + if extract_noise: + noise_outputs[current_idx:current_idx + window - give_up_length] = tmp_noise[:-give_up_length] else: - # For subsequent segments, account for the give-up length - outputs[current_idx + give_up_length:current_idx + window - give_up_length] = tmp_output[give_up_length:-give_up_length] + outputs[current_idx + give_up_length:current_idx + window - give_up_length] = \ + tmp_output[give_up_length:-give_up_length] + if extract_noise: + noise_outputs[current_idx + give_up_length:current_idx + window - give_up_length] = \ + tmp_noise[give_up_length:-give_up_length] - current_idx += stride # Move to the next segment + current_idx += stride - return outputs # Return the accumulated outputs from segments + if extract_noise: + return outputs, noise_outputs + return outputs else: - # If no segmentation is required, process the entire input - return _decode_one_audio_mossformergan_se_16k(model, device, inputs, norm_factor, args) # Inference on full input + if extract_noise: + enhanced, noise = _decode_one_audio_mossformergan_se_16k( + model, device, inputs, norm_factor, args, extract_noise) + return enhanced, noise + else: + return _decode_one_audio_mossformergan_se_16k( + model, device, inputs, norm_factor, args, extract_noise) @torch.no_grad() -def _decode_one_audio_mossformergan_se_16k(model, device, inputs, norm_factor, args): +def _decode_one_audio_mossformergan_se_16k(model, device, inputs, norm_factor, args, extract_noise=False): """Processes audio inputs through the MossFormerGAN model for speech enhancement. This function performs the following steps: @@ -261,48 +270,58 @@ def _decode_one_audio_mossformergan_se_16k(model, device, inputs, norm_factor, a Args: model (nn.Module): The trained MossFormerGAN model used for decoding. device (torch.device): The device (CPU or GPU) for computation. - inputs (torch.Tensor): Input audio tensor of shape (B, T), where B is the batch size and T is the number of time steps. - norm_factor (torch.Tensor): A norm tensor to regularize input amplitude + inputs (torch.Tensor): Input audio tensor of shape (B, T). + norm_factor (torch.Tensor): A norm tensor to regularize input amplitude. args (Namespace): Contains arguments for STFT parameters and normalization. + extract_noise (bool): Whether to extract noise signal (default: False) Returns: - numpy.ndarray: The decoded audio output, which has been enhanced by the model. 
+ If extract_noise is True: + tuple: (enhanced_audio, noise_audio) + else: + ndarray: enhanced_audio """ - input_len = inputs.size(-1) # Get the length of the input audio - nframe = int(np.ceil(input_len / args.win_inc)) # Calculate the number of frames based on window increment - padded_len = nframe * args.win_inc # Calculate the padded length to fit the model - padding_len = padded_len - input_len # Determine how much padding is needed + input_len = inputs.size(-1) + nframe = int(np.ceil(input_len / args.win_inc)) + padded_len = nframe * args.win_inc + padding_len = padded_len - input_len - # Pad the input audio with the beginning of the input - inputs = torch.cat([inputs, inputs[:, :padding_len]], dim=-1) + # Save original input for noise calculation if needed + original_inputs = inputs.clone() - # Prepare inputs for STFT by transposing and normalizing - inputs = torch.transpose(inputs, 0, 1) # Change shape for STFT - inputs = torch.transpose(inputs * norm_factor, 0, 1) # Apply normalization factor and transpose back + # Pad inputs + inputs = torch.cat([inputs, inputs[:, :padding_len]], dim=-1) + inputs = torch.transpose(inputs, 0, 1) + inputs = torch.transpose(inputs * norm_factor, 0, 1) - # Perform Short-Time Fourier Transform (STFT) on the normalized inputs + # Compute STFT inputs_spec = stft(inputs, args, center=True, periodic=True, onesided=True) - inputs_spec = inputs_spec.to(torch.float32) # Ensure the spectrogram is in float32 format + inputs_spec = inputs_spec.to(torch.float32) - # Compress the power of the spectrogram to improve model performance + # Compress power spectrum inputs_spec = power_compress(inputs_spec).permute(0, 1, 3, 2) - # Pass the compressed spectrogram through the model to get predicted real and imaginary parts + # Get model predictions out_list = model(inputs_spec) pred_real, pred_imag = out_list[0].permute(0, 1, 3, 2), out_list[1].permute(0, 1, 3, 2) - # Uncompress the predicted spectrogram to get the magnitude and phase + # Uncompress predicted spectrum pred_spec_uncompress = power_uncompress(pred_real, pred_imag).squeeze(1) - # Perform Inverse STFT (iSTFT) to convert back to time domain audio + # Reconstruct enhanced audio outputs = istft(pred_spec_uncompress, args, center=True, periodic=True, onesided=True) - - # Normalize the output audio by dividing by the normalization factor outputs = outputs.squeeze(0) / norm_factor + enhanced = outputs[:input_len] + + if extract_noise: + # Calculate noise signal + original = original_inputs.squeeze(0)[:input_len] + noise = original - enhanced + return enhanced.detach().cpu().numpy(), noise.detach().cpu().numpy() - return outputs[:input_len].detach().cpu().numpy() # Return the output as a numpy array + return enhanced.detach().cpu().numpy() -def decode_one_audio_mossformer2_se_48k(model, device, inputs, args): +def decode_one_audio_mossformer2_se_48k(model, device, inputs, args, extract_noise=False): """Processes audio inputs through the MossFormer2 model for speech enhancement at 48kHz. This function decodes audio input using the following steps: @@ -313,27 +332,30 @@ def decode_one_audio_mossformer2_se_48k(model, device, inputs, args): 5. Passes the filter banks through the model to get a predicted mask. 6. Applies the mask to the spectrogram of the audio segment and reconstructs the audio. 7. For shorter inputs, processes them in one go without segmentation. - + Args: model (nn.Module): The trained MossFormer2 model used for decoding. device (torch.device): The device (CPU or GPU) for computation. 
- inputs (torch.Tensor): Input audio tensor of shape (B, T), where B is the batch size and T is the number of time steps. + inputs (torch.Tensor): Input audio tensor of shape (B, T). args (Namespace): Contains arguments for sampling rate, window size, and other parameters. + extract_noise (bool): Whether to extract noise signal (default: False) Returns: - numpy.ndarray: The decoded audio output, normalized to the range [-1, 1]. + If extract_noise is True: + tuple: (enhanced_audio, noise_audio) + else: + ndarray: enhanced_audio normalized to [-1, 1] """ - inputs = inputs[0, :] # Extract the first element from the input tensor - input_len = inputs.shape[0] # Get the length of the input audio - inputs = inputs * MAX_WAV_VALUE # Normalize the input to the maximum WAV value + inputs = inputs[0, :] + input_len = inputs.shape[0] + inputs = inputs * MAX_WAV_VALUE - # Check if input length exceeds the defined threshold for online decoding if input_len > args.sampling_rate * args.one_time_decode_length: # 20 seconds online_decoding = True if online_decoding: - window = int(args.sampling_rate * args.decode_window) # Define window length (e.g., 4s for 48kHz) - stride = int(window * 0.75) # Define stride length (e.g., 3s for 48kHz) - t = inputs.shape[0] # Update length after potential padding + window = int(args.sampling_rate * args.decode_window) # Define window length + stride = int(window * 0.75) # Define stride length + t = inputs.shape[0] # Pad input if necessary to match window size if t < window: @@ -346,14 +368,16 @@ def decode_one_audio_mossformer2_se_48k(model, device, inputs, args): padding = t - (t - window) // stride * stride inputs = np.concatenate([inputs, np.zeros(padding)], 0) - audio = torch.from_numpy(inputs).type(torch.FloatTensor) # Convert to Torch tensor - t = audio.shape[0] # Update length after conversion - outputs = torch.from_numpy(np.zeros(t)) # Initialize output tensor - give_up_length = (window - stride) // 2 # Determine length to ignore at the edges - dfsmn_memory_length = 0 # Placeholder for potential memory length - current_idx = 0 # Initialize current index for sliding window + audio = torch.from_numpy(inputs).type(torch.FloatTensor) + t = audio.shape[0] + outputs = torch.from_numpy(np.zeros(t)) + if extract_noise: + noise_outputs = torch.from_numpy(np.zeros(t)) + + give_up_length = (window - stride) // 2 + dfsmn_memory_length = 0 + current_idx = 0 - # Process audio in sliding window segments while current_idx + window <= t: # Select appropriate segment of audio for processing if current_idx < dfsmn_memory_length: @@ -361,72 +385,91 @@ def decode_one_audio_mossformer2_se_48k(model, device, inputs, args): else: audio_segment = audio[current_idx - dfsmn_memory_length:current_idx + window] - # Compute filter banks for the audio segment + # Compute filter banks and their deltas fbanks = compute_fbank(audio_segment.unsqueeze(0), args) - - # Compute deltas for filter banks - fbank_tr = torch.transpose(fbanks, 0, 1) # Transpose for delta computation - fbank_delta = torchaudio.functional.compute_deltas(fbank_tr) # First-order delta - fbank_delta_delta = torchaudio.functional.compute_deltas(fbank_delta) # Second-order delta - - # Transpose back to original shape + fbank_tr = torch.transpose(fbanks, 0, 1) + fbank_delta = torchaudio.functional.compute_deltas(fbank_tr) + fbank_delta_delta = torchaudio.functional.compute_deltas(fbank_delta) fbank_delta = torch.transpose(fbank_delta, 0, 1) fbank_delta_delta = torch.transpose(fbank_delta_delta, 0, 1) - - # Concatenate the original 
filter banks with their deltas + fbanks = torch.cat([fbanks, fbank_delta, fbank_delta_delta], dim=1) - fbanks = fbanks.unsqueeze(0).to(device) # Add batch dimension and move to device + fbanks = fbanks.unsqueeze(0).to(device) - # Pass filter banks through the model + # Model inference Out_List = model(fbanks) - pred_mask = Out_List[-1] # Get the predicted mask from the output - - # Apply STFT to the audio segment + pred_mask = Out_List[-1] spectrum = stft(audio_segment, args) - pred_mask = pred_mask.permute(2, 1, 0) # Permute dimensions for masking - masked_spec = spectrum.cpu() * pred_mask.detach().cpu() # Apply mask to the spectrum - masked_spec_complex = masked_spec[:, :, 0] + 1j * masked_spec[:, :, 1] # Convert to complex form + pred_mask = pred_mask.permute(2, 1, 0) - # Reconstruct audio from the masked spectrogram + # Process enhanced audio + masked_spec = spectrum.cpu() * pred_mask.detach().cpu() + masked_spec_complex = masked_spec[:, :, 0] + 1j * masked_spec[:, :, 1] output_segment = istft(masked_spec_complex, args, len(audio_segment)) - # Store the output segment in the output tensor + # Process noise if requested + if extract_noise: + noise_mask = 1 - pred_mask.detach().cpu() + noise_spec = spectrum.cpu() * noise_mask + noise_spec_complex = noise_spec[:, :, 0] + 1j * noise_spec[:, :, 1] + noise_segment = istft(noise_spec_complex, args, len(audio_segment)) + + # Store results if current_idx == 0: outputs[current_idx:current_idx + window - give_up_length] = output_segment[:-give_up_length] + if extract_noise: + noise_outputs[current_idx:current_idx + window - give_up_length] = noise_segment[:-give_up_length] else: - output_segment = output_segment[-window:] # Get the latest window of output - outputs[current_idx + give_up_length:current_idx + window - give_up_length] = output_segment[give_up_length:-give_up_length] - - current_idx += stride # Move to the next segment + output_segment = output_segment[-window:] + outputs[current_idx + give_up_length:current_idx + window - give_up_length] = \ + output_segment[give_up_length:-give_up_length] + if extract_noise: + noise_segment = noise_segment[-window:] + noise_outputs[current_idx + give_up_length:current_idx + window - give_up_length] = \ + noise_segment[give_up_length:-give_up_length] + + current_idx += stride + + if extract_noise: + return outputs.numpy() / MAX_WAV_VALUE, noise_outputs.numpy() / MAX_WAV_VALUE + return outputs.numpy() / MAX_WAV_VALUE else: - # Process the entire audio at once if it is shorter than the threshold + # Process shorter audio in one go audio = torch.from_numpy(inputs).type(torch.FloatTensor) + + # Compute filter banks and their deltas fbanks = compute_fbank(audio.unsqueeze(0), args) - - # Compute deltas for filter banks fbank_tr = torch.transpose(fbanks, 0, 1) fbank_delta = torchaudio.functional.compute_deltas(fbank_tr) fbank_delta_delta = torchaudio.functional.compute_deltas(fbank_delta) fbank_delta = torch.transpose(fbank_delta, 0, 1) fbank_delta_delta = torch.transpose(fbank_delta_delta, 0, 1) - - # Concatenate the original filter banks with their deltas + fbanks = torch.cat([fbanks, fbank_delta, fbank_delta_delta], dim=1) - fbanks = fbanks.unsqueeze(0).to(device) # Add batch dimension and move to device + fbanks = fbanks.unsqueeze(0).to(device) - # Pass filter banks through the model + # Model inference Out_List = model(fbanks) - pred_mask = Out_List[-1] # Get the predicted mask - spectrum = stft(audio, args) # Apply STFT to the audio - pred_mask = pred_mask.permute(2, 1, 0) # Permute dimensions for 
masking - masked_spec = spectrum * pred_mask.detach().cpu() # Apply mask to the spectrum - masked_spec_complex = masked_spec[:, :, 0] + 1j * masked_spec[:, :, 1] # Convert to complex form - - # Reconstruct audio from the masked spectrogram - outputs = istft(masked_spec_complex, args, len(audio)) + pred_mask = Out_List[-1] + spectrum = stft(audio, args) + pred_mask = pred_mask.permute(2, 1, 0) + + # Process enhanced audio + masked_spec = spectrum * pred_mask.detach().cpu() + masked_spec_complex = masked_spec[:, :, 0] + 1j * masked_spec[:, :, 1] + enhanced = istft(masked_spec_complex, args, len(audio)) + + if extract_noise: + # Calculate noise signal + noise_mask = 1 - pred_mask.detach().cpu() + noise_spec = spectrum * noise_mask + noise_spec_complex = noise_spec[:, :, 0] + 1j * noise_spec[:, :, 1] + noise = istft(noise_spec_complex, args, len(audio)) + + return enhanced.numpy() / MAX_WAV_VALUE, noise.numpy() / MAX_WAV_VALUE - return outputs.numpy() / MAX_WAV_VALUE # Return the output normalized to [-1, 1] + return enhanced.numpy() / MAX_WAV_VALUE def decode_one_audio_AV_MossFormer2_TSE_16K(model, inputs, args): """Processes video inputs through the AV mossformer2 model with Target speaker extraction (TSE) for decoding at 16kHz.
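A minimal usage sketch of the new extract_noise option added by this diff, assuming a single 16 kHz input file and the FRCRN_SE_16K model; the file paths are illustrative only:

from clearvoice import ClearVoice
import soundfile as sf

# For a single input file with extract_noise=True, __call__ returns an
# (enhanced_audio, noise_audio) tuple instead of a single array.
myClearVoice = ClearVoice(task='speech_enhancement', model_names=['FRCRN_SE_16K'])
enhanced_audio, noise_audio = myClearVoice(input_path='samples/input.wav',  # hypothetical path
                                           online_write=False,
                                           extract_noise=True)

# Save both signals; 16000 matches the FRCRN_SE_16K sampling rate.
sf.write('samples/enhanced.wav', enhanced_audio, 16000)
sf.write('samples/noise.wav', noise_audio, 16000)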