From 09174981d76a803f526667efd1539f40b18beac1 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Tue, 24 Dec 2024 17:07:42 +0530 Subject: [PATCH] AIP-72: Allow pushing and pulling XCom from Task Context (#45075) Part of https://github.com/apache/airflow/issues/44481 There is a lot of cleanup to do but I wanted to get a basic DAG that uses XCom working first. Example DAG used: `tutorial_dag` ```py from __future__ import annotations # [START tutorial] # [START import_module] import json import textwrap import pendulum # The DAG object; we'll need this to instantiate a DAG from airflow.models.dag import DAG # Operators; we need this to operate! from airflow.providers.standard.operators.python import PythonOperator # [END import_module] # [START instantiate_dag] with DAG( "tutorial_dag", # [START default_args] # These args will get passed on to each operator # You can override them on a per-task basis during operator initialization default_args={"retries": 2}, # [END default_args] description="DAG tutorial", schedule=None, start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), catchup=False, tags=["example"], ) as dag: # [END instantiate_dag] # [START documentation] dag.doc_md = __doc__ # [END documentation] # [START extract_function] def extract(**kwargs): ti = kwargs["ti"] data_string = '{"1001": 301.27, "1002": 433.21, "1003": 502.22}' ti.xcom_push("order_data", data_string) # [END extract_function] # [START transform_function] def transform(**kwargs): ti = kwargs["ti"] extract_data_string = ti.xcom_pull(task_ids="extract", key="order_data") order_data = json.loads(extract_data_string) total_order_value = 0 for value in order_data.values(): total_order_value += value total_value = {"total_order_value": total_order_value} total_value_json_string = json.dumps(total_value) ti.xcom_push("total_order_value", total_value_json_string) # [END transform_function] # [START load_function] def load(**kwargs): ti = kwargs["ti"] total_value_string = ti.xcom_pull(task_ids="transform", key="total_order_value") total_order_value = json.loads(total_value_string) print(total_order_value) # [END load_function] # [START main_flow] extract_task = PythonOperator( task_id="extract", python_callable=extract, ) extract_task.doc_md = textwrap.dedent( """\ #### Extract task A simple Extract task to get data ready for the rest of the data pipeline. In this case, getting data is simulated by reading from a hardcoded JSON string. This data is then put into xcom, so that it can be processed by the next task. """ ) transform_task = PythonOperator( task_id="transform", python_callable=transform, ) transform_task.doc_md = textwrap.dedent( """\ #### Transform task A simple Transform task which takes in the collection of order data from xcom and computes the total order value. This computed value is then put into xcom, so that it can be processed by the next task. """ ) load_task = PythonOperator( task_id="load", python_callable=load, ) load_task.doc_md = textwrap.dedent( """\ #### Load task A simple Load task which takes in the result of the Transform task, by reading it from xcom and instead of saving it to end user review, just prints it out. """ ) extract_task >> transform_task >> load_task ``` --- image image --- image --- .../api_fastapi/execution_api/routes/xcoms.py | 16 ++- airflow/utils/context.py | 2 +- task_sdk/src/airflow/sdk/api/client.py | 9 +- .../src/airflow/sdk/execution_time/comms.py | 2 +- .../airflow/sdk/execution_time/supervisor.py | 3 +- .../airflow/sdk/execution_time/task_runner.py | 114 +++++++++++++++++- task_sdk/tests/api/test_client.py | 16 ++- task_sdk/tests/execution_time/test_context.py | 2 +- .../tests/execution_time/test_supervisor.py | 2 +- .../execution_api/routes/test_xcoms.py | 83 +++++++++++-- tests/models/test_xcom.py | 22 +++- 11 files changed, 236 insertions(+), 35 deletions(-) diff --git a/airflow/api_fastapi/execution_api/routes/xcoms.py b/airflow/api_fastapi/execution_api/routes/xcoms.py index 93285eb3a74f4..faacd543fca2b 100644 --- a/airflow/api_fastapi/execution_api/routes/xcoms.py +++ b/airflow/api_fastapi/execution_api/routes/xcoms.py @@ -22,7 +22,6 @@ from typing import Annotated from fastapi import Body, HTTPException, Query, status -from pydantic import Json from airflow.api_fastapi.common.db.common import SessionDep from airflow.api_fastapi.common.router import AirflowRouter @@ -92,7 +91,7 @@ def get_xcom( ) try: - xcom_value = BaseXCom.deserialize_value(result) + xcom_value = BaseXCom.orm_deserialize_value(result) except json.JSONDecodeError: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -118,7 +117,7 @@ def set_xcom( task_id: str, key: str, value: Annotated[ - Json, + str, Body( description="A JSON-formatted string representing the value to set for the XCom.", openapi_examples={ @@ -142,6 +141,17 @@ def set_xcom( map_index: Annotated[int, Query()] = -1, ): """Set an Airflow XCom.""" + try: + json.loads(value) + except json.JSONDecodeError: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={ + "reason": "invalid_format", + "message": "XCom value is not a valid JSON-formatted string", + }, + ) + if not has_xcom_access(key, token): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, diff --git a/airflow/utils/context.py b/airflow/utils/context.py index c6cf2db498532..28bcd2fe6701d 100644 --- a/airflow/utils/context.py +++ b/airflow/utils/context.py @@ -477,7 +477,7 @@ def lazy_mapping_from_context(source: Context) -> Mapping[str, Any]: """ if not isinstance(source, Context): # Sometimes we are passed a plain dict (usually in tests, or in User's - # custom operators) -- be lienent about what we accept so we don't + # custom operators) -- be lenient about what we accept so we don't # break anything for users. return source diff --git a/task_sdk/src/airflow/sdk/api/client.py b/task_sdk/src/airflow/sdk/api/client.py index da91c2bd98dd2..7488ef3e88a81 100644 --- a/task_sdk/src/airflow/sdk/api/client.py +++ b/task_sdk/src/airflow/sdk/api/client.py @@ -207,11 +207,16 @@ class XComOperations: def __init__(self, client: Client): self.client = client - def get(self, dag_id: str, run_id: str, task_id: str, key: str, map_index: int = -1) -> XComResponse: + def get( + self, dag_id: str, run_id: str, task_id: str, key: str, map_index: int | None = None + ) -> XComResponse: """Get a XCom value from the API server.""" # TODO: check if we need to use map_index as params in the uri # ref: https://github.com/apache/airflow/blob/v2-10-stable/airflow/api_connexion/openapi/v1.yaml#L1785C1-L1785C81 - resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params={"map_index": map_index}) + params = {} + if map_index is not None: + params.update({"map_index": map_index}) + resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params) return XComResponse.model_validate_json(resp.read()) def set( diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py b/task_sdk/src/airflow/sdk/execution_time/comms.py index 97fadcafc409e..b90787ca4cfc9 100644 --- a/task_sdk/src/airflow/sdk/execution_time/comms.py +++ b/task_sdk/src/airflow/sdk/execution_time/comms.py @@ -167,7 +167,7 @@ class GetXCom(BaseModel): dag_id: str run_id: str task_id: str - map_index: int = -1 + map_index: int | None = None type: Literal["GetXCom"] = "GetXCom" diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py index 73bc446a28df8..811d1ce86a60d 100644 --- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py @@ -62,7 +62,6 @@ from airflow.sdk.execution_time.comms import ( ConnectionResult, DeferTask, - ErrorResponse, GetConnection, GetVariable, GetXCom, @@ -719,7 +718,7 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): if isinstance(conn, ConnectionResponse): conn_result = ConnectionResult.from_conn_response(conn) resp = conn_result.model_dump_json(exclude_unset=True).encode() - elif isinstance(conn, ErrorResponse): + else: resp = conn.model_dump_json().encode() elif isinstance(msg, GetVariable): var = self.client.variables.get(msg.key) diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index 48dd3ecbfcd67..d8d540318f175 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -21,6 +21,7 @@ import os import sys +from collections.abc import Iterable from datetime import datetime, timezone from io import FileIO from typing import TYPE_CHECKING, Annotated, Any, Generic, TextIO, TypeVar @@ -33,12 +34,15 @@ from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.execution_time.comms import ( DeferTask, + GetXCom, RescheduleTask, SetRenderedFields, + SetXCom, StartupDetails, TaskState, ToSupervisor, ToTask, + XComResult, ) from airflow.sdk.execution_time.context import ConnectionAccessor @@ -111,11 +115,109 @@ def get_template_context(self): "ts_nodash_with_tz": ts_nodash_with_tz, } context.update(context_from_server) + # TODO: We should use/move TypeDict from airflow.utils.context.Context return context - def xcom_pull(self, *args, **kwargs): ... + def xcom_pull( + self, + task_ids: str | Iterable[str] | None = None, # TODO: Simplify to a single task_id? (breaking change) + dag_id: str | None = None, + key: str = "return_value", # TODO: Make this a constant (``XCOM_RETURN_KEY``) + include_prior_dates: bool = False, # TODO: Add support for this + *, + map_indexes: int | Iterable[int] | None = None, + default: Any = None, + run_id: str | None = None, + ) -> Any: + """ + Pull XComs that optionally meet certain criteria. + + :param key: A key for the XCom. If provided, only XComs with matching + keys will be returned. The default key is ``'return_value'``, also + available as constant ``XCOM_RETURN_KEY``. This key is automatically + given to XComs returned by tasks (as opposed to being pushed + manually). To remove the filter, pass *None*. + :param task_ids: Only XComs from tasks with matching ids will be + pulled. Pass *None* to remove the filter. + :param dag_id: If provided, only pulls XComs from this DAG. If *None* + (default), the DAG of the calling task is used. + :param map_indexes: If provided, only pull XComs with matching indexes. + If *None* (default), this is inferred from the task(s) being pulled + (see below for details). + :param include_prior_dates: If False, only XComs from the current + logical_date are returned. If *True*, XComs from previous dates + are returned as well. + :param run_id: If provided, only pulls XComs from a DagRun w/a matching run_id. + If *None* (default), the run_id of the calling task is used. + + When pulling one single task (``task_id`` is *None* or a str) without + specifying ``map_indexes``, the return value is inferred from whether + the specified task is mapped. If not, value from the one single task + instance is returned. If the task to pull is mapped, an iterator (not a + list) yielding XComs from mapped task instances is returned. In either + case, ``default`` (*None* if not specified) is returned if no matching + XComs are found. + + When pulling multiple tasks (i.e. either ``task_id`` or ``map_index`` is + a non-str iterable), a list of matching XComs is returned. Elements in + the list is ordered by item ordering in ``task_id`` and ``map_index``. + """ + if dag_id is None: + dag_id = self.dag_id + if run_id is None: + run_id = self.run_id + + if task_ids is None: + task_ids = self.task_id + elif not isinstance(task_ids, str) and isinstance(task_ids, Iterable): + # TODO: Handle multiple task_ids or remove support + raise NotImplementedError("Multiple task_ids are not supported yet") + + if map_indexes is None: + map_indexes = self.map_index + elif isinstance(map_indexes, Iterable): + # TODO: Handle multiple map_indexes or remove support + raise NotImplementedError("Multiple map_indexes are not supported yet") + + log = structlog.get_logger(logger_name="task") + SUPERVISOR_COMMS.send_request( + log=log, + msg=GetXCom( + key=key, + dag_id=dag_id, + task_id=task_ids, + run_id=run_id, + map_index=map_indexes, + ), + ) + + msg = SUPERVISOR_COMMS.get_message() + if TYPE_CHECKING: + assert isinstance(msg, XComResult) - def xcom_push(self, *args, **kwargs): ... + value = msg.value + if value is not None: + return value + return default + + def xcom_push(self, key: str, value: Any): + """ + Make an XCom available for tasks to pull. + + :param key: Key to store the value under. + :param value: Value to store. Only be JSON-serializable may be used otherwise. + """ + log = structlog.get_logger(logger_name="task") + SUPERVISOR_COMMS.send_request( + log=log, + msg=SetXCom( + key=key, + value=value, + dag_id=self.dag_id, + task_id=self.task_id, + run_id=self.run_id, + ), + ) def parse(what: StartupDetails) -> RuntimeTaskInstance: @@ -269,10 +371,16 @@ def run(ti: RuntimeTaskInstance, log: Logger): msg: ToSupervisor | None = None try: # TODO: pre execute etc. - # TODO next_method to support resuming from deferred # TODO: Get a real context object ti.task = ti.task.prepare_for_execution() context = ti.get_template_context() + # TODO: Get things from _execute_task_with_callbacks + # - Clearing XCom + # - Setting Current Context (set_current_context) + # - Render Templates + # - Update RTIF + # - Pre Execute + # etc ti.task.execute(context) # type: ignore[attr-defined] msg = TaskState(state=TerminalTIState.SUCCESS, end_date=datetime.now(tz=timezone.utc)) except TaskDeferred as defer: diff --git a/task_sdk/tests/api/test_client.py b/task_sdk/tests/api/test_client.py index 16b8d6c9bfe15..279502793ee23 100644 --- a/task_sdk/tests/api/test_client.py +++ b/task_sdk/tests/api/test_client.py @@ -313,13 +313,23 @@ class TestXCOMOperations: response parsing. """ - def test_xcom_get_success(self): + @pytest.mark.parametrize( + "value", + [ + pytest.param("value1", id="string-value"), + pytest.param({"key1": "value1"}, id="dict-value"), + pytest.param('{"key1": "value1"}', id="dict-str-value"), + pytest.param(["value1", "value2"], id="list-value"), + pytest.param({"key": "test_key", "value": {"key2": "value2"}}, id="nested-dict-value"), + ], + ) + def test_xcom_get_success(self, value): # Simulate a successful response from the server when getting an xcom def handle_request(request: httpx.Request) -> httpx.Response: if request.url.path == "/xcoms/dag_id/run_id/task_id/key": return httpx.Response( status_code=201, - json={"key": "test_key", "value": "test_value"}, + json={"key": "test_key", "value": value}, ) return httpx.Response(status_code=400, json={"detail": "Bad Request"}) @@ -332,7 +342,7 @@ def handle_request(request: httpx.Request) -> httpx.Response: ) assert isinstance(result, XComResponse) assert result.key == "test_key" - assert result.value == "test_value" + assert result.value == value def test_xcom_get_success_with_map_index(self): # Simulate a successful response from the server when getting an xcom with map_index passed diff --git a/task_sdk/tests/execution_time/test_context.py b/task_sdk/tests/execution_time/test_context.py index 65d2b50f8a17f..a3220c3bef1e3 100644 --- a/task_sdk/tests/execution_time/test_context.py +++ b/task_sdk/tests/execution_time/test_context.py @@ -67,7 +67,7 @@ def test_getattr_connection(self): ) as mock_supervisor_comms: mock_supervisor_comms.get_message.return_value = conn_result - # Fetch the connection; Triggers __getattr__ + # Fetch the connection; triggers __getattr__ conn = accessor.mysql_conn expected_conn = Connection(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306) diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py index 098e5914f2551..53da57cf178eb 100644 --- a/task_sdk/tests/execution_time/test_supervisor.py +++ b/task_sdk/tests/execution_time/test_supervisor.py @@ -817,7 +817,7 @@ def watched_subprocess(self, mocker): GetXCom(dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key"), b'{"key":"test_key","value":"test_value","type":"XComResult"}\n', "xcoms.get", - ("test_dag", "test_run", "test_task", "test_key", -1), + ("test_dag", "test_run", "test_task", "test_key", None), XComResult(key="test_key", value="test_value"), id="get_xcom", ), diff --git a/tests/api_fastapi/execution_api/routes/test_xcoms.py b/tests/api_fastapi/execution_api/routes/test_xcoms.py index d9d33f28d4425..6347db9b6db28 100644 --- a/tests/api_fastapi/execution_api/routes/test_xcoms.py +++ b/tests/api_fastapi/execution_api/routes/test_xcoms.py @@ -21,6 +21,7 @@ import pytest +from airflow.api_fastapi.execution_api.datamodels.xcom import XComResponse from airflow.models.dagrun import DagRun from airflow.models.xcom import XCom from airflow.utils.session import create_session @@ -40,19 +41,21 @@ class TestXComsGetEndpoint: @pytest.mark.parametrize( ("value", "expected_value"), [ - ("value1", "value1"), - ({"key2": "value2"}, {"key2": "value2"}), - ({"key2": "value2", "key3": ["value3"]}, {"key2": "value2", "key3": ["value3"]}), - (["value1"], ["value1"]), + ('"value1"', '"value1"'), + ('{"key2": "value2"}', '{"key2": "value2"}'), + ('{"key2": "value2", "key3": ["value3"]}', '{"key2": "value2", "key3": ["value3"]}'), + ('["value1"]', '["value1"]'), ], ) def test_xcom_get_from_db(self, client, create_task_instance, session, value, expected_value): """Test that XCom value is returned from the database in JSON-compatible format.""" ti = create_task_instance() ti.xcom_push(key="xcom_1", value=value, session=session) - session.commit() + xcom = session.query(XCom).filter_by(task_id=ti.task_id, dag_id=ti.dag_id, key="xcom_1").first() + assert xcom.value == expected_value + response = client.get(f"/execution/xcoms/{ti.dag_id}/{ti.run_id}/{ti.task_id}/xcom_1") assert response.status_code == 200 @@ -86,19 +89,17 @@ class TestXComsSetEndpoint: @pytest.mark.parametrize( ("value", "expected_value"), [ - ('"value1"', "value1"), - ('{"key2": "value2"}', {"key2": "value2"}), - ('{"key2": "value2", "key3": ["value3"]}', {"key2": "value2", "key3": ["value3"]}), - ('["value1"]', ["value1"]), + ('"value1"', '"value1"'), + ('{"key2": "value2"}', '{"key2": "value2"}'), + ('{"key2": "value2", "key3": ["value3"]}', '{"key2": "value2", "key3": ["value3"]}'), + ('["value1"]', '["value1"]'), ], ) def test_xcom_set(self, client, create_task_instance, session, value, expected_value): """ Test that XCom value is set correctly. The value is passed as a JSON string in the request body. - This is then validated via Pydantic.Json type in the request body and converted to - a Python object before being sent to XCom.set. XCom.set then uses json.dumps to - serialize it and store the value in the database. This is done so that Task SDK in multiple - languages can use the same API to set XCom values. + XCom.set then uses json.dumps to serialize it and store the value in the database. + This is done so that Task SDK in multiple languages can use the same API to set XCom values. """ ti = create_task_instance() session.commit() @@ -114,6 +115,24 @@ def test_xcom_set(self, client, create_task_instance, session, value, expected_v xcom = session.query(XCom).filter_by(task_id=ti.task_id, dag_id=ti.dag_id, key="xcom_1").first() assert xcom.value == expected_value + @pytest.mark.parametrize( + "value", + ["value1", {"key2": "value2"}, ["value1"]], + ) + def test_xcom_set_invalid_json(self, client, create_task_instance, value): + response = client.post( + "/execution/xcoms/dag/runid/task/xcom_1", + json="invalid_json", + ) + + assert response.status_code == 400 + assert response.json() == { + "detail": { + "reason": "invalid_format", + "message": "XCom value is not a valid JSON-formatted string", + } + } + def test_xcom_access_denied(self, client): with mock.patch("airflow.api_fastapi.execution_api.routes.xcoms.has_xcom_access", return_value=False): response = client.post( @@ -128,3 +147,41 @@ def test_xcom_access_denied(self, client): "message": "Task does not have access to set XCom key 'xcom_perms'", } } + + @pytest.mark.parametrize( + ("value", "expected_value"), + [ + ('"value1"', '"value1"'), + ('{"key2": "value2"}', '{"key2": "value2"}'), + ('{"key2": "value2", "key3": ["value3"]}', '{"key2": "value2", "key3": ["value3"]}'), + ('["value1"]', '["value1"]'), + ], + ) + def test_xcom_roundtrip(self, client, create_task_instance, session, value, expected_value): + """ + Test that XCom value is set and retrieved correctly using API. + + This test sets an XCom value using the API and then retrieves it using the API so we can + ensure client and server are working correctly together. The server expects a JSON string + and it will also return a JSON string. It is the client's responsibility to parse the JSON + string into a native object. This is useful for Task SDKs in other languages. + """ + ti = create_task_instance() + + session.commit() + client.post( + f"/execution/xcoms/{ti.dag_id}/{ti.run_id}/{ti.task_id}/test_xcom_roundtrip", + json=value, + ) + + xcom = ( + session.query(XCom) + .filter_by(task_id=ti.task_id, dag_id=ti.dag_id, key="test_xcom_roundtrip") + .first() + ) + assert xcom.value == expected_value + + response = client.get(f"/execution/xcoms/{ti.dag_id}/{ti.run_id}/{ti.task_id}/test_xcom_roundtrip") + + assert response.status_code == 200 + assert XComResponse.model_validate_json(response.read()).value == expected_value diff --git a/tests/models/test_xcom.py b/tests/models/test_xcom.py index 85e3249fe6cbb..1e62b414de2da 100644 --- a/tests/models/test_xcom.py +++ b/tests/models/test_xcom.py @@ -324,18 +324,30 @@ def test_xcom_get_many_from_prior_dates(self, session, tis_for_xcom_get_many_fro class TestXComSet: - def test_xcom_set(self, session, task_instance): + @pytest.mark.parametrize( + ("key", "value", "expected_value"), + [ + pytest.param("xcom_dict", {"key": "value"}, {"key": "value"}, id="dict"), + pytest.param("xcom_int", 123, 123, id="int"), + pytest.param("xcom_float", 45.67, 45.67, id="float"), + pytest.param("xcom_str", "hello", "hello", id="str"), + pytest.param("xcom_bool", True, True, id="bool"), + pytest.param("xcom_list", [1, 2, 3], [1, 2, 3], id="list"), + ], + ) + def test_xcom_set(self, session, task_instance, key, value, expected_value): XCom.set( - key="xcom_1", - value={"key": "value"}, + key=key, + value=value, dag_id=task_instance.dag_id, task_id=task_instance.task_id, run_id=task_instance.run_id, session=session, ) stored_xcoms = session.query(XCom).all() - assert stored_xcoms[0].key == "xcom_1" - assert stored_xcoms[0].value == {"key": "value"} + assert stored_xcoms[0].key == key + assert isinstance(stored_xcoms[0].value, type(expected_value)) + assert stored_xcoms[0].value == expected_value assert stored_xcoms[0].dag_id == "dag" assert stored_xcoms[0].task_id == "task_1" assert stored_xcoms[0].logical_date == task_instance.logical_date