diff --git a/src/dispatch/plugins/dispatch_slack/case/messages.py b/src/dispatch/plugins/dispatch_slack/case/messages.py index 35c663c3b7ef..1a20a6f17913 100644 --- a/src/dispatch/plugins/dispatch_slack/case/messages.py +++ b/src/dispatch/plugins/dispatch_slack/case/messages.py @@ -14,7 +14,7 @@ from slack_sdk.web.client import WebClient from sqlalchemy.orm import Session -from dispatch.case.enums import CaseStatus +from dispatch.case.enums import CaseResolutionReason, CaseStatus from dispatch.case.models import Case from dispatch.config import DISPATCH_UI_URL from dispatch.messaging.strings import CASE_STATUS_DESCRIPTIONS, CASE_VISIBILITY_DESCRIPTIONS @@ -320,13 +320,17 @@ def create_genai_signal_analysis_message( return signal_metadata_blocks # Fetch related cases - related_cases = ( - signal_service.get_cases_for_signal( - db_session=db_session, signal_id=first_instance_signal.id + related_cases = [] + for resolution_reason in CaseResolutionReason: + related_cases.extend( + signal_service.get_cases_for_signal_by_resolution_reason( + db_session=db_session, + signal_id=first_instance_signal.id, + resolution_reason=resolution_reason, + ) + .from_self() # NOTE: function deprecated in SQLAlchemy 1.4 and removed in 2.0 + .filter(Case.id != case.id) ) - .from_self() # NOTE: function deprecated in SQLAlchemy 1.4 and removed in 2.0 - .filter(Case.id != case.id) - ) # Prepare historical context historical_context = [] diff --git a/src/dispatch/signal/service.py b/src/dispatch/signal/service.py index 1123bbc07e23..e66d6a27cf44 100644 --- a/src/dispatch/signal/service.py +++ b/src/dispatch/signal/service.py @@ -761,6 +761,16 @@ def get_unprocessed_signal_instance_ids(session: Session) -> list[int]: def get_instances_in_case(db_session: Session, case_id: int) -> Query: + """ + Retrieves signal instances associated with a given case. + + Args: + db_session (Session): The database session. + case_id (int): The ID of the case. + + Returns: + Query: A SQLAlchemy query object for the signal instances associated with the case. + """ return ( db_session.query(SignalInstance, Signal) .join(Signal) @@ -771,10 +781,46 @@ def get_instances_in_case(db_session: Session, case_id: int) -> Query: def get_cases_for_signal(db_session: Session, signal_id: int, limit: int = 10) -> Query: + """ + Retrieves cases associated with a given signal. + + Args: + db_session (Session): The database session. + signal_id (int): The ID of the signal. + limit (int, optional): The maximum number of cases to retrieve. Defaults to 10. + + Returns: + Query: A SQLAlchemy query object for the cases associated with the signal. + """ + return ( + db_session.query(Case) + .join(SignalInstance) + .filter(SignalInstance.signal_id == signal_id) + .order_by(desc(Case.created_at)) + .limit(limit) + ) + + +def get_cases_for_signal_by_resolution_reason( + db_session: Session, signal_id: int, resolution_reason: str, limit: int = 10 +) -> Query: + """ + Retrieves cases associated with a given signal and resolution reason. + + Args: + db_session (Session): The database session. + signal_id (int): The ID of the signal. + resolution_reason (str): The resolution reason to filter cases by. + limit (int, optional): The maximum number of cases to retrieve. Defaults to 10. + + Returns: + Query: A SQLAlchemy query object for the cases associated with the signal and resolution reason. + """ return ( db_session.query(Case) .join(SignalInstance) .filter(SignalInstance.signal_id == signal_id) + .filter(Case.resolution_reason == resolution_reason) .order_by(desc(Case.created_at)) .limit(limit) )