Skip to content

Commit

Permalink
fix: avoid nil in ToolProviderDeclaration
Browse files Browse the repository at this point in the history
  • Loading branch information
Yeuoly committed Nov 7, 2024
1 parent 115a5ce commit 67e67fd
Show file tree
Hide file tree
Showing 2 changed files with 239 additions and 0 deletions.
211 changes: 211 additions & 0 deletions cmd/commandline/init/templates/python/tts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
from collections.abc import Generator
import concurrent.futures
from functools import reduce
from io import BytesIO
from typing import Optional

from openai import OpenAI
from pydub import AudioSegment

from dify_plugin import TTSModel
from dify_plugin.errors.model import (
CredentialsValidateFailedError,
InvokeBadRequestError,
)
from ..common_openai import _CommonOpenAI


class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
"""
Model class for OpenAI text-to-speech model.
"""

def _invoke(
    self,
    model: str,
    credentials: dict,
    content_text: str,
    voice: str,
    user: Optional[str] = None,
) -> bytes | Generator[bytes, None, None]:
    """
    Synthesize speech for the given text.

    :param model: model name
    :param credentials: model credentials
    :param content_text: text content to be converted to speech
    :param voice: requested model timbre; replaced by the model's default
        voice when empty or not among the model's voices
    :param user: unique user id (not used by this method)
    :return: whatever ``_tts_invoke_streaming`` returns for the resolved voice
    :raises InvokeBadRequestError: if no voices are found for the model
    """
    available = self.get_tts_model_voices(model=model, credentials=credentials)
    if not available:
        raise InvokeBadRequestError("No voices found for the model")

    # The short-circuit keeps the lookup from running for an empty voice.
    voice_is_known = bool(voice) and voice in [
        entry["value"] for entry in available
    ]
    if not voice_is_known:
        voice = self._get_model_default_voice(model, credentials)

    # The streaming implementation is used unconditionally.
    audio_stream = self._tts_invoke_streaming(
        model=model, credentials=credentials, content_text=content_text, voice=voice
    )
    return audio_stream

def validate_credentials(
    self, model: str, credentials: dict, user: Optional[str] = None
) -> None:
    """
    Check the credentials by synthesizing a short fixed phrase.

    :param model: model name
    :param credentials: model credentials
    :param user: unique user id (not used by this method)
    :return: None
    :raises CredentialsValidateFailedError: if the probe synthesis fails for
        any reason; the message is the text of the underlying error
    """
    probe_text = "Hello Dify!"
    try:
        # Resolved inside the try so a failure here is reported the same way
        # as a failure of the synthesis call itself.
        default_voice = self._get_model_default_voice(model, credentials)
        self._tts_invoke(
            model=model,
            credentials=credentials,
            content_text=probe_text,
            voice=default_voice,
        )
    except Exception as error:
        raise CredentialsValidateFailedError(str(error))

def _tts_invoke(
    self, model: str, credentials: dict, content_text: str, voice: str
) -> bytes:
    """
    Synthesize the whole text and return it as one audio payload.

    The text is split into pieces no longer than the model's word limit
    (500 when the model reports none), each piece is synthesized on a
    thread pool, and the resulting clips are concatenated in order.

    :param model: model name
    :param credentials: model credentials
    :param content_text: text content to be converted to speech
    :param voice: model timbre
    :return: the combined audio, encoded in the model's audio type
    :raises InvokeBadRequestError: if any step fails or no audio is produced
    """
    audio_type = self._get_model_audio_type(model, credentials)
    word_limit = self._get_model_word_limit(model, credentials) or 500
    max_workers = self._get_model_workers_limit(model, credentials)
    try:
        pieces = list(
            self._split_text_into_sentences(
                org_text=content_text, max_length=word_limit
            )
        )
        clips = []

        # Fan the pieces out to worker threads; results are collected in
        # submission order so the audio keeps the order of the text.
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=max_workers
        ) as pool:
            pending = [
                pool.submit(
                    self._process_sentence,
                    sentence=piece,
                    model=model,
                    voice=voice,
                    credentials=credentials,
                )
                for piece in pieces
            ]
            for task in pending:
                try:
                    clip = task.result()
                except Exception as error:
                    raise InvokeBadRequestError(str(error))
                else:
                    # Empty / missing results are dropped here.
                    if clip:
                        clips.append(clip)

        if not clips:
            raise InvokeBadRequestError("No audio bytes found")

        segments = [
            AudioSegment.from_file(BytesIO(clip), format=audio_type)
            for clip in clips
        ]
        merged = segments[0]
        for segment in segments[1:]:
            merged = merged + segment

        out = BytesIO()
        merged.export(out, format=audio_type)
        return out.getvalue()
    except Exception as error:
        # Everything, including the errors raised above, surfaces as a
        # bad-request error carrying the original message text.
        raise InvokeBadRequestError(str(error))

def _tts_invoke_streaming(
    self, model: str, credentials: dict, content_text: str, voice: str
) -> Generator[bytes, None, None]:
    """
    Stream synthesized speech for ``content_text`` as MP3 byte chunks.

    Text longer than the model's word limit (500 when the model reports
    none) is split into pieces that are requested separately and streamed
    back in their original order; shorter text is sent as one request.

    :param model: model name
    :param credentials: model credentials
    :param content_text: text content to be converted to speech
    :param voice: requested model timbre; replaced by the model's default
        voice when empty or not among the model's voices
    :return: generator yielding audio bytes in chunks of up to 1024 bytes
    :raises InvokeBadRequestError: if the model has no voices, the text
        yields no pieces, or any request/streaming step fails
    """
    try:
        # doc: https://platform.openai.com/docs/guides/text-to-speech
        credentials_kwargs = self._to_credential_kwargs(credentials)
        client = OpenAI(**credentials_kwargs)

        voices = self.get_tts_model_voices(model=model, credentials=credentials)
        if not voices:
            raise InvokeBadRequestError("No voices found for the model")

        # Voice entries are dicts carrying the voice id under "value" (the
        # same lookup _invoke performs); comparing the string against the
        # dicts themselves would never match and always force the default.
        if not voice or voice not in [d["value"] for d in voices]:
            voice = self._get_model_default_voice(model, credentials)

        word_limit = self._get_model_word_limit(model, credentials) or 500
        if len(content_text) > word_limit:
            # Materialize the pieces: they are counted and iterated below.
            sentences = list(
                self._split_text_into_sentences(
                    content_text, max_length=word_limit
                )
            )
            if not sentences:
                raise InvokeBadRequestError("No audio bytes found")

            # The context manager shuts the pool down even when the
            # consumer stops iterating early or a request fails.
            with concurrent.futures.ThreadPoolExecutor(
                max_workers=min(3, len(sentences))
            ) as executor:
                futures = [
                    executor.submit(
                        client.audio.speech.with_streaming_response.create,
                        model=model,
                        response_format="mp3",
                        input=sentence,
                        voice=voice,  # type: ignore
                    )
                    for sentence in sentences
                ]
                for future in futures:
                    # Enter each streaming response with ``with`` so it is
                    # exited once its bytes are consumed.
                    with future.result() as response:
                        yield from response.iter_bytes(1024)
        else:
            with client.audio.speech.with_streaming_response.create(
                model=model,
                voice=voice,  # type: ignore
                response_format="mp3",
                input=content_text.strip(),
            ) as response:
                yield from response.iter_bytes(1024)
    except Exception as ex:
        raise InvokeBadRequestError(str(ex)) from ex

def _process_sentence(
    self, sentence: str, model: str, voice: str, credentials: dict
) -> Optional[bytes]:
    """
    Synthesize a single piece of text through the OpenAI speech API.

    :param sentence: text content to be converted to speech; surrounding
        whitespace is stripped before sending
    :param model: model name
    :param voice: model timbre
    :param credentials: model credentials
    :return: the audio bytes, or None when the response body is not bytes
    """
    # transform credentials to kwargs for model instance
    credentials_kwargs = self._to_credential_kwargs(credentials)
    client = OpenAI(**credentials_kwargs)
    response = client.audio.speech.create(
        model=model, voice=voice, input=sentence.strip()
    )
    # Read the body once and reuse it, rather than reading a second time
    # for the return value.
    audio = response.read()
    if isinstance(audio, bytes):
        return audio
    return None
28 changes: 28 additions & 0 deletions internal/types/entities/plugin_entities/tool_declaration.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,18 @@ type ToolProviderDeclaration struct {
ToolFiles []string `json:"-" yaml:"-"`
}

// MarshalJSON encodes the declaration as JSON, replacing a nil
// CredentialsSchema or Tools slice with an empty slice before encoding.
// The receiver itself is not modified; the substitution is made on a copy.
func (t *ToolProviderDeclaration) MarshalJSON() ([]byte, error) {
	// A locally defined type has the same fields but none of the methods,
	// so json.Marshal below does not call back into this MarshalJSON.
	type plain ToolProviderDeclaration

	out := plain(*t)
	if out.CredentialsSchema == nil {
		out.CredentialsSchema = []ProviderConfig{}
	}
	if out.Tools == nil {
		out.Tools = []ToolDeclaration{}
	}

	return json.Marshal(out)
}

func (t *ToolProviderDeclaration) UnmarshalYAML(value *yaml.Node) error {
type alias struct {
Identity ToolProviderIdentity `yaml:"identity"`
Expand Down Expand Up @@ -196,6 +208,14 @@ func (t *ToolProviderDeclaration) UnmarshalYAML(value *yaml.Node) error {
}
}

if t.CredentialsSchema == nil {
t.CredentialsSchema = []ProviderConfig{}
}

if t.Tools == nil {
t.Tools = []ToolDeclaration{}
}

return nil
}

Expand Down Expand Up @@ -250,6 +270,14 @@ func (t *ToolProviderDeclaration) UnmarshalJSON(data []byte) error {
}
}

if t.CredentialsSchema == nil {
t.CredentialsSchema = []ProviderConfig{}
}

if t.Tools == nil {
t.Tools = []ToolDeclaration{}
}

return nil
}

Expand Down

0 comments on commit 67e67fd

Please sign in to comment.