From daec880d8d10ae2baae87d32caa2f9d173ba7161 Mon Sep 17 00:00:00 2001
From: zhangqianze
Date: Fri, 13 Dec 2024 02:39:18 +0800
Subject: [PATCH] feat: update readme

---
 .../extension/gemini_v2v_python/README.md     |  72 ++++++------
 .../extension/gemini_v2v_python/extension.py  | 108 +++++++++++++++++-
 .../extension/gemini_v2v_python/manifest.json |   6 +
 .../gemini_v2v_python/requirements.txt        |   2 +-
 playground/src/common/graph.ts                |   3 +
 playground/src/common/hooks.ts                |   2 +-
 .../components/Chat/ChatCfgModuleSelect.tsx   |  50 +++++---
 playground/src/components/ui/dropdown.tsx     |  26 ++++-
 playground/src/store/reducers/global.ts       |  17 ++-
 9 files changed, 224 insertions(+), 62 deletions(-)

diff --git a/agents/ten_packages/extension/gemini_v2v_python/README.md b/agents/ten_packages/extension/gemini_v2v_python/README.md
index 3cd294f3..e43f7041 100644
--- a/agents/ten_packages/extension/gemini_v2v_python/README.md
+++ b/agents/ten_packages/extension/gemini_v2v_python/README.md
@@ -1,65 +1,63 @@
-# openai_v2v_python
+# gemini_v2v_python
 
-An extension for integrating OpenAI's Next Generation of **Multimodal** AI into your application, providing configurable AI-driven features such as conversational agents, task automation, and tool integration.
+An extension for integrating Gemini's next-generation **multimodal** AI into your application, providing configurable AI-driven features such as conversational agents, task automation, and tool integration.
 
 ## Features
 
-
-
-- OpenAI **Multimodal** Integration: Leverage GPT **Multimodal** models for voice to voice as well as text processing.
+- Gemini **Multimodal** Integration: Leverage Gemini **multimodal** models for voice-to-voice as well as text processing.
 - Configurable: Easily customize API keys, model settings, prompts, temperature, etc.
 - Async Queue Processing: Supports real-time message processing with task cancellation and prioritization.
 
-
 ## API
 
-Refer to `api` definition in [manifest.json] and default values in [property.json](property.json).
-
-
+Refer to the `api` definition in [manifest.json](manifest.json) and default values in [property.json](property.json).
 
 | **Property** | **Type** | **Description** |
 |----------------------------|------------|-------------------------------------------|
-| `api_key` | `string` | API key for authenticating with OpenAI |
-| `temperature` | `float64` | Sampling temperature, higher values mean more randomness |
-| `model` | `string` | Model identifier (e.g., GPT-3.5, GPT-4) |
-| `max_tokens` | `int64` | Maximum number of tokens to generate |
-| `system_message` | `string` | Default system message to send to the model |
-| `voice` | `string` | Voice that OpenAI model speeches, such as `alloy`, `echo`, `shimmer`, etc |
-| `server_vad` | `bool` | Flag to enable or disable server vad of OpenAI |
-| `language` | `string` | Language that OpenAO model reponds, such as `en-US`, `zh-CN`, etc |
-| `dump` | `bool` | Flag to enable or disable audio dump for debugging purpose |
-
-### Data Out:
+| `api_key` | `string` | API key for authenticating with Gemini |
+| `temperature` | `float32` | Sampling temperature; higher values mean more randomness |
+| `model` | `string` | Model identifier (e.g., `gemini-2.0-flash-exp`) |
+| `max_tokens` | `int32` | Maximum number of tokens to generate |
+| `system_message` | `string` | Default system message to send to the model |
+| `voice` | `string` | Voice the Gemini model speaks with, such as `Puck`, `Charon`, `Kore`, etc. |
+| `server_vad` | `bool` | Flag to enable or disable server VAD for Gemini |
+| `language` | `string` | Language the Gemini model responds in, such as `en-US`, `zh-CN`, etc. |
+| `dump` | `bool` | Flag to enable or disable audio dump for debugging purposes |
+| `base_uri` | `string` | Base URI for connecting to the Gemini service |
+| `audio_out` | `bool` | Flag to enable or disable audio output |
+| `input_transcript` | `bool` | Flag to enable input transcript processing |
+| `sample_rate` | `int32` | Sample rate for audio processing |
+| `stream_id` | `int32` | Stream ID for identifying audio streams |
+| `greeting` | `string` | Greeting message for initial interaction |
+
+### Data Out
+
 | **Name** | **Property** | **Type** | **Description** |
 |----------------|--------------|------------|-------------------------------|
 | `text_data` | `text` | `string` | Outgoing text data |
+| `append` | `text` | `string` | Additional text appended to the output |
 
-### Command Out:
+### Command Out
+
 | **Name** | **Description** |
 |----------------|---------------------------------------------|
 | `flush` | Response after flushing the current state |
+| `tool_call` | Invokes a tool with specific arguments |
 
-### Audio Frame In:
+### Audio Frame In
+
 | **Name** | **Description** |
 |------------------|-------------------------------------------|
 | `pcm_frame` | Audio frame input for voice processing |
 
-### Audio Frame Out:
+### Video Frame In
+
 | **Name** | **Description** |
 |------------------|-------------------------------------------|
-| `pcm_frame` | Audio frame output after voice processing |
-
-
-### Azure Support
+| `video_frame` | Video frame input for processing |
 
-This extension also support Azure OpenAI Service, the propoerty settings are as follow:
+### Audio Frame Out
 
-``` json
-{
-  "base_uri": "wss://xxx.openai.azure.com",
-  "path": "/openai/realtime?api-version=xxx&deployment=xxx",
-  "api_key": "xxx",
-  "model": "gpt-4o-realtime-preview",
-  "vendor": "azure"
-}
-```
\ No newline at end of file
+| **Name** | **Description** |
+|------------------|-------------------------------------------|
+| `pcm_frame` | Audio frame output after voice processing |
diff --git a/agents/ten_packages/extension/gemini_v2v_python/extension.py b/agents/ten_packages/extension/gemini_v2v_python/extension.py
index af6e846c..fe802c97 100644
--- a/agents/ten_packages/extension/gemini_v2v_python/extension.py
+++ b/agents/ten_packages/extension/gemini_v2v_python/extension.py
@@ -34,6 +34,9 @@ from ten_ai_base.types import LLMToolMetadata, LLMToolResult, LLMChatCompletionContentPartParam, TTSPcmOptions
 from google.genai.types import LiveServerMessage, LiveClientRealtimeInput, Blob, LiveConnectConfig, LiveConnectConfigDict, GenerationConfig, SpeechConfig, VoiceConfig, PrebuiltVoiceConfig, Content, Part, Tool, FunctionDeclaration, Schema, LiveClientToolResponse, FunctionCall, FunctionResponse
 from google.genai.live import AsyncSession
+from PIL import Image
+from io import BytesIO
+from base64 import b64encode
 import urllib.parse
 import google.genai._api_client
@@ -49,6 +52,62 @@ class Role(str, Enum):
     User = "user"
     Assistant = "assistant"
 
+
+def rgb2base64jpeg(rgb_data, width, height):
+    # Convert the raw RGBA frame buffer to a PIL image and drop the alpha channel
+    pil_image = Image.frombytes("RGBA", (width, height), bytes(rgb_data))
+    pil_image = pil_image.convert("RGB")
+
+    # Resize the image while maintaining its aspect ratio
+    pil_image = resize_image_keep_aspect(pil_image, 512)
+
+    # Save the image to a BytesIO object in JPEG format
+    buffered = BytesIO()
+    pil_image.save(buffered, format="JPEG")
+
+    # Get the byte data of the JPEG image
+    jpeg_image_data = buffered.getvalue()
+
+    # Convert the JPEG byte data to a Base64-encoded string
+    base64_encoded_image = b64encode(jpeg_image_data).decode("utf-8")
+
+    return base64_encoded_image
+
+
+def resize_image_keep_aspect(image, max_size=512):
+    """
+    Resize an image while maintaining its aspect ratio, ensuring the larger dimension is max_size.
+    If both dimensions are smaller than max_size, the image is not resized.
+
+    :param image: A PIL Image object
+    :param max_size: The maximum size for the larger dimension (width or height)
+    :return: A PIL Image object (resized or original)
+    """
+    # Get current width and height
+    width, height = image.size
+
+    # If both dimensions are already smaller than max_size, return the original image
+    if width <= max_size and height <= max_size:
+        return image
+
+    # Calculate the aspect ratio
+    aspect_ratio = width / height
+
+    # Determine the new dimensions
+    if width > height:
+        new_width = max_size
+        new_height = int(max_size / aspect_ratio)
+    else:
+        new_height = max_size
+        new_width = int(max_size * aspect_ratio)
+
+    # Resize the image with the new dimensions
+    resized_image = image.resize((new_width, new_height))
+
+    return resized_image
+
 @dataclass
 class GeminiRealtimeConfig(BaseConfig):
     base_uri: str = "generativelanguage.googleapis.com"
@@ -101,6 +160,9 @@ def __init__(self, name):
         self.client = None
         self.session:AsyncSession = None
         self.leftover_bytes = b''
+        self.video_task = None
+        self.image_queue = asyncio.Queue()
+        self.video_buff: str = ""
 
     async def on_init(self, ten_env: AsyncTenEnv) -> None:
         await super().on_init(ten_env)
@@ -133,7 +195,7 @@ async def on_start(self, ten_env: AsyncTenEnv) -> None:
             )
 
             self.loop.create_task(self._loop(ten_env))
-
+            self.loop.create_task(self._on_video(ten_env))
             # self.loop.create_task(self._loop())
         except Exception as e:
@@ -231,6 +293,7 @@ async def on_stop(self, ten_env: AsyncTenEnv) -> None:
             await self.session.close()
 
     async def on_audio_frame(self, ten_env: AsyncTenEnv, audio_frame: AudioFrame) -> None:
+        await super().on_audio_frame(ten_env, audio_frame)
         try:
             stream_id = audio_frame.get_property_int("stream_id")
             if self.channel_name == "":
@@ -281,6 +344,37 @@ async def on_cmd(self, ten_env: AsyncTenEnv, cmd: Cmd) -> None:
     async def on_data(self, ten_env: AsyncTenEnv, data: Data) -> None:
         pass
 
+    async def on_video_frame(self, async_ten_env, video_frame):
+        await super().on_video_frame(async_ten_env, video_frame)
+        image_data = video_frame.get_buf()
+        image_width = video_frame.get_width()
+        image_height = video_frame.get_height()
+        await self.image_queue.put([image_data, image_width, image_height])
+
+    async def _on_video(self, ten_env: AsyncTenEnv):
+        while True:
+            # Process the first frame from the queue
+            [image_data, image_width, image_height] = await self.image_queue.get()
+            self.video_buff = rgb2base64jpeg(image_data, image_width, image_height)
+            media_chunks = [{
+                "data": self.video_buff,
+                "mime_type": "image/jpeg",
+            }]
+            try:
+                if self.connected:
+                    await self.session.send(media_chunks)
+            except Exception as e:
+                ten_env.log_error(f"Failed to send image: {e}")
+
+            # Drop any frames that queued up while this one was being sent
+            while not self.image_queue.empty():
+                await self.image_queue.get()
+
+            # Wait for 1 second before processing the next frame
+            await asyncio.sleep(1)
+
     # Direction: IN
     async def _on_audio(self, buff: bytearray):
         self.buff += buff
@@ -288,11 +382,16 @@
         if self.connected and len(self.buff) >= self.audio_len_threshold:
             # await self.conn.send_audio_data(self.buff)
             try:
-                await self.session.send(LiveClientRealtimeInput(media_chunks=[Blob(data=self.buff, mime_type="audio/pcm")]))
+                media_chunks = [{
+                    "data": b64encode(self.buff).decode(),
+                    "mime_type": "audio/pcm",
+                }]
+                await self.session.send(media_chunks)
                 self.buff = b''
             except Exception as e:
-                pass
-                # self.ten_env.log_error(f"Failed to send audio {e}")
+                self.ten_env.log_error(f"Failed to send audio: {e}")
 
     def _get_session_config(self) -> LiveConnectConfigDict:
         def tool_dict(tool: LLMToolMetadata):
@@ -553,3 +652,4 @@ async def _update_usage(self, usage: dict) -> None:
                 "first_token_latency_99": np.percentile(self.first_token_times, 99)
             }))
         self.ten_env.send_data(data)
+
diff --git a/agents/ten_packages/extension/gemini_v2v_python/manifest.json b/agents/ten_packages/extension/gemini_v2v_python/manifest.json
index a1bb9ea5..4c25224e 100644
--- a/agents/ten_packages/extension/gemini_v2v_python/manifest.json
+++ b/agents/ten_packages/extension/gemini_v2v_python/manifest.json
@@ -82,6 +82,12 @@
       }
     }
   ],
+  "video_frame_in": [
+    {
+      "name": "video_frame",
+      "property": {}
+    }
+  ],
   "data_out": [
     {
       "name": "text_data",
diff --git a/agents/ten_packages/extension/gemini_v2v_python/requirements.txt b/agents/ten_packages/extension/gemini_v2v_python/requirements.txt
index 6f9ab33f..5370cc22 100644
--- a/agents/ten_packages/extension/gemini_v2v_python/requirements.txt
+++ b/agents/ten_packages/extension/gemini_v2v_python/requirements.txt
@@ -1,2 +1,2 @@
 asyncio
-google-genai
\ No newline at end of file
+google-genai==0.2.1
\ No newline at end of file
diff --git a/playground/src/common/graph.ts b/playground/src/common/graph.ts
index 7d1cc42a..3239aecf 100644
--- a/playground/src/common/graph.ts
+++ b/playground/src/common/graph.ts
@@ -420,6 +420,9 @@ class GraphEditor {
       // If no protocolLabel is provided, remove the entire connection
       graph.connections.splice(connectionIndex, 1)
     }
+
+    // Clean up empty connections
+    GraphEditor.removeEmptyConnections(graph);
   }
 
   static findNode(graph: Graph, nodeName: string): Node | null {
diff --git a/playground/src/common/hooks.ts b/playground/src/common/hooks.ts
index 405eec32..dad89b89 100644
--- a/playground/src/common/hooks.ts
+++ b/playground/src/common/hooks.ts
@@ -153,7 +153,7 @@ const useGraphs = () => {
   }
 
   const update = async (graphId: string, updates: Partial<Graph>) => {
-    await dispatch(updateGraph({ graphId, updates }))
+    await dispatch(updateGraph({ graphId, updates })).unwrap()
   }
 
   const getGraphNodeAddonByName = useCallback(
diff --git a/playground/src/components/Chat/ChatCfgModuleSelect.tsx b/playground/src/components/Chat/ChatCfgModuleSelect.tsx
index dcae871c..fa9842a8 100644
--- a/playground/src/components/Chat/ChatCfgModuleSelect.tsx
+++ b/playground/src/components/Chat/ChatCfgModuleSelect.tsx
@@ -31,13 +31,13 @@ import {
 import { Button } from "@/components/ui/button"
 import { cn } from "@/lib/utils"
 import { useAppSelector, useGraphs, } from "@/common/hooks"
-import { AddonDef, Graph, Destination, GraphEditor, ProtocolLabel as GraphConnProtocol } from "@/common/graph"
+import { AddonDef, Graph, Destination, GraphEditor, ProtocolLabel as GraphConnProtocol, ProtocolLabel } from "@/common/graph"
 import { toast } from "sonner"
 import { BoxesIcon, ChevronRightIcon, LoaderCircleIcon, SettingsIcon, Trash2Icon, WrenchIcon } from "lucide-react"
 import { useForm } from "react-hook-form"
 import { zodResolver } from "@hookform/resolvers/zod"
 import { z } from "zod"
-import { DropdownMenu, DropdownMenuContent, DropdownMenuItem, DropdownMenuPortal, DropdownMenuSub, DropdownMenuSubContent, DropdownMenuSubTrigger, DropdownMenuTrigger } from "../ui/dropdown"
+import { DropdownMenu, DropdownMenuCheckboxItem, DropdownMenuContent, DropdownMenuItem, DropdownMenuPortal, DropdownMenuSub, DropdownMenuSubContent, DropdownMenuSubTrigger, DropdownMenuTrigger } from "../ui/dropdown"
 import { isLLM } from "@/common"
 import { compatibleTools, ModuleRegistry, ModuleTypeLabels } from "@/common/moduleConfig"
@@ -122,25 +122,46 @@ export function RemoteModuleCfgSheet() {
                 const selectedGraphCopy: Graph = JSON.parse(JSON.stringify(selectedGraph));
                 const nodes = selectedGraphCopy.nodes;
                 let needUpdate = false;
+                let enableRTCVideoSubscribe = false;
 
+                // Retrieve the agora_rtc node
+                const agoraRtcNode = GraphEditor.findNode(selectedGraphCopy, "agora_rtc");
+                if (!agoraRtcNode) {
+                  toast.error("agora_rtc node not found in the graph");
+                  return;
+                }
+
                 // Update graph nodes with selected modules
                 Object.entries(data).forEach(([key, value]) => {
                   const node = nodes.find((n) => n.name === key);
                   if (node && value && node.addon !== value) {
                     node.addon = value;
                     node.property = addonModules.find((module) => module.name === value)?.defaultProperty;
+
+                    if (node.addon === "gemini_v2v_python") {
+                      GraphEditor.addOrUpdateConnection(
+                        selectedGraphCopy,
+                        `${agoraRtcNode.extensionGroup}.${agoraRtcNode.name}`,
+                        `${node.extensionGroup}.${node.name}`,
+                        ProtocolLabel.VIDEO_FRAME,
+                        "video_frame"
+                      );
+                      enableRTCVideoSubscribe = true;
+                    } else {
+                      GraphEditor.removeConnection(
+                        selectedGraphCopy,
+                        `${agoraRtcNode.extensionGroup}.${agoraRtcNode.name}`,
+                        `${node.extensionGroup}.${node.name}`,
+                        ProtocolLabel.VIDEO_FRAME,
+                        "video_frame"
+                      );
+                    }
+
                     needUpdate = true;
                   }
                 });
 
-                // Retrieve the agora_rtc node
-                const agoraRtcNode = GraphEditor.findNode(selectedGraphCopy, "agora_rtc");
-                if (!agoraRtcNode) {
-                  toast.error("agora_rtc node not found in the graph");
-                  return;
-                }
-
                 // Identify removed tools and process them
                 const currentToolsInGraph = nodes
                   .filter((node) => installedAndRegisteredToolModules.map((module) => module.name).includes(node.addon))
@@ -154,8 +175,9 @@
 
                 // Process tool modules
                 if (tools.length > 0) {
-                  GraphEditor.enableRTCVideoSubscribe(selectedGraphCopy, tools.some((tool) => tool.includes("vision")));
-
+                  if (!enableRTCVideoSubscribe) {
+                    enableRTCVideoSubscribe = tools.some((tool) => tool.includes("vision"))
+                  }
                   tools.forEach((tool) => {
                     if (!currentToolsInGraph.includes(tool)) {
                       const toolModule = addonModules.find((module) => module.name === tool);
@@ -177,6 +199,8 @@
                   needUpdate = true;
                 }
 
+                GraphEditor.enableRTCVideoSubscribe(selectedGraphCopy, enableRTCVideoSubscribe);
+
                 // Perform the update if changes are detected
                 if (needUpdate) {
                   try {
@@ -184,8 +208,8 @@
                     toast.success("Modules updated", {
                       description: `Graph: ${selectedGraphCopy.id}`,
                     });
-                  } catch (e) {
-                    toast.error("Failed to update modules");
+                  } catch (e: any) {
+                    toast.error(`Failed to update modules: ${e}`);
                   }
                 }
               }}
diff --git a/playground/src/components/ui/dropdown.tsx b/playground/src/components/ui/dropdown.tsx
index b3586b08..b07bf9a6 100644
--- a/playground/src/components/ui/dropdown.tsx
+++ b/playground/src/components/ui/dropdown.tsx
@@ -4,7 +4,7 @@ import * as React from "react";
 import * as DropdownMenuPrimitive from "@radix-ui/react-dropdown-menu";
 
 import { cn } from "@/lib/utils";
-import { ChevronRightIcon } from "lucide-react";
+import { CheckIcon, ChevronRightIcon } from "lucide-react";
 
 const DropdownMenu = DropdownMenuPrimitive.Root;
 
@@ -99,6 +99,29 @@ DropdownMenuSubContent.displayName =
 
 const DropdownMenuPortal = DropdownMenuPrimitive.Portal;
 
+
+const DropdownMenuCheckboxItem = React.forwardRef<
+  React.ElementRef<typeof DropdownMenuPrimitive.CheckboxItem>,
+  React.ComponentPropsWithoutRef<typeof DropdownMenuPrimitive.CheckboxItem>
+>(({ className, checked, children, ...props }, ref) => (
+  <DropdownMenuPrimitive.CheckboxItem
+    ref={ref}
+    className={cn(
+      "relative flex cursor-default select-none items-center rounded-sm px-2 py-1.5 text-sm outline-none focus:bg-accent focus:text-accent-foreground data-[disabled]:pointer-events-none data-[disabled]:opacity-50",
+      className
+    )}
+    checked={checked}
+    {...props}
+  >
+    {children}
+    <span className="ml-auto flex h-4 w-4 items-center justify-center">
+      {checked && <CheckIcon className="h-4 w-4" />}
+    </span>
+  </DropdownMenuPrimitive.CheckboxItem>
+));
+DropdownMenuCheckboxItem.displayName =
+  DropdownMenuPrimitive.CheckboxItem.displayName;
+
 export {
   DropdownMenu,
   DropdownMenuTrigger,
@@ -108,4 +131,5 @@ export {
   DropdownMenuSubTrigger,
   DropdownMenuSubContent,
   DropdownMenuPortal,
+  DropdownMenuCheckboxItem,
 };
diff --git a/playground/src/store/reducers/global.ts b/playground/src/store/reducers/global.ts
index f13dce2b..9ab2f66a 100644
--- a/playground/src/store/reducers/global.ts
+++ b/playground/src/store/reducers/global.ts
@@ -199,12 +199,19 @@ export const updateGraph = createAsyncThunk(
   "global/updateGraph",
   async (
     { graphId, updates }: { graphId: string; updates: Partial<Graph> },
-    { dispatch }
+    { dispatch, rejectWithValue }
   ) => {
-    await apiUpdateGraph(graphId, updates);
-    await apiSaveProperty();
-    const updatedGraph = await apiFetchGraphDetails(graphId);
-    dispatch(setGraph(updatedGraph));
+    try {
+      await apiUpdateGraph(graphId, updates);
+      await apiSaveProperty();
+      const updatedGraph = await apiFetchGraphDetails(graphId);
+      dispatch(setGraph(updatedGraph));
+      return updatedGraph; // Fulfilled payload for callers that .unwrap()
+    } catch (error: any) {
+      // Reject with a useful payload so the UI can report the failure
+      console.error("Error updating graph:", error);
+      return rejectWithValue(error.response?.data || error.message);
+    }
   }
 );
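
The `_on_video` loop above throttles video by taking the newest frame, sending it, discarding whatever queued up in the meantime, and sleeping, so the model receives at most one frame per second regardless of the camera rate. A minimal, self-contained sketch of that pattern follows; `fake_send` and `producer` are hypothetical stand-ins for the Gemini live session's `send` and the RTC video-frame callback, not part of the extension:

```python
import asyncio


# Stand-in for the google-genai live session send used in _on_video.
async def fake_send(media_chunks: list[dict]) -> None:
    print(f"sent {media_chunks[0]['mime_type']}: {media_chunks[0]['data']}")


async def producer(queue: asyncio.Queue) -> None:
    # Simulate a camera pushing roughly 10 frames per second.
    for i in range(30):
        await queue.put(f"frame-{i}")
        await asyncio.sleep(0.1)


async def consumer(queue: asyncio.Queue) -> None:
    # Same shape as _on_video: send one frame, drop the backlog, sleep 1s.
    for _ in range(3):
        frame = await queue.get()
        await fake_send([{"data": frame, "mime_type": "image/jpeg"}])
        while not queue.empty():
            queue.get_nowait()  # discard frames that arrived while sending
        await asyncio.sleep(1)


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    await asyncio.gather(producer(queue), consumer(queue))


asyncio.run(main())
```

Dropping the backlog rather than draining it sequentially keeps latency bounded: the model always sees the most recent frame instead of falling progressively behind the camera.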
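The image pipeline can also be exercised in isolation. The sketch below assumes the `rgb2base64jpeg` and `resize_image_keep_aspect` helpers from the extension.py diff are in scope; it feeds a synthetic 1280x720 RGBA buffer through (the same shape `video_frame.get_buf()` provides) and verifies the result decodes as a JPEG capped at 512 pixels on the longer side:

```python
from base64 import b64decode
from io import BytesIO

from PIL import Image

# Synthetic solid-color 1280x720 RGBA buffer, 4 bytes per pixel.
width, height = 1280, 720
rgb_data = bytes([30, 144, 255, 255]) * (width * height)

encoded = rgb2base64jpeg(rgb_data, width, height)

# Round-trip the base64 payload to confirm format and size constraints.
decoded = Image.open(BytesIO(b64decode(encoded)))
assert decoded.format == "JPEG"
assert max(decoded.size) <= 512
print(decoded.size)  # (512, 288): 16:9 aspect ratio preserved
```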