Only noise as a result #253
Open · mpastewski opened this issue Nov 22, 2023 · 4 comments

mpastewski commented Nov 22, 2023

I'm new to TTS. I'm trying to replicate your work, and when using Encodec I get only noise instead of speech. I haven't run it with SoundStream, since I don't know where to find a pretrained checkpoint online.

Please help me get started. Below is the code I'm testing with.

Also, working example scripts that use audiolm would be a great help for people new to the topic.

import numpy as np
import torch
import torchaudio
from audiolm_pytorch import (
    AudioLM, HubertWithKmeans, SemanticTransformer, CoarseTransformer, FineTransformer,
    SemanticTransformerTrainer, CoarseTransformerTrainer, FineTransformerTrainer,
    SoundStream, SoundStreamTrainer, EncodecWrapper
)

# Step 1: Setup the neural codec (SoundStream or Encodec)
# Option 1: Using Encodec
encodec = EncodecWrapper()

# Option 2: Using SoundStream
#soundstream = SoundStream(
#    codebook_size = 4096,
#    rq_num_quantizers = 8,
#    rq_groups = 2,
#    use_lookup_free_quantizer = False,
#    use_finite_scalar_quantizer = True,
#    attn_window_size = 128,
#    attn_depth = 2
#)

# Assuming you have trained SoundStream, load it as follows:
#soundstream = SoundStream.init_and_load_from('./path/to/checkpoint.pt')

# Step 2: Setup and train the Hierarchical Transformers
# Semantic Transformer
wav2vec = HubertWithKmeans(
    checkpoint_path = './hubert/hubert_base_ls960.pt',
    kmeans_path = './hubert/hubert_base_ls960_L9_km500.bin'
)

semantic_transformer = SemanticTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    dim = 1024,
    depth = 6,
    flash_attn = True
).cuda()

# Coarse Transformer
coarse_transformer = CoarseTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    codebook_size = 1024,
    num_coarse_quantizers = 3,
    dim = 512,
    depth = 6,
    flash_attn = True
)

# Fine Transformer
fine_transformer = FineTransformer(
    num_coarse_quantizers = 3,
    num_fine_quantizers = 5,
    codebook_size = 1024,
    dim = 512,
    depth = 6,
    flash_attn = True
)

# Training steps for each transformer are omitted for brevity.
# Train each transformer with its respective trainer on your own dataset.
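# A minimal sketch of one stage (the folder is a hypothetical placeholder;
# the coarse and fine stages use CoarseTransformerTrainer and
# FineTransformerTrainer analogously, as in the full script later in this
# thread):
#
# semantic_trainer = SemanticTransformerTrainer(
#     transformer = semantic_transformer,
#     wav2vec = wav2vec,
#     folder = './path/to/audio',
#     batch_size = 4,
#     data_max_length = 320 * 32,
#     num_train_steps = 1_000_000
# )
# semantic_trainer.train()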

audiolm = AudioLM(
    wav2vec = wav2vec,
    codec = encodec,
    semantic_transformer = semantic_transformer,
    coarse_transformer = coarse_transformer,
    fine_transformer = fine_transformer
).cuda()

# Generate audio with text condition
text_to_generate = "My name is Bond, James Bond"
generated_wav_with_text_condition = audiolm(text = [text_to_generate])

# Move the generated tensor to the CPU and convert it to a numpy array
audio_data = generated_wav_with_text_condition.detach().cpu().numpy()

# torchaudio.save expects (channels, frames); if the audio is mono and 1-D,
# add a leading channel dimension
if audio_data.ndim == 1:
    audio_data = audio_data.reshape(1, -1)

# Normalize audio to the range [-1.0, 1.0] if it's not already
audio_data = audio_data / np.max(np.abs(audio_data))

# Convert the numpy array back to a PyTorch tensor
tensor_audio_data = torch.from_numpy(audio_data)

# Save the audio data to a file using torchaudio
output_file_path = 'bond_audio.wav'
# EncodecWrapper wraps the 24 kHz EnCodec model; saving at the codec's
# native rate avoids sped-up or slowed-down playback
torchaudio.save(output_file_path, tensor_audio_data, 24000)
HanJu-Chen commented

Hi, I'm new to AudioLM and am trying to run the code. May I ask how you put the three transformers together? I trained the three transformers separately to get the .pt files, but I'm not quite sure how to load them.
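For reference, the loading-and-assembly pattern used later in this thread is roughly the following (a sketch; the checkpoint filenames are the ones from my-yy's script below):

semantic_transformer.load("results/semantic.transformer.405000.pt")
coarse_transformer.load("results_coarse/coarse.transformer.683000.pt")
fine_transformer.load("results_fine/fine.transformer.828000.pt")

audiolm = AudioLM(
    wav2vec = wav2vec,
    codec = encodec,
    semantic_transformer = semantic_transformer,
    coarse_transformer = coarse_transformer,
    fine_transformer = fine_transformer
).cuda()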


my-yy commented Nov 27, 2023

Similar issue:
I trained on the LibriTTS training dataset with a batch size of 64 and obtained the following models:

semantic.transformer.405000.pt
coarse.transformer.683000.pt
fine.transformer.828000.pt
However, the test results do not sound like human speech; the speech rate seems very fast.
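A sped-up (or slowed-down) result is often a sample-rate mismatch: writing N samples at samplerate S gives a duration of N / S seconds, so writing with a higher rate than the codec actually produced makes playback proportionally faster. A quick sanity check, assuming EncodecWrapper's native 24 kHz (an assumption worth verifying for your codec):

# Hypothetical sanity check: the printed duration should match how long
# the saved file sounds on playback
codec_rate = 24000  # assumed native rate of the wrapped EnCodec model
print(f"expected duration: {len(wav) / codec_rate:.2f} s")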

import argparse

import soundfile
import torch
import torchaudio
from torchaudio.transforms import Resample

from audiolm_pytorch import (
    AudioLM, EncodecWrapper, HubertWithKmeans,
    SemanticTransformer, SemanticTransformerTrainer,
    CoarseTransformer, CoarseTransformerTrainer,
    FineTransformer, FineTransformerTrainer
)

the_path = "/zhangpai21/webdataset/audio/splitspeechdatasets/LibriTTS-R/train/tar"

wav2vec = HubertWithKmeans(
    checkpoint_path='./hubert/hubert_base_ls960.pt',
    kmeans_path='./hubert/hubert_base_ls960_L9_km500.bin'
)

encodec = EncodecWrapper()

semantic_transformer = SemanticTransformer(
    num_semantic_tokens=wav2vec.codebook_size,
    dim=1024,
    depth=6,
    flash_attn=False
)

coarse_transformer = CoarseTransformer(
    num_semantic_tokens=wav2vec.codebook_size,
    codebook_size=1024,
    num_coarse_quantizers=3,
    dim=512,
    depth=6,
    flash_attn=False
)
fine_transformer = FineTransformer(
    num_coarse_quantizers=3,
    num_fine_quantizers=5,
    codebook_size=1024,
    dim=512,
    depth=6,
    flash_attn=True
)


def train_semantic():
    trainer = SemanticTransformerTrainer(
        transformer=semantic_transformer,
        wav2vec=wav2vec,
        folder=the_path,
        batch_size=64,
        data_max_length=320 * 32,
        num_train_steps=1_000_000,
        results_folder='./results',
    )

    trainer.train()


def train_coarse():
    trainer = CoarseTransformerTrainer(
        transformer=coarse_transformer,
        codec=encodec,
        wav2vec=wav2vec,
        folder=the_path,
        batch_size=64,
        data_max_length=320 * 32,
        num_train_steps=1_000_000,
        results_folder='./results_coarse',
    )

    trainer.train()


def train_fine():
    trainer = FineTransformerTrainer(
        transformer=fine_transformer,
        codec=encodec,
        folder=the_path,
        batch_size=64,
        data_max_length=320 * 32,
        num_train_steps=1_000_000,
        results_folder='./results_fine',
    )

    trainer.train()


def get_real_semantic_ids():
    audio_path = "/zhangpai21/workspace/cgy/1_projects/1_valle/p1/examples/libritts/prompts/8463_294825_000043_000000.wav"
    data, sample_hz = torchaudio.load(audio_path)
    target_sample_rate = 16000
    resample_transform = Resample(orig_freq=sample_hz, new_freq=target_sample_rate)
    data = resample_transform(data)
    semantic_token_ids = wav2vec(data)
    return semantic_token_ids


def do_inference():
    semantic_transformer.load("results/semantic.transformer.405000.pt")
    coarse_transformer.load("results_coarse/coarse.transformer.683000.pt")
    fine_transformer.load("results_fine/fine.transformer.828000.pt")

    audiolm = AudioLM(
        wav2vec=wav2vec,
        codec=encodec,
        semantic_transformer=semantic_transformer,
        coarse_transformer=coarse_transformer,
        fine_transformer=fine_transformer
    ).cuda()

    generated_wav = audiolm(batch_size=1)
    # tensor([[-0.0101, -0.0111, -0.0039,  ..., -0.0439, -0.1621,  0.0146]])
    wav = generated_wav[0].detach().cpu().numpy()
    # EncodecWrapper wraps the 24 kHz EnCodec model, so write at that rate
    soundfile.write("output.wav", wav, samplerate=24000)




if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--type", default="coarse")
    args = parser.parse_args()

    if args.type == "semantic":
        train_semantic()
    elif args.type == "coarse":
        train_coarse()
    elif args.type == "fine":
        train_fine()
    elif args.type == "infer":
        do_inference()


drishyakarki commented Dec 7, 2023

> I'm new to TTS. I'm trying to replicate your work, and when using Encodec I get only noise instead of speech. […]

I think you need to train SoundStream separately. Here you can try the notebook.
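A minimal training sketch, reusing the SoundStream configuration from the original post and the SoundStreamTrainer interface from the audiolm-pytorch README ('/path/to/audio' is a placeholder):

from audiolm_pytorch import SoundStream, SoundStreamTrainer

soundstream = SoundStream(
    codebook_size = 4096,
    rq_num_quantizers = 8,
    rq_groups = 2,
    use_lookup_free_quantizer = False,
    use_finite_scalar_quantizer = True,
    attn_window_size = 128,
    attn_depth = 2
)

trainer = SoundStreamTrainer(
    soundstream,
    folder = '/path/to/audio',
    batch_size = 4,
    grad_accum_every = 8,
    data_max_length_seconds = 2,
    num_train_steps = 1_000_000
).cuda()

trainer.train()

# Once trained, load it as in the original post:
# soundstream = SoundStream.init_and_load_from('./path/to/checkpoint.pt')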


lcc-404 commented Mar 21, 2024

> Similar issue: I trained on the LibriTTS training dataset with a batch size of 64 and obtained the following models […]

Hi! I'm running into the same problem as you; the generated wav sounds sped up. Have you solved it?
