From 5c5527b5fa384eadff7d8cbecb9988d498214331 Mon Sep 17 00:00:00 2001 From: zhangqianze Date: Thu, 12 Dec 2024 23:40:05 +0800 Subject: [PATCH 1/3] feat: support gemini v2v --- .../extension/gemini_v2v_python/BUILD.gn | 21 + .../extension/gemini_v2v_python/README.md | 65 ++ .../extension/gemini_v2v_python/__init__.py | 8 + .../extension/gemini_v2v_python/addon.py | 21 + .../extension/gemini_v2v_python/extension.py | 555 ++++++++++++++++++ .../extension/gemini_v2v_python/manifest.json | 156 +++++ .../extension/gemini_v2v_python/property.json | 12 + .../gemini_v2v_python/requirements.txt | 2 + .../extension/openai_v2v_python/property.json | 2 +- playground/src/common/moduleConfig.ts | 6 + 10 files changed, 847 insertions(+), 1 deletion(-) create mode 100644 agents/ten_packages/extension/gemini_v2v_python/BUILD.gn create mode 100644 agents/ten_packages/extension/gemini_v2v_python/README.md create mode 100644 agents/ten_packages/extension/gemini_v2v_python/__init__.py create mode 100644 agents/ten_packages/extension/gemini_v2v_python/addon.py create mode 100644 agents/ten_packages/extension/gemini_v2v_python/extension.py create mode 100644 agents/ten_packages/extension/gemini_v2v_python/manifest.json create mode 100644 agents/ten_packages/extension/gemini_v2v_python/property.json create mode 100644 agents/ten_packages/extension/gemini_v2v_python/requirements.txt diff --git a/agents/ten_packages/extension/gemini_v2v_python/BUILD.gn b/agents/ten_packages/extension/gemini_v2v_python/BUILD.gn new file mode 100644 index 00000000..066a7ee4 --- /dev/null +++ b/agents/ten_packages/extension/gemini_v2v_python/BUILD.gn @@ -0,0 +1,21 @@ +# +# +# Agora Real Time Engagement +# Created by Wei Hu in 2022-11. +# Copyright (c) 2024 Agora IO. All rights reserved. +# +# +import("//build/feature/ten_package.gni") + +ten_package("gemini_v2v_python") { + package_kind = "extension" + + resources = [ + "__init__.py", + "addon.py", + "extension.py", + "log.py", + "manifest.json", + "property.json", + ] +} diff --git a/agents/ten_packages/extension/gemini_v2v_python/README.md b/agents/ten_packages/extension/gemini_v2v_python/README.md new file mode 100644 index 00000000..3cd294f3 --- /dev/null +++ b/agents/ten_packages/extension/gemini_v2v_python/README.md @@ -0,0 +1,65 @@ +# openai_v2v_python + +An extension for integrating OpenAI's Next Generation of **Multimodal** AI into your application, providing configurable AI-driven features such as conversational agents, task automation, and tool integration. + +## Features + + + +- OpenAI **Multimodal** Integration: Leverage GPT **Multimodal** models for voice to voice as well as text processing. +- Configurable: Easily customize API keys, model settings, prompts, temperature, etc. +- Async Queue Processing: Supports real-time message processing with task cancellation and prioritization. + + +## API + +Refer to `api` definition in [manifest.json] and default values in [property.json](property.json). 
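For example, a working configuration (these values mirror the defaults this patch ships in `property.json`; only `api_key` is required):

``` json
{
  "api_key": "${env:GEMINI_API_KEY}",
  "base_uri": "generativelanguage.googleapis.com",
  "api_version": "v1alpha",
  "model": "gemini-2.0-flash-exp",
  "voice": "Puck",
  "language": "en-US",
  "temperature": 0.9,
  "max_tokens": 2048,
  "server_vad": true,
  "dump": true
}
```

The full property table: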
| **Property**               | **Type**   | **Description**                           |
|----------------------------|------------|-------------------------------------------|
| `api_key`                  | `string`   | API key for authenticating with OpenAI    |
| `temperature`              | `float64`  | Sampling temperature; higher values mean more randomness |
| `model`                    | `string`   | Model identifier (e.g., GPT-3.5, GPT-4)   |
| `max_tokens`               | `int64`    | Maximum number of tokens to generate      |
| `system_message`           | `string`   | Default system message to send to the model |
| `voice`                    | `string`   | Voice the OpenAI model speaks with, such as `alloy`, `echo`, `shimmer`, etc. |
| `server_vad`               | `bool`     | Flag to enable or disable OpenAI's server-side VAD |
| `language`                 | `string`   | Language the OpenAI model responds in, such as `en-US`, `zh-CN`, etc. |
| `dump`                     | `bool`     | Flag to enable or disable audio dump for debugging purposes |

### Data Out:
| **Name**       | **Property** | **Type**   | **Description**               |
|----------------|--------------|------------|-------------------------------|
| `text_data`    | `text`       | `string`   | Outgoing text data            |

### Command Out:
| **Name**       | **Description**                             |
|----------------|---------------------------------------------|
| `flush`        | Response after flushing the current state   |

### Audio Frame In:
| **Name**         | **Description**                           |
|------------------|-------------------------------------------|
| `pcm_frame`      | Audio frame input for voice processing    |

### Audio Frame Out:
| **Name**         | **Description**                           |
|------------------|-------------------------------------------|
| `pcm_frame`      | Audio frame output after voice processing |


### Azure Support

This extension also supports the Azure OpenAI Service; the property settings are as follows:

``` json
{
  "base_uri": "wss://xxx.openai.azure.com",
  "path": "/openai/realtime?api-version=xxx&deployment=xxx",
  "api_key": "xxx",
  "model": "gpt-4o-realtime-preview",
  "vendor": "azure"
}
```
\ No newline at end of file
diff --git a/agents/ten_packages/extension/gemini_v2v_python/__init__.py b/agents/ten_packages/extension/gemini_v2v_python/__init__.py
new file mode 100644
index 00000000..8cd75dde
--- /dev/null
+++ b/agents/ten_packages/extension/gemini_v2v_python/__init__.py
@@ -0,0 +1,8 @@
+#
+#
+# Agora Real Time Engagement
+# Created by Wei Hu in 2024-08.
+# Copyright (c) 2024 Agora IO. All rights reserved.
+#
+#
+from . import addon
diff --git a/agents/ten_packages/extension/gemini_v2v_python/addon.py b/agents/ten_packages/extension/gemini_v2v_python/addon.py
new file mode 100644
index 00000000..13801742
--- /dev/null
+++ b/agents/ten_packages/extension/gemini_v2v_python/addon.py
@@ -0,0 +1,21 @@
+#
+#
+# Agora Real Time Engagement
+# Created by Wei Hu in 2024-08.
+# Copyright (c) 2024 Agora IO. All rights reserved.
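#
# The decorator below registers this addon under the name
# "gemini_v2v_python"; the TEN runtime then calls on_create_instance to
# construct one GeminiRealtimeExtension per graph node. The extension module
# is imported inside on_create_instance, which keeps addon discovery from
# importing the heavy google-genai dependency until an instance is needed.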
+# +# +from ten import ( + Addon, + register_addon_as_extension, + TenEnv, +) + + +@register_addon_as_extension("gemini_v2v_python") +class GeminiRealtimeExtensionAddon(Addon): + + def on_create_instance(self, ten_env: TenEnv, name: str, context) -> None: + from .extension import GeminiRealtimeExtension + ten_env.log_info("GeminiRealtimeExtensionAddon on_create_instance") + ten_env.on_create_instance_done(GeminiRealtimeExtension(name), context) diff --git a/agents/ten_packages/extension/gemini_v2v_python/extension.py b/agents/ten_packages/extension/gemini_v2v_python/extension.py new file mode 100644 index 00000000..af6e846c --- /dev/null +++ b/agents/ten_packages/extension/gemini_v2v_python/extension.py @@ -0,0 +1,555 @@ +# +# +# Agora Real Time Engagement +# Created by Wei Hu in 2024-08. +# Copyright (c) 2024 Agora IO. All rights reserved. +# +# +import asyncio +import base64 +from enum import Enum +import json +import traceback +import time +from google import genai +import numpy as np +from datetime import datetime +from typing import Iterable, cast + +import websockets + +from ten import ( + AudioFrame, + AsyncTenEnv, + Cmd, + StatusCode, + CmdResult, + Data, +) +from ten.audio_frame import AudioFrameDataFmt +from ten_ai_base.const import CMD_PROPERTY_RESULT, CMD_TOOL_CALL +from ten_ai_base.llm import AsyncLLMBaseExtension +from dataclasses import dataclass, field +from ten_ai_base import BaseConfig, ChatMemory, EVENT_MEMORY_EXPIRED, EVENT_MEMORY_APPENDED, LLMUsage, LLMCompletionTokensDetails, LLMPromptTokensDetails +from ten_ai_base.types import LLMToolMetadata, LLMToolResult, LLMChatCompletionContentPartParam, TTSPcmOptions +from google.genai.types import LiveServerMessage, LiveClientRealtimeInput, Blob, LiveConnectConfig, LiveConnectConfigDict, GenerationConfig, SpeechConfig, VoiceConfig, PrebuiltVoiceConfig, Content, Part, Tool, FunctionDeclaration, Schema, LiveClientToolResponse, FunctionCall, FunctionResponse +from google.genai.live import AsyncSession + +import urllib.parse +import google.genai._api_client + +google.genai._api_client.urllib = urllib + +CMD_IN_FLUSH = "flush" +CMD_IN_ON_USER_JOINED = "on_user_joined" +CMD_IN_ON_USER_LEFT = "on_user_left" +CMD_OUT_FLUSH = "flush" + +class Role(str, Enum): + User = "user" + Assistant = "assistant" + +@dataclass +class GeminiRealtimeConfig(BaseConfig): + base_uri: str = "generativelanguage.googleapis.com" + api_key: str = "" + api_version: str = "v1alpha" + model: str = "gemini-2.0-flash-exp" + language: str = "en-US" + prompt: str = "" + temperature: float = 0.5 + max_tokens: int = 1024 + voice: str = "Puck" + server_vad: bool = True + audio_out: bool = True + input_transcript: bool = True + sample_rate: int = 24000 + stream_id: int = 0 + dump: bool = False + greeting: str = "" + + def build_ctx(self) -> dict: + return { + "language": self.language, + "model": self.model, + } + +class GeminiRealtimeExtension(AsyncLLMBaseExtension): + def __init__(self, name): + super().__init__(name) + self.config: GeminiRealtimeConfig = None + self.stopped: bool = False + self.connected: bool = False + self.buffer: bytearray = b'' + self.memory: ChatMemory = None + self.total_usage: LLMUsage = LLMUsage() + self.users_count = 0 + + self.stream_id: int = 0 + self.remote_stream_id: int = 0 + self.channel_name: str = "" + self.audio_len_threshold: int = 5120 + + self.completion_times = [] + self.connect_times = [] + self.first_token_times = [] + + self.buff: bytearray = b'' + self.transcript: str = "" + self.ctx: dict = {} + self.input_end = 
time.time() + self.client = None + self.session:AsyncSession = None + self.leftover_bytes = b'' + + async def on_init(self, ten_env: AsyncTenEnv) -> None: + await super().on_init(ten_env) + ten_env.log_debug("on_init") + + async def on_start(self, ten_env: AsyncTenEnv) -> None: + await super().on_start(ten_env) + ten_env.log_debug("on_start") + + self.loop = asyncio.get_event_loop() + + self.config = GeminiRealtimeConfig.create(ten_env=ten_env) + ten_env.log_info(f"config: {self.config}") + + if not self.config.api_key: + ten_env.log_error("api_key is required") + return + + try: + self.ctx = self.config.build_ctx() + self.ctx["greeting"] = self.config.greeting + + + self.client = genai.Client( + api_key=self.config.api_key, + http_options={ + 'api_version': self.config.api_version, + 'url': self.config.base_uri, + } + + ) + self.loop.create_task(self._loop(ten_env)) + + + # self.loop.create_task(self._loop()) + except Exception as e: + traceback.print_exc() + self.ten_env.log_error(f"Failed to init client {e}") + + self.ten_env = ten_env + + async def _loop(self, ten_env: AsyncTenEnv) -> None: + while not self.stopped: + await asyncio.sleep(1) + try: + config:LiveConnectConfig = self._get_session_config() + ten_env.log_info(f"Start listen") + async with self.client.aio.live.connect(model=self.config.model, config=config) as session: + ten_env.log_info(f"Connected") + session = cast(AsyncSession, session) + self.session = session + self.connected = True + + await self._greeting() + + while True: + try: + async for response in session.receive(): + response = cast(LiveServerMessage, response) + try: + if response.server_content: + if response.server_content.interrupted: + ten_env.log_info(f"Interrupted") + await self._flush() + continue + elif not response.server_content.turn_complete and response.server_content.model_turn: + for part in response.server_content.model_turn.parts: + self.send_audio_out(ten_env, part.inline_data.data, sample_rate=24000, bytes_per_sample=2, number_of_channels=1) + elif response.server_content.turn_complete: + ten_env.log_info(f"Turn complete") + elif response.setup_complete: + ten_env.log_info(f"Setup complete") + elif response.tool_call: + func_calls = response.tool_call.function_calls + await self._handle_tool_call(func_calls) + except Exception as e: + traceback.print_exc() + ten_env.log_error(f"Failed to handle response") + + await self._flush() + ten_env.log_info(f"Finish listen") + except websockets.exceptions.ConnectionClosedOK: + ten_env.log_info("Connection closed") + break + except Exception as e: + self.ten_env.log_error(f"Failed to handle loop {e}") + + def send_audio_out(self, ten_env: AsyncTenEnv, audio_data: bytes, **args: TTSPcmOptions) -> None: + """End sending audio out.""" + sample_rate = args.get("sample_rate", 24000) + bytes_per_sample = args.get("bytes_per_sample", 2) + number_of_channels = args.get("number_of_channels", 1) + try: + # Combine leftover bytes with new audio data + combined_data = self.leftover_bytes + audio_data + + # Check if combined_data length is odd + if len(combined_data) % (bytes_per_sample * number_of_channels) != 0: + # Save the last incomplete frame + valid_length = len(combined_data) - (len(combined_data) % (bytes_per_sample * number_of_channels)) + self.leftover_bytes = combined_data[valid_length:] + combined_data = combined_data[:valid_length] + else: + self.leftover_bytes = b'' + + if combined_data: + f = AudioFrame.create("pcm_frame") + f.set_sample_rate(sample_rate) + f.set_bytes_per_sample(bytes_per_sample) + 
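                # Each sample tick occupies bytes_per_sample * number_of_channels
                # bytes (2 bytes for the 16-bit mono default), so the alignment
                # step above guarantees samples_per_channel below divides evenly;
                # e.g. 10 ms of 24 kHz mono PCM is 240 samples = 480 bytes.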
f.set_number_of_channels(number_of_channels) + f.set_data_fmt(AudioFrameDataFmt.INTERLEAVE) + f.set_samples_per_channel(len(combined_data) // (bytes_per_sample * number_of_channels)) + f.alloc_buf(len(combined_data)) + buff = f.lock_buf() + buff[:] = combined_data + f.unlock_buf(buff) + ten_env.send_audio_frame(f) + except Exception as e: + pass + # ten_env.log_error(f"error send audio frame, {traceback.format_exc()}") + + async def on_stop(self, ten_env: AsyncTenEnv) -> None: + await super().on_stop(ten_env) + ten_env.log_info("on_stop") + + self.stopped = True + if self.session: + await self.session.close() + + async def on_audio_frame(self, ten_env: AsyncTenEnv, audio_frame: AudioFrame) -> None: + try: + stream_id = audio_frame.get_property_int("stream_id") + if self.channel_name == "": + self.channel_name = audio_frame.get_property_string("channel") + + if self.remote_stream_id == 0: + self.remote_stream_id = stream_id + + frame_buf = audio_frame.get_buf() + self._dump_audio_if_need(frame_buf, Role.User) + + await self._on_audio(frame_buf) + if not self.config.server_vad: + self.input_end = time.time() + except Exception as e: + traceback.print_exc() + self.ten_env.log_error(f"on audio frame failed {e}") + + async def on_cmd(self, ten_env: AsyncTenEnv, cmd: Cmd) -> None: + cmd_name = cmd.get_name() + ten_env.log_debug("on_cmd name {}".format(cmd_name)) + + status = StatusCode.OK + detail = "success" + + if cmd_name == CMD_IN_FLUSH: + # Will only flush if it is client side vad + await self._flush() + await ten_env.send_cmd(Cmd.create(CMD_OUT_FLUSH)) + ten_env.log_info("on flush") + elif cmd_name == CMD_IN_ON_USER_JOINED: + self.users_count += 1 + # Send greeting when first user joined + if self.users_count == 1: + await self._greeting() + elif cmd_name == CMD_IN_ON_USER_LEFT: + self.users_count -= 1 + else: + # Register tool + await super().on_cmd(ten_env, cmd) + return + + cmd_result = CmdResult.create(status) + cmd_result.set_property_string("detail", detail) + ten_env.return_result(cmd_result, cmd) + + # Not support for now + async def on_data(self, ten_env: AsyncTenEnv, data: Data) -> None: + pass + + # Direction: IN + async def _on_audio(self, buff: bytearray): + self.buff += buff + # Buffer audio + if self.connected and len(self.buff) >= self.audio_len_threshold: + # await self.conn.send_audio_data(self.buff) + try: + await self.session.send(LiveClientRealtimeInput(media_chunks=[Blob(data=self.buff, mime_type="audio/pcm")])) + self.buff = b'' + except Exception as e: + pass + # self.ten_env.log_error(f"Failed to send audio {e}") + + def _get_session_config(self) -> LiveConnectConfigDict: + def tool_dict(tool: LLMToolMetadata): + required = [] + properties:dict[str, "Schema"] = {} + + for param in tool.parameters: + properties[param.name] = Schema( + type=param.type.upper(), + description=param.description + ) + if param.required: + required.append(param.name) + + + t = Tool( + function_declarations=[FunctionDeclaration( + name=tool.name, + description=tool.description, + parameters=Schema( + type="OBJECT", + properties=properties, + required=required + ) + )]) + + return t + + tools = [tool_dict(t) for t in self.available_tools] if len(self.available_tools) > 0 else None + config = LiveConnectConfig( + response_modalities=["AUDIO"], + system_instruction=Content(parts=[Part(text=self.config.prompt)]), + tools=tools, + # voice is currently not working + # speech_config=SpeechConfig( + # voice_config=VoiceConfig( + # prebuilt_voice_config=PrebuiltVoiceConfig( + # 
voice_name=self.config.voice + # ) + # ) + # ), + generation_config=GenerationConfig( + temperature=self.config.temperature, + max_output_tokens=self.config.max_tokens + ) + ) + + return config + + async def on_tools_update(self, ten_env: AsyncTenEnv, tool: LLMToolMetadata) -> None: + """Called when a new tool is registered. Implement this method to process the new tool.""" + self.ten_env.log_info(f"on tools update {tool}") + # await self._update_session() + + def _replace(self, prompt: str) -> str: + result = prompt + for token, value in self.ctx.items(): + result = result.replace("{"+token+"}", value) + return result + + # Direction: OUT + def _on_audio_delta(self, delta: bytes) -> None: + audio_data = base64.b64decode(delta) + self.ten_env.log_debug(f"on_audio_delta audio_data len {len(audio_data)} samples {len(audio_data) // 2}") + self._dump_audio_if_need(audio_data, Role.Assistant) + + f = AudioFrame.create("pcm_frame") + f.set_sample_rate(self.config.sample_rate) + f.set_bytes_per_sample(2) + f.set_number_of_channels(1) + f.set_data_fmt(AudioFrameDataFmt.INTERLEAVE) + f.set_samples_per_channel(len(audio_data) // 2) + f.alloc_buf(len(audio_data)) + buff = f.lock_buf() + buff[:] = audio_data + f.unlock_buf(buff) + self.ten_env.send_audio_frame(f) + + def _send_transcript(self, content: str, role: Role, is_final: bool) -> None: + def is_punctuation(char): + if char in [",", ",", ".", "。", "?", "?", "!", "!"]: + return True + return False + + def parse_sentences(sentence_fragment, content): + sentences = [] + current_sentence = sentence_fragment + for char in content: + current_sentence += char + if is_punctuation(char): + # Check if the current sentence contains non-punctuation characters + stripped_sentence = current_sentence + if any(c.isalnum() for c in stripped_sentence): + sentences.append(stripped_sentence) + current_sentence = "" # Reset for the next sentence + + remain = current_sentence # Any remaining characters form the incomplete sentence + return sentences, remain + + def send_data(ten_env: AsyncTenEnv, sentence: str, stream_id: int, role: str, is_final: bool): + try: + d = Data.create("text_data") + d.set_property_string("text", sentence) + d.set_property_bool("end_of_segment", is_final) + d.set_property_string("role", role) + d.set_property_int("stream_id", stream_id) + ten_env.log_info( + f"send transcript text [{sentence}] stream_id {stream_id} is_final {is_final} end_of_segment {is_final} role {role}") + ten_env.send_data(d) + except Exception as e: + ten_env.log_error(f"Error send text data {role}: {sentence} {is_final} {e}") + + stream_id = self.remote_stream_id if role == Role.User else 0 + try: + if role == Role.Assistant and not is_final: + sentences, self.transcript = parse_sentences(self.transcript, content) + for s in sentences: + send_data(self.ten_env, s, stream_id, role, is_final) + else: + send_data(self.ten_env, content, stream_id, role, is_final) + except Exception as e: + self.ten_env.log_error(f"Error send text data {role}: {content} {is_final} {e}") + + def _dump_audio_if_need(self, buf: bytearray, role: Role) -> None: + if not self.config.dump: + return + + with open("{}_{}.pcm".format(role, self.channel_name), "ab") as dump_file: + dump_file.write(buf) + + async def _handle_tool_call(self, func_calls:list[FunctionCall] ) -> None: + function_responses = [] + for call in func_calls: + tool_call_id = call.id + name = call.name + arguments = call.args + self.ten_env.log_info(f"_handle_tool_call {tool_call_id} {name} {arguments}") + cmd: Cmd = 
Cmd.create(CMD_TOOL_CALL) + cmd.set_property_string("name", name) + cmd.set_property_from_json("arguments", json.dumps(arguments)) + result: CmdResult = await self.ten_env.send_cmd(cmd) + + func_response = FunctionResponse( + id=tool_call_id, + name=name, + response={"error":"Failed to call tool"} + ) + if result.get_status_code() == StatusCode.OK: + tool_result: LLMToolResult = json.loads( + result.get_property_to_json(CMD_PROPERTY_RESULT)) + + result_content = tool_result["content"] + func_response = FunctionResponse( + id=tool_call_id, + name=name, + response={ + "output": result_content + } + ) + self.ten_env.log_info(f"tool_result: {tool_call_id} {tool_result}") + else: + self.ten_env.log_error(f"Tool call failed") + function_responses.append(func_response) + # await self.conn.send_request(tool_response) + # await self.conn.send_request(ResponseCreate()) + self.ten_env.log_info(f"_remote_tool_call finish {name} {arguments}") + await self.session.send(LiveClientToolResponse(function_responses=function_responses)) + + def _greeting_text(self) -> str: + text = "Hi, there." + if self.config.language == "zh-CN": + text = "你好。" + elif self.config.language == "ja-JP": + text = "こんにちは" + elif self.config.language == "ko-KR": + text = "안녕하세요" + return text + + + def _convert_tool_params_to_dict(self, tool: LLMToolMetadata): + json = { + "type": "object", + "properties": {}, + "required": [] + } + + for param in tool.parameters: + json["properties"][param.name] = { + "type": param.type, + "description": param.description + } + if param.required: + json["required"].append(param.name) + + return json + + + def _convert_to_content_parts(self, content: Iterable[LLMChatCompletionContentPartParam]): + content_parts = [] + + + if isinstance(content, str): + content_parts.append({ + "type": "text", + "text": content + }) + else: + for part in content: + # Only text content is supported currently for v2v model + if part["type"] == "text": + content_parts.append(part) + return content_parts + + async def _greeting(self) -> None: + if self.connected and self.users_count == 1: + text = self._greeting_text() + if self.config.greeting: + text = "Say '" + self.config.greeting + "' to me." 
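            # Phrasing the greeting as an instruction ("Say '...' to me.") makes
            # the model speak it back verbatim; end_of_turn=True on the send
            # below closes the client turn so an audio reply starts immediately.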
+ self.ten_env.log_info(f"send greeting {text}") + await self.session.send(text, end_of_turn=True) + + async def _flush(self) -> None: + try: + c = Cmd.create("flush") + await self.ten_env.send_cmd(c) + except: + self.ten_env.log_error(f"Error flush") + + async def _update_usage(self, usage: dict) -> None: + self.total_usage.completion_tokens += usage.get("output_tokens") + self.total_usage.prompt_tokens += usage.get("input_tokens") + self.total_usage.total_tokens += usage.get("total_tokens") + if not self.total_usage.completion_tokens_details: + self.total_usage.completion_tokens_details = LLMCompletionTokensDetails() + if not self.total_usage.prompt_tokens_details: + self.total_usage.prompt_tokens_details = LLMPromptTokensDetails() + + if usage.get("output_token_details"): + self.total_usage.completion_tokens_details.accepted_prediction_tokens += usage["output_token_details"].get("text_tokens") + self.total_usage.completion_tokens_details.audio_tokens += usage["output_token_details"].get("audio_tokens") + + if usage.get("input_token_details:"): + self.total_usage.prompt_tokens_details.audio_tokens += usage["input_token_details"].get("audio_tokens") + self.total_usage.prompt_tokens_details.cached_tokens += usage["input_token_details"].get("cached_tokens") + self.total_usage.prompt_tokens_details.text_tokens += usage["input_token_details"].get("text_tokens") + + self.ten_env.log_info(f"total usage: {self.total_usage}") + + data = Data.create("llm_stat") + data.set_property_from_json("usage", json.dumps(self.total_usage.model_dump())) + if self.connect_times and self.completion_times and self.first_token_times: + data.set_property_from_json("latency", json.dumps({ + "connection_latency_95": np.percentile(self.connect_times, 95), + "completion_latency_95": np.percentile(self.completion_times, 95), + "first_token_latency_95": np.percentile(self.first_token_times, 95), + "connection_latency_99": np.percentile(self.connect_times, 99), + "completion_latency_99": np.percentile(self.completion_times, 99), + "first_token_latency_99": np.percentile(self.first_token_times, 99) + })) + self.ten_env.send_data(data) diff --git a/agents/ten_packages/extension/gemini_v2v_python/manifest.json b/agents/ten_packages/extension/gemini_v2v_python/manifest.json new file mode 100644 index 00000000..a1bb9ea5 --- /dev/null +++ b/agents/ten_packages/extension/gemini_v2v_python/manifest.json @@ -0,0 +1,156 @@ +{ + "type": "extension", + "name": "gemini_v2v_python", + "version": "0.1.0", + "dependencies": [ + { + "type": "system", + "name": "ten_runtime_python", + "version": "0.4" + } + ], + "package": { + "include": [ + "manifest.json", + "property.json", + "BUILD.gn", + "**.tent", + "**.py", + "README.md", + "realtime/**.tent", + "realtime/**.py" + ] + }, + "api": { + "property": { + "base_uri": { + "type": "string" + }, + "api_key": { + "type": "string" + }, + "api_version": { + "type": "string" + }, + "model": { + "type": "string" + }, + "language": { + "type": "string" + }, + "prompt": { + "type": "string" + }, + "temperature": { + "type": "float32" + }, + "max_tokens": { + "type": "int32" + }, + "voice": { + "type": "string" + }, + "server_vad": { + "type": "bool" + }, + "audio_out": { + "type": "bool" + }, + "input_transcript": { + "type": "bool" + }, + "sample_rate": { + "type": "int32" + }, + "stream_id": { + "type": "int32" + }, + "dump": { + "type": "bool" + }, + "greeting": { + "type": "string" + } + }, + "audio_frame_in": [ + { + "name": "pcm_frame", + "property": { + "stream_id": { + "type": "int64" + } 
+ } + } + ], + "data_out": [ + { + "name": "text_data", + "property": { + "text": { + "type": "string" + } + } + }, + { + "name": "append", + "property": { + "text": { + "type": "string" + } + } + } + ], + "cmd_in": [ + { + "name": "tool_register", + "property": { + "name": { + "type": "string" + }, + "description": { + "type": "string" + }, + "parameters": { + "type": "string" + } + }, + "required": [ + "name", + "description", + "parameters" + ], + "result": { + "property": { + "response": { + "type": "string" + } + } + } + } + ], + "cmd_out": [ + { + "name": "flush" + }, + { + "name": "tool_call", + "property": { + "name": { + "type": "string" + }, + "args": { + "type": "string" + } + }, + "required": [ + "name" + ] + } + ], + "audio_frame_out": [ + { + "name": "pcm_frame" + } + ] + } +} \ No newline at end of file diff --git a/agents/ten_packages/extension/gemini_v2v_python/property.json b/agents/ten_packages/extension/gemini_v2v_python/property.json new file mode 100644 index 00000000..25a81332 --- /dev/null +++ b/agents/ten_packages/extension/gemini_v2v_python/property.json @@ -0,0 +1,12 @@ +{ + "api_key": "${env:GEMINI_API_KEY}", + "temperature": 0.9, + "base_uri": "generativelanguage.googleapis.com", + "model": "gemini-2.0-flash-exp", + "api_version": "v1alpha", + "max_tokens": 2048, + "voice": "Puck", + "language": "en-US", + "server_vad": true, + "dump": true +} \ No newline at end of file diff --git a/agents/ten_packages/extension/gemini_v2v_python/requirements.txt b/agents/ten_packages/extension/gemini_v2v_python/requirements.txt new file mode 100644 index 00000000..6f9ab33f --- /dev/null +++ b/agents/ten_packages/extension/gemini_v2v_python/requirements.txt @@ -0,0 +1,2 @@ +asyncio +google-genai \ No newline at end of file diff --git a/agents/ten_packages/extension/openai_v2v_python/property.json b/agents/ten_packages/extension/openai_v2v_python/property.json index 90392f6a..29e689ce 100644 --- a/agents/ten_packages/extension/openai_v2v_python/property.json +++ b/agents/ten_packages/extension/openai_v2v_python/property.json @@ -8,5 +8,5 @@ "server_vad": true, "dump": true, "history": 10, - "enable_storage": true + "enable_storage": false } \ No newline at end of file diff --git a/playground/src/common/moduleConfig.ts b/playground/src/common/moduleConfig.ts index 0b21c173..e18ab995 100644 --- a/playground/src/common/moduleConfig.ts +++ b/playground/src/common/moduleConfig.ts @@ -101,6 +101,11 @@ export const v2vModuleRegistry: Record = { name: "openai_v2v_python", type: ModuleRegistry.ModuleType.V2V, label: "OpenAI Realtime", + }, + gemini_v2v_python: { + name: "gemini_v2v_python", + type: ModuleRegistry.ModuleType.V2V, + label: "Gemini Realtime", } } @@ -137,4 +142,5 @@ export const moduleRegistry: Record = { export const compatibleTools: Record = { openai_chatgpt_python: ["vision_tool_python", "weatherapi_tool_python", "bingsearch_tool_python"], openai_v2v_python: ["weatherapi_tool_python", "bingsearch_tool_python"], + gemini_v2v_python: ["weatherapi_tool_python", "bingsearch_tool_python"], } \ No newline at end of file From daec880d8d10ae2baae87d32caa2f9d173ba7161 Mon Sep 17 00:00:00 2001 From: zhangqianze Date: Fri, 13 Dec 2024 02:39:18 +0800 Subject: [PATCH 2/3] feat: update readme --- .../extension/gemini_v2v_python/README.md | 72 ++++++------ .../extension/gemini_v2v_python/extension.py | 108 +++++++++++++++++- .../extension/gemini_v2v_python/manifest.json | 6 + .../gemini_v2v_python/requirements.txt | 2 +- playground/src/common/graph.ts | 3 + 
 playground/src/common/hooks.ts               |   2 +-
 .../components/Chat/ChatCfgModuleSelect.tsx  |  50 +++++---
 playground/src/components/ui/dropdown.tsx    |  26 ++++-
 playground/src/store/reducers/global.ts      |  17 ++-
 9 files changed, 224 insertions(+), 62 deletions(-)

diff --git a/agents/ten_packages/extension/gemini_v2v_python/README.md b/agents/ten_packages/extension/gemini_v2v_python/README.md
index 3cd294f3..e43f7041 100644
--- a/agents/ten_packages/extension/gemini_v2v_python/README.md
+++ b/agents/ten_packages/extension/gemini_v2v_python/README.md
@@ -1,65 +1,63 @@
-# openai_v2v_python
+# gemini_v2v_python

-An extension for integrating OpenAI's Next Generation of **Multimodal** AI into your application, providing configurable AI-driven features such as conversational agents, task automation, and tool integration.
+An extension for integrating Gemini's Next Generation of **Multimodal** AI into your application, providing configurable AI-driven features such as conversational agents, task automation, and tool integration.

## Features

-- OpenAI **Multimodal** Integration: Leverage GPT **Multimodal** models for voice to voice as well as text processing.
+- Gemini **Multimodal** Integration: Leverage Gemini **Multimodal** models for voice-to-voice as well as text processing.
- Configurable: Easily customize API keys, model settings, prompts, temperature, etc.
- Async Queue Processing: Supports real-time message processing with task cancellation and prioritization.

## API

-Refer to `api` definition in [manifest.json] and default values in [property.json](property.json).
+Refer to the `api` definition in [manifest.json](manifest.json) and default values in [property.json](property.json).

| **Property**               | **Type**   | **Description**                           |
|----------------------------|------------|-------------------------------------------|
| `api_key`                  | `string`   | API key for authenticating with Gemini    |
| `temperature`              | `float32`  | Sampling temperature; higher values mean more randomness |
| `model`                    | `string`   | Model identifier (e.g., `gemini-2.0-flash-exp`) |
| `max_tokens`               | `int32`    | Maximum number of tokens to generate      |
| `prompt`                   | `string`   | Default system instruction to send to the model |
| `voice`                    | `string`   | Voice the Gemini model uses, such as `Puck`, `Charon`, `Kore`, etc. |
| `server_vad`               | `bool`     | Flag to enable or disable server VAD for Gemini |
| `language`                 | `string`   | Language the Gemini model responds in, such as `en-US`, `zh-CN`, etc.
| +| `dump` | `bool` | Flag to enable or disable audio dump for debugging purposes | +| `base_uri` | `string` | Base URI for connecting to the Gemini service | +| `audio_out` | `bool` | Flag to enable or disable audio output | +| `input_transcript` | `bool` | Flag to enable input transcript processing | +| `sample_rate` | `int32` | Sample rate for audio processing | +| `stream_id` | `int32` | Stream ID for identifying audio streams | +| `greeting` | `string` | Greeting message for initial interaction | + +### Data Out + | **Name** | **Property** | **Type** | **Description** | |----------------|--------------|------------|-------------------------------| | `text_data` | `text` | `string` | Outgoing text data | +| `append` | `text` | `string` | Additional text appended to the output | + +### Command Out -### Command Out: | **Name** | **Description** | |----------------|---------------------------------------------| | `flush` | Response after flushing the current state | +| `tool_call` | Invokes a tool with specific arguments | + +### Audio Frame In -### Audio Frame In: | **Name** | **Description** | |------------------|-------------------------------------------| | `pcm_frame` | Audio frame input for voice processing | -### Audio Frame Out: +### Video Frame In + | **Name** | **Description** | |------------------|-------------------------------------------| -| `pcm_frame` | Audio frame output after voice processing | - - -### Azure Support +| `video_frame` | Video frame input for processing | -This extension also support Azure OpenAI Service, the propoerty settings are as follow: +### Audio Frame Out -``` json -{ - "base_uri": "wss://xxx.openai.azure.com", - "path": "/openai/realtime?api-version=xxx&deployment=xxx", - "api_key": "xxx", - "model": "gpt-4o-realtime-preview", - "vendor": "azure" -} -``` \ No newline at end of file +| **Name** | **Description** | +|------------------|-------------------------------------------| +| `pcm_frame` | Audio frame output after voice processing | diff --git a/agents/ten_packages/extension/gemini_v2v_python/extension.py b/agents/ten_packages/extension/gemini_v2v_python/extension.py index af6e846c..fe802c97 100644 --- a/agents/ten_packages/extension/gemini_v2v_python/extension.py +++ b/agents/ten_packages/extension/gemini_v2v_python/extension.py @@ -34,6 +34,9 @@ from ten_ai_base.types import LLMToolMetadata, LLMToolResult, LLMChatCompletionContentPartParam, TTSPcmOptions from google.genai.types import LiveServerMessage, LiveClientRealtimeInput, Blob, LiveConnectConfig, LiveConnectConfigDict, GenerationConfig, SpeechConfig, VoiceConfig, PrebuiltVoiceConfig, Content, Part, Tool, FunctionDeclaration, Schema, LiveClientToolResponse, FunctionCall, FunctionResponse from google.genai.live import AsyncSession +from PIL import Image +from io import BytesIO +from base64 import b64encode import urllib.parse import google.genai._api_client @@ -49,6 +52,62 @@ class Role(str, Enum): User = "user" Assistant = "assistant" + +def rgb2base64jpeg(rgb_data, width, height): + # Convert the RGB image to a PIL Image + pil_image = Image.frombytes("RGBA", (width, height), bytes(rgb_data)) + pil_image = pil_image.convert("RGB") + + # Resize the image while maintaining its aspect ratio + pil_image = resize_image_keep_aspect(pil_image, 512) + + # Save the image to a BytesIO object in JPEG format + buffered = BytesIO() + pil_image.save(buffered, format="JPEG") + pil_image.save("test.jpg", format="JPEG") + + # Get the byte data of the JPEG image + jpeg_image_data = buffered.getvalue() 
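    # The <=512 px downscale plus JPEG re-encode typically keeps a frame to a
    # few tens of kilobytes, cheap enough to interleave with the realtime
    # audio stream at the 1 fps cadence _on_video uses.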
+ + # Convert the JPEG byte data to a Base64 encoded string + base64_encoded_image = b64encode(jpeg_image_data).decode("utf-8") + + # Create the data URL + # mime_type = "image/jpeg" + return base64_encoded_image + +def resize_image_keep_aspect(image, max_size=512): + """ + Resize an image while maintaining its aspect ratio, ensuring the larger dimension is max_size. + If both dimensions are smaller than max_size, the image is not resized. + + :param image: A PIL Image object + :param max_size: The maximum size for the larger dimension (width or height) + :return: A PIL Image object (resized or original) + """ + # Get current width and height + width, height = image.size + + # If both dimensions are already smaller than max_size, return the original image + if width <= max_size and height <= max_size: + return image + + # Calculate the aspect ratio + aspect_ratio = width / height + + # Determine the new dimensions + if width > height: + new_width = max_size + new_height = int(max_size / aspect_ratio) + else: + new_height = max_size + new_width = int(max_size * aspect_ratio) + + # Resize the image with the new dimensions + resized_image = image.resize((new_width, new_height)) + + return resized_image + @dataclass class GeminiRealtimeConfig(BaseConfig): base_uri: str = "generativelanguage.googleapis.com" @@ -101,6 +160,9 @@ def __init__(self, name): self.client = None self.session:AsyncSession = None self.leftover_bytes = b'' + self.video_task = None + self.image_queue = asyncio.Queue() + self.video_buff: str = "" async def on_init(self, ten_env: AsyncTenEnv) -> None: await super().on_init(ten_env) @@ -133,7 +195,7 @@ async def on_start(self, ten_env: AsyncTenEnv) -> None: ) self.loop.create_task(self._loop(ten_env)) - + self.loop.create_task(self._on_video(ten_env)) # self.loop.create_task(self._loop()) except Exception as e: @@ -231,6 +293,7 @@ async def on_stop(self, ten_env: AsyncTenEnv) -> None: await self.session.close() async def on_audio_frame(self, ten_env: AsyncTenEnv, audio_frame: AudioFrame) -> None: + await super().on_audio_frame(ten_env, audio_frame) try: stream_id = audio_frame.get_property_int("stream_id") if self.channel_name == "": @@ -281,6 +344,37 @@ async def on_cmd(self, ten_env: AsyncTenEnv, cmd: Cmd) -> None: async def on_data(self, ten_env: AsyncTenEnv, data: Data) -> None: pass + async def on_video_frame(self, async_ten_env, video_frame): + await super().on_video_frame(async_ten_env, video_frame) + image_data = video_frame.get_buf() + image_width = video_frame.get_width() + image_height = video_frame.get_height() + await self.image_queue.put([image_data, image_width, image_height]) + + + async def _on_video(self, ten_env:AsyncTenEnv): + while True: + + # Process the first frame from the queue + [image_data, image_width, image_height] = await self.image_queue.get() + self.video_buff = rgb2base64jpeg(image_data, image_width, image_height) + media_chunks = [{ + "data": self.video_buff, + "mime_type": "image/jpeg", + }] + try: + if self.connected: + await self.session.send(media_chunks) + except Exception as e: + self.ten_env.log_error(f"Failed to send image {e}") + + # Skip remaining frames for the second + while not self.image_queue.empty(): + await self.image_queue.get() + + # Wait for 1 second before processing the next frame + await asyncio.sleep(1) + # Direction: IN async def _on_audio(self, buff: bytearray): self.buff += buff @@ -288,11 +382,16 @@ async def _on_audio(self, buff: bytearray): if self.connected and len(self.buff) >= self.audio_len_threshold: # await 
self.conn.send_audio_data(self.buff) try: - await self.session.send(LiveClientRealtimeInput(media_chunks=[Blob(data=self.buff, mime_type="audio/pcm")])) + media_chunks = [{ + "data": base64.b64encode(self.buff).decode(), + "mime_type": "audio/pcm", + }] + # await self.session.send(LiveClientRealtimeInput(media_chunks=media_chunks)) + await self.session.send(media_chunks) self.buff = b'' except Exception as e: - pass - # self.ten_env.log_error(f"Failed to send audio {e}") + # pass + self.ten_env.log_error(f"Failed to send audio {e}") def _get_session_config(self) -> LiveConnectConfigDict: def tool_dict(tool: LLMToolMetadata): @@ -553,3 +652,4 @@ async def _update_usage(self, usage: dict) -> None: "first_token_latency_99": np.percentile(self.first_token_times, 99) })) self.ten_env.send_data(data) + diff --git a/agents/ten_packages/extension/gemini_v2v_python/manifest.json b/agents/ten_packages/extension/gemini_v2v_python/manifest.json index a1bb9ea5..4c25224e 100644 --- a/agents/ten_packages/extension/gemini_v2v_python/manifest.json +++ b/agents/ten_packages/extension/gemini_v2v_python/manifest.json @@ -82,6 +82,12 @@ } } ], + "video_frame_in": [ + { + "name": "video_frame", + "property": {} + } + ], "data_out": [ { "name": "text_data", diff --git a/agents/ten_packages/extension/gemini_v2v_python/requirements.txt b/agents/ten_packages/extension/gemini_v2v_python/requirements.txt index 6f9ab33f..5370cc22 100644 --- a/agents/ten_packages/extension/gemini_v2v_python/requirements.txt +++ b/agents/ten_packages/extension/gemini_v2v_python/requirements.txt @@ -1,2 +1,2 @@ asyncio -google-genai \ No newline at end of file +google-genai==0.2.1 \ No newline at end of file diff --git a/playground/src/common/graph.ts b/playground/src/common/graph.ts index 7d1cc42a..3239aecf 100644 --- a/playground/src/common/graph.ts +++ b/playground/src/common/graph.ts @@ -420,6 +420,9 @@ class GraphEditor { // If no protocolLabel is provided, remove the entire connection graph.connections.splice(connectionIndex, 1) } + + // Clean up empty connections + GraphEditor.removeEmptyConnections(graph); } static findNode(graph: Graph, nodeName: string): Node | null { diff --git a/playground/src/common/hooks.ts b/playground/src/common/hooks.ts index 405eec32..dad89b89 100644 --- a/playground/src/common/hooks.ts +++ b/playground/src/common/hooks.ts @@ -153,7 +153,7 @@ const useGraphs = () => { } const update = async (graphId: string, updates: Partial) => { - await dispatch(updateGraph({ graphId, updates })) + await dispatch(updateGraph({ graphId, updates })).unwrap() } const getGraphNodeAddonByName = useCallback( diff --git a/playground/src/components/Chat/ChatCfgModuleSelect.tsx b/playground/src/components/Chat/ChatCfgModuleSelect.tsx index dcae871c..fa9842a8 100644 --- a/playground/src/components/Chat/ChatCfgModuleSelect.tsx +++ b/playground/src/components/Chat/ChatCfgModuleSelect.tsx @@ -31,13 +31,13 @@ import { import { Button } from "@/components/ui/button" import { cn } from "@/lib/utils" import { useAppSelector, useGraphs, } from "@/common/hooks" -import { AddonDef, Graph, Destination, GraphEditor, ProtocolLabel as GraphConnProtocol } from "@/common/graph" +import { AddonDef, Graph, Destination, GraphEditor, ProtocolLabel as GraphConnProtocol, ProtocolLabel } from "@/common/graph" import { toast } from "sonner" import { BoxesIcon, ChevronRightIcon, LoaderCircleIcon, SettingsIcon, Trash2Icon, WrenchIcon } from "lucide-react" import { useForm } from "react-hook-form" import { zodResolver } from "@hookform/resolvers/zod" 
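// The submit handler below adds an agora_rtc -> <module> "video_frame"
// connection whenever gemini_v2v_python is selected and removes it otherwise,
// then toggles enableRTCVideoSubscribe so RTC video capture matches.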
import { z } from "zod" -import { DropdownMenu, DropdownMenuContent, DropdownMenuItem, DropdownMenuPortal, DropdownMenuSub, DropdownMenuSubContent, DropdownMenuSubTrigger, DropdownMenuTrigger } from "../ui/dropdown" +import { DropdownMenu, DropdownMenuCheckboxItem, DropdownMenuContent, DropdownMenuItem, DropdownMenuPortal, DropdownMenuSub, DropdownMenuSubContent, DropdownMenuSubTrigger, DropdownMenuTrigger } from "../ui/dropdown" import { isLLM } from "@/common" import { compatibleTools, ModuleRegistry, ModuleTypeLabels } from "@/common/moduleConfig" @@ -122,25 +122,46 @@ export function RemoteModuleCfgSheet() { const selectedGraphCopy: Graph = JSON.parse(JSON.stringify(selectedGraph)); const nodes = selectedGraphCopy.nodes; let needUpdate = false; + let enableRTCVideoSubscribe = false; + // Retrieve the agora_rtc node + const agoraRtcNode = GraphEditor.findNode(selectedGraphCopy, "agora_rtc"); + if (!agoraRtcNode) { + toast.error("agora_rtc node not found in the graph"); + return; + } + // Update graph nodes with selected modules Object.entries(data).forEach(([key, value]) => { const node = nodes.find((n) => n.name === key); if (node && value && node.addon !== value) { node.addon = value; node.property = addonModules.find((module) => module.name === value)?.defaultProperty; + + if(node.addon === "gemini_v2v_python") { + GraphEditor.addOrUpdateConnection( + selectedGraphCopy, + `${agoraRtcNode.extensionGroup}.${agoraRtcNode.name}`, + `${node.extensionGroup}.${node.name}`, + ProtocolLabel.VIDEO_FRAME, + "video_frame" + ); + enableRTCVideoSubscribe = true; + } else { + GraphEditor.removeConnection( + selectedGraphCopy, + `${agoraRtcNode.extensionGroup}.${agoraRtcNode.name}`, + `${node.extensionGroup}.${node.name}`, + ProtocolLabel.VIDEO_FRAME, + "video_frame" + ); + } + needUpdate = true; } }); - // Retrieve the agora_rtc node - const agoraRtcNode = GraphEditor.findNode(selectedGraphCopy, "agora_rtc"); - if (!agoraRtcNode) { - toast.error("agora_rtc node not found in the graph"); - return; - } - // Identify removed tools and process them const currentToolsInGraph = nodes .filter((node) => installedAndRegisteredToolModules.map((module) => module.name).includes(node.addon)) @@ -154,8 +175,9 @@ export function RemoteModuleCfgSheet() { // Process tool modules if (tools.length > 0) { - GraphEditor.enableRTCVideoSubscribe(selectedGraphCopy, tools.some((tool) => tool.includes("vision"))); - + if(!enableRTCVideoSubscribe) { + enableRTCVideoSubscribe = tools.some((tool) => tool.includes("vision")) + } tools.forEach((tool) => { if (!currentToolsInGraph.includes(tool)) { const toolModule = addonModules.find((module) => module.name === tool); @@ -177,6 +199,8 @@ export function RemoteModuleCfgSheet() { needUpdate = true; } + GraphEditor.enableRTCVideoSubscribe(selectedGraphCopy, enableRTCVideoSubscribe); + // Perform the update if changes are detected if (needUpdate) { try { @@ -184,8 +208,8 @@ export function RemoteModuleCfgSheet() { toast.success("Modules updated", { description: `Graph: ${selectedGraphCopy.id}`, }); - } catch (e) { - toast.error("Failed to update modules"); + } catch (e:any) { + toast.error(`Failed to update modules: ${e}`); } } }} diff --git a/playground/src/components/ui/dropdown.tsx b/playground/src/components/ui/dropdown.tsx index b3586b08..b07bf9a6 100644 --- a/playground/src/components/ui/dropdown.tsx +++ b/playground/src/components/ui/dropdown.tsx @@ -4,7 +4,7 @@ import * as React from "react"; import * as DropdownMenuPrimitive from "@radix-ui/react-dropdown-menu"; import { cn 
} from "@/lib/utils"; -import { ChevronRightIcon } from "lucide-react"; +import { CheckIcon, ChevronRightIcon } from "lucide-react"; const DropdownMenu = DropdownMenuPrimitive.Root; @@ -99,6 +99,29 @@ DropdownMenuSubContent.displayName = const DropdownMenuPortal = DropdownMenuPrimitive.Portal; + +const DropdownMenuCheckboxItem = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, checked, children, ...props }, ref) => ( + + {children} + + {checked && } + + +)); +DropdownMenuCheckboxItem.displayName = + DropdownMenuPrimitive.CheckboxItem.displayName; + export { DropdownMenu, DropdownMenuTrigger, @@ -108,4 +131,5 @@ export { DropdownMenuSubTrigger, DropdownMenuSubContent, DropdownMenuPortal, + DropdownMenuCheckboxItem, }; diff --git a/playground/src/store/reducers/global.ts b/playground/src/store/reducers/global.ts index f13dce2b..9ab2f66a 100644 --- a/playground/src/store/reducers/global.ts +++ b/playground/src/store/reducers/global.ts @@ -199,12 +199,19 @@ export const updateGraph = createAsyncThunk( "global/updateGraph", async ( { graphId, updates }: { graphId: string; updates: Partial }, - { dispatch } + { dispatch, rejectWithValue } ) => { - await apiUpdateGraph(graphId, updates); - await apiSaveProperty(); - const updatedGraph = await apiFetchGraphDetails(graphId); - dispatch(setGraph(updatedGraph)); + try { + await apiUpdateGraph(graphId, updates); + await apiSaveProperty(); + const updatedGraph = await apiFetchGraphDetails(graphId); + dispatch(setGraph(updatedGraph)); + return updatedGraph; // Optionally return the updated graph + } catch (error: any) { + // Handle error gracefully + console.error("Error updating graph:", error); + return rejectWithValue(error.response?.data || error.message); + } } ); From 881d47f25e18f75cabed9cc30f106439f69771c8 Mon Sep 17 00:00:00 2001 From: zhangqianze Date: Fri, 13 Dec 2024 03:19:06 +0800 Subject: [PATCH 3/3] feat: upgrade playground image --- agents/ten_packages/extension/gemini_v2v_python/extension.py | 2 +- docker-compose.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/agents/ten_packages/extension/gemini_v2v_python/extension.py b/agents/ten_packages/extension/gemini_v2v_python/extension.py index fe802c97..9c007a1c 100644 --- a/agents/ten_packages/extension/gemini_v2v_python/extension.py +++ b/agents/ten_packages/extension/gemini_v2v_python/extension.py @@ -64,7 +64,7 @@ def rgb2base64jpeg(rgb_data, width, height): # Save the image to a BytesIO object in JPEG format buffered = BytesIO() pil_image.save(buffered, format="JPEG") - pil_image.save("test.jpg", format="JPEG") + # pil_image.save("test.jpg", format="JPEG") # Get the byte data of the JPEG image jpeg_image_data = buffered.getvalue() diff --git a/docker-compose.yml b/docker-compose.yml index 64472fe2..b2ef8968 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -18,7 +18,7 @@ services: networks: - ten_agent_network ten_agent_playground: - image: ghcr.io/ten-framework/ten_agent_playground:0.6.1-39-gcda3b08 + image: ghcr.io/ten-framework/ten_agent_playground:0.6.2-9-gdaec880 container_name: ten_agent_playground restart: always ports: