diff --git a/src/dispatch/cli.py b/src/dispatch/cli.py index 95ae7d5a3f96..de62735e8d09 100644 --- a/src/dispatch/cli.py +++ b/src/dispatch/cli.py @@ -3,13 +3,13 @@ import click import uvicorn + from dispatch import __version__, config from dispatch.enums import UserRoles from dispatch.plugin.models import PluginInstance -from .scheduler import scheduler from .extensions import configure_extensions - +from .scheduler import scheduler os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1" @@ -80,10 +80,10 @@ def list_plugins(): ) def install_plugins(force): """Installs all plugins, or only one.""" + from dispatch.common.utils.cli import install_plugins from dispatch.database.core import SessionLocal from dispatch.plugin import service as plugin_service from dispatch.plugin.models import Plugin - from dispatch.common.utils.cli import install_plugins from dispatch.plugins.base import plugins install_plugins() @@ -162,9 +162,9 @@ def dispatch_user(): ) def register_user(email: str, role: str, password: str, organization: str): """Registers a new user.""" - from dispatch.database.core import refetch_db_session from dispatch.auth import service as user_service - from dispatch.auth.models import UserRegister, UserOrganization + from dispatch.auth.models import UserOrganization, UserRegister + from dispatch.database.core import refetch_db_session db_session = refetch_db_session(organization_slug=organization) user = user_service.get_by_email(email=email, db_session=db_session) @@ -198,9 +198,9 @@ def register_user(email: str, role: str, password: str, organization: str): ) def update_user(email: str, role: str, organization: str): """Updates a user's roles.""" - from dispatch.database.core import SessionLocal from dispatch.auth import service as user_service - from dispatch.auth.models import UserUpdate, UserOrganization + from dispatch.auth.models import UserOrganization, UserUpdate + from dispatch.database.core import SessionLocal db_session = SessionLocal() user = user_service.get_by_email(email=email, db_session=db_session) @@ -222,9 +222,9 @@ def update_user(email: str, role: str, organization: str): @click.password_option() def reset_user_password(email: str, password: str): """Resets a user's password.""" - from dispatch.database.core import SessionLocal from dispatch.auth import service as user_service from dispatch.auth.models import UserUpdate + from dispatch.database.core import SessionLocal db_session = SessionLocal() user = user_service.get_by_email(email=email, db_session=db_session) @@ -249,9 +249,7 @@ def database_init(): """Initializes a new database.""" click.echo("Initializing new database...") from .database.core import engine - from .database.manage import ( - init_database, - ) + from .database.manage import init_database init_database(engine) click.secho("Success.", fg="green") @@ -265,12 +263,13 @@ def database_init(): ) def restore_database(dump_file): """Restores the database via psql.""" - from sh import psql, createdb, ErrorReturnCode_1 + from sh import ErrorReturnCode_1, createdb, psql + from dispatch.config import ( + DATABASE_CREDENTIALS, DATABASE_HOSTNAME, DATABASE_NAME, DATABASE_PORT, - DATABASE_CREDENTIALS, ) username, password = str(DATABASE_CREDENTIALS).split(":") @@ -318,11 +317,12 @@ def restore_database(dump_file): def dump_database(dump_file): """Dumps the database via pg_dump.""" from sh import pg_dump + from dispatch.config import ( + DATABASE_CREDENTIALS, DATABASE_HOSTNAME, DATABASE_NAME, DATABASE_PORT, - DATABASE_CREDENTIALS, ) username, password = str(DATABASE_CREDENTIALS).split(":") @@ -345,7 +345,7 @@ def dump_database(dump_file): @click.option("--yes", is_flag=True, help="Silences all confirmation prompts.") def drop_database(yes): """Drops all data in database.""" - from sqlalchemy_utils import drop_database, database_exists + from sqlalchemy_utils import database_exists, drop_database if database_exists(str(config.SQLALCHEMY_DATABASE_URI)): if yes: @@ -378,10 +378,10 @@ def drop_database(yes): def upgrade_database(tag, sql, revision, revision_type): """Upgrades database schema to newest version.""" import sqlalchemy - from sqlalchemy import inspect - from sqlalchemy_utils import database_exists from alembic import command as alembic_command from alembic.config import Config as AlembicConfig + from sqlalchemy import inspect + from sqlalchemy_utils import database_exists from .database.core import engine from .database.manage import init_database @@ -570,6 +570,7 @@ def revision_database( ): """Create new database revision.""" import types + from alembic import command as alembic_command from alembic.config import Config as AlembicConfig @@ -623,20 +624,16 @@ def dispatch_scheduler(): from .evergreen.scheduled import create_evergreen_reminders # noqa from .feedback.incident.scheduled import feedback_report_daily # noqa from .feedback.service.scheduled import oncall_shift_feedback # noqa - from .incident_cost.scheduled import calculate_incidents_response_cost # noqa - from .incident.scheduled import ( # noqa - incident_auto_tagger, - incident_close_reminder, - incident_report_daily, + from .incident.scheduled import ( + incident_auto_tagger, # noqa ) + from .signal.scheduled import consume_signals # noqa + from .incident_cost.scheduled import calculate_incidents_response_cost # noqa from .monitor.scheduled import sync_active_stable_monitors # noqa from .report.scheduled import incident_report_reminders # noqa - from .signal.scheduled import consume_signals # noqa - from .tag.scheduled import sync_tags, build_tag_models # noqa - from .task.scheduled import ( # noqa - create_incident_tasks_reminders, - sync_incident_tasks_daily, - sync_active_stable_incident_tasks, + from .tag.scheduled import build_tag_models, sync_tags # noqa + from .task.scheduled import ( + create_incident_tasks_reminders, # noqa ) from .term.scheduled import sync_terms # noqa from .workflow.scheduled import sync_workflows # noqa @@ -661,6 +658,7 @@ def list_tasks(): def start_tasks(tasks, exclude, eager): """Starts the scheduler.""" import signal + from dispatch.common.utils.cli import install_plugins install_plugins() @@ -704,6 +702,7 @@ def dispatch_server(): def show_routes(): """Prints all available routes.""" from tabulate import tabulate + from dispatch.main import api_router table = [] @@ -716,9 +715,11 @@ def show_routes(): @dispatch_server.command("config") def show_config(): """Prints the current config as dispatch sees it.""" - import sys import inspect + import sys + from tabulate import tabulate + from dispatch import config func_members = inspect.getmembers(sys.modules[config.__name__]) @@ -771,15 +772,58 @@ def signals_group(): pass +@signals_group.command("consume") +@click.argument("organization") +@click.argument("project") +@click.argument("plugin") +def consume_signals(organization, project, plugin): # TODO support multiple from one command + """Runs a continuous process that consumes signals from the specified plugin.""" + from sqlalchemy import true + + from dispatch.common.utils.cli import install_plugins + from dispatch.database.core import refetch_db_session + from dispatch.project import service as project_service + from dispatch.project.models import ProjectRead + + install_plugins() + + session = refetch_db_session(organization) + + project = project_service.get_by_name_or_raise( + db_session=session, project_in=ProjectRead(name=project) + ) + + instances = ( + session.query(PluginInstance) + .filter(PluginInstance.enabled == true()) + .filter(PluginInstance.project_id == project.id) + .all() + ) + + instance = None + for i in instances: + if i.plugin.slug == "signal-consumer": + instance: PluginInstance = i + break + + if not instance: + click.secho( + f"No signal consumer plugin has been configured for this organization/plugin. Organization: {organization} Project: {project}", + fg="red", + ) + return + + @signals_group.command("process") def process_signals(): """Runs a continuous process that does additional processing on newly created signals.""" from sqlalchemy import asc - from dispatch.database.core import sessionmaker, engine, SessionLocal - from dispatch.signal.models import SignalInstance + + from dispatch.common.utils.cli import install_plugins + from dispatch.database.core import SessionLocal, engine, sessionmaker from dispatch.organization.service import get_all as get_all_organizations from dispatch.signal import flows as signal_flows - from dispatch.common.utils.cli import install_plugins + from dispatch.signal.models import SignalInstance install_plugins() @@ -818,20 +862,15 @@ def process_signals(): @click.argument("project") def run_slack_websocket(organization: str, project: str): """Runs the slack websocket process.""" - from sqlalchemy import true - from slack_bolt.adapter.socket_mode import SocketModeHandler + from sqlalchemy import true - from dispatch.database.core import refetch_db_session from dispatch.common.utils.cli import install_plugins + from dispatch.database.core import refetch_db_session from dispatch.plugins.dispatch_slack.bolt import app + from dispatch.plugins.dispatch_slack.case.interactive import configure as case_configure from dispatch.plugins.dispatch_slack.incident.interactive import configure as incident_configure - from dispatch.plugins.dispatch_slack.feedback.interactive import ( # noqa - configure as feedback_configure, - ) from dispatch.plugins.dispatch_slack.workflow import configure as workflow_configure - from dispatch.plugins.dispatch_slack.case.interactive import configure as case_configure - from dispatch.project import service as project_service from dispatch.project.models import ProjectRead @@ -883,6 +922,7 @@ def run_slack_websocket(organization: str, project: str): def shell(ipython_args): """Starts an ipython shell importing our app. Useful for debugging.""" import sys + import IPython from IPython.terminal.ipapp import load_default_config diff --git a/src/dispatch/plugins/dispatch_aws/_version.py b/src/dispatch/plugins/dispatch_aws/_version.py new file mode 100644 index 000000000000..3dc1f76bc69e --- /dev/null +++ b/src/dispatch/plugins/dispatch_aws/_version.py @@ -0,0 +1 @@ +__version__ = "0.1.0" diff --git a/src/dispatch/plugins/dispatch_aws/config.py b/src/dispatch/plugins/dispatch_aws/config.py new file mode 100644 index 000000000000..557886e6166b --- /dev/null +++ b/src/dispatch/plugins/dispatch_aws/config.py @@ -0,0 +1,29 @@ +from pydantic import Field +from dispatch.config import BaseConfigurationModel + + +class AWSSQSConfiguration(BaseConfigurationModel): + """SQS configuration description.""" + + queue_name: str = Field( + title="SQS Queue Name", + description="SQS Queue Name, not the ARN.", + ) + + queue_owner: str = Field( + title="SQS Queue Owner", + description="SQS Queue Owner Account ID.", + ) + + region: str = Field( + title="AWS Region", + description="AWS Region.", + default="us-east-1", + ) + + batch_size: int = Field( + title="SQS Batch Size", + description="SQS Batch Size.", + default=10, + le=10, + ) diff --git a/src/dispatch/plugins/dispatch_aws/plugin.py b/src/dispatch/plugins/dispatch_aws/plugin.py new file mode 100644 index 000000000000..0fe4c7490de6 --- /dev/null +++ b/src/dispatch/plugins/dispatch_aws/plugin.py @@ -0,0 +1,66 @@ +""" +.. module: dispatch.plugins.dispatchaws.plugin + :platform: Unix + :copyright: (c) 2023 by Netflix Inc., see AUTHORS for more + :license: Apache, see LICENSE for more details. +.. moduleauthor:: Kevin Glisson +""" +import boto3 +import json +import logging + +from dispatch.metrics import provider as metrics_provider +from dispatch.plugins.bases import SignalConsumerPlugin +from dispatch.signal import service as signal_service +from dispatch.signal.models import SignalInstance +from dispatch.plugins.dispatch_aws.config import AWSSQSConfigurationSchema + +from . import __version__ + +log = logging.getLogger(__name__) + + +class SQSSignalConsumerPlugin(SignalConsumerPlugin): + title = "AWS SQS - Signal Consumer" + slug = "aws-sqs-signal-consumer" + description = "Uses sqs to consume signals" + version = __version__ + + author = "Netflix" + author_url = "https://github.com/netflix/dispatch.git" + + def __init__(self): + self.configuration_schema = AWSSQSConfigurationSchema + + def consume( + self, + ): + client = boto3.client("sqs", region_name=self.configuration.region) + sqs_queue_url: str = client.get_queue_url( + QueueName=self.configuration.queue_name, QueueOwnerAWSAccountId=self.sqs_queue_owner + )["QueueUrl"] + + while True: + response = client.receive_message( + QueueUrl=sqs_queue_url, + MaxNumberOfMessages=self.configuration.batch_size, + VisibilityTimeout=2 * self.round_length, + WaitTimeSeconds=self.round_length, + ) + if response.get("Messages") and len(response.get("Messages")) > 0: + entries = [] + for message in response["Messages"]: + try: + body = json.loads(message["Body"]) + signal = signal_service.create_signal_instance(SignalInstance(**body)) + metrics_provider.counter( + "sqs.signal.received", tags={"signalName": signal.name} + ) + log.debug(f"Received signal: {signal}") + entries.append( + {"Id": message["MessageId"], "ReceiptHandle": message["ReceiptHandle"]} + ) + except Exception as e: + log.exception(e) + + client.delete_message_batch(QueueUrl=sqs_queue_url, Entries=entries) diff --git a/src/dispatch/signal/exceptions.py b/src/dispatch/signal/exceptions.py new file mode 100644 index 000000000000..867412acd0d4 --- /dev/null +++ b/src/dispatch/signal/exceptions.py @@ -0,0 +1,13 @@ +from dispatch.exceptions import DispatchException + + +class SignalNotIdentifiedException(DispatchException): + pass + + +class SignalNotDefinedException(DispatchException): + pass + + +class SignalNotEnabledException(DispatchException): + pass diff --git a/src/dispatch/signal/service.py b/src/dispatch/signal/service.py index f445763af842..8c87b3f1c4bc 100644 --- a/src/dispatch/signal/service.py +++ b/src/dispatch/signal/service.py @@ -18,6 +18,15 @@ from dispatch.tag import service as tag_service from dispatch.workflow import service as workflow_service from dispatch.entity.models import Entity +from sqlalchemy.exc import IntegrityError +from dispatch.entity_type.models import EntityScopeEnum +from dispatch.entity import service as entity_service + +from .exceptions import ( + SignalNotDefinedException, + SignalNotEnabledException, + SignalNotIdentifiableException, +) from .models import ( Signal, @@ -109,6 +118,66 @@ def get_signal_engagement_by_name_or_raise( return signal_engagement +def create_signal_instance(*, db_session: Session, signal_instance_in: SignalInstanceCreate): + project = project_service.get_by_name_or_default( + db_session=db_session, project_in=signal_instance_in.project + ) + + if not signal_instance_in.signal: + external_id = signal_instance_in.raw.get("externalId") + variant = signal_instance_in.raw.get("variant") + + if external_id or variant: + signal = get_by_variant_or_external_id( + db_session=db_session, + project_id=project.id, + external_id=external_id, + variant=variant, + ) + + signal_instance_in.signal = signal + else: + msg = "An external id or variant must be provided." + raise SignalNotIdentifiableException(msg) + + if not signal: + msg = f"No signal definition found. External Id: {external_id} Variant: {variant}" + raise SignalNotDefinedException(msg) + + if not signal.enabled: + msg = f"Signal definition not enabled. Signal Name: {signal.name}" + raise SignalNotEnabledException(msg) + + try: + signal_instance = create_instance( + db_session=db_session, signal_instance_in=signal_instance_in + ) + signal_instance.signal = signal + db_session.commit() + except IntegrityError: + db_session.rollback() + signal_instance = update_instance( + db_session=db_session, signal_instance_in=signal_instance_in + ) + # Note: we can do this because it's still relatively cheap, if we add more logic to the flow + # this will need to be moved to a background function (similar to case creation) + # fetch `all` entities that should be associated with all signal definitions + entity_types = entity_type_service.get_all( + db_session=db_session, scope=EntityScopeEnum.all + ).all() + entity_types = signal_instance.signal.entity_types + entity_types + + if entity_types: + entities = entity_service.find_entities( + db_session=db_session, + signal_instance=signal_instance, + entity_types=entity_types, + ) + signal_instance.entities = entities + db_session.commit() + return signal_instance + + def create_signal_filter( *, db_session: Session, creator: DispatchUser, signal_filter_in: SignalFilterCreate ) -> SignalFilter: