Skip to content

Commit

Permalink
allow for passing in custom mel spec module (#200)
Browse files: browse the repository at this point in the history
  • Loading branch information
chenht2021 authored Oct 21, 2024
1 parent 25cdc51 commit 795cb19
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions model/dataset.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -8,8 +8,10 @@
import torchaudio
from datasets import load_from_disk
from datasets import Dataset as Dataset_
from torch import nn

from model.modules import MelSpec
from model.utils import default


class HFDataset(Dataset):
Expand Down Expand Up @@ -77,15 +79,22 @@ def __init__(
hop_length=256,
n_mel_channels=100,
preprocessed_mel=False,
mel_spec_module: nn.Module | None = None,
):
self.data = custom_dataset
self.durations = durations
self.target_sample_rate = target_sample_rate
self.hop_length = hop_length
self.preprocessed_mel = preprocessed_mel

if not preprocessed_mel:
self.mel_spectrogram = MelSpec(
target_sample_rate=target_sample_rate, hop_length=hop_length, n_mel_channels=n_mel_channels
self.mel_spectrogram = default(
mel_spec_module,
MelSpec(
target_sample_rate=target_sample_rate,
hop_length=hop_length,
n_mel_channels=n_mel_channels,
),
)

def get_frame_len(self, index):
Expand Down Expand Up @@ -201,6 +210,7 @@ def load_dataset(
tokenizer: str = "pinyin",
dataset_type: str = "CustomDataset",
audio_type: str = "raw",
mel_spec_module: nn.Module | None = None,
mel_spec_kwargs: dict = dict(),
) -> CustomDataset | HFDataset:
"""
Expand All @@ -224,7 +234,11 @@ def load_dataset(
data_dict = json.load(f)
durations = data_dict["duration"]
train_dataset = CustomDataset(
train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs
train_dataset,
durations=durations,
preprocessed_mel=preprocessed_mel,
mel_spec_module=mel_spec_module,
**mel_spec_kwargs,
)

elif dataset_type == "CustomDatasetPath":
Expand Down

0 comments on commit 795cb19

Please sign in to comment.