diff --git a/src/dispatch/database/revisions/tenant/versions/2024-12-05_575ca7d954a8.py b/src/dispatch/database/revisions/tenant/versions/2024-12-05_575ca7d954a8.py new file mode 100644 index 000000000000..8365c4d919cc --- /dev/null +++ b/src/dispatch/database/revisions/tenant/versions/2024-12-05_575ca7d954a8.py @@ -0,0 +1,29 @@ +"""Adds incident summary to the incident table. + +Revision ID: 575ca7d954a8 +Revises: 928b725d64f6 +Create Date: 2024-12-05 15:05:46.932404 + +""" + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "575ca7d954a8" +down_revision = "928b725d64f6" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("incident", sa.Column("summary", sa.String(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("incident", "summary") + # ### end Alembic commands ### diff --git a/src/dispatch/incident/flows.py b/src/dispatch/incident/flows.py index f5bc36acf805..e07f5e2efec1 100644 --- a/src/dispatch/incident/flows.py +++ b/src/dispatch/incident/flows.py @@ -552,6 +552,9 @@ def incident_closed_status_flow(incident: Incident, db_session=None): # to rate and provide feedback about the incident send_incident_rating_feedback_message(incident, db_session) + # if an AI plugin is enabled, we send the incident review doc for summary + incident_service.generate_incident_summary(incident, db_session) + def conversation_topic_dispatcher( user_email: str, diff --git a/src/dispatch/incident/models.py b/src/dispatch/incident/models.py index 5d98f383768e..d5e727d6ecd6 100644 --- a/src/dispatch/incident/models.py +++ b/src/dispatch/incident/models.py @@ -220,6 +220,8 @@ def last_executive_report(self): notifications_group_id = Column(Integer, ForeignKey("group.id")) notifications_group = relationship("Group", foreign_keys=[notifications_group_id]) + summary = Column(String, nullable=True) + @hybrid_property def total_cost(self): total_cost = 0 @@ -323,6 +325,7 @@ class IncidentReadMinimal(IncidentBase): reporters_location: Optional[str] stable_at: Optional[datetime] = None storage: Optional[StorageRead] = None + summary: Optional[str] = None tags: Optional[List[TagRead]] = [] tasks: Optional[List[TaskReadMinimal]] = [] total_cost: Optional[float] @@ -344,6 +347,7 @@ class IncidentUpdate(IncidentBase): reported_at: Optional[datetime] = None reporter: Optional[ParticipantUpdate] stable_at: Optional[datetime] = None + summary: Optional[str] = None tags: Optional[List[TagRead]] = [] terms: Optional[List[TermRead]] = [] @@ -393,6 +397,7 @@ class IncidentRead(IncidentBase): reporters_location: Optional[str] stable_at: Optional[datetime] = None storage: Optional[StorageRead] = None + summary: Optional[str] = None tags: Optional[List[TagRead]] = [] tasks: Optional[List[TaskRead]] = [] terms: Optional[List[TermRead]] = [] diff --git a/src/dispatch/incident/scheduled.py b/src/dispatch/incident/scheduled.py index 301bd1448189..bb13356b3057 100644 --- a/src/dispatch/incident/scheduled.py +++ b/src/dispatch/incident/scheduled.py @@ -286,25 +286,13 @@ def incident_report_weekly(db_session: Session, project: Project): if incident.visibility == Visibility.restricted: continue try: - pir_doc = storage_plugin.instance.get( - file_id=incident.incident_review_document.resource_id, - mime_type="text/plain", - ) - prompt = f""" - Given the text of the security post-incident review document below, - provide answers to the following questions in a paragraph format. - Do not include the questions in your response. - 1. What is the summary of what happened? - 2. What were the overall risk(s)? - 3. How were the risk(s) mitigated? - 4. How was the incident resolved? - 5. What are the follow-up tasks? - - {pir_doc} - """ - - response = ai_plugin.instance.chat_completion(prompt=prompt) - summary = response["choices"][0]["message"]["content"] + # if already summary generated, use that instead + if incident.summary: + summary = incident.summary + else: + summary = incident_service.generate_incident_summary( + db_session=db_session, incident=incident + ) item = { "commander_fullname": incident.commander.individual.name, diff --git a/src/dispatch/incident/service.py b/src/dispatch/incident/service.py index b4ce96f0d88f..6667c3e107ac 100644 --- a/src/dispatch/incident/service.py +++ b/src/dispatch/incident/service.py @@ -11,9 +11,11 @@ from typing import List, Optional from pydantic.error_wrappers import ErrorWrapper, ValidationError +from sqlalchemy.orm import Session + from dispatch.decorators import timer from dispatch.case import service as case_service -from dispatch.database.core import SessionLocal +from dispatch.enums import Visibility from dispatch.event import service as event_service from dispatch.exceptions import NotFoundError from dispatch.incident.priority import service as incident_priority_service @@ -27,6 +29,7 @@ from dispatch.project import service as project_service from dispatch.tag import service as tag_service from dispatch.term import service as term_service +from dispatch.ticket import flows as ticket_flows from .enums import IncidentStatus from .models import Incident, IncidentCreate, IncidentRead, IncidentUpdate @@ -35,9 +38,7 @@ log = logging.getLogger(__name__) -def resolve_and_associate_role( - db_session: SessionLocal, incident: Incident, role: ParticipantRoleType -): +def resolve_and_associate_role(db_session: Session, incident: Incident, role: ParticipantRoleType): """For a given role type resolve which individual email should be assigned that role.""" email_address = None service_id = None @@ -65,12 +66,12 @@ def resolve_and_associate_role( @timer -def get(*, db_session, incident_id: int) -> Optional[Incident]: +def get(*, db_session: Session, incident_id: int) -> Optional[Incident]: """Returns an incident based on the given id.""" return db_session.query(Incident).filter(Incident.id == incident_id).first() -def get_by_name(*, db_session, project_id: int, name: str) -> Optional[Incident]: +def get_by_name(*, db_session: Session, project_id: int, name: str) -> Optional[Incident]: """Returns an incident based on the given name.""" return ( db_session.query(Incident) @@ -80,7 +81,9 @@ def get_by_name(*, db_session, project_id: int, name: str) -> Optional[Incident] ) -def get_all_open_by_incident_type(*, db_session, incident_type_id: int) -> List[Optional[Incident]]: +def get_all_open_by_incident_type( + *, db_session: Session, incident_type_id: int +) -> List[Optional[Incident]]: """Returns all non-closed incidents based on the given incident type.""" return ( db_session.query(Incident) @@ -90,7 +93,9 @@ def get_all_open_by_incident_type(*, db_session, incident_type_id: int) -> List[ ) -def get_by_name_or_raise(*, db_session, project_id: int, incident_in: IncidentRead) -> Incident: +def get_by_name_or_raise( + *, db_session: Session, project_id: int, incident_in: IncidentRead +) -> Incident: """Returns an incident based on a given name or raises ValidationError""" incident = get_by_name(db_session=db_session, project_id=project_id, name=incident_in.name) @@ -110,12 +115,14 @@ def get_by_name_or_raise(*, db_session, project_id: int, incident_in: IncidentRe return incident -def get_all(*, db_session, project_id: int) -> List[Optional[Incident]]: +def get_all(*, db_session: Session, project_id: int) -> List[Optional[Incident]]: """Returns all incidents.""" return db_session.query(Incident).filter(Incident.project_id == project_id) -def get_all_by_status(*, db_session, status: str, project_id: int) -> List[Optional[Incident]]: +def get_all_by_status( + *, db_session: Session, status: str, project_id: int +) -> List[Optional[Incident]]: """Returns all incidents based on the given status.""" return ( db_session.query(Incident) @@ -125,7 +132,7 @@ def get_all_by_status(*, db_session, status: str, project_id: int) -> List[Optio ) -def get_all_last_x_hours(*, db_session, hours: int) -> List[Optional[Incident]]: +def get_all_last_x_hours(*, db_session: Session, hours: int) -> List[Optional[Incident]]: """Returns all incidents in the last x hours.""" now = datetime.utcnow() return ( @@ -134,7 +141,7 @@ def get_all_last_x_hours(*, db_session, hours: int) -> List[Optional[Incident]]: def get_all_last_x_hours_by_status( - *, db_session, status: str, hours: int, project_id: int + *, db_session: Session, status: str, hours: int, project_id: int ) -> List[Optional[Incident]]: """Returns all incidents of a given status in the last x hours.""" now = datetime.utcnow() @@ -167,7 +174,7 @@ def get_all_last_x_hours_by_status( ) -def create(*, db_session, incident_in: IncidentCreate) -> Incident: +def create(*, db_session: Session, incident_in: IncidentCreate) -> Incident: """Creates a new incident.""" project = project_service.get_by_name_or_default( db_session=db_session, project_in=incident_in.project @@ -326,7 +333,7 @@ def create(*, db_session, incident_in: IncidentCreate) -> Incident: return incident -def update(*, db_session, incident: Incident, incident_in: IncidentUpdate) -> Incident: +def update(*, db_session: Session, incident: Incident, incident_in: IncidentUpdate) -> Incident: """Updates an existing incident.""" incident_type = incident_type_service.get_by_name_or_default( db_session=db_session, @@ -378,6 +385,16 @@ def update(*, db_session, incident: Incident, incident_in: IncidentUpdate) -> In incident_cost_service.update_incident_response_cost( incident_id=incident.id, db_session=db_session ) + # if the new incident type has plugin metadata and the + # project key of the ticket is the same, also update the ticket with the new metadata + if incident_type.plugin_metadata: + ticket_flows.update_incident_ticket_metadata( + db_session=db_session, + ticket_id=incident.ticket.resource_id, + project_id=incident.project.id, + incident_id=incident.id, + incident_type=incident_type, + ) update_data = incident_in.dict( skip_defaults=True, @@ -417,7 +434,72 @@ def update(*, db_session, incident: Incident, incident_in: IncidentUpdate) -> In return incident -def delete(*, db_session, incident_id: int): +def delete(*, db_session: Session, incident_id: int): """Deletes an existing incident.""" db_session.query(Incident).filter(Incident.id == incident_id).delete() db_session.commit() + + +def generate_incident_summary(*, db_session: Session, incident: Incident) -> str: + """Generates a summary of the incident.""" + # Skip summary for restricted incidents + if incident.visibility == Visibility.restricted: + return "Incident summary not generated for restricted incident." + + # Skip if no incident review document + if not incident.incident_review_document or not incident.incident_review_document.resource_id: + log.info( + f"Incident summary not generated for incident {incident.id}. No review document found." + ) + return "Incident summary not generated. No review document found." + + # Don't generate if no enabled ai plugin or storage plugin + ai_plugin = plugin_service.get_active_instance( + db_session=db_session, plugin_type="artificial-intelligence", project_id=incident.project.id + ) + if not ai_plugin: + log.info( + f"Incident summary not generated for incident {incident.id}. No AI plugin enabled." + ) + return "Incident summary not generated. No AI plugin enabled." + + storage_plugin = plugin_service.get_active_instance( + db_session=db_session, plugin_type="storage", project_id=incident.project.id + ) + + if not storage_plugin: + log.info( + f"Incident summary not generated for incident {incident.id}. No storage plugin enabled." + ) + return "Incident summary not generated. No storage plugin enabled." + + try: + pir_doc = storage_plugin.instance.get( + file_id=incident.incident_review_document.resource_id, + mime_type="text/plain", + ) + prompt = f""" + Given the text of the security post-incident review document below, + provide answers to the following questions in a paragraph format. + Do not include the questions in your response. + 1. What is the summary of what happened? + 2. What were the overall risk(s)? + 3. How were the risk(s) mitigated? + 4. How was the incident resolved? + 5. What are the follow-up tasks? + + {pir_doc} + """ + + response = ai_plugin.instance.chat_completion(prompt=prompt) + summary = response["choices"][0]["message"]["content"] + + incident.summary = summary + db_session.add(incident) + db_session.commit() + + return summary + + except Exception as e: + log.exception(f"Error trying to generate summary for incident {incident.id}: {e}") + return "Incident summary not generated. An error occurred." diff --git a/src/dispatch/incident/views.py b/src/dispatch/incident/views.py index cdeaf24bbadc..8bab7317001c 100644 --- a/src/dispatch/incident/views.py +++ b/src/dispatch/incident/views.py @@ -47,7 +47,7 @@ IncidentRead, IncidentUpdate, ) -from .service import create, delete, get, update +from .service import create, delete, get, update, generate_incident_summary log = logging.getLogger(__name__) @@ -497,3 +497,18 @@ def get_incident_forecast( {"name": "Actual", "data": actual[1:]}, ], } + + +@router.get( + "/{incident_id}/regenerate", + summary="Regenerates incident sumamary", + dependencies=[Depends(PermissionsDependency([IncidentEventPermission]))], +) +def generate_summary( + db_session: DbSession, + current_incident: CurrentIncident, +): + return generate_incident_summary( + db_session=db_session, + incident=current_incident, + ) diff --git a/src/dispatch/messaging/strings.py b/src/dispatch/messaging/strings.py index dfe6eda5678c..e8240612a6ef 100644 --- a/src/dispatch/messaging/strings.py +++ b/src/dispatch/messaging/strings.py @@ -100,7 +100,7 @@ class MessageType(DispatchEnum): ).strip() INCIDENT_WEEKLY_REPORT_NO_INCIDENTS_DESCRIPTION = """ -No open incidents have been closed in the last week.""".replace( +No open visibility incidents have been closed in the last week.""".replace( "\n", " " ).strip() diff --git a/src/dispatch/plugins/bases/signal_consumer.py b/src/dispatch/plugins/bases/signal_consumer.py index 4a2d516e107e..d3c1a7edfbde 100644 --- a/src/dispatch/plugins/bases/signal_consumer.py +++ b/src/dispatch/plugins/bases/signal_consumer.py @@ -14,6 +14,3 @@ class SignalConsumerPlugin(Plugin): def consume(self, **kwargs): raise NotImplementedError - - def delete(self, **kwargs): - raise NotImplementedError diff --git a/src/dispatch/plugins/dispatch_aws/plugin.py b/src/dispatch/plugins/dispatch_aws/plugin.py index 18c0d4fa8462..9c120b01896d 100644 --- a/src/dispatch/plugins/dispatch_aws/plugin.py +++ b/src/dispatch/plugins/dispatch_aws/plugin.py @@ -6,13 +6,15 @@ .. moduleauthor:: Kevin Glisson """ +import base64 import json import logging +import zlib from typing import TypedDict import boto3 -from pydantic import ValidationError from psycopg2.errors import UniqueViolation +from pydantic import ValidationError from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session @@ -28,6 +30,13 @@ log = logging.getLogger(__name__) +def decompress_json(compressed_str: str) -> str: + """Decompress a base64 encoded zlibed JSON string.""" + decoded = base64.b64decode(compressed_str) + decompressed = zlib.decompress(decoded) + return decompressed.decode("utf-8") + + class SqsEntries(TypedDict): Id: str ReceiptHandle: str @@ -36,7 +45,7 @@ class SqsEntries(TypedDict): class AWSSQSSignalConsumerPlugin(SignalConsumerPlugin): title = "AWS SQS - Signal Consumer" slug = "aws-sqs-signal-consumer" - description = "Uses sqs to consume signals" + description = "Uses SQS to consume signals." version = __version__ author = "Netflix" @@ -60,20 +69,32 @@ def consume(self, db_session: Session, project: Project) -> None: WaitTimeSeconds=20, ) if not response.get("Messages") or len(response["Messages"]) == 0: - log.info("No messages received from SQS") + log.info("No messages received from SQS.") continue entries: list[SqsEntries] = [] for message in response["Messages"]: - body = json.loads(message["Body"]) - signal_data = json.loads(body["Message"]) + try: + message_body = json.loads(message["Body"]) + message_body_message = message_body.get("Message") + message_attributes = message_body.get("MessageAttributes", {}) + + if message_attributes.get("compressed", {}).get("Value") == "zlib": + # Message is compressed, decompress it + message_body_message = decompress_json(message_body_message) + + signal_data = json.loads(message_body_message) + except Exception as e: + log.exception(f"Unable to extract signal data from SQS message: {e}") + continue + try: signal_instance_in = SignalInstanceCreate( project=project, raw=signal_data, **signal_data ) except ValidationError as e: log.warning( - f"Received signal instance that does not conform to `SignalInstanceCreate` structure, skipping creation: {e}" + f"Received a signal instance that does not conform to the SignalInstanceCreate pydantic model. Skipping creation: {e}" ) continue @@ -83,7 +104,7 @@ def consume(self, db_session: Session, project: Project) -> None: db_session=db_session, signal_instance_id=signal_instance_in.raw["id"] ): log.info( - f"Received signal instance that already exists in the database, skipping creation: {signal_instance_in.raw['id']}" + f"Received a signal that already exists in the database. Skipping signal instance creation: {signal_instance_in.raw['id']}" ) continue @@ -96,13 +117,17 @@ def consume(self, db_session: Session, project: Project) -> None: except IntegrityError as e: if isinstance(e.orig, UniqueViolation): log.info( - f"Received signal instance that already exists in the database, skipping creation: {e}" + f"Received a signal that already exists in the database. Skipping signal instance creation: {e}" ) else: - log.exception(f"Integrity error when creating signal instance: {e}") + log.exception( + f"Encountered an integrity error when trying to create a signal instance: {e}" + ) continue except Exception as e: - log.exception(f"Unable to create signal instance: {e}") + log.exception( + f"Unable to create signal instance. Signal name/variant: {signal_instance_in.raw['name'] if signal_instance_in.raw and signal_instance_in.raw['name'] else signal_instance_in.raw['variant']}. Error: {e}" + ) db_session.rollback() continue else: @@ -114,7 +139,7 @@ def consume(self, db_session: Session, project: Project) -> None: }, ) log.debug( - f"Received signal: name: {signal_instance.signal.name} id: {signal_instance.signal.id}" + f"Received a signal with name {signal_instance.signal.name} and id {signal_instance.signal.id}" ) entries.append( {"Id": message["MessageId"], "ReceiptHandle": message["ReceiptHandle"]} diff --git a/src/dispatch/plugins/dispatch_core/plugin.py b/src/dispatch/plugins/dispatch_core/plugin.py index ec3db480c076..f6ca14145a2b 100644 --- a/src/dispatch/plugins/dispatch_core/plugin.py +++ b/src/dispatch/plugins/dispatch_core/plugin.py @@ -249,6 +249,14 @@ def create_case_ticket( "resource_type": "dispatch-internal-ticket", } + def update_metadata( + self, + ticket_id: str, + metadata: dict, + ): + """Updates the metadata of a Dispatch ticket.""" + return + def update_case_ticket( self, ticket_id: str, diff --git a/src/dispatch/plugins/dispatch_jira/plugin.py b/src/dispatch/plugins/dispatch_jira/plugin.py index d66d1b848f66..222dbda7d774 100644 --- a/src/dispatch/plugins/dispatch_jira/plugin.py +++ b/src/dispatch/plugins/dispatch_jira/plugin.py @@ -315,6 +315,7 @@ def create( reporter = get_user_field(client, self.configuration, reporter_email) project_id, issue_type_name = process_plugin_metadata(incident_type_plugin_metadata) + other_fields = create_dict_from_plugin_metadata(incident_type_plugin_metadata) if not project_id: project_id = self.configuration.default_project_id @@ -335,6 +336,7 @@ def create( "assignee": assignee, "reporter": reporter, "summary": title, + **other_fields, } ticket = create(self.configuration, client, issue_fields) @@ -401,6 +403,31 @@ def update( return update(self.configuration, client, issue, issue_fields, status) + def update_metadata( + self, + ticket_id: str, + metadata: dict, + ): + """Updates the metadata of a Jira issue.""" + client = create_client(self.configuration) + issue = client.issue(ticket_id) + + # check to make sure project id matches metadata + project_id, issue_type_name = process_plugin_metadata(metadata) + if project_id and issue.fields.project.key != project_id: + log.warning( + f"Project key mismatch between Jira issue {issue.fields.project.key} and metadata {project_id} for ticket {ticket_id}" + ) + return + other_fields = create_dict_from_plugin_metadata(metadata) + issue_fields = { + **other_fields, + } + if issue_type_name: + issue_fields["issuetype"] = {"name": issue_type_name} + + issue.update(fields=issue_fields) + def create_case_ticket( self, case_id: int, diff --git a/src/dispatch/signal/service.py b/src/dispatch/signal/service.py index 751201594093..ba158d9e7b0e 100644 --- a/src/dispatch/signal/service.py +++ b/src/dispatch/signal/service.py @@ -544,9 +544,18 @@ def delete(*, db_session: Session, signal_id: int): return signal_id -def is_valid_uuid(val): +def is_valid_uuid(value) -> bool: + """ + Checks if the provided value is a valid UUID. + + Args: + val: The value to be checked. + + Returns: + bool: True if the value is a valid UUID, False otherwise. + """ try: - uuid.UUID(str(val), version=4) + uuid.UUID(str(value), version=4) return True except ValueError: return False @@ -587,7 +596,7 @@ def create_instance( signal_instance.id = signal_instance_in.raw["id"] if signal_instance.id and not is_valid_uuid(signal_instance.id): - msg = f"Invalid signal id format. Expecting UUID format. Received {signal_instance.id}." + msg = f"Invalid signal id format. Expecting UUIDv4 format. Signal id: {signal_instance.id}. Signal name/variant: {signal_instance.raw['name'] if signal_instance and signal_instance.raw and signal_instance.raw.get('name') else signal_instance.raw['variant']}" log.warn(msg) raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, diff --git a/src/dispatch/static/dispatch/src/incident/TimelineReportTab.vue b/src/dispatch/static/dispatch/src/incident/TimelineReportTab.vue index adae8685783a..1eac8f4a3554 100644 --- a/src/dispatch/static/dispatch/src/incident/TimelineReportTab.vue +++ b/src/dispatch/static/dispatch/src/incident/TimelineReportTab.vue @@ -1,6 +1,33 @@