diff --git a/src/dispatch/signal/models.py b/src/dispatch/signal/models.py index dc7fdce4932e..a0fc51e5ce8d 100644 --- a/src/dispatch/signal/models.py +++ b/src/dispatch/signal/models.py @@ -1,15 +1,15 @@ import uuid from datetime import datetime -from typing import List, Optional, Any +from typing import Any, List, Optional from pydantic import Field from sqlalchemy import ( + JSON, Boolean, Column, DateTime, ForeignKey, Integer, - JSON, PrimaryKeyConstraint, String, Table, @@ -24,23 +24,21 @@ from dispatch.case.priority.models import CasePriority, CasePriorityRead from dispatch.case.type.models import CaseType, CaseTypeRead from dispatch.data.source.models import SourceBase -from dispatch.entity_type.models import EntityType -from dispatch.project.models import ProjectRead - from dispatch.database.core import Base from dispatch.entity.models import EntityRead -from dispatch.entity_type.models import EntityTypeRead -from dispatch.tag.models import TagRead +from dispatch.entity_type.models import EntityType, EntityTypeRead from dispatch.enums import DispatchEnum from dispatch.models import ( DispatchBase, EvergreenMixin, NameStr, + Pagination, PrimaryKey, ProjectMixin, TimeStampMixin, - Pagination, ) +from dispatch.project.models import ProjectRead +from dispatch.tag.models import TagRead from dispatch.workflow.models import WorkflowRead @@ -290,6 +288,10 @@ class SignalEngagementRead(SignalEngagementBase): id: PrimaryKey +class SignalEngagementUpdate(SignalEngagementBase): + id: PrimaryKey + + class SignalEngagementPagination(Pagination): items: List[SignalEngagementRead] diff --git a/src/dispatch/signal/service.py b/src/dispatch/signal/service.py index 02d3a26c5ae2..751201594093 100644 --- a/src/dispatch/signal/service.py +++ b/src/dispatch/signal/service.py @@ -36,6 +36,7 @@ SignalEngagement, SignalEngagementCreate, SignalEngagementRead, + SignalEngagementUpdate, SignalFilter, SignalFilterAction, SignalFilterCreate, @@ -51,32 +52,6 @@ log = logging.getLogger(__name__) -def create_signal_engagement( - *, db_session: Session, creator: DispatchUser, signal_engagement_in: SignalEngagementCreate -) -> SignalEngagement: - """Creates a new signal filter.""" - project = project_service.get_by_name_or_raise( - db_session=db_session, project_in=signal_engagement_in.project - ) - - entity_type = entity_type_service.get( - db_session=db_session, entity_type_id=signal_engagement_in.entity_type.id - ) - - signal_engagement = SignalEngagement( - name=signal_engagement_in.name, - description=signal_engagement_in.description, - message=signal_engagement_in.message, - require_mfa=signal_engagement_in.require_mfa, - entity_type=entity_type, - creator=creator, - project=project, - ) - db_session.add(signal_engagement) - db_session.commit() - return signal_engagement - - def get_signal_engagement( *, db_session: Session, signal_engagement_id: int ) -> Optional[SignalEngagement]: @@ -88,18 +63,6 @@ def get_signal_engagement( ) -def get_all_by_entity_type(*, db_session: Session, entity_type_id: int) -> list[SignalInstance]: - """Fetches all signal instances associated with a given entity type.""" - return ( - db_session.query(SignalInstance) - .join(SignalInstance.signal) - .join(assoc_signal_entity_types) - .join(EntityType) - .filter(assoc_signal_entity_types.c.entity_type_id == entity_type_id) - .all() - ) - - def get_signal_engagement_by_name( *, db_session, project_id: int, name: str ) -> Optional[SignalEngagement]: @@ -136,6 +99,66 @@ def get_signal_engagement_by_name_or_raise( return signal_engagement +def create_signal_engagement( + *, db_session: Session, creator: DispatchUser, signal_engagement_in: SignalEngagementCreate +) -> SignalEngagement: + """Creates a new signal engagement.""" + project = project_service.get_by_name_or_raise( + db_session=db_session, project_in=signal_engagement_in.project + ) + + entity_type = entity_type_service.get( + db_session=db_session, entity_type_id=signal_engagement_in.entity_type.id + ) + + signal_engagement = SignalEngagement( + name=signal_engagement_in.name, + description=signal_engagement_in.description, + message=signal_engagement_in.message, + require_mfa=signal_engagement_in.require_mfa, + entity_type=entity_type, + creator=creator, + project=project, + ) + db_session.add(signal_engagement) + db_session.commit() + return signal_engagement + + +def update_signal_engagement( + *, + db_session: Session, + signal_engagement: SignalEngagement, + signal_engagement_in: SignalEngagementUpdate, +) -> SignalEngagement: + """Updates an existing signal engagement.""" + signal_engagement_data = signal_engagement.dict() + update_data = signal_engagement_in.dict( + skip_defaults=True, + exclude={}, + ) + + for field in signal_engagement_data: + if field in update_data: + setattr(signal_engagement, field, update_data[field]) + + db_session.add(signal_engagement) + db_session.commit() + return signal_engagement + + +def get_all_by_entity_type(*, db_session: Session, entity_type_id: int) -> list[SignalInstance]: + """Fetches all signal instances associated with a given entity type.""" + return ( + db_session.query(SignalInstance) + .join(SignalInstance.signal) + .join(assoc_signal_entity_types) + .join(EntityType) + .filter(assoc_signal_entity_types.c.entity_type_id == entity_type_id) + .all() + ) + + def create_signal_instance(*, db_session: Session, signal_instance_in: SignalInstanceCreate): """Creates a new signal instance.""" project = project_service.get_by_name_or_default( @@ -347,6 +370,7 @@ def create(*, db_session: Session, signal_in: SignalCreate) -> Signal: exclude={ "case_priority", "case_type", + "engagements", "entity_types", "filters", "oncall_service", diff --git a/src/dispatch/signal/views.py b/src/dispatch/signal/views.py index 9ac621f04f29..6b38f20c5568 100644 --- a/src/dispatch/signal/views.py +++ b/src/dispatch/signal/views.py @@ -1,12 +1,11 @@ import logging from typing import Union -from fastapi import APIRouter, BackgroundTasks, HTTPException, Request, Response, status, Depends +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request, Response, status from pydantic.error_wrappers import ErrorWrapper, ValidationError - from sqlalchemy.exc import IntegrityError -from dispatch.auth.permissions import SensitiveProjectActionPermission, PermissionsDependency +from dispatch.auth.permissions import PermissionsDependency, SensitiveProjectActionPermission from dispatch.auth.service import CurrentUser from dispatch.database.core import DbSession from dispatch.database.service import CommonParameters, search_filter_sort_paginate @@ -21,6 +20,7 @@ SignalEngagementCreate, SignalEngagementPagination, SignalEngagementRead, + SignalEngagementUpdate, SignalFilterCreate, SignalFilterPagination, SignalFilterRead, @@ -40,8 +40,10 @@ delete_signal_filter, get, get_by_primary_or_external_id, + get_signal_engagement, get_signal_filter, update, + update_signal_engagement, update_signal_filter, ) @@ -127,12 +129,14 @@ def get_signal_engagements(common: CommonParameters): @router.get("/engagements/{engagement_id}", response_model=SignalEngagementRead) -def get_signal_engagement( +def get_engagement( db_session: DbSession, signal_engagement_id: PrimaryKey, ): """Gets a signal engagement by its id.""" - engagement = get(db_session=db_session, signal_engagement_id=signal_engagement_id) + engagement = get_signal_engagement( + db_session=db_session, signal_engagement_id=signal_engagement_id + ) if not engagement: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -164,6 +168,46 @@ def create_engagement( ) from None +@router.put( + "/engagements/{signal_engagement_id}", + response_model=SignalEngagementRead, + dependencies=[Depends(PermissionsDependency([SensitiveProjectActionPermission]))], +) +def update_engagement( + db_session: DbSession, + signal_engagement_id: PrimaryKey, + signal_engagement_in: SignalEngagementUpdate, +): + """Updates an existing signal engagement.""" + signal_engagement = get_signal_engagement( + db_session=db_session, signal_engagement_id=signal_engagement_id + ) + if not signal_engagement: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=[{"msg": "A signal engagement with this id does not exist."}], + ) + + try: + signal_engagement = update_signal_engagement( + db_session=db_session, + signal_engagement=signal_engagement, + signal_engagement_in=signal_engagement_in, + ) + except IntegrityError: + raise ValidationError( + [ + ErrorWrapper( + ExistsError(msg="A signal engagement with this name already exists."), + loc="name", + ) + ], + model=SignalEngagementUpdate, + ) from None + + return signal_engagement + + @router.post("/filters", response_model=SignalFilterRead) def create_filter( db_session: DbSession, @@ -188,7 +232,7 @@ def create_filter( @router.put( "/filters/{signal_filter_id}", - response_model=SignalRead, + response_model=SignalFilterRead, dependencies=[Depends(PermissionsDependency([SensitiveProjectActionPermission]))], ) def update_filter(