Skip to content

Commit

Permalink
Uses a representative sample of related cases for the GenAI analysis (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
mvilanova authored Oct 3, 2024
1 parent e8c2a7f commit 7268e33
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 7 deletions.
18 changes: 11 additions & 7 deletions src/dispatch/plugins/dispatch_slack/case/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from slack_sdk.web.client import WebClient
from sqlalchemy.orm import Session

from dispatch.case.enums import CaseStatus
from dispatch.case.enums import CaseResolutionReason, CaseStatus
from dispatch.case.models import Case
from dispatch.config import DISPATCH_UI_URL
from dispatch.messaging.strings import CASE_STATUS_DESCRIPTIONS, CASE_VISIBILITY_DESCRIPTIONS
Expand Down Expand Up @@ -320,13 +320,17 @@ def create_genai_signal_analysis_message(
return signal_metadata_blocks

# Fetch related cases
related_cases = (
signal_service.get_cases_for_signal(
db_session=db_session, signal_id=first_instance_signal.id
related_cases = []
for resolution_reason in CaseResolutionReason:
related_cases.extend(
signal_service.get_cases_for_signal_by_resolution_reason(
db_session=db_session,
signal_id=first_instance_signal.id,
resolution_reason=resolution_reason,
)
.from_self() # NOTE: function deprecated in SQLAlchemy 1.4 and removed in 2.0
.filter(Case.id != case.id)
)
.from_self() # NOTE: function deprecated in SQLAlchemy 1.4 and removed in 2.0
.filter(Case.id != case.id)
)

# Prepare historical context
historical_context = []
Expand Down
46 changes: 46 additions & 0 deletions src/dispatch/signal/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,16 @@ def get_unprocessed_signal_instance_ids(session: Session) -> list[int]:


def get_instances_in_case(db_session: Session, case_id: int) -> Query:
"""
Retrieves signal instances associated with a given case.
Args:
db_session (Session): The database session.
case_id (int): The ID of the case.
Returns:
Query: A SQLAlchemy query object for the signal instances associated with the case.
"""
return (
db_session.query(SignalInstance, Signal)
.join(Signal)
Expand All @@ -771,10 +781,46 @@ def get_instances_in_case(db_session: Session, case_id: int) -> Query:


def get_cases_for_signal(db_session: Session, signal_id: int, limit: int = 10) -> Query:
"""
Retrieves cases associated with a given signal.
Args:
db_session (Session): The database session.
signal_id (int): The ID of the signal.
limit (int, optional): The maximum number of cases to retrieve. Defaults to 10.
Returns:
Query: A SQLAlchemy query object for the cases associated with the signal.
"""
return (
db_session.query(Case)
.join(SignalInstance)
.filter(SignalInstance.signal_id == signal_id)
.order_by(desc(Case.created_at))
.limit(limit)
)


def get_cases_for_signal_by_resolution_reason(
db_session: Session, signal_id: int, resolution_reason: str, limit: int = 10
) -> Query:
"""
Retrieves cases associated with a given signal and resolution reason.
Args:
db_session (Session): The database session.
signal_id (int): The ID of the signal.
resolution_reason (str): The resolution reason to filter cases by.
limit (int, optional): The maximum number of cases to retrieve. Defaults to 10.
Returns:
Query: A SQLAlchemy query object for the cases associated with the signal and resolution reason.
"""
return (
db_session.query(Case)
.join(SignalInstance)
.filter(SignalInstance.signal_id == signal_id)
.filter(Case.resolution_reason == resolution_reason)
.order_by(desc(Case.created_at))
.limit(limit)
)

0 comments on commit 7268e33

Please sign in to comment.