From 67e67fdd2da4631159eade4ed979ce24847f23aa Mon Sep 17 00:00:00 2001
From: Yeuoly
Date: Thu, 7 Nov 2024 16:27:51 +0800
Subject: [PATCH] fix: avoid nil in ToolProviderDeclaration

---
 cmd/commandline/init/templates/python/tts.py       | 211 ++++++++++++++++++
 .../plugin_entities/tool_declaration.go            |  28 +++
 2 files changed, 239 insertions(+)

diff --git a/cmd/commandline/init/templates/python/tts.py b/cmd/commandline/init/templates/python/tts.py
index e69de29..2e60ccc 100644
--- a/cmd/commandline/init/templates/python/tts.py
+++ b/cmd/commandline/init/templates/python/tts.py
@@ -0,0 +1,211 @@
+from collections.abc import Generator
+import concurrent.futures
+from functools import reduce
+from io import BytesIO
+from typing import Optional
+
+from openai import OpenAI
+from pydub import AudioSegment
+
+from dify_plugin import TTSModel
+from dify_plugin.errors.model import (
+    CredentialsValidateFailedError,
+    InvokeBadRequestError,
+)
+from ..common_openai import _CommonOpenAI
+
+
+class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
+    """
+    Model class for the OpenAI text-to-speech model.
+    """
+
+    def _invoke(
+        self,
+        model: str,
+        credentials: dict,
+        content_text: str,
+        voice: str,
+        user: Optional[str] = None,
+    ) -> bytes | Generator[bytes, None, None]:
+        """
+        Invoke the text-to-speech model.
+
+        :param model: model name
+        :param credentials: model credentials
+        :param content_text: text content to be synthesized
+        :param voice: model timbre
+        :param user: unique user id
+        :return: audio of the given text
+        """
+
+        voices = self.get_tts_model_voices(model=model, credentials=credentials)
+        if not voices:
+            raise InvokeBadRequestError("No voices found for the model")
+
+        if not voice or voice not in [d["value"] for d in voices]:
+            voice = self._get_model_default_voice(model, credentials)
+
+        # the streaming variant is always used here
+        return self._tts_invoke_streaming(
+            model=model, credentials=credentials, content_text=content_text, voice=voice
+        )
+
+    def validate_credentials(
+        self, model: str, credentials: dict, user: Optional[str] = None
+    ) -> None:
+        """
+        Validate credentials for the text-to-speech model.
+
+        :param model: model name
+        :param credentials: model credentials
+        :param user: unique user id
+        :return: None
+        """
+        try:
+            self._tts_invoke(
+                model=model,
+                credentials=credentials,
+                content_text="Hello Dify!",
+                voice=self._get_model_default_voice(model, credentials),
+            )
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    def _tts_invoke(
+        self, model: str, credentials: dict, content_text: str, voice: str
+    ) -> bytes:
+        """
+        Invoke the text-to-speech model and return the full audio.
+
+        :param model: model name
+        :param credentials: model credentials
+        :param content_text: text content to be synthesized
+        :param voice: model timbre
+        :return: audio of the given text
+        """
+        audio_type = self._get_model_audio_type(model, credentials)
+        word_limit = self._get_model_word_limit(model, credentials) or 500
+        max_workers = self._get_model_workers_limit(model, credentials)
+        try:
+            sentences = list(
+                self._split_text_into_sentences(
+                    org_text=content_text, max_length=word_limit
+                )
+            )
+            audio_bytes_list = []
+
+            # Create a thread pool and synthesize each sentence concurrently
+            with concurrent.futures.ThreadPoolExecutor(
+                max_workers=max_workers
+            ) as executor:
+                futures = [
+                    executor.submit(
+                        self._process_sentence,
+                        sentence=sentence,
+                        model=model,
+                        voice=voice,
+                        credentials=credentials,
+                    )
+                    for sentence in sentences
+                ]
+                for future in futures:
+                    try:
+                        if result := future.result():
+                            audio_bytes_list.append(result)
+                    except Exception as ex:
+                        raise InvokeBadRequestError(str(ex))
+
+            if len(audio_bytes_list) > 0:
+                audio_segments = [
+                    AudioSegment.from_file(BytesIO(audio_bytes), format=audio_type)
+                    for audio_bytes in audio_bytes_list
+                    if audio_bytes
+                ]
+                combined_segment = reduce(lambda x, y: x + y, audio_segments)
+                buffer: BytesIO = BytesIO()
+                combined_segment.export(buffer, format=audio_type)
+                buffer.seek(0)
+
+                return buffer.read()
+            else:
+                raise InvokeBadRequestError("No audio bytes found")
+        except Exception as ex:
+            raise InvokeBadRequestError(str(ex))
+
+    def _tts_invoke_streaming(
+        self, model: str, credentials: dict, content_text: str, voice: str
+    ) -> Generator[bytes, None, None]:
+        """
+        Invoke the text-to-speech model and stream the audio.
+
+        :param model: model name
+        :param credentials: model credentials
+        :param content_text: text content to be synthesized
+        :param voice: model timbre
+        :return: audio of the given text, yielded in chunks
+        """
+        try:
+            # doc: https://platform.openai.com/docs/guides/text-to-speech
+            credentials_kwargs = self._to_credential_kwargs(credentials)
+            client = OpenAI(**credentials_kwargs)
+
+            voices = self.get_tts_model_voices(model=model, credentials=credentials)
+            if not voices:
+                raise InvokeBadRequestError("No voices found for the model")
+
+            if not voice or voice not in [d["value"] for d in voices]:
+                voice = self._get_model_default_voice(model, credentials)
+
+            word_limit = self._get_model_word_limit(model, credentials) or 500
+            if len(content_text) > word_limit:
+                sentences = self._split_text_into_sentences(
+                    content_text, max_length=word_limit
+                )
+                executor = concurrent.futures.ThreadPoolExecutor(
+                    max_workers=min(3, len(sentences))
+                )
+                futures = [
+                    executor.submit(
+                        client.audio.speech.with_streaming_response.create,
+                        model=model,
+                        response_format="mp3",
+                        input=sentences[i],
+                        voice=voice,  # type: ignore
+                    )
+                    for i in range(len(sentences))
+                ]
+                for future in futures:
+                    with future.result() as stream:
+                        yield from stream.iter_bytes(1024)
+
+            else:
+                response = client.audio.speech.with_streaming_response.create(
+                    model=model,
+                    voice=voice,  # type: ignore
+                    response_format="mp3",
+                    input=content_text.strip(),
+                )
+                with response as stream:
+                    yield from stream.iter_bytes(1024)
+        except Exception as ex:
+            raise InvokeBadRequestError(str(ex))
+
+    def _process_sentence(self, sentence: str, model: str, voice, credentials: dict):
+        """
+        Call the OpenAI text-to-speech API for a single sentence.
+
+        :param model: model name
+        :param credentials: model credentials
+        :param voice: model timbre
+        :param sentence: text content to be synthesized
+        :return: audio of the given sentence
+        """
+        # transform credentials to kwargs for the client
+        credentials_kwargs = self._to_credential_kwargs(credentials)
+        client = OpenAI(**credentials_kwargs)
+        response = client.audio.speech.create(
+            model=model, voice=voice, input=sentence.strip()
+        )
+        data = response.read()
+        return data if isinstance(data, bytes) else None
diff --git a/internal/types/entities/plugin_entities/tool_declaration.go b/internal/types/entities/plugin_entities/tool_declaration.go
index 6e70c19..3611370 100644
--- a/internal/types/entities/plugin_entities/tool_declaration.go
+++ b/internal/types/entities/plugin_entities/tool_declaration.go
@@ -130,6 +130,18 @@ type ToolProviderDeclaration struct {
 	ToolFiles []string `json:"-" yaml:"-"`
 }
 
+func (t *ToolProviderDeclaration) MarshalJSON() ([]byte, error) {
+	type alias ToolProviderDeclaration
+	p := alias(*t)
+	if p.CredentialsSchema == nil {
+		p.CredentialsSchema = []ProviderConfig{}
+	}
+	if p.Tools == nil {
+		p.Tools = []ToolDeclaration{}
+	}
+	return json.Marshal(p)
+}
+
 func (t *ToolProviderDeclaration) UnmarshalYAML(value *yaml.Node) error {
 	type alias struct {
 		Identity ToolProviderIdentity `yaml:"identity"`
@@ -196,6 +208,14 @@ func (t *ToolProviderDeclaration) UnmarshalYAML(value *yaml.Node) error {
 		}
 	}
 
+	if t.CredentialsSchema == nil {
+		t.CredentialsSchema = []ProviderConfig{}
+	}
+
+	if t.Tools == nil {
+		t.Tools = []ToolDeclaration{}
+	}
+
 	return nil
 }
 
@@ -250,6 +270,14 @@ func (t *ToolProviderDeclaration) UnmarshalJSON(data []byte) error {
 		}
 	}
 
+	if t.CredentialsSchema == nil {
+		t.CredentialsSchema = []ProviderConfig{}
+	}
+
+	if t.Tools == nil {
+		t.Tools = []ToolDeclaration{}
+	}
+
 	return nil
 }
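
Note (not part of the patch): the MarshalJSON override and the post-unmarshal
normalization work because Go's encoding/json encodes a nil slice as `null`,
while an empty but non-nil slice encodes as `[]`, which is what consumers of
ToolProviderDeclaration expect. A minimal standalone sketch of that behavior;
the Provider type and its field are illustrative only, not the real
declaration struct:

    package main

    import (
    	"encoding/json"
    	"fmt"
    )

    // Provider stands in for ToolProviderDeclaration; only a slice
    // field is needed to show the nil-vs-empty distinction.
    type Provider struct {
    	Tools []string `json:"tools"`
    }

    func main() {
    	var p Provider // p.Tools is nil
    	out, _ := json.Marshal(p)
    	fmt.Println(string(out)) // {"tools":null}

    	p.Tools = []string{} // normalized, as the patch's MarshalJSON does
    	out, _ = json.Marshal(p)
    	fmt.Println(string(out)) // {"tools":[]}
    }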