diff --git a/requirements-base.in b/requirements-base.in index e6834e4d73dd..95003484d6f4 100644 --- a/requirements-base.in +++ b/requirements-base.in @@ -6,6 +6,7 @@ atlassian-python-api==3.32.0 attrs==22.1.0 bcrypt blockkit +boto3 cachetools chardet click @@ -44,8 +45,8 @@ schemathesis sentry-asgi sentry-sdk sh -slack-bolt slack_sdk +slack-bolt slowapi spacy sqlalchemy-filters diff --git a/setup.py b/setup.py index 0314fe703c4e..ef838342801a 100644 --- a/setup.py +++ b/setup.py @@ -404,6 +404,7 @@ def run(self): "dispatch.plugins": [ "dispatch_atlassian_confluence = dispatch.plugins.dispatch_atlassian_confluence.plugin:ConfluencePagePlugin", "dispatch_atlassian_confluence_document = dispatch.plugins.dispatch_atlassian_confluence.docs.plugin:ConfluencePageDocPlugin", + "dispatch_aws_sqs = dispatch.plugins.dispatch_aws.plugin:AWSSQSSignalConsumerPlugin", "dispatch_basic_auth = dispatch.plugins.dispatch_core.plugin:BasicAuthProviderPlugin", "dispatch_contact = dispatch.plugins.dispatch_core.plugin:DispatchContactPlugin", "dispatch_document_resolver = dispatch.plugins.dispatch_core.plugin:DispatchDocumentResolverPlugin", diff --git a/src/dispatch/cli.py b/src/dispatch/cli.py index 689e3ec9c66b..0675e157cbec 100644 --- a/src/dispatch/cli.py +++ b/src/dispatch/cli.py @@ -3,13 +3,13 @@ import click import uvicorn + from dispatch import __version__, config from dispatch.enums import UserRoles from dispatch.plugin.models import PluginInstance -from .scheduler import scheduler from .extensions import configure_extensions - +from .scheduler import scheduler os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1" @@ -80,10 +80,10 @@ def list_plugins(): ) def install_plugins(force): """Installs all plugins, or only one.""" + from dispatch.common.utils.cli import install_plugins from dispatch.database.core import SessionLocal from dispatch.plugin import service as plugin_service from dispatch.plugin.models import Plugin - from dispatch.common.utils.cli import install_plugins from dispatch.plugins.base import plugins install_plugins() @@ -162,9 +162,9 @@ def dispatch_user(): ) def register_user(email: str, role: str, password: str, organization: str): """Registers a new user.""" - from dispatch.database.core import refetch_db_session from dispatch.auth import service as user_service - from dispatch.auth.models import UserRegister, UserOrganization + from dispatch.auth.models import UserOrganization, UserRegister + from dispatch.database.core import refetch_db_session db_session = refetch_db_session(organization_slug=organization) user = user_service.get_by_email(email=email, db_session=db_session) @@ -198,9 +198,9 @@ def register_user(email: str, role: str, password: str, organization: str): ) def update_user(email: str, role: str, organization: str): """Updates a user's roles.""" - from dispatch.database.core import SessionLocal from dispatch.auth import service as user_service - from dispatch.auth.models import UserUpdate, UserOrganization + from dispatch.auth.models import UserOrganization, UserUpdate + from dispatch.database.core import SessionLocal db_session = SessionLocal() user = user_service.get_by_email(email=email, db_session=db_session) @@ -222,9 +222,9 @@ def update_user(email: str, role: str, organization: str): @click.password_option() def reset_user_password(email: str, password: str): """Resets a user's password.""" - from dispatch.database.core import SessionLocal from dispatch.auth import service as user_service from dispatch.auth.models import UserUpdate + from dispatch.database.core import SessionLocal db_session = SessionLocal() user = user_service.get_by_email(email=email, db_session=db_session) @@ -249,9 +249,7 @@ def database_init(): """Initializes a new database.""" click.echo("Initializing new database...") from .database.core import engine - from .database.manage import ( - init_database, - ) + from .database.manage import init_database init_database(engine) click.secho("Success.", fg="green") @@ -265,12 +263,13 @@ def database_init(): ) def restore_database(dump_file): """Restores the database via psql.""" - from sh import psql, createdb, ErrorReturnCode_1 + from sh import ErrorReturnCode_1, createdb, psql + from dispatch.config import ( + DATABASE_CREDENTIALS, DATABASE_HOSTNAME, DATABASE_NAME, DATABASE_PORT, - DATABASE_CREDENTIALS, ) username, password = str(DATABASE_CREDENTIALS).split(":") @@ -318,11 +317,12 @@ def restore_database(dump_file): def dump_database(dump_file): """Dumps the database via pg_dump.""" from sh import pg_dump + from dispatch.config import ( + DATABASE_CREDENTIALS, DATABASE_HOSTNAME, DATABASE_NAME, DATABASE_PORT, - DATABASE_CREDENTIALS, ) username, password = str(DATABASE_CREDENTIALS).split(":") @@ -345,7 +345,7 @@ def dump_database(dump_file): @click.option("--yes", is_flag=True, help="Silences all confirmation prompts.") def drop_database(yes): """Drops all data in database.""" - from sqlalchemy_utils import drop_database, database_exists + from sqlalchemy_utils import database_exists, drop_database if database_exists(str(config.SQLALCHEMY_DATABASE_URI)): if yes: @@ -378,10 +378,10 @@ def drop_database(yes): def upgrade_database(tag, sql, revision, revision_type): """Upgrades database schema to newest version.""" import sqlalchemy - from sqlalchemy import inspect - from sqlalchemy_utils import database_exists from alembic import command as alembic_command from alembic.config import Config as AlembicConfig + from sqlalchemy import inspect + from sqlalchemy_utils import database_exists from .database.core import engine from .database.manage import init_database @@ -570,6 +570,7 @@ def revision_database( ): """Create new database revision.""" import types + from alembic import command as alembic_command from alembic.config import Config as AlembicConfig @@ -623,20 +624,15 @@ def dispatch_scheduler(): from .evergreen.scheduled import create_evergreen_reminders # noqa from .feedback.incident.scheduled import feedback_report_daily # noqa from .feedback.service.scheduled import oncall_shift_feedback # noqa - from .incident_cost.scheduled import calculate_incidents_response_cost # noqa - from .incident.scheduled import ( # noqa - incident_auto_tagger, - incident_close_reminder, - incident_report_daily, + from .incident.scheduled import ( + incident_auto_tagger, # noqa ) + from .incident_cost.scheduled import calculate_incidents_response_cost # noqa from .monitor.scheduled import sync_active_stable_monitors # noqa from .report.scheduled import incident_report_reminders # noqa - from .signal.scheduled import consume_signals # noqa - from .tag.scheduled import sync_tags, build_tag_models # noqa - from .task.scheduled import ( # noqa - create_incident_tasks_reminders, - sync_incident_tasks_daily, - sync_active_stable_incident_tasks, + from .tag.scheduled import build_tag_models, sync_tags # noqa + from .task.scheduled import ( + create_incident_tasks_reminders, # noqa ) from .term.scheduled import sync_terms # noqa from .workflow.scheduled import sync_workflows # noqa @@ -662,6 +658,7 @@ def list_tasks(): def start_tasks(tasks, exclude, eager): """Starts the scheduler.""" import signal + from dispatch.common.utils.cli import install_plugins install_plugins() @@ -705,6 +702,7 @@ def dispatch_server(): def show_routes(): """Prints all available routes.""" from tabulate import tabulate + from dispatch.main import api_router table = [] @@ -717,9 +715,11 @@ def show_routes(): @dispatch_server.command("config") def show_config(): """Prints the current config as dispatch sees it.""" - import sys import inspect + import sys + from tabulate import tabulate + from dispatch import config func_members = inspect.getmembers(sys.modules[config.__name__]) @@ -772,15 +772,51 @@ def signals_group(): pass +@signals_group.command("consume") +@click.argument("organization") +@click.argument("project") +@click.argument("plugin_name") +def consume_signals(organization, project, plugin_name): # TODO support multiple from one command + """Runs a continuous process that consumes signals from the specified plugin.""" + from dispatch.common.utils.cli import install_plugins + from dispatch.database.core import refetch_db_session + from dispatch.project import service as project_service + from dispatch.project.models import ProjectRead + from dispatch.plugin import service as plugin_service + + install_plugins() + + session = refetch_db_session(organization) + + project = project_service.get_by_name_or_raise( + db_session=session, project_in=ProjectRead(name=project) + ) + + plugins = plugin_service.get_active_instances( + db_session=session, plugin_type="signal-consumer", project_id=project.id + ) + + if not plugins: + log.debug( + "No signals consumed. No signal-consumer plugins enabled. Project: {project.name}. Organization: {project.organization.name}" + ) + return + + for plugin in plugins: + if plugin.plugin.slug == plugin_name: + plugin.instance.consume(db_session=session, project=project) + + @signals_group.command("process") def process_signals(): """Runs a continuous process that does additional processing on newly created signals.""" from sqlalchemy import asc - from dispatch.database.core import sessionmaker, engine, SessionLocal - from dispatch.signal.models import SignalInstance + + from dispatch.common.utils.cli import install_plugins + from dispatch.database.core import SessionLocal, engine, sessionmaker from dispatch.organization.service import get_all as get_all_organizations from dispatch.signal import flows as signal_flows - from dispatch.common.utils.cli import install_plugins + from dispatch.signal.models import SignalInstance install_plugins() @@ -819,20 +855,15 @@ def process_signals(): @click.argument("project") def run_slack_websocket(organization: str, project: str): """Runs the slack websocket process.""" - from sqlalchemy import true - from slack_bolt.adapter.socket_mode import SocketModeHandler + from sqlalchemy import true - from dispatch.database.core import refetch_db_session from dispatch.common.utils.cli import install_plugins + from dispatch.database.core import refetch_db_session from dispatch.plugins.dispatch_slack.bolt import app + from dispatch.plugins.dispatch_slack.case.interactive import configure as case_configure from dispatch.plugins.dispatch_slack.incident.interactive import configure as incident_configure - from dispatch.plugins.dispatch_slack.feedback.interactive import ( # noqa - configure as feedback_configure, - ) from dispatch.plugins.dispatch_slack.workflow import configure as workflow_configure - from dispatch.plugins.dispatch_slack.case.interactive import configure as case_configure - from dispatch.project import service as project_service from dispatch.project.models import ProjectRead @@ -884,6 +915,7 @@ def run_slack_websocket(organization: str, project: str): def shell(ipython_args): """Starts an ipython shell importing our app. Useful for debugging.""" import sys + import IPython from IPython.terminal.ipapp import load_default_config diff --git a/src/dispatch/plugins/dispatch_aws/__init__.py b/src/dispatch/plugins/dispatch_aws/__init__.py new file mode 100644 index 000000000000..ad5cc752c07b --- /dev/null +++ b/src/dispatch/plugins/dispatch_aws/__init__.py @@ -0,0 +1 @@ +from ._version import __version__ # noqa diff --git a/src/dispatch/plugins/dispatch_aws/_version.py b/src/dispatch/plugins/dispatch_aws/_version.py new file mode 100644 index 000000000000..3dc1f76bc69e --- /dev/null +++ b/src/dispatch/plugins/dispatch_aws/_version.py @@ -0,0 +1 @@ +__version__ = "0.1.0" diff --git a/src/dispatch/plugins/dispatch_aws/config.py b/src/dispatch/plugins/dispatch_aws/config.py new file mode 100644 index 000000000000..d14e3f58ae2a --- /dev/null +++ b/src/dispatch/plugins/dispatch_aws/config.py @@ -0,0 +1,29 @@ +from pydantic import Field +from dispatch.config import BaseConfigurationModel + + +class AWSSQSConfiguration(BaseConfigurationModel): + """Signal SQS configuration""" + + queue_name: str = Field( + title="Queue Name", + description="Queue Name, not the ARN.", + ) + + queue_owner: str = Field( + title="Queue Owner", + description="Queue Owner Account ID.", + ) + + region: str = Field( + title="AWS Region", + description="AWS Region.", + default="us-east-1", + ) + + batch_size: int = Field( + title="Batch Size", + description="Number of messages to retrieve from SQS.", + default=10, + le=10, + ) diff --git a/src/dispatch/plugins/dispatch_aws/plugin.py b/src/dispatch/plugins/dispatch_aws/plugin.py new file mode 100644 index 000000000000..fb6902e83cc1 --- /dev/null +++ b/src/dispatch/plugins/dispatch_aws/plugin.py @@ -0,0 +1,78 @@ +""" +.. module: dispatch.plugins.dispatchaws.plugin + :platform: Unix + :copyright: (c) 2023 by Netflix Inc., see AUTHORS for more + :license: Apache, see LICENSE for more details. +.. moduleauthor:: Kevin Glisson +""" +import boto3 +import json +import logging + +from dispatch.metrics import provider as metrics_provider +from dispatch.plugins.bases import SignalConsumerPlugin +from dispatch.signal import service as signal_service +from dispatch.signal.models import SignalInstanceCreate +from dispatch.plugins.dispatch_aws.config import AWSSQSConfiguration + +from . import __version__ + +log = logging.getLogger(__name__) + + +class AWSSQSSignalConsumerPlugin(SignalConsumerPlugin): + title = "AWS SQS - Signal Consumer" + slug = "aws-sqs-signal-consumer" + description = "Uses sqs to consume signals" + version = __version__ + + author = "Netflix" + author_url = "https://github.com/netflix/dispatch.git" + + def __init__(self): + self.configuration_schema = AWSSQSConfiguration + + def consume(self, db_session, project): + client = boto3.client("sqs", region_name=self.configuration.region) + queue_url: str = client.get_queue_url( + QueueName=self.configuration.queue_name, + QueueOwnerAWSAccountId=self.configuration.queue_owner, + )["QueueUrl"] + + while True: + response = client.receive_message( + QueueUrl=queue_url, + MaxNumberOfMessages=self.configuration.batch_size, + VisibilityTimeout=40, + WaitTimeSeconds=20, + ) + if response.get("Messages") and len(response.get("Messages")) > 0: + entries = [] + for message in response["Messages"]: + try: + body = json.loads(message["Body"]) + signal_data = json.loads(body["Message"]) + + signal_instance = signal_service.create_signal_instance( + db_session=db_session, + signal_instance_in=SignalInstanceCreate( + project=project, raw=signal_data, **signal_data + ), + ) + metrics_provider.counter( + "aws-sqs-signal-consumer.signal.received", + tags={ + "signalName": signal_instance.signal.name, + "externalId": signal_instance.signal.external_id, + }, + ) + log.debug( + f"Received signal: SignalName: {signal_instance.signal.name} ExernalId: {signal_instance.signal.external_id}" + ) + entries.append( + {"Id": message["MessageId"], "ReceiptHandle": message["ReceiptHandle"]} + ) + except Exception as e: + log.exception(e) + + client.delete_message_batch(QueueUrl=queue_url, Entries=entries) diff --git a/src/dispatch/signal/exceptions.py b/src/dispatch/signal/exceptions.py new file mode 100644 index 000000000000..867412acd0d4 --- /dev/null +++ b/src/dispatch/signal/exceptions.py @@ -0,0 +1,13 @@ +from dispatch.exceptions import DispatchException + + +class SignalNotIdentifiedException(DispatchException): + pass + + +class SignalNotDefinedException(DispatchException): + pass + + +class SignalNotEnabledException(DispatchException): + pass diff --git a/src/dispatch/signal/models.py b/src/dispatch/signal/models.py index d35154a8c325..a295c02fb486 100644 --- a/src/dispatch/signal/models.py +++ b/src/dispatch/signal/models.py @@ -363,11 +363,12 @@ class AdditionalMetadata(DispatchBase): class SignalInstanceBase(DispatchBase): - project: ProjectRead + project: Optional[ProjectRead] case: Optional[CaseReadMinimal] canary: Optional[bool] = False entities: Optional[List[EntityRead]] = [] raw: dict[str, Any] + external_id: Optional[str] filter_action: SignalFilterAction = None created_at: Optional[datetime] = None diff --git a/src/dispatch/signal/scheduled.py b/src/dispatch/signal/scheduled.py deleted file mode 100644 index f09d5415f396..000000000000 --- a/src/dispatch/signal/scheduled.py +++ /dev/null @@ -1,50 +0,0 @@ -""" -.. module: dispatch.signal.scheduled - :platform: Unix - :copyright: (c) 2022 by Netflix Inc., see AUTHORS for more - :license: Apache, see LICENSE for more details. -""" -import logging - -from schedule import every - -from dispatch.database.core import SessionLocal -from dispatch.decorators import scheduled_project_task, timer -from dispatch.plugin import service as plugin_service -from dispatch.project.models import Project -from dispatch.scheduler import scheduler -from dispatch.signal import flows as signal_flows - -log = logging.getLogger(__name__) - - -# TODO do we want per signal source flexibility? -@scheduler.add(every(1).minutes, name="signal-consume") -@timer -@scheduled_project_task -def consume_signals(db_session: SessionLocal, project: Project): - """Consume signals from external sources.""" - plugins = plugin_service.get_active_instances( - db_session=db_session, plugin_type="signal-consumer", project_id=project.id - ) - - if not plugins: - log.debug( - "No signals consumed. No signal-consumer plugins enabled. Project: {project.name}. Organization: {project.organization.name}" - ) - return - - for plugin in plugins: - log.debug(f"Consuming signals using signal-consumer plugin: {plugin.plugin.slug}") - signal_instances = plugin.instance.consume() - for signal_instance_data in signal_instances: - log.info(f"Attempting to process the following signal: {signal_instance_data}") - try: - signal_flows.create_signal_instance( - db_session=db_session, - project=project, - signal_instance_data=signal_instance_data, - ) - except Exception as e: - log.debug(signal_instance_data) - log.exception(e) diff --git a/src/dispatch/signal/service.py b/src/dispatch/signal/service.py index f445763af842..eadc7e0924e1 100644 --- a/src/dispatch/signal/service.py +++ b/src/dispatch/signal/service.py @@ -18,6 +18,15 @@ from dispatch.tag import service as tag_service from dispatch.workflow import service as workflow_service from dispatch.entity.models import Entity +from sqlalchemy.exc import IntegrityError +from dispatch.entity_type.models import EntityScopeEnum +from dispatch.entity import service as entity_service + +from .exceptions import ( + SignalNotDefinedException, + SignalNotEnabledException, + SignalNotIdentifiedException, +) from .models import ( Signal, @@ -109,6 +118,58 @@ def get_signal_engagement_by_name_or_raise( return signal_engagement +def create_signal_instance(*, db_session: Session, signal_instance_in: SignalInstanceCreate): + if not signal_instance_in.signal: + external_id = signal_instance_in.external_id + + # this assumes the external_ids are uuids + if external_id: + signal = ( + db_session.query(Signal).filter(Signal.external_id == external_id).one_or_none() + ) + signal_instance_in.signal = signal + else: + msg = "An externalId must be provided." + raise SignalNotIdentifiedException(msg) + + if not signal: + msg = f"No signal definition found. ExternalId: {external_id}" + raise SignalNotDefinedException(msg) + + if not signal.enabled: + msg = f"Signal definition not enabled. SignalName: {signal.name} ExternalId: {signal.external_id}" + raise SignalNotEnabledException(msg) + + try: + signal_instance = create_instance( + db_session=db_session, signal_instance_in=signal_instance_in + ) + signal_instance.signal = signal + db_session.commit() + except IntegrityError: + db_session.rollback() + signal_instance = update_instance( + db_session=db_session, signal_instance_in=signal_instance_in + ) + # Note: we can do this because it's still relatively cheap, if we add more logic here + # this will need to be moved to a background function (similar to case creation) + # fetch `all` entities that should be associated with all signal definitions + entity_types = entity_type_service.get_all( + db_session=db_session, scope=EntityScopeEnum.all + ).all() + entity_types = signal_instance.signal.entity_types + entity_types + + if entity_types: + entities = entity_service.find_entities( + db_session=db_session, + signal_instance=signal_instance, + entity_types=entity_types, + ) + signal_instance.entities = entities + db_session.commit() + return signal_instance + + def create_signal_filter( *, db_session: Session, creator: DispatchUser, signal_filter_in: SignalFilterCreate ) -> SignalFilter: @@ -451,6 +512,7 @@ def create_instance( "project", "entities", "raw", + "external_id", } ), raw=json.loads(json.dumps(signal_instance_in.raw)),