Skip to content

Commit

Permalink
dag run state change endpoints notify listeners about state change (a…
Browse files Browse the repository at this point in the history
…pache#45652)

Signed-off-by: Maciej Obuchowski <[email protected]>
  • Loading branch information
mobuchowski authored Jan 15, 2025
1 parent 9c75dac commit 6afde78
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 2 deletions.
4 changes: 4 additions & 0 deletions airflow/api_fastapi/core_api/routes/public/dag_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
)
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
from airflow.exceptions import ParamValidationError
from airflow.listeners.listener import get_listener_manager
from airflow.models import DAG, DagModel, DagRun
from airflow.models.dag_version import DagVersion
from airflow.timetables.base import DataInterval
Expand Down Expand Up @@ -159,10 +160,13 @@ def patch_dag_run(
attr_value = getattr(patch_body, "state")
if attr_value == DAGRunPatchStates.SUCCESS:
set_dag_run_state_to_success(dag=dag, run_id=dag_run.run_id, commit=True, session=session)
get_listener_manager().hook.on_dag_run_success(dag_run=dag_run, msg="")
elif attr_value == DAGRunPatchStates.QUEUED:
set_dag_run_state_to_queued(dag=dag, run_id=dag_run.run_id, commit=True, session=session)
# Not notifying on queued - only notifying on RUNNING, this is happening in scheduler
elif attr_value == DAGRunPatchStates.FAILED:
set_dag_run_state_to_failed(dag=dag, run_id=dag_run.run_id, commit=True, session=session)
get_listener_manager().hook.on_dag_run_failed(dag_run=dag_run, msg="")
elif attr_name == "note":
# Once Authentication is implemented in this FastAPI app,
# user id will be added when updating dag run note
Expand Down
24 changes: 24 additions & 0 deletions tests/api_fastapi/core_api/routes/public/test_dag_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import time_machine
from sqlalchemy import select

from airflow.listeners.listener import get_listener_manager
from airflow.models import DagModel, DagRun
from airflow.models.asset import AssetEvent, AssetModel
from airflow.models.param import Param
Expand Down Expand Up @@ -943,6 +944,29 @@ def test_patch_dag_run_bad_request(self, test_client):
body = response.json()
assert body["detail"][0]["msg"] == "Input should be 'queued', 'success' or 'failed'"

@pytest.fixture(autouse=True)
def clean_listener_manager(self):
get_listener_manager().clear()
yield
get_listener_manager().clear()

@pytest.mark.parametrize(
"state, listener_state",
[
("queued", []),
("success", [DagRunState.SUCCESS]),
("failed", [DagRunState.FAILED]),
],
)
def test_patch_dag_run_notifies_listeners(self, test_client, state, listener_state):
from tests.listeners.class_listener import ClassBasedListener

listener = ClassBasedListener()
get_listener_manager().add_listener(listener)
response = test_client.patch(f"/public/dags/{DAG1_ID}/dagRuns/{DAG1_RUN1_ID}", json={"state": state})
assert response.status_code == 200
assert listener.state == listener_state


class TestDeleteDagRun:
def test_delete_dag_run(self, test_client):
Expand Down
47 changes: 45 additions & 2 deletions tests/listeners/class_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
from airflow.listeners import hookimpl
from airflow.utils.state import DagRunState, TaskInstanceState

from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS
from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS

if AIRFLOW_V_2_10_PLUS:
if AIRFLOW_V_3_0_PLUS:

class ClassBasedListener:
def __init__(self):
Expand All @@ -41,6 +41,49 @@ def before_stopping(self, component):
stopped_component = component
self.state.append(DagRunState.SUCCESS)

@hookimpl
def on_task_instance_running(self, previous_state, task_instance):
self.state.append(TaskInstanceState.RUNNING)

@hookimpl
def on_task_instance_success(self, previous_state, task_instance):
self.state.append(TaskInstanceState.SUCCESS)

@hookimpl
def on_task_instance_failed(self, previous_state, task_instance, error: None | str | BaseException):
self.state.append(TaskInstanceState.FAILED)

@hookimpl
def on_dag_run_running(self, dag_run, msg: str):
self.state.append(DagRunState.RUNNING)

@hookimpl
def on_dag_run_success(self, dag_run, msg: str):
self.state.append(DagRunState.SUCCESS)

@hookimpl
def on_dag_run_failed(self, dag_run, msg: str):
self.state.append(DagRunState.FAILED)

elif AIRFLOW_V_2_10_PLUS:

class ClassBasedListener: # type: ignore[no-redef]
def __init__(self):
self.started_component = None
self.stopped_component = None
self.state = []

@hookimpl
def on_starting(self, component):
self.started_component = component
self.state.append(DagRunState.RUNNING)

@hookimpl
def before_stopping(self, component):
global stopped_component
stopped_component = component
self.state.append(DagRunState.SUCCESS)

@hookimpl
def on_task_instance_running(self, previous_state, task_instance, session):
self.state.append(TaskInstanceState.RUNNING)
Expand Down

0 comments on commit 6afde78

Please sign in to comment.