Skip to content

Commit

Permalink
AIP-72: Allow pushing and pulling XCom from Task Context (apache#45075)
Browse files Browse the repository at this point in the history
Part of apache#44481

There is a lot of cleanup to do but I wanted to get a basic DAG that uses XCom working first.

Example DAG used: `tutorial_dag`

```py

from __future__ import annotations

# [START tutorial]
# [START import_module]
import json
import textwrap

import pendulum

# The DAG object; we'll need this to instantiate a DAG
from airflow.models.dag import DAG

# Operators; we need this to operate!
from airflow.providers.standard.operators.python import PythonOperator

# [END import_module]

# [START instantiate_dag]
with DAG(
    "tutorial_dag",
    # [START default_args]
    # These args will get passed on to each operator
    # You can override them on a per-task basis during operator initialization
    default_args={"retries": 2},
    # [END default_args]
    description="DAG tutorial",
    schedule=None,
    start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
    catchup=False,
    tags=["example"],
) as dag:
    # [END instantiate_dag]
    # [START documentation]
    dag.doc_md = __doc__
    # [END documentation]

    # [START extract_function]
    def extract(**kwargs):
        ti = kwargs["ti"]
        data_string = '{"1001": 301.27, "1002": 433.21, "1003": 502.22}'
        ti.xcom_push("order_data", data_string)

    # [END extract_function]

    # [START transform_function]
    def transform(**kwargs):
        ti = kwargs["ti"]
        extract_data_string = ti.xcom_pull(task_ids="extract", key="order_data")
        order_data = json.loads(extract_data_string)

        total_order_value = 0
        for value in order_data.values():
            total_order_value += value

        total_value = {"total_order_value": total_order_value}
        total_value_json_string = json.dumps(total_value)
        ti.xcom_push("total_order_value", total_value_json_string)

    # [END transform_function]

    # [START load_function]
    def load(**kwargs):
        ti = kwargs["ti"]
        total_value_string = ti.xcom_pull(task_ids="transform", key="total_order_value")
        total_order_value = json.loads(total_value_string)

        print(total_order_value)

    # [END load_function]

    # [START main_flow]
    extract_task = PythonOperator(
        task_id="extract",
        python_callable=extract,
    )
    extract_task.doc_md = textwrap.dedent(
        """\
    #### Extract task
    A simple Extract task to get data ready for the rest of the data pipeline.
    In this case, getting data is simulated by reading from a hardcoded JSON string.
    This data is then put into xcom, so that it can be processed by the next task.
    """
    )

    transform_task = PythonOperator(
        task_id="transform",
        python_callable=transform,
    )
    transform_task.doc_md = textwrap.dedent(
        """\
    #### Transform task
    A simple Transform task which takes in the collection of order data from xcom
    and computes the total order value.
    This computed value is then put into xcom, so that it can be processed by the next task.
    """
    )

    load_task = PythonOperator(
        task_id="load",
        python_callable=load,
    )
    load_task.doc_md = textwrap.dedent(
        """\
    #### Load task
    A simple Load task which takes in the result of the Transform task, by reading it
    from xcom and instead of saving it to end user review, just prints it out.
    """
    )

    extract_task >> transform_task >> load_task
```

---
<img width="1703" alt="image" src="https://github.com/user-attachments/assets/10025ef4-0410-4c2a-9bb6-1e68f51a8805" />

<img width="1710" alt="image" src="https://github.com/user-attachments/assets/201b61c0-3998-4b06-b0d4-2145120321f8" />

---
<img width="1721" alt="image" src="https://github.com/user-attachments/assets/dd9c50e3-20c5-4762-99f9-c02a8c16732e" />
  • Loading branch information
kaxil authored Dec 24, 2024
1 parent 56a75c4 commit 0917498
Show file tree
Hide file tree
Showing 11 changed files with 236 additions and 35 deletions.
16 changes: 13 additions & 3 deletions airflow/api_fastapi/execution_api/routes/xcoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from typing import Annotated

from fastapi import Body, HTTPException, Query, status
from pydantic import Json

from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.common.router import AirflowRouter
Expand Down Expand Up @@ -92,7 +91,7 @@ def get_xcom(
)

try:
xcom_value = BaseXCom.deserialize_value(result)
xcom_value = BaseXCom.orm_deserialize_value(result)
except json.JSONDecodeError:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
Expand All @@ -118,7 +117,7 @@ def set_xcom(
task_id: str,
key: str,
value: Annotated[
Json,
str,
Body(
description="A JSON-formatted string representing the value to set for the XCom.",
openapi_examples={
Expand All @@ -142,6 +141,17 @@ def set_xcom(
map_index: Annotated[int, Query()] = -1,
):
"""Set an Airflow XCom."""
try:
json.loads(value)
except json.JSONDecodeError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"reason": "invalid_format",
"message": "XCom value is not a valid JSON-formatted string",
},
)

if not has_xcom_access(key, token):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
Expand Down
2 changes: 1 addition & 1 deletion airflow/utils/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ def lazy_mapping_from_context(source: Context) -> Mapping[str, Any]:
"""
if not isinstance(source, Context):
# Sometimes we are passed a plain dict (usually in tests, or in User's
# custom operators) -- be lienent about what we accept so we don't
# custom operators) -- be lenient about what we accept so we don't
# break anything for users.
return source

Expand Down
9 changes: 7 additions & 2 deletions task_sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,11 +207,16 @@ class XComOperations:
def __init__(self, client: Client):
self.client = client

def get(self, dag_id: str, run_id: str, task_id: str, key: str, map_index: int = -1) -> XComResponse:
def get(
self, dag_id: str, run_id: str, task_id: str, key: str, map_index: int | None = None
) -> XComResponse:
"""Get a XCom value from the API server."""
# TODO: check if we need to use map_index as params in the uri
# ref: https://github.com/apache/airflow/blob/v2-10-stable/airflow/api_connexion/openapi/v1.yaml#L1785C1-L1785C81
resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params={"map_index": map_index})
params = {}
if map_index is not None:
params.update({"map_index": map_index})
resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params)
return XComResponse.model_validate_json(resp.read())

def set(
Expand Down
2 changes: 1 addition & 1 deletion task_sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ class GetXCom(BaseModel):
dag_id: str
run_id: str
task_id: str
map_index: int = -1
map_index: int | None = None
type: Literal["GetXCom"] = "GetXCom"


Expand Down
3 changes: 1 addition & 2 deletions task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@
from airflow.sdk.execution_time.comms import (
ConnectionResult,
DeferTask,
ErrorResponse,
GetConnection,
GetVariable,
GetXCom,
Expand Down Expand Up @@ -719,7 +718,7 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger):
if isinstance(conn, ConnectionResponse):
conn_result = ConnectionResult.from_conn_response(conn)
resp = conn_result.model_dump_json(exclude_unset=True).encode()
elif isinstance(conn, ErrorResponse):
else:
resp = conn.model_dump_json().encode()
elif isinstance(msg, GetVariable):
var = self.client.variables.get(msg.key)
Expand Down
114 changes: 111 additions & 3 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import os
import sys
from collections.abc import Iterable
from datetime import datetime, timezone
from io import FileIO
from typing import TYPE_CHECKING, Annotated, Any, Generic, TextIO, TypeVar
Expand All @@ -33,12 +34,15 @@
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.execution_time.comms import (
DeferTask,
GetXCom,
RescheduleTask,
SetRenderedFields,
SetXCom,
StartupDetails,
TaskState,
ToSupervisor,
ToTask,
XComResult,
)
from airflow.sdk.execution_time.context import ConnectionAccessor

Expand Down Expand Up @@ -111,11 +115,109 @@ def get_template_context(self):
"ts_nodash_with_tz": ts_nodash_with_tz,
}
context.update(context_from_server)
# TODO: We should use/move TypeDict from airflow.utils.context.Context
return context

def xcom_pull(self, *args, **kwargs): ...
def xcom_pull(
self,
task_ids: str | Iterable[str] | None = None, # TODO: Simplify to a single task_id? (breaking change)
dag_id: str | None = None,
key: str = "return_value", # TODO: Make this a constant (``XCOM_RETURN_KEY``)
include_prior_dates: bool = False, # TODO: Add support for this
*,
map_indexes: int | Iterable[int] | None = None,
default: Any = None,
run_id: str | None = None,
) -> Any:
"""
Pull XComs that optionally meet certain criteria.
:param key: A key for the XCom. If provided, only XComs with matching
keys will be returned. The default key is ``'return_value'``, also
available as constant ``XCOM_RETURN_KEY``. This key is automatically
given to XComs returned by tasks (as opposed to being pushed
manually). To remove the filter, pass *None*.
:param task_ids: Only XComs from tasks with matching ids will be
pulled. Pass *None* to remove the filter.
:param dag_id: If provided, only pulls XComs from this DAG. If *None*
(default), the DAG of the calling task is used.
:param map_indexes: If provided, only pull XComs with matching indexes.
If *None* (default), this is inferred from the task(s) being pulled
(see below for details).
:param include_prior_dates: If False, only XComs from the current
logical_date are returned. If *True*, XComs from previous dates
are returned as well.
:param run_id: If provided, only pulls XComs from a DagRun w/a matching run_id.
If *None* (default), the run_id of the calling task is used.
When pulling one single task (``task_id`` is *None* or a str) without
specifying ``map_indexes``, the return value is inferred from whether
the specified task is mapped. If not, value from the one single task
instance is returned. If the task to pull is mapped, an iterator (not a
list) yielding XComs from mapped task instances is returned. In either
case, ``default`` (*None* if not specified) is returned if no matching
XComs are found.
When pulling multiple tasks (i.e. either ``task_id`` or ``map_index`` is
a non-str iterable), a list of matching XComs is returned. Elements in
the list is ordered by item ordering in ``task_id`` and ``map_index``.
"""
if dag_id is None:
dag_id = self.dag_id
if run_id is None:
run_id = self.run_id

if task_ids is None:
task_ids = self.task_id
elif not isinstance(task_ids, str) and isinstance(task_ids, Iterable):
# TODO: Handle multiple task_ids or remove support
raise NotImplementedError("Multiple task_ids are not supported yet")

if map_indexes is None:
map_indexes = self.map_index
elif isinstance(map_indexes, Iterable):
# TODO: Handle multiple map_indexes or remove support
raise NotImplementedError("Multiple map_indexes are not supported yet")

log = structlog.get_logger(logger_name="task")
SUPERVISOR_COMMS.send_request(
log=log,
msg=GetXCom(
key=key,
dag_id=dag_id,
task_id=task_ids,
run_id=run_id,
map_index=map_indexes,
),
)

msg = SUPERVISOR_COMMS.get_message()
if TYPE_CHECKING:
assert isinstance(msg, XComResult)

def xcom_push(self, *args, **kwargs): ...
value = msg.value
if value is not None:
return value
return default

def xcom_push(self, key: str, value: Any):
"""
Make an XCom available for tasks to pull.
:param key: Key to store the value under.
:param value: Value to store. Only be JSON-serializable may be used otherwise.
"""
log = structlog.get_logger(logger_name="task")
SUPERVISOR_COMMS.send_request(
log=log,
msg=SetXCom(
key=key,
value=value,
dag_id=self.dag_id,
task_id=self.task_id,
run_id=self.run_id,
),
)


def parse(what: StartupDetails) -> RuntimeTaskInstance:
Expand Down Expand Up @@ -269,10 +371,16 @@ def run(ti: RuntimeTaskInstance, log: Logger):
msg: ToSupervisor | None = None
try:
# TODO: pre execute etc.
# TODO next_method to support resuming from deferred
# TODO: Get a real context object
ti.task = ti.task.prepare_for_execution()
context = ti.get_template_context()
# TODO: Get things from _execute_task_with_callbacks
# - Clearing XCom
# - Setting Current Context (set_current_context)
# - Render Templates
# - Update RTIF
# - Pre Execute
# etc
ti.task.execute(context) # type: ignore[attr-defined]
msg = TaskState(state=TerminalTIState.SUCCESS, end_date=datetime.now(tz=timezone.utc))
except TaskDeferred as defer:
Expand Down
16 changes: 13 additions & 3 deletions task_sdk/tests/api/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,13 +313,23 @@ class TestXCOMOperations:
response parsing.
"""

def test_xcom_get_success(self):
@pytest.mark.parametrize(
"value",
[
pytest.param("value1", id="string-value"),
pytest.param({"key1": "value1"}, id="dict-value"),
pytest.param('{"key1": "value1"}', id="dict-str-value"),
pytest.param(["value1", "value2"], id="list-value"),
pytest.param({"key": "test_key", "value": {"key2": "value2"}}, id="nested-dict-value"),
],
)
def test_xcom_get_success(self, value):
# Simulate a successful response from the server when getting an xcom
def handle_request(request: httpx.Request) -> httpx.Response:
if request.url.path == "/xcoms/dag_id/run_id/task_id/key":
return httpx.Response(
status_code=201,
json={"key": "test_key", "value": "test_value"},
json={"key": "test_key", "value": value},
)
return httpx.Response(status_code=400, json={"detail": "Bad Request"})

Expand All @@ -332,7 +342,7 @@ def handle_request(request: httpx.Request) -> httpx.Response:
)
assert isinstance(result, XComResponse)
assert result.key == "test_key"
assert result.value == "test_value"
assert result.value == value

def test_xcom_get_success_with_map_index(self):
# Simulate a successful response from the server when getting an xcom with map_index passed
Expand Down
2 changes: 1 addition & 1 deletion task_sdk/tests/execution_time/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_getattr_connection(self):
) as mock_supervisor_comms:
mock_supervisor_comms.get_message.return_value = conn_result

# Fetch the connection; Triggers __getattr__
# Fetch the connection; triggers __getattr__
conn = accessor.mysql_conn

expected_conn = Connection(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306)
Expand Down
2 changes: 1 addition & 1 deletion task_sdk/tests/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,7 @@ def watched_subprocess(self, mocker):
GetXCom(dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key"),
b'{"key":"test_key","value":"test_value","type":"XComResult"}\n',
"xcoms.get",
("test_dag", "test_run", "test_task", "test_key", -1),
("test_dag", "test_run", "test_task", "test_key", None),
XComResult(key="test_key", value="test_value"),
id="get_xcom",
),
Expand Down
Loading

0 comments on commit 0917498

Please sign in to comment.