diff --git a/.changeset/real-squids-warn.md b/.changeset/real-squids-warn.md new file mode 100644 index 000000000..43c5d096d --- /dev/null +++ b/.changeset/real-squids-warn.md @@ -0,0 +1,5 @@ +--- +"livekit-plugins-openai": patch +--- + +add session_updated event for RealtimeSession diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/__init__.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/__init__.py index ac9b866d6..471deef37 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/__init__.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/__init__.py @@ -11,6 +11,7 @@ RealtimeOutput, RealtimeResponse, RealtimeSession, + RealtimeSessionOptions, RealtimeToolCall, ServerVadOptions, ) @@ -25,6 +26,7 @@ "RealtimeSession", "RealtimeModel", "RealtimeError", + "RealtimeSessionOptions", "ServerVadOptions", "InputTranscriptionOptions", "ConversationItemCreated", diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py index 83b2cbfa6..46794c85a 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py @@ -21,6 +21,7 @@ EventTypes = Literal[ "start_session", + "session_updated", "error", "input_speech_started", "input_speech_stopped", @@ -151,18 +152,22 @@ class RealtimeError: @dataclass -class _ModelOptions: +class RealtimeSessionOptions: model: api_proto.OpenAIModel | str modalities: list[api_proto.Modality] instructions: str voice: api_proto.Voice input_audio_format: api_proto.AudioFormat output_audio_format: api_proto.AudioFormat - input_audio_transcription: InputTranscriptionOptions - turn_detection: ServerVadOptions + input_audio_transcription: InputTranscriptionOptions | None + turn_detection: ServerVadOptions | None tool_choice: api_proto.ToolChoice temperature: float max_response_output_tokens: int | Literal["inf"] + + +@dataclass +class _ModelOptions(RealtimeSessionOptions): api_key: str | None base_url: str entra_token: str | None @@ -901,12 +906,19 @@ def session_update( function_data["type"] = "function" tools.append(function_data) - server_vad_opts: api_proto.ServerVad = { - "type": "server_vad", - "threshold": self._opts.turn_detection.threshold, - "prefix_padding_ms": self._opts.turn_detection.prefix_padding_ms, - "silence_duration_ms": self._opts.turn_detection.silence_duration_ms, - } + server_vad_opts: api_proto.ServerVad | None = None + if self._opts.turn_detection is not None: + server_vad_opts = { + "type": "server_vad", + "threshold": self._opts.turn_detection.threshold, + "prefix_padding_ms": self._opts.turn_detection.prefix_padding_ms, + "silence_duration_ms": self._opts.turn_detection.silence_duration_ms, + } + input_audio_transcription_opts: api_proto.InputAudioTranscription | None = None + if self._opts.input_audio_transcription is not None: + input_audio_transcription_opts = { + "model": self._opts.input_audio_transcription.model, + } session_data: api_proto.ClientEvent.SessionUpdateData = { "modalities": self._opts.modalities, @@ -914,9 +926,7 @@ def session_update( "voice": self._opts.voice, "input_audio_format": self._opts.input_audio_format, "output_audio_format": self._opts.output_audio_format, - "input_audio_transcription": { - "model": self._opts.input_audio_transcription.model, - }, + "input_audio_transcription": input_audio_transcription_opts, "turn_detection": server_vad_opts, "tools": tools, "tool_choice": self._opts.tool_choice, @@ -1103,6 +1113,8 @@ async def _recv_task(): if event == "session.created": self._handle_session_created(data) + if event == "session.updated": + self._handle_session_updated(data) elif event == "error": self._handle_error(data) elif event == "input_audio_buffer.speech_started": @@ -1171,6 +1183,42 @@ def _handle_session_created( ): self._session_id = session_created["session"]["id"] + def _handle_session_updated( + self, session_updated: api_proto.ServerEvent.SessionUpdated + ): + session = session_updated["session"] + if session["turn_detection"] is None: + turn_detection = None + else: + turn_detection = ServerVadOptions( + threshold=session["turn_detection"]["threshold"], + prefix_padding_ms=session["turn_detection"]["prefix_padding_ms"], + silence_duration_ms=session["turn_detection"]["silence_duration_ms"], + ) + if session["input_audio_transcription"] is None: + input_audio_transcription = None + else: + input_audio_transcription = InputTranscriptionOptions( + model=session["input_audio_transcription"]["model"], + ) + + self.emit( + "session_updated", + RealtimeSessionOptions( + model=session["model"], + modalities=session["modalities"], + instructions=session["instructions"], + voice=session["voice"], + input_audio_format=session["input_audio_format"], + output_audio_format=session["output_audio_format"], + input_audio_transcription=input_audio_transcription, + turn_detection=turn_detection, + tool_choice=session["tool_choice"], + temperature=session["temperature"], + max_response_output_tokens=session["max_response_output_tokens"], + ), + ) + def _handle_error(self, error: api_proto.ServerEvent.Error): logger.error( "OpenAI S2S error %s",