feature complete gemini audio, transcription, and phrase endpointing demo
kwindla committed Dec 22, 2024
1 parent f5f0de0 commit ab5df1a
Showing 2 changed files with 117 additions and 77 deletions.
181 changes: 110 additions & 71 deletions examples/foundational/22d-natural-conversation-gemini-audio.py
@@ -57,6 +57,14 @@
logger.remove(0)
logger.add(sys.stderr, level="DEBUG")

# TRANSCRIBER_MODEL = "gemini-1.5-flash-latest"
# CLASSIFIER_MODEL = "gemini-1.5-flash-latest"
# CONVERSATION_MODEL = "gemini-1.5-flash-latest"

TRANSCRIBER_MODEL = "gemini-2.0-flash-exp"
CLASSIFIER_MODEL = "gemini-2.0-flash-exp"
CONVERSATION_MODEL = "gemini-2.0-flash-exp"

transcriber_system_instruction = """You are an audio transcriber. You are receiving audio from a user. Your job is to
transcribe the input audio to text exactly as it was said by the user.
@@ -347,6 +355,11 @@


class AudioAccumulator(FrameProcessor):
"""Buffers user audio until the user stops speaking.
Always pushes a fresh context with a single audio message.
"""

def __init__(self, **kwargs):
super().__init__(**kwargs)
self._audio_frames = []
@@ -376,14 +389,6 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
self._user_speaking_utterance_state = True

elif isinstance(frame, UserStoppedSpeakingFrame):
if self._audio_frames[-1]:
fr = self._audio_frames[-1]
frame_duration = len(fr.audio) / 2 * fr.num_channels / fr.sample_rate

logger.debug(
f"!!! Frame duration: ({len(fr.audio)}) ({fr.num_channels}) ({fr.sample_rate}) {frame_duration}"
)

# 16-bit PCM is 2 bytes per sample; assuming mono audio at 16 kHz, seconds = bytes / 2 / 16000.
data = b"".join(frame.audio for frame in self._audio_frames)
logger.debug(
f"Processing audio buffer seconds: ({len(self._audio_frames)}) ({len(data)}) {len(data) / 2 / 16000}"
@@ -415,6 +420,12 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):


class CompletenessCheck(FrameProcessor):
"""Checks the classifier LLM's output to determine whether the user has finished speaking.
Triggers the notifier when the user has finished speaking, and also when an idle
timeout is reached.
"""

wait_time = 5.0

def __init__(self, notifier: BaseNotifier, audio_accumulator: AudioAccumulator, **kwargs):
@@ -427,12 +438,13 @@ def __init__(self, notifier: BaseNotifier, audio_accumulator: AudioAccumulator,
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)

if isinstance(frame, TextFrame) and frame.text.startswith("YES"):
if isinstance(frame, UserStartedSpeakingFrame):
if self._idle_task:
self._idle_task.cancel()
elif isinstance(frame, TextFrame) and frame.text.startswith("YES"):
logger.debug("Completeness check YES")
if self._idle_task:
logger.debug("CompletenessCheck idle wait CANCEL")
self._idle_task.cancel()
self._idle_task = None
await self.push_frame(UserStoppedSpeakingFrame())
await self._audio_accumulator.reset()
await self._notifier.notify()
@@ -443,24 +455,95 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
if self._wakeup_time:
self._wakeup_time = time.time() + self.wait_time
else:
logger.debug("CompletenessCheck idle wait START")
# logger.debug("!!! CompletenessCheck idle wait START")
self._wakeup_time = time.time() + self.wait_time
self._idle_task = self.get_event_loop().create_task(self._idle_task_handler())

async def _idle_task_handler(self):
try:
while time.time() < self._wakeup_time:
await asyncio.sleep(0.01)
logger.debug("CompletenessCheck idle wait OVER")
# logger.debug(f"!!! CompletenessCheck idle wait OVER")
await self._audio_accumulator.reset()
await self._notifier.notify()
except asyncio.CancelledError:
# logger.debug(f"!!! CompletenessCheck idle wait CANCEL")
pass
except Exception as e:
logger.error(f"CompletenessCheck idle wait error: {e}")
raise e
finally:
# logger.debug(f"!!! CompletenessCheck idle wait FINALLY")
self._wakeup_time = 0
self._idle_task = None
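
A note on the pattern: the idle timeout above uses an extendable deadline. The watcher task polls _wakeup_time, so further incomplete speech can push the deadline back without cancelling and recreating the task. A minimal standalone sketch of the same pattern (the class and names here are illustrative, not part of this commit):

import asyncio
import time


class ExtendableTimeout:
    """Fire an async callback once `delay` seconds pass without another extend()."""

    def __init__(self, delay: float, callback):
        self._delay = delay
        self._callback = callback  # async callable
        self._wakeup_time = 0.0
        self._task = None

    def extend(self):
        # Push the deadline back; start the watcher only if it isn't already running.
        # Must be called from within a running event loop.
        self._wakeup_time = time.time() + self._delay
        if self._task is None:
            self._task = asyncio.create_task(self._watch())

    def cancel(self):
        if self._task:
            self._task.cancel()
            self._task = None

    async def _watch(self):
        try:
            # Poll instead of sleeping the full delay so extend() can move the
            # deadline without cancelling and recreating this task.
            while time.time() < self._wakeup_time:
                await asyncio.sleep(0.01)
            await self._callback()
        except asyncio.CancelledError:
            pass
        finally:
            self._task = None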


class UserAggregatorBuffer(LLMResponseAggregator):
"""Buffers the output of the transcription LLM. Used by the bot output gate."""

def __init__(self, **kwargs):
super().__init__(
messages=None,
role=None,
start_frame=LLMFullResponseStartFrame,
end_frame=LLMFullResponseEndFrame,
accumulator_frame=TextFrame,
handle_interruptions=True,
expect_stripped_words=False,
)
self._transcription = ""

async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
# parent method pushes frames
if isinstance(frame, UserStartedSpeakingFrame):
self._transcription = ""

async def _push_aggregation(self):
if self._aggregation:
self._transcription = self._aggregation
self._aggregation = ""

logger.debug(f"[Transcription] {self._transcription}")

async def wait_for_transcription(self):
while not self._transcription:
await asyncio.sleep(0.01)
tx = self._transcription
self._transcription = ""
return tx
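
wait_for_transcription() above polls every 10 ms until the transcriber LLM has produced text. A variant built on asyncio.Event would avoid the polling loop entirely; this is an alternative sketch, not what the demo uses:

import asyncio


class EventTranscriptionBuffer:
    """Variant of the transcription hand-off that waits on an asyncio.Event instead of polling."""

    def __init__(self):
        self._transcription = ""
        self._ready = asyncio.Event()

    def set_transcription(self, text: str):
        self._transcription = text
        self._ready.set()

    async def wait_for_transcription(self) -> str:
        await self._ready.wait()
        self._ready.clear()
        tx, self._transcription = self._transcription, ""
        return tx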


class ConversationAudioContextAssembler(FrameProcessor):
"""Takes the single-message context generated by the AudioAccumulator and adds it to the conversation LLM's context."""

def __init__(self, context: OpenAILLMContext, **kwargs):
super().__init__(**kwargs)
self._context = context

async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)

# We must not block system frames.
if isinstance(frame, SystemFrame):
await self.push_frame(frame, direction)
return

if isinstance(frame, OpenAILLMContextFrame):
GoogleLLMContext.upgrade_to_google(self._context)
last_message = frame.context.messages[-1]
self._context._messages.append(last_message)
await self.push_frame(OpenAILLMContextFrame(context=self._context))
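
The assembler keeps one long-lived context for the conversation LLM while AudioAccumulator emits throwaway single-message contexts; only the newest message is grafted on, so conversation history is preserved. A simplified, hypothetical illustration of that step (plain dicts stand in for the real message objects):

# Hypothetical stand-ins; the real code appends the last message of the fresh
# context to the long-lived conversation context.
history = [{"role": "user", "parts": ["earlier turn"]}, {"role": "model", "parts": ["reply"]}]
fresh_context = [{"role": "user", "parts": ["<latest audio message>"]}]
history.append(fresh_context[-1])
assert history[-1]["parts"] == ["<latest audio message>"]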


class OutputGate(FrameProcessor):
"""Buffers output frames until the notifier is triggered.
When the notifier fires, waits until a transcription is ready, then:
1. Replaces the last user audio message with the transcription.
2. Flushes the frames buffer.
"""

def __init__(
self,
notifier: BaseNotifier,
@@ -501,6 +584,13 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
await self.push_frame(frame, direction)
return

if isinstance(frame, LLMFullResponseStartFrame):
# Remove the audio message from the context. We will never need it again.
# If the completeness check fails, a new audio message will be appended to the context.
# If the completeness check succeeds, our notifier will fire and we will append the
# transcription to the context.
self._context._messages.pop()

if self._gate_open:
await self.push_frame(frame, direction)
return
@@ -517,16 +607,13 @@ async def _stop(self):

async def _gate_task_handler(self):
while True:
# logger.debug("!!! Waiting for notifier")
try:
await self._notifier.wait()

# logger.debug("!!! Notified")
transcription = await self._transcription_buffer.wait_for_transcription()

last_message = self._context.messages[-1]
if last_message.role == "user":
last_message.parts = [glm.Part(text=transcription)]
transcription = await self._transcription_buffer.wait_for_transcription() or "-"
self._context._messages.append(
glm.Content(role="user", parts=[glm.Part(text=transcription)])
)

self.open_gate()
for frame, direction in self._frames_buffer:
@@ -540,54 +627,6 @@ async def _gate_task_handler(self):
break
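
OutputGate implements a common gating pattern: frames are buffered while the gate is closed and flushed in arrival order when it opens. Stripped of the transcription bookkeeping, the pattern looks roughly like this (an illustrative simplification, not the real class):

class SimpleGate:
    """Minimal sketch of the buffer-and-flush gating pattern."""

    def __init__(self):
        self._open = False
        self._buffer = []

    async def handle(self, frame, push):
        if self._open:
            await push(frame)  # gate open: pass through immediately
        else:
            self._buffer.append(frame)  # gate closed: hold for later

    async def open_gate(self, push):
        self._open = True
        for frame in self._buffer:  # flush in arrival order
            await push(frame)
        self._buffer = []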


class ConversationAudioContextAssembler(FrameProcessor):
def __init__(self, context: OpenAILLMContext, **kwargs):
super().__init__(**kwargs)
self._context = context

async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)

# We must not block system frames.
if isinstance(frame, SystemFrame):
await self.push_frame(frame, direction)
return

if isinstance(frame, OpenAILLMContextFrame):
GoogleLLMContext.upgrade_to_google(self._context)
last_message = frame.context.messages[-1]
self._context._messages.append(last_message)
await self.push_frame(OpenAILLMContextFrame(context=self._context))


class UserAggregatorBuffer(LLMResponseAggregator):
def __init__(self, **kwargs):
super().__init__(
messages=None,
role=None,
start_frame=LLMFullResponseStartFrame,
end_frame=LLMFullResponseEndFrame,
accumulator_frame=TextFrame,
handle_interruptions=True,
expect_stripped_words=False,
)
self._transcription = ""

async def _push_aggregation(self):
if self._aggregation:
self._transcription = self._aggregation
self._aggregation = ""

logger.debug(f"[Transcription] {self._transcription}")

async def wait_for_transcription(self):
while not self._transcription:
await asyncio.sleep(0.01)
tx = self._transcription
self._transcription = ""
return tx


async def main():
async with aiohttp.ClientSession() as session:
(room_url, _) = await configure(session)
@@ -613,7 +652,7 @@ async def main():
# This is the LLM that will transcribe user speech.
tx_llm = GoogleLLMService(
name="Transcriber",
model="gemini-2.0-flash-exp",
model=TRANSCRIBER_MODEL,
api_key=os.getenv("GOOGLE_API_KEY"),
temperature=0.0,
system_instruction=transcriber_system_instruction,
@@ -622,7 +661,7 @@
# This is the LLM that will classify user speech as complete or incomplete.
classifier_llm = GoogleLLMService(
name="Classifier",
model="gemini-2.0-flash-exp",
model=CLASSIFIER_MODEL,
api_key=os.getenv("GOOGLE_API_KEY"),
temperature=0.0,
system_instruction=classifier_system_instruction,
@@ -631,7 +670,7 @@
# This is the regular LLM that responds conversationally.
conversation_llm = GoogleLLMService(
name="Conversation",
model="gemini-2.0-flash-exp",
model=CONVERSATION_MODEL,
api_key=os.getenv("GOOGLE_API_KEY"),
system_instruction=conversation_system_instruction,
)
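
The pipeline wiring that follows is collapsed in this diff. Based on the class docstrings above, the intended data flow is roughly the following (a sketch inferred from this commit, not the collapsed code itself):

# audio from transport
#   └─> AudioAccumulator                          # buffers until the user stops speaking
#         ├─> tx_llm ─> UserAggregatorBuffer      # transcription text, read by OutputGate
#         ├─> classifier_llm ─> CompletenessCheck # "YES" (or idle timeout) fires the notifier
#         └─> ConversationAudioContextAssembler ─> conversation_llm ─> OutputGate
#               (bot output is held until the notifier fires)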
13 changes: 7 additions & 6 deletions src/pipecat/services/google.py
@@ -635,7 +635,7 @@ async def _process_context(self, context: OpenAILLMContext):

messages = context.messages
if context.system_message and self._system_instruction != context.system_message:
# logger.debug(f"System instruction changed: {context.system_message}")
logger.debug(f"System instruction changed: {context.system_message}")
self._system_instruction = context.system_message
self._create_client()

@@ -673,15 +673,16 @@ async def _process_context(self, context: OpenAILLMContext):
await self.stop_ttfb_metrics()

if response.usage_metadata:
# Use only the prompt token count from the response object
prompt_tokens = response.usage_metadata.prompt_token_count
completion_tokens = response.usage_metadata.candidates_token_count
total_tokens = response.usage_metadata.total_token_count
total_tokens = prompt_tokens

async for chunk in response:
if chunk.usage_metadata:
prompt_tokens += response.usage_metadata.prompt_token_count
completion_tokens += response.usage_metadata.candidates_token_count
total_tokens += response.usage_metadata.total_token_count
# Count only the completion tokens from each chunk; the prompt token count is
# repeated in every chunk's usage_metadata and has already been counted above.
completion_tokens += chunk.usage_metadata.candidates_token_count
total_tokens += chunk.usage_metadata.candidates_token_count
try:
for c in chunk.parts:
if c.text:
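
In condensed form, the corrected accounting takes the prompt count once and sums completion tokens per chunk; note the old loop also read response.usage_metadata instead of chunk.usage_metadata inside the loop, so it re-added the repeated prompt count on every chunk. A worked example with illustrative numbers:

# Illustrative numbers only. A 1,000-token prompt streamed back in three
# chunks carrying 20, 30, and 25 completion tokens should yield:
prompt_tokens = 1000                              # counted once, from the response
completion_tokens = 20 + 30 + 25                  # summed per chunk -> 75
total_tokens = prompt_tokens + completion_tokens  # 1075
# The old code added the prompt count again for each chunk, reporting
# 4 * 1000 = 4000 prompt tokens for the same call.
assert (completion_tokens, total_tokens) == (75, 1075)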
