Skip to content

Commit

Permalink
Fixes updating signal engagements (#5292)
Browse files Browse the repository at this point in the history
  • Loading branch information
mvilanova authored Oct 4, 2024
1 parent 77eb5e5 commit b30180e
Show file tree
Hide file tree
Showing 8 changed files with 33 additions and 29 deletions.
5 changes: 3 additions & 2 deletions src/dispatch/data/alert/service.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Optional

from pydantic.error_wrappers import ErrorWrapper, ValidationError

from dispatch.exceptions import NotFoundError

from .models import Alert, AlertCreate, AlertUpdate, AlertRead
from .models import Alert, AlertCreate, AlertRead, AlertUpdate


def get(*, db_session, alert_id: int) -> Optional[Alert]:
Expand All @@ -16,7 +17,7 @@ def get_by_name(*, db_session, name: str) -> Optional[Alert]:
return db_session.query(Alert).filter(Alert.name == name).one_or_none()


def get_by_name_or_raise(*, db_session, alert_in=AlertRead) -> AlertRead:
def get_by_name_or_raise(*, db_session, alert_in: AlertRead) -> AlertRead:
"""Returns the alert specified or raises ValidationError."""
alert = get_by_name(db_session=db_session, name=alert_in.name)

Expand Down
6 changes: 3 additions & 3 deletions src/dispatch/organization/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def get_by_name(*, db_session, name: str) -> Optional[Organization]:
return db_session.query(Organization).filter(Organization.name == name).one_or_none()


def get_by_name_or_raise(*, db_session, organization_in=OrganizationRead) -> Organization:
def get_by_name_or_raise(*, db_session, organization_in: OrganizationRead) -> Organization:
"""Returns the organization specified or raises ValidationError."""
organization = get_by_name(db_session=db_session, name=organization_in.name)

Expand All @@ -67,7 +67,7 @@ def get_by_slug(*, db_session, slug: str) -> Optional[Organization]:
return db_session.query(Organization).filter(Organization.slug == slug).one_or_none()


def get_by_slug_or_raise(*, db_session, organization_in=OrganizationRead) -> Organization:
def get_by_slug_or_raise(*, db_session, organization_in: OrganizationRead) -> Organization:
"""Returns the organization specified or raises ValidationError."""
organization = get_by_slug(db_session=db_session, slug=organization_in.slug)

Expand All @@ -85,7 +85,7 @@ def get_by_slug_or_raise(*, db_session, organization_in=OrganizationRead) -> Org
return organization


def get_by_name_or_default(*, db_session, organization_in=OrganizationRead) -> Organization:
def get_by_name_or_default(*, db_session, organization_in: OrganizationRead) -> Organization:
"""Returns a organization based on a name or the default if not specified."""
if organization_in.name:
return get_by_name_or_raise(db_session=db_session, organization_in=organization_in)
Expand Down
10 changes: 5 additions & 5 deletions src/dispatch/project/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

from pydantic import ValidationError
from pydantic.error_wrappers import ErrorWrapper
from dispatch.exceptions import NotFoundError

from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import true

from .models import Project, ProjectCreate, ProjectUpdate, ProjectRead
from dispatch.exceptions import NotFoundError

from .models import Project, ProjectCreate, ProjectRead, ProjectUpdate


def get(*, db_session: Session, project_id: int) -> Project | None:
Expand Down Expand Up @@ -42,7 +42,7 @@ def get_by_name(*, db_session: Session, name: str) -> Optional[Project]:
return db_session.query(Project).filter(Project.name == name).one_or_none()


def get_by_name_or_raise(*, db_session: Session, project_in=ProjectRead) -> Project:
def get_by_name_or_raise(*, db_session: Session, project_in: ProjectRead) -> Project:
"""Returns the project specified or raises ValidationError."""
project = get_by_name(db_session=db_session, name=project_in.name)

Expand All @@ -60,7 +60,7 @@ def get_by_name_or_raise(*, db_session: Session, project_in=ProjectRead) -> Proj
return project


def get_by_name_or_default(*, db_session, project_in=ProjectRead) -> Project:
def get_by_name_or_default(*, db_session, project_in: ProjectRead) -> Project:
"""Returns a project based on a name or the default if not specified."""
if project_in:
if project_in.name:
Expand Down
4 changes: 2 additions & 2 deletions src/dispatch/service/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from dispatch.project.models import ProjectRead
from dispatch.search_filter import service as search_filter_service

from .models import Service, ServiceCreate, ServiceUpdate, ServiceRead
from .models import Service, ServiceCreate, ServiceRead, ServiceUpdate


def get(*, db_session, service_id: int) -> Optional[Service]:
Expand All @@ -31,7 +31,7 @@ def get_by_name(*, db_session, project_id: int, name: str) -> Optional[Service]:
)


def get_by_name_or_raise(*, db_session, project_id, service_in=ServiceRead) -> ServiceRead:
def get_by_name_or_raise(*, db_session, project_id, service_in: ServiceRead) -> ServiceRead:
"""Returns the service specified or raises ValidationError."""
source = get_by_name(db_session=db_session, project_id=project_id, name=service_in.name)

Expand Down
17 changes: 10 additions & 7 deletions src/dispatch/signal/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,9 @@ def get_signal_engagement_by_name(


def get_signal_engagement_by_name_or_raise(
*, db_session: Session, project_id: int, signal_engagement_in=SignalEngagementRead
*, db_session: Session, project_id: int, signal_engagement_in: SignalEngagementRead
) -> SignalEngagement:
"""Gets a signal engagement by its name or raises an error if not found."""
signal_engagement = get_signal_engagement_by_name(
db_session=db_session, project_id=project_id, name=signal_engagement_in.name
)
Expand All @@ -124,7 +125,7 @@ def get_signal_engagement_by_name_or_raise(
[
ErrorWrapper(
NotFoundError(
msg="Signal Engagement not found.",
msg="Signal engagement not found.",
signal_engagement=signal_engagement_in.name,
),
loc="signalEngagement",
Expand Down Expand Up @@ -226,7 +227,7 @@ def delete_signal_filter(*, db_session: Session, signal_filter_id: int) -> int:


def get_signal_filter_by_name_or_raise(
*, db_session: Session, project_id: int, signal_filter_in=SignalFilterRead
*, db_session: Session, project_id: int, signal_filter_in: SignalFilterRead
) -> SignalFilter:
signal_filter = get_signal_filter_by_name(
db_session=db_session, project_id=project_id, name=signal_filter_in.name
Expand Down Expand Up @@ -371,9 +372,9 @@ def create(*, db_session: Session, signal_in: SignalCreate) -> Signal:
signal.entity_types = entity_types

engagements = []
for eng in signal_in.engagements:
for signal_engagement_in in signal_in.engagements:
signal_engagement = get_signal_engagement_by_name(
db_session=db_session, project_id=project.id, signal_engagement_in=eng
db_session=db_session, project_id=project.id, name=signal_engagement_in.name
)
engagements.append(signal_engagement)

Expand Down Expand Up @@ -455,9 +456,11 @@ def update(*, db_session: Session, signal: Signal, signal_in: SignalUpdate) -> S

if signal_in.engagements:
engagements = []
for eng in signal_in.engagements:
for signal_engagement_in in signal_in.engagements:
signal_engagement = get_signal_engagement_by_name_or_raise(
db_session=db_session, project_id=signal.project.id, signal_engagement_in=eng
db_session=db_session,
project_id=signal.project.id,
signal_engagement_in=signal_engagement_in,
)
engagements.append(signal_engagement)

Expand Down
5 changes: 3 additions & 2 deletions src/dispatch/tag/service.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Optional

from pydantic.error_wrappers import ErrorWrapper, ValidationError

from dispatch.exceptions import NotFoundError
from dispatch.project import service as project_service
from dispatch.tag_type import service as tag_type_service

from .models import Tag, TagCreate, TagUpdate, TagRead
from .models import Tag, TagCreate, TagRead, TagUpdate


def get(*, db_session, tag_id: int) -> Optional[Tag]:
Expand All @@ -23,7 +24,7 @@ def get_by_name(*, db_session, project_id: int, name: str) -> Optional[Tag]:
)


def get_by_name_or_raise(*, db_session, project_id: int, tag_in=TagRead) -> TagRead:
def get_by_name_or_raise(*, db_session, project_id: int, tag_in: TagRead) -> TagRead:
"""Returns the tag specified or raises ValidationError."""
tag = get_by_name(db_session=db_session, project_id=project_id, name=tag_in.name)

Expand Down
4 changes: 2 additions & 2 deletions src/dispatch/tag_type/service.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Optional

from pydantic.error_wrappers import ErrorWrapper, ValidationError
from dispatch.exceptions import NotFoundError

from dispatch.exceptions import NotFoundError
from dispatch.project import service as project_service

from .models import TagType, TagTypeCreate, TagTypeRead, TagTypeUpdate
Expand Down Expand Up @@ -33,7 +33,7 @@ def get_storage_tag_type_for_project(*, db_session, project_id) -> TagType | Non
)


def get_by_name_or_raise(*, db_session, project_id: int, tag_type_in=TagTypeRead) -> TagType:
def get_by_name_or_raise(*, db_session, project_id: int, tag_type_in: TagTypeRead) -> TagType:
"""Returns the tag_type specified or raises ValidationError."""
tag_type = get_by_name(db_session=db_session, project_id=project_id, name=tag_type_in.name)

Expand Down
11 changes: 5 additions & 6 deletions src/dispatch/workflow/service.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from typing import List, Optional

from pydantic.error_wrappers import ErrorWrapper, ValidationError
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import true

from pydantic.error_wrappers import ErrorWrapper, ValidationError

from dispatch.case import service as case_service
from dispatch.config import DISPATCH_UI_URL
from dispatch.document import service as document_service
Expand All @@ -18,12 +17,12 @@

from .models import (
Workflow,
WorkflowInstance,
WorkflowCreate,
WorkflowRead,
WorkflowUpdate,
WorkflowInstance,
WorkflowInstanceCreate,
WorkflowInstanceUpdate,
WorkflowRead,
WorkflowUpdate,
)


Expand All @@ -37,7 +36,7 @@ def get_by_name(*, db_session, name: str) -> Optional[Workflow]:
return db_session.query(Workflow).filter(Workflow.name == name).one_or_none()


def get_by_name_or_raise(*, db_session: Session, workflow_in=WorkflowRead) -> Workflow:
def get_by_name_or_raise(*, db_session: Session, workflow_in: WorkflowRead) -> Workflow:
workflow = get_by_name(db_session=db_session, name=workflow_in.name)

if not workflow:
Expand Down

0 comments on commit b30180e

Please sign in to comment.