Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve the TTS model loading process #22

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 30 additions & 17 deletions demo/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
import streamlit as st
from huggingface_hub import list_repo_files

from opennotebookllm.podcast_maker.config import SpeakerConfig, PodcastConfig
from opennotebookllm.podcast_maker.config import PodcastConfig, SpeakerConfig
from opennotebookllm.preprocessing import DATA_LOADERS, DATA_CLEANERS
from opennotebookllm.inference.model_loaders import load_llama_cpp_model
from opennotebookllm.inference.model_loaders import (
load_llama_cpp_model,
load_parler_tts_model_and_tokenizer,
)
from opennotebookllm.inference.text_to_text import text_to_text_stream
from opennotebookllm.podcast_maker.script_to_audio import (
parse_script_to_waveform,
Expand All @@ -28,19 +31,6 @@
"Jon's voice is calm with very clear audio and no background noise."
)

speaker_1 = SpeakerConfig(
model_id="parler-tts/parler-tts-mini-v1",
speaker_id="1",
speaker_description=speaker_1_description,
)
speaker_2 = SpeakerConfig(
model_id="parler-tts/parler-tts-mini-v1",
speaker_id="2",
speaker_description=speaker_2_description,
)
speakers = {s.speaker_id: s for s in [speaker_1, speaker_2]}
sample_pod_config = PodcastConfig(speakers=speakers)

CURATED_REPOS = [
"allenai/OLMoE-1B-7B-0924-Instruct-GGUF",
"MaziyarPanahi/SmolLM2-1.7B-Instruct-GGUF",
Expand Down Expand Up @@ -110,9 +100,32 @@

if st.button("Generate Audio"):
filename = "demo_podcast.wav"

with st.spinner("Downloading and Loading TTS Model..."):
model, tokenizer = load_parler_tts_model_and_tokenizer(
"parler-tts/parler-tts-mini-v1", "cpu"
)
speaker_1 = SpeakerConfig(
model=model,
speaker_id="1",
tokenizer=tokenizer,
speaker_description=speaker_1_description,
)
speaker_2 = SpeakerConfig(
model=model,
speaker_id="2",
tokenizer=tokenizer,
speaker_description=speaker_2_description,
)
demo_podcast_config = PodcastConfig(
speakers={s.speaker_id: s for s in [speaker_1, speaker_2]}
)

with st.spinner("Generating Audio..."):
waveform = parse_script_to_waveform(final_script, sample_pod_config)
waveform = parse_script_to_waveform(
final_script, demo_podcast_config
)
save_waveform_as_file(
waveform, sample_pod_config.sampling_rate, filename
waveform, demo_podcast_config.sampling_rate, filename
)
st.audio(filename)
2 changes: 1 addition & 1 deletion src/opennotebookllm/inference/model_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def load_parler_tts_model_and_tokenizer(
Loads the given model_id using parler_tts.from_pretrained.

Examples:
>>> model = load_parler_tts_model_and_tokenizer("parler-tts/parler-tts-mini-v1", "cpu")
>>> model, tokenizer = load_parler_tts_model_and_tokenizer("parler-tts/parler-tts-mini-v1", "cpu")

Args:
model_id (str): The model id to load.
Expand Down
39 changes: 22 additions & 17 deletions src/opennotebookllm/inference/text_to_speech.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,44 @@
import numpy as np
from opennotebookllm.inference.model_loaders import (
load_parler_tts_model_and_tokenizer,
)
from opennotebookllm.podcast_maker.config import SpeakerConfig
from transformers import PreTrainedModel, PreTrainedTokenizerBase


def _speech_generation_parler(input_text: str, tts_config: SpeakerConfig) -> np.ndarray:
model, tokenizer = load_parler_tts_model_and_tokenizer(tts_config.model_id)

def _speech_generation_parler(
input_text: str,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizerBase,
speaker_description: str,
) -> np.ndarray:
input_ids = tokenizer(speaker_description, return_tensors="pt").input_ids
prompt_input_ids = tokenizer(input_text, return_tensors="pt").input_ids
input_ids = tokenizer(tts_config.speaker_description, return_tensors="pt").input_ids

generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
waveform = generation.cpu().numpy().squeeze()

return waveform


def text_to_speech(input_text: str, tts_config: SpeakerConfig) -> np.ndarray:
def text_to_speech(
input_text: str,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizerBase,
speaker_profile: str,
) -> np.ndarray:
"""
Generates a speech waveform using the input_text and a speaker configuration that defines which model to use and its parameters.
Generates a speech waveform using the input_text, a model and a speaker profile to define a distinct voice pattern.

Examples:
>>> model, tokenizer = load_parler_tts_model_and_tokenizer("parler-tts/parler-tts-mini-v1", "cpu")
>>> waveform = text_to_speech("Welcome to our amazing podcast", model, tokenizer, "Laura's voice is exciting and fast in delivery with very clear audio and no background noise.")

Args:
input_text (str): The text to convert to speech.
tts_config: Configuration parameters for TTS model.

model (PreTrainedModel): The model used for generating the waveform.
tokenizer (PreTrainedTokenizerBase): The tokenizer used to tokenize the input text and the speaker profile before they are passed to the model.
speaker_profile (str): A description used by the ParlerTTS model to configure the speaker profile.
Returns:
numpy array: The waveform of the speech as a 2D numpy array
"""
if "parler" in tts_config.model_id:
return _speech_generation_parler(input_text, tts_config)
model_id = model.config.name_or_path
if "parler" in model_id:
return _speech_generation_parler(input_text, model, tokenizer, speaker_profile)
else:
raise NotImplementedError(
f"Model {tts_config.model_id} not yet implemented for TTS"
)
raise NotImplementedError(f"Model {model_id} not yet implemented for TTS")
10 changes: 7 additions & 3 deletions src/opennotebookllm/podcast_maker/config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from typing import Dict, Optional

from pydantic import BaseModel
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from pydantic import BaseModel, ConfigDict


class SpeakerConfig(BaseModel):
model_id: str
model_config = ConfigDict(arbitrary_types_allowed=True)

model: PreTrainedModel
speaker_id: str
# ParlerTTS specific configuration
tokenizer: Optional[PreTrainedTokenizerBase] = None
speaker_description: Optional[str] = (
None # This description is used by the ParlerTTS model to configure the speaker profile
)
Expand Down
34 changes: 29 additions & 5 deletions src/opennotebookllm/podcast_maker/script_to_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

import numpy as np

from demo.app import sample_pod_config
from opennotebookllm.inference.model_loaders import load_parler_tts_model_and_tokenizer
from opennotebookllm.inference.text_to_speech import text_to_speech

from opennotebookllm.podcast_maker.config import PodcastConfig
from opennotebookllm.podcast_maker.config import PodcastConfig, SpeakerConfig


def parse_script_to_waveform(script: str, podcast_config: PodcastConfig):
Expand All @@ -24,8 +24,14 @@ def parse_script_to_waveform(script: str, podcast_config: PodcastConfig):
for part in parts:
if ":" in part:
speaker_id, speaker_text = part.split(":")
speaker_model = podcast_config.speakers[speaker_id].model
speaker_tokenizer = podcast_config.speakers[speaker_id].tokenizer
speaker_description = podcast_config.speakers[
speaker_id
].speaker_description

speaker_waveform = text_to_speech(
speaker_text, podcast_config.speakers[speaker_id]
speaker_text, speaker_model, speaker_tokenizer, speaker_description
)
podcast_waveform.append(speaker_waveform)

Expand All @@ -48,12 +54,30 @@ def save_waveform_as_file(
"Speaker 1: Welcome to our podcast. Speaker 2: It's great to be here!"
)

model, tokenizer = load_parler_tts_model_and_tokenizer(
"parler-tts/parler-tts-mini-v1", "cpu"
)
speaker_1 = SpeakerConfig(
model=model,
speaker_id="1",
tokenizer=tokenizer,
speaker_description="Laura's voice is exciting and fast in delivery with very clear audio and no background noise.",
)
speaker_2 = SpeakerConfig(
model=model,
speaker_id="2",
tokenizer=tokenizer,
speaker_description="Jon's voice is calm with very clear audio and no background noise.",
)
demo_podcast_config = PodcastConfig(
speakers={s.speaker_id: s for s in [speaker_1, speaker_2]}
)
test_podcast_waveform = parse_script_to_waveform(
test_podcast_script, sample_pod_config
test_podcast_script, demo_podcast_config
)

save_waveform_as_file(
test_podcast_waveform,
sampling_rate=sample_pod_config.sampling_rate,
sampling_rate=demo_podcast_config.sampling_rate,
filename=test_filename,
)
17 changes: 13 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@

import pytest

from opennotebookllm.inference.model_loaders import load_parler_tts_model_and_tokenizer
from opennotebookllm.podcast_maker.config import (
PodcastConfig,
SpeakerConfig,
speaker_1_description,
speaker_2_description,
)


Expand All @@ -27,14 +26,24 @@ def podcast_script():

@pytest.fixture()
def podcast_config():
speaker_1_description = "Laura's voice is exciting and fast in delivery with very clear audio and no background noise."
speaker_2_description = (
"Jon's voice is calm with very clear audio and no background noise."
)

model, tokenizer = load_parler_tts_model_and_tokenizer(
"parler-tts/parler-tts-mini-v1", "cpu"
)
speaker_1 = SpeakerConfig(
model_id="parler-tts/parler-tts-mini-v1",
model=model,
speaker_id="1",
tokenizer=tokenizer,
speaker_description=speaker_1_description,
)
speaker_2 = SpeakerConfig(
model_id="parler-tts/parler-tts-mini-v1",
model=model,
speaker_id="2",
tokenizer=tokenizer,
speaker_description=speaker_2_description,
)
speakers = {s.speaker_id: s for s in [speaker_1, speaker_2]}
Expand Down
9 changes: 7 additions & 2 deletions tests/integration/test_text_to_text_to_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,13 @@ def test_text_to_text_to_speech(tmp_path: Path, podcast_config: PodcastConfig):
stop=".",
)

speaker_config = list(podcast_config.speakers.values())[0]
waveform = text_to_speech(input_text=result, tts_config=speaker_config)
speaker_cfg = list(podcast_config.speakers.values())[0]
waveform = text_to_speech(
input_text=result,
model=speaker_cfg.model,
tokenizer=speaker_cfg.tokenizer,
speaker_profile=speaker_cfg.speaker_description,
)

filename = str(tmp_path / "test_text_to_text_to_speech_parler.wav")
save_waveform_as_file(
Expand Down
7 changes: 6 additions & 1 deletion tests/unit/inference/test_text_to_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,13 @@
def test_text_to_speech_parler(
tmp_path: Path, tts_prompt: str, podcast_config: PodcastConfig
):
speaker_cfg = list(podcast_config.speakers.values())[0]

waveform = text_to_speech(
input_text=tts_prompt, tts_config=list(podcast_config.speakers.values())[0]
tts_prompt,
speaker_cfg.model,
speaker_cfg.tokenizer,
speaker_cfg.speaker_description,
)

save_waveform_as_file(
Expand Down