Introduce support for extracting and processing grounding metadata from GoogleLLMService. #1030

Open · wants to merge 1 commit into main
130 changes: 130 additions & 0 deletions examples/foundational/31-gemini-grounding-metadata.py
@@ -0,0 +1,130 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import asyncio
import os
import sys
from pathlib import Path

import aiohttp
from dotenv import load_dotenv
from loguru import logger

from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import Frame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.services.cartesia import CartesiaTTSService
from pipecat.services.deepgram import DeepgramSTTService
from pipecat.services.google import GoogleLLMService, LLMSearchResponseFrame
from pipecat.transports.services.daily import DailyParams, DailyTransport

sys.path.append(str(Path(__file__).parent.parent))
from runner import configure

load_dotenv(override=True)

logger.remove(0)
logger.add(sys.stderr, level="DEBUG")

# Configure Google Search as a retrieval tool so Gemini can ground its responses
search_tool = {"google_search_retrieval": {}}
tools = [search_tool]

system_instruction = """
You are an expert at providing the most recent news from any place. Your responses will be converted to audio, so avoid using special characters or overly complex formatting.

Always use the google search API to retrieve the latest news. You must also use it to check which day is today.

You can:
- Use the Google search API to check the current date.
- Provide the most recent and relevant news from any place by using the google search API.
- Answer any questions the user may have, ensuring your responses are accurate and concise.

Start each interaction by asking the user about which place they would like to know the information.
"""


class LLMSearchLoggerProcessor(FrameProcessor):
    async def process_frame(self, frame: Frame, direction: FrameDirection):
        await super().process_frame(frame, direction)

        if isinstance(frame, LLMSearchResponseFrame):
            print(f"LLMSearchLoggerProcessor: {frame}")

        # Pass the frame through in its original direction
        await self.push_frame(frame, direction)


async def main():
    async with aiohttp.ClientSession() as session:
        (room_url, token) = await configure(session)

        transport = DailyTransport(
            room_url,
            token,
            "Latest news!",
            DailyParams(
                audio_out_enabled=True,
                vad_enabled=True,
                vad_analyzer=SileroVADAnalyzer(),
                vad_audio_passthrough=True,
            ),
        )

        stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))

        tts = CartesiaTTSService(
            api_key=os.getenv("CARTESIA_API_KEY"),
            voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22",  # British Lady
        )

        # Initialize the Google Gemini LLM with the search grounding tool enabled
        llm = GoogleLLMService(
            api_key=os.getenv("GOOGLE_API_KEY"),
            system_instruction=system_instruction,
            tools=tools,
        )

        context = OpenAILLMContext(
            [
                {
                    "role": "user",
                    "content": "Start by greeting the user warmly, introducing yourself, and mentioning the current day. Be friendly and engaging to set a positive tone for the interaction.",
                }
            ],
        )
        context_aggregator = llm.create_context_aggregator(context)

        llm_search_logger = LLMSearchLoggerProcessor()

        pipeline = Pipeline(
            [
                transport.input(),
                stt,
                context_aggregator.user(),
                llm,
                llm_search_logger,
                tts,
                transport.output(),
                context_aggregator.assistant(),
            ]
        )

        task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True))

        @transport.event_handler("on_first_participant_joined")
        async def on_first_participant_joined(transport, participant):
            await task.queue_frames([context_aggregator.user().get_context_frame()])

        runner = PipelineRunner()
        await runner.run(task)


if __name__ == "__main__":
    asyncio.run(main())
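To try this example (assuming the standard examples/foundational runner setup), set GOOGLE_API_KEY, DEEPGRAM_API_KEY, CARTESIA_API_KEY, and your Daily credentials in a .env file, run the script, and join the Daily room it connects to. Whenever a response is grounded, the LLMSearchLoggerProcessor prints the LLMSearchResponseFrame it observes.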
2 changes: 2 additions & 0 deletions src/pipecat/services/google/__init__.py
@@ -0,0 +1,2 @@
from .frames import LLMSearchResponseFrame
from .google import *
33 changes: 33 additions & 0 deletions src/pipecat/services/google/frames.py
@@ -0,0 +1,33 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

from dataclasses import dataclass, field
from typing import List, Optional

from pipecat.frames.frames import DataFrame


@dataclass
class LLMSearchResult:
    text: str
    # Gemini reports per-support confidence as a list of scores
    confidence: Optional[List[float]] = None


@dataclass
class LLMSearchOrigin:
    site_uri: Optional[str] = None
    site_title: Optional[str] = None
    results: List[LLMSearchResult] = field(default_factory=list)


@dataclass
class LLMSearchResponseFrame(DataFrame):
    search_result: Optional[str] = None
    rendered_content: Optional[str] = None
    origins: List[LLMSearchOrigin] = field(default_factory=list)

    def __str__(self):
        return f"LLMSearchResponseFrame(search_result={self.search_result}, origins={self.origins})"
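A minimal sketch (not part of the diff) of how these dataclasses nest; the values below are hypothetical:

from pipecat.services.google.frames import (
    LLMSearchOrigin,
    LLMSearchResponseFrame,
    LLMSearchResult,
)

# Hypothetical grounded result, for illustration only
origin = LLMSearchOrigin(
    site_uri="https://example.com/article",
    site_title="Example News",
    results=[LLMSearchResult(text="A grounded sentence.", confidence=[0.87])],
)
frame = LLMSearchResponseFrame(search_result="A grounded sentence.", origins=[origin])
print(frame)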
@@ -38,6 +38,7 @@
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import LLMService, TTSService
from pipecat.services.google.frames import LLMSearchResponseFrame
from pipecat.services.openai import (
OpenAIAssistantContextAggregator,
OpenAIUserContextAggregator,
@@ -639,6 +640,9 @@ async def _process_context(self, context: OpenAILLMContext):
        completion_tokens = 0
        total_tokens = 0

        grounding_metadata = None
        search_result = ""

        try:
            logger.debug(
                # f"Generating chat: {self._system_instruction} | {context.get_messages_for_logging()}"
@@ -698,6 +702,7 @@ async def _process_context(self, context: OpenAILLMContext):
                try:
                    for c in chunk.parts:
                        if c.text:
                            search_result += c.text
                            await self.push_frame(LLMTextFrame(c.text))
                        elif c.function_call:
                            logger.debug(f"!!! Function call: {c.function_call}")
@@ -708,6 +713,63 @@ async def _process_context(self, context: OpenAILLMContext):
                                function_name=c.function_call.name,
                                arguments=args,
                            )
                    # Handle grounding metadata.
                    # It seems only the last chunk we receive may contain this information.
                    # If the response doesn't include grounding metadata, the response wasn't grounded.
                    if chunk.candidates:
                        for candidate in chunk.candidates:
                            # logger.debug(f"candidate received: {candidate}")
                            # Extract grounding metadata
                            grounding_metadata = (
                                {
                                    "rendered_content": getattr(
                                        getattr(
                                            getattr(candidate, "grounding_metadata", None),
                                            "search_entry_point",
                                            None,
                                        ),
                                        "rendered_content",
                                        None,
                                    ),
                                    "origins": [
                                        {
                                            "site_uri": getattr(grounding_chunk.web, "uri", None),
                                            "site_title": getattr(grounding_chunk.web, "title", None),
                                            "results": [
                                                {
                                                    "text": getattr(grounding_support.segment, "text", ""),
                                                    "confidence": getattr(
                                                        grounding_support, "confidence_scores", None
                                                    ),
                                                }
                                                for grounding_support in getattr(
                                                    getattr(candidate, "grounding_metadata", None),
                                                    "grounding_supports",
                                                    [],
                                                )
                                                if index in getattr(grounding_support, "grounding_chunk_indices", [])
                                            ],
                                        }
                                        for index, grounding_chunk in enumerate(
                                            getattr(
                                                getattr(candidate, "grounding_metadata", None),
                                                "grounding_chunks",
                                                [],
                                            )
                                        )
                                    ],
                                }
                                if getattr(candidate, "grounding_metadata", None)
                                else None
                            )
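                            # For reference, a grounded candidate produces a dict shaped
                            # like this (hypothetical values, not from the PR):
                            # {
                            #     "rendered_content": "<search suggestions HTML>",
                            #     "origins": [
                            #         {
                            #             "site_uri": "https://example.com/article",
                            #             "site_title": "Example News",
                            #             "results": [
                            #                 {"text": "A grounded sentence.", "confidence": [0.87]},
                            #             ],
                            #         },
                            #     ],
                            # }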
                except Exception as e:
                    # Google LLMs seem to flag safety issues a lot!
                    if chunk.candidates[0].finish_reason == 3:
@@ -720,6 +782,14 @@ async def _process_context(self, context: OpenAILLMContext):
        except Exception as e:
            logger.exception(f"{self} exception: {e}")
        finally:
            if grounding_metadata is not None and isinstance(grounding_metadata, dict):
                llm_search_frame = LLMSearchResponseFrame(
                    search_result=search_result,
                    origins=grounding_metadata["origins"],
                    rendered_content=grounding_metadata["rendered_content"],
                )
                await self.push_frame(llm_search_frame)

Contributor Author commented on the LLMSearchResponseFrame construction above:

One nice improvement would be to create an RTVI message so we can share this response with the RTVI clients. What do you think? I can create a follow-up PR to implement the message.

Maybe something like: RTVIBotLLMSearchResponse. What do you think?

I have included it in this PR: #1051

            await self.start_llm_usage_metrics(
                LLMTokenUsage(
                    prompt_tokens=prompt_tokens,
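For readers untangling the nested comprehension in _process_context, here is a behaviorally equivalent helper — a sketch assuming the same grounding attributes the diff relies on (grounding_metadata, search_entry_point, grounding_chunks, grounding_supports), not code from the PR:

from typing import Any, Dict, Optional


def extract_grounding_metadata(candidate: Any) -> Optional[Dict]:
    """Flattened equivalent of the nested comprehension in the diff (sketch only)."""
    gm = getattr(candidate, "grounding_metadata", None)
    if not gm:
        return None

    entry_point = getattr(gm, "search_entry_point", None)
    supports = getattr(gm, "grounding_supports", [])

    origins = []
    for index, grounding_chunk in enumerate(getattr(gm, "grounding_chunks", [])):
        # Collect the response segments that cite this chunk as a source
        results = [
            {
                "text": getattr(support.segment, "text", ""),
                "confidence": getattr(support, "confidence_scores", None),
            }
            for support in supports
            if index in getattr(support, "grounding_chunk_indices", [])
        ]
        origins.append(
            {
                "site_uri": getattr(grounding_chunk.web, "uri", None),
                "site_title": getattr(grounding_chunk.web, "title", None),
                "results": results,
            }
        )

    return {
        "rendered_content": getattr(entry_point, "rendered_content", None),
        "origins": origins,
    }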