Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TTS component #19

Merged
merged 45 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
ae59d27
Add .idea (PyCharm files) to gitignore
Kostis-S-Z Nov 26, 2024
8f01d75
Add Audio generation section to demo
Kostis-S-Z Nov 26, 2024
fd84d63
Add parler TTS model loader
Kostis-S-Z Nov 26, 2024
1c8093e
[WIP] Add script to podcast parser
Kostis-S-Z Nov 26, 2024
d567ed5
[WIP] Add text to speech code
Kostis-S-Z Nov 26, 2024
12178c8
[WIP] Add simple unit tests
Kostis-S-Z Nov 26, 2024
12a622f
Add .wav files to gitignore
Kostis-S-Z Nov 26, 2024
4084db7
Fix tiny typo in docs
Kostis-S-Z Nov 26, 2024
d43334e
Update default sampling rate
Kostis-S-Z Nov 27, 2024
936f433
Remove outdated fixture
Kostis-S-Z Nov 27, 2024
3290a01
Add podcast config fixture
Kostis-S-Z Nov 27, 2024
62df818
Update return type in TTS model loader
Kostis-S-Z Nov 27, 2024
89fb11e
Update TTS code to use pydantic Config
Kostis-S-Z Nov 27, 2024
f91caf0
Update tests
Kostis-S-Z Nov 27, 2024
ebb261c
Add pydantic config for TTS component
Kostis-S-Z Nov 27, 2024
61c11cb
Update tests
Kostis-S-Z Nov 27, 2024
51afcf1
Use tmp_path in tests to autoremove generated wav files
Kostis-S-Z Nov 27, 2024
98a1f33
Update prompt fixture
Kostis-S-Z Nov 27, 2024
bf6abae
Fix package imports
Kostis-S-Z Nov 27, 2024
e64aa88
Update comment docs
Kostis-S-Z Nov 27, 2024
cf32993
Rewrite TTS model loading
Kostis-S-Z Nov 27, 2024
0fd01b8
Update imports in app & set sample constants podcast config
Kostis-S-Z Nov 28, 2024
a0eec86
Add pydantic in project requirements
Kostis-S-Z Nov 28, 2024
04033fc
Use Python's wave to save audio file instead of scipy
Kostis-S-Z Nov 28, 2024
610ac9a
Update from 6-audio-generation-component
Kostis-S-Z Nov 28, 2024
f213613
Update comment
Kostis-S-Z Nov 28, 2024
9d03672
Improve the TTS model loading process
Kostis-S-Z Nov 28, 2024
c0ef8b8
Add parler_tts to project dependencies
Kostis-S-Z Nov 28, 2024
2e6c8b2
Update TTS part of demo
Kostis-S-Z Nov 28, 2024
b8a5eeb
Fix wave module saving wav file
Kostis-S-Z Nov 28, 2024
3bdbf78
Use soundfile instead of wave for saving .wav file
Kostis-S-Z Nov 28, 2024
c35135c
Fix script format
Kostis-S-Z Nov 28, 2024
9ee2c91
fix(demo/app.py): Drop nested button. (#24)
daavoo Nov 28, 2024
efa5708
Merge branch 'main' into 6-audio-generation-component
daavoo Nov 28, 2024
dae21b3
fix(parse_script_to_waveform): Remove extra quote
daavoo Nov 28, 2024
be82900
Updates to demo to include audio part (#26)
daavoo Dec 3, 2024
0dd8b78
Merge from main
Kostis-S-Z Dec 3, 2024
42c1a44
Add minimum codespaces machine specifications
Kostis-S-Z Dec 3, 2024
96d88a8
Add text_to_speech reference in API.md docs
Kostis-S-Z Dec 3, 2024
76a92f8
Add note in README about cold start of demo taking long time
Kostis-S-Z Dec 3, 2024
43d9383
Add Troubleshooting section in README
Kostis-S-Z Dec 3, 2024
b3501c3
Use sample rate from model config instead of hardcoded
Kostis-S-Z Dec 3, 2024
fa2225a
Update install instructions readme and docs
Kostis-S-Z Dec 4, 2024
3b84710
Fix outdated docstring example
Kostis-S-Z Dec 4, 2024
35ab5c7
Fix imports and references of old repo name
Kostis-S-Z Dec 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -159,4 +159,7 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/

# Generated audio files
*.wav
10 changes: 10 additions & 0 deletions demo/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from opennotebookllm.preprocessing import DATA_LOADERS, DATA_CLEANERS
from opennotebookllm.inference.model_loaders import load_llama_cpp_model
from opennotebookllm.inference.text_to_text import text_to_text_stream
from opennotebookllm.podcast_maker.script_to_audio import script_to_audio

PODCAST_PROMPT = """
You are a helpful podcast writer.
Expand Down Expand Up @@ -75,10 +76,19 @@
if st.button("Generate Podcast Script"):
with st.spinner("Generating Podcast Script..."):
text = ""
final_script = ""
for chunk in text_to_text_stream(
clean_text, model, system_prompt=system_prompt.strip()
):
text += chunk
final_script += chunk
if text.endswith("\n"):
st.write(text)
text = ""

if st.button("Generate Audio"):
filename = "demo_podcast.wav"
with st.spinner("Generating Audio..."):
script_to_audio(final_script, filename=filename)

st.audio(filename)
daavoo marked this conversation as resolved.
Show resolved Hide resolved
29 changes: 28 additions & 1 deletion src/opennotebookllm/inference/model_loaders.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from typing import Tuple

from llama_cpp import Llama
from parler_tts import ParlerTTSForConditionalGeneration
daavoo marked this conversation as resolved.
Show resolved Hide resolved
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedModel


def load_llama_cpp_model(
Expand All @@ -8,7 +12,7 @@ def load_llama_cpp_model(
Loads the given model_id using Llama.from_pretrained.

Examples:
>>> model = load_model(
>>> model = load_llama_cpp_model(
Kostis-S-Z marked this conversation as resolved.
Show resolved Hide resolved
"allenai/OLMoE-1B-7B-0924-Instruct-GGUF/olmoe-1b-7b-0924-instruct-q8_0.gguf")

Args:
Expand All @@ -26,3 +30,26 @@ def load_llama_cpp_model(
n_ctx=0,
)
return model


def load_parler_tts_model_and_tokenizer(
daavoo marked this conversation as resolved.
Show resolved Hide resolved
model_id: str, device: str = "cpu"
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
"""
Loads the given model_id using parler_tts.from_pretrained.

Examples:
>>> model = load_parler_tts_model_and_tokenizer("parler-tts/parler-tts-mini-v1", "cpu")

Args:
model_id (str): The model id to load.
Format is expected to be `{repo}/{filename}`.
device (str): The device to load the model on, such as "cuda:0" or "cpu".

Returns:
PreTrainedModel: The loaded model.
"""
model = ParlerTTSForConditionalGeneration.from_pretrained(model_id).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_id)

return model, tokenizer
49 changes: 49 additions & 0 deletions src/opennotebookllm/inference/text_to_speech.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import numpy as np
from src.opennotebookllm.inference.model_loaders import (
load_parler_tts_model_and_tokenizer,
)


default_speaker_1_description = "Laura's voice is exciting and fast in delivery with very clear audio and no background noise."
default_speaker_2_description = (
"Jon's voice is calm with very clear audio and no background noise."
daavoo marked this conversation as resolved.
Show resolved Hide resolved
)


def _speech_generation_parler(
input_text: str, model_id: str, speaker_description: str
) -> np.array:
model, tokenizer = load_parler_tts_model_and_tokenizer(model_id)

prompt_input_ids = tokenizer(input_text, return_tensors="pt").input_ids
input_ids = tokenizer(speaker_description, return_tensors="pt").input_ids

generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
waveform = generation.cpu().numpy().squeeze()

return waveform


def text_to_speech(
input_text: str,
model_id: str,
speaker_description: str = default_speaker_1_description,
) -> np.array:
"""
Generates a speech waveform using the input_text, a speaker description and a given model id.

Examples:
>>> waveform = text_to_speech("Welcome to our amazing podcast", "parler-tts/parler-tts-mini-v1", "Laura's voice is exciting and fast in delivery with very clear audio and no background noise.")

Args:
input_text (str): The text to convert to speech.
model_id (str): A model id from the registered models list.
speaker_description (str): A description in natural language of how we want the voice to sound.

Returns:
numpy array: The waveform of the speech as a 2D numpy array
"""
if "parler" in model_id:
return _speech_generation_parler(input_text, model_id, speaker_description)
else:
raise NotImplementedError(f"Model {model_id} not yet implemented for TTS")
Empty file.
34 changes: 34 additions & 0 deletions src/opennotebookllm/podcast_maker/script_to_audio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import numpy as np

from src.opennotebookllm.inference.text_to_speech import (
text_to_speech,
default_speaker_1_description,
default_speaker_2_description,
)
from scipy.io.wavfile import write


def script_to_audio(
script: str,
model_id: str = "parler-tts/parler-tts-mini-v1",
filename: str = "podcast.wav",
sampling_rate: int = 44_100,
):
parts = script.split("Speaker")
podcast_waveform = []
for part in parts:
if ":" in part:
speaker_id, speaker_text = part.split(":")
daavoo marked this conversation as resolved.
Show resolved Hide resolved
if int(speaker_id) == 1:
speaker_1 = text_to_speech(
speaker_text, model_id, default_speaker_1_description
)
podcast_waveform.append(speaker_1)
elif int(speaker_id) == 2:
speaker_2 = text_to_speech(
speaker_text, model_id, default_speaker_2_description
)
podcast_waveform.append(speaker_2)

podcast_waveform = np.concatenate(podcast_waveform)
write(filename, rate=sampling_rate, data=podcast_waveform)
15 changes: 15 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,18 @@
@pytest.fixture(scope="session")
def example_data():
return Path(__file__).parent.parent / "example_data"


@pytest.fixture()
def tts_prompt():
return "Wow you are really good at writing unit tests!"


@pytest.fixture()
def tts_speaker_description():
return "Laura's voice is enthusiastic and fast with a very close recording that has no background noise."


@pytest.fixture()
def podcast_script():
return "Speaker 1: Welcome to our podcast. Speaker 2: It's great to be here!"
12 changes: 12 additions & 0 deletions tests/unit/inference/test_text_to_speech.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from src.opennotebookllm.inference.text_to_speech import text_to_speech
from scipy.io.wavfile import write


def test_text_to_speech_parler(tts_prompt, tts_speaker_description):
daavoo marked this conversation as resolved.
Show resolved Hide resolved
waveform = text_to_speech(
input_text=tts_prompt,
speaker_description=tts_speaker_description,
model_id="parler-tts/parler-tts-mini-v1",
)

write("test_parler_tts.wav", rate=44_100, data=waveform)
9 changes: 9 additions & 0 deletions tests/unit/podcast_maker/test_script_to_audio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import os
from src.opennotebookllm.podcast_maker.script_to_audio import script_to_audio


def test_parse_script(podcast_script: str):
filename = "test_podcast.wav"
script_to_audio(podcast_script, filename=filename)

assert os.path.isfile(filename)