Skip to content

Commit

Permalink
fix: fix wrong xcom push key (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
zongsizhang authored Sep 5, 2024
1 parent 638aa71 commit 5e3a8e6
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
5 changes: 3 additions & 2 deletions airflow_providers_wherobots/operators/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
Define the Operators for triggering and monitoring the execution of Wherobots Run
"""

from enum import auto, Enum
from enum import auto
from time import sleep
from typing import Optional, Sequence, Any

from airflow.models import BaseOperator
from strenum import StrEnum
from wherobots.db import Runtime

from airflow_providers_wherobots.hooks.base import DEFAULT_CONN_ID
Expand All @@ -21,7 +22,7 @@
)


class XComKey(str, Enum):
class XComKey(StrEnum):
run_id = auto()


Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "airflow-providers-wherobots"
version = "0.1.8"
version = "0.1.9"
description = "Airflow extension for communicating with Wherobots Cloud"
authors = ["zongsi.zhang <[email protected]>"]
readme = "README.md"
Expand Down
5 changes: 4 additions & 1 deletion tests/unit_tests/operator/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def test_execute_handle_states(
test_item: Tuple[list[Run], TaskInstanceState],
):
get_run_results, task_state = test_item
mocker.patch(
mocked_create_run = mocker.patch(
"airflow_providers_wherobots.hooks.rest_api.WherobotsRestAPIHook.create_run",
return_value=run_factory.build(status=RunStatus.PENDING),
)
Expand All @@ -171,6 +171,9 @@ def test_execute_handle_states(
except Exception as e:
assert isinstance(e, RuntimeError)
assert ti.state == task_state
# test xcom push
if task_state == TaskInstanceState.SUCCESS:
assert ti.xcom_pull(key="run_id") == mocked_create_run.return_value.ext_id

def test_on_kill(
self,
Expand Down

0 comments on commit 5e3a8e6

Please sign in to comment.