Skip to content

Commit

Permalink
Adopt new fields in API, change constructor arguments format (#10)
Browse files Browse the repository at this point in the history
* feat: add new fields for run operator, change payloads to dict

* fix: make runtime field string
  • Loading branch information
zongsizhang authored Sep 9, 2024
1 parent 5e3a8e6 commit aee5671
Show file tree
Hide file tree
Showing 8 changed files with 52 additions and 168 deletions.
5 changes: 2 additions & 3 deletions airflow_providers_wherobots/hooks/rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from airflow_providers_wherobots.hooks.base import DEFAULT_CONN_ID
from airflow_providers_wherobots.wherobots.models import (
Run,
CreateRunPayload,
LogsResponse,
)

Expand Down Expand Up @@ -81,11 +80,11 @@ def get_run(self, run_id: str) -> Run:
resp_json = self._api_call("GET", f"/runs/{run_id}").json()
return Run.model_validate(resp_json)

def create_run(self, payload: CreateRunPayload) -> Run:
def create_run(self, payload: dict[str, Any]) -> Run:
resp_json = self._api_call(
"POST",
"/runs",
payload=payload.model_dump(mode="json"),
payload=payload,
).json()
return Run.model_validate(resp_json)

Expand Down
27 changes: 14 additions & 13 deletions airflow_providers_wherobots/operators/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,10 @@

from airflow.models import BaseOperator
from strenum import StrEnum
from wherobots.db import Runtime

from airflow_providers_wherobots.hooks.base import DEFAULT_CONN_ID
from airflow_providers_wherobots.hooks.rest_api import WherobotsRestAPIHook
from airflow_providers_wherobots.wherobots.models import (
PythonRunPayload,
JavaRunPayload,
CreateRunPayload,
RUN_NAME_ALPHABET,
RunStatus,
Run,
Expand All @@ -36,9 +32,10 @@ class WherobotsRunOperator(BaseOperator):
def __init__(
self,
name: Optional[str] = None,
runtime: Optional[Runtime] = Runtime.SEDONA,
python: Optional[PythonRunPayload] = None,
java: Optional[JavaRunPayload] = None,
runtime: str = "TINY",
run_python: Optional[dict[str, Any]] = None,
run_jar: Optional[dict[str, Any]] = None,
environment: Optional[dict[str, Any]] = None,
polling_interval: int = 20,
wherobots_conn_id: str = DEFAULT_CONN_ID,
poll_logs: bool = False,
Expand All @@ -47,12 +44,16 @@ def __init__(
):
super().__init__(**kwargs)
# If the user specifies the name, we will use it and rely on the server to validate the name
self.run_payload = CreateRunPayload(
runtime=runtime,
name=name or self.default_run_name,
python=python,
java=java,
)
self.run_payload: dict[str, Any] = {
"runtime": runtime,
"name": name or self.default_run_name,
}
if run_python:
self.run_payload["runPython"] = run_python
if run_jar:
self.run_payload["runJar"] = run_jar
if environment:
self.run_payload["environment"] = environment
self._polling_interval = polling_interval
self.wherobots_conn_id = wherobots_conn_id
self.xcom_push = xcom_push
Expand Down
80 changes: 2 additions & 78 deletions airflow_providers_wherobots/wherobots/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
import string
from datetime import datetime
from enum import auto
from typing import Optional, Sequence, List
from typing import Optional, List

from pydantic import BaseModel, Field, ConfigDict, computed_field
from pydantic import BaseModel, Field, ConfigDict
from strenum import StrEnum
from wherobots.db import Runtime

RUN_NAME_ALPHABET = string.ascii_letters + string.digits + "-_."

Expand Down Expand Up @@ -39,81 +38,6 @@ class Run(WherobotsModel):
end_time: Optional[datetime] = Field(default=None, alias="completeTime")


class PythonRunPayload(BaseModel):
"""
Model for the payload of Run with type == "python"
"""

# For airflow to render the template fields
template_fields: Sequence[str] = Field(
("uri", "args", "entrypoint"), exclude=True, init=False
)

uri: str
args: list[str] = []
entrypoint: Optional[str] = None

@classmethod
def create(cls, uri: str, args: list[str], entrypoint: Optional[str] = None):
return cls(uri=uri, args=args, entrypoint=entrypoint)


class JavaRunPayload(BaseModel):
"""
Model for the payload of Run with type == "python"
"""

# For airflow to render the template fields
template_fields: Sequence[str] = Field(("uri", "args", "main_class"), exclude=True)

uri: str
args: list[str] = []
main_class: Optional[str] = Field(None, alias="mainClass")

@classmethod
def create(cls, uri: str, args: list[str], main_class: Optional[str] = None):
return cls(uri=uri, args=args, mainClass=main_class)


class RunType(StrEnum):
python = auto()
java = auto()


class CreateRunPayload(BaseModel):
# For airflow to render the template fields
template_fields: Sequence[str] = Field(("name", "python", "java"), exclude=True)

runtime: Runtime
name: Optional[str] = None
python: Optional[PythonRunPayload] = None
java: Optional[JavaRunPayload] = None
timeout_seconds: int = Field(3600, alias="timeoutSeconds")

@computed_field
def type(self) -> RunType:
run_type = RunType.python if self.python else RunType.java
assert isinstance(run_type, RunType)
return run_type

@classmethod
def create(
cls,
runtime: Runtime,
name: str,
python: Optional[PythonRunPayload] = None,
java: Optional[JavaRunPayload] = None,
timeout_seconds: int = 3600,
):
return cls(
runtime=runtime,
name=name,
python=python,
java=java,
timeoutSeconds=timeout_seconds,
)


class LogItem(BaseModel):
timestamp: int
raw: str
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "airflow-providers-wherobots"
version = "0.1.9"
version = "0.1.10"
description = "Airflow extension for communicating with Wherobots Cloud"
authors = ["zongsi.zhang <[email protected]>"]
readme = "README.md"
Expand Down
9 changes: 3 additions & 6 deletions tests/integration_tests/operator/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@
from airflow.models import Connection

from airflow_providers_wherobots.operators.run import WherobotsRunOperator
from airflow_providers_wherobots.wherobots.models import (
PythonRunPayload,
)
from tests.unit_tests.operator.test_run import execute_dag

DEFAULT_START = pendulum.datetime(2021, 9, 13, tz="UTC")
Expand All @@ -28,9 +25,9 @@ def test_staging_run_success(staging_conn: Connection, dag: DAG) -> None:
wherobots_conn_id=staging_conn.conn_id,
task_id="test_run_smoke",
name="airflow_operator_test_run_{{ ts_nodash }}",
python=PythonRunPayload(
uri="s3://wbts-wbc-rcv7vl73oy/hao9o6y8ci/data/customer-z4asgjn7clrcbz/very_simple_job.py"
),
run_python={
"uri": "s3://wbts-wbc-rcv7vl73oy/hao9o6y8ci/data/customer-z4asgjn7clrcbz/very_simple_job.py"
},
dag=dag,
)
execute_dag(dag, task_id=operator.task_id)
24 changes: 10 additions & 14 deletions tests/unit_tests/hooks/test_rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
)
from airflow_providers_wherobots.wherobots.models import (
Run,
CreateRunPayload,
PythonRunPayload,
LogsResponse,
)
from tests.unit_tests import helpers
Expand Down Expand Up @@ -108,22 +106,20 @@ def test_create_run(self, test_default_conn) -> None:
"""
test_run: Run = helpers.run_factory.build()
url = f"https://{test_default_conn.host}/runs"
create_payload = CreateRunPayload.create(
name=test_run.name,
runtime=Runtime.SEDONA,
python=PythonRunPayload(
uri="s3://bucket/test.py",
args=["arg1", "arg2"],
entrypoint="src.main",
),
)
create_payload = {
"name": test_run.name,
"runtime": Runtime.SEDONA.value,
"python": {
"uri": "s3://bucket/test.py",
"args": ["arg1", "arg2"],
"entrypoint": "src.main",
},
}
responses.add(
responses.POST,
url,
json=test_run.model_dump(mode="json"),
match=[
matchers.json_params_matcher(create_payload.model_dump(mode="json"))
],
match=[matchers.json_params_matcher(create_payload)],
status=HTTPStatus.OK,
)
with WherobotsRestAPIHook() as hook:
Expand Down
42 changes: 20 additions & 22 deletions tests/unit_tests/operator/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@

from airflow_providers_wherobots.operators.run import WherobotsRunOperator
from airflow_providers_wherobots.wherobots.models import (
PythonRunPayload,
RunStatus,
CreateRunPayload,
LogsResponse,
Run,
LogItem,
Expand Down Expand Up @@ -74,22 +72,22 @@ def test_render_template(self, mocker: MockerFixture, dag: DAG):
operator = WherobotsRunOperator(
task_id="test_render_template_python",
name="test_run_{{ ds }}",
python=PythonRunPayload(
uri="s3://bucket/test-{{ ds }}.py",
args=["{{ ds }}"],
entrypoint="src.main_{{ ds }}",
),
run_python={
"uri": "s3://bucket/test-{{ ds }}.py",
"args": ["{{ ds }}"],
},
dag=dag,
)
execute_dag(dag, task_id=operator.task_id)
assert create_run.call_count == 1
rendered_payload = create_run.call_args.args[0]
assert isinstance(rendered_payload, CreateRunPayload)
assert isinstance(rendered_payload, dict)
expected_ds = data_interval_start.format("YYYY-MM-DD")
assert rendered_payload.name == f"test_run_{expected_ds}"
assert rendered_payload.python.uri == f"s3://bucket/test-{expected_ds}.py"
assert rendered_payload.python.args == [expected_ds]
assert rendered_payload.python.entrypoint == f"src.main_{expected_ds}"
assert rendered_payload["name"] == f"test_run_{expected_ds}"
assert (
rendered_payload["runPython"]["uri"] == f"s3://bucket/test-{expected_ds}.py"
)
assert rendered_payload["runPython"]["args"] == [expected_ds]

@pytest.mark.usefixtures("clean_airflow_db")
def test_default_name(self, mocker: MockerFixture, dag: DAG):
Expand All @@ -100,13 +98,13 @@ def test_default_name(self, mocker: MockerFixture, dag: DAG):
)
operator = WherobotsRunOperator(
task_id="test_default_name",
python=PythonRunPayload(uri=""),
run_python={"uri": ""},
dag=dag,
)
execute_dag(dag, task_id=operator.task_id)
rendered_payload = create_run.call_args.args[0]
assert isinstance(rendered_payload, CreateRunPayload)
assert rendered_payload.name == operator.default_run_name.replace(
assert isinstance(rendered_payload, dict)
assert rendered_payload["name"] == operator.default_run_name.replace(
"{{ ts_nodash }}", data_interval_start.strftime("%Y%m%dT%H%M%S")
)

Expand Down Expand Up @@ -160,7 +158,7 @@ def test_execute_handle_states(
)
operator = WherobotsRunOperator(
task_id=f"test_execute_{uuid.uuid4()}",
python=PythonRunPayload(uri=""),
run_python={"uri": ""},
dag=dag,
polling_interval=0,
poll_logs=poll_logs,
Expand All @@ -186,11 +184,11 @@ def test_on_kill(
operator = WherobotsRunOperator(
task_id="test_render_template_python",
name="test_run_{{ ds }}",
python=PythonRunPayload(
uri="s3://bucket/test-{{ ds }}.py",
args=["{{ ds }}"],
entrypoint="src.main_{{ ds }}",
),
run_python={
"uri": "s3://bucket/test-{{ ds }}.py",
"args": ["{{ ds }}"],
"entrypoint": "src.main_{{ ds }}",
},
dag=dag,
)
operator.on_kill()
Expand All @@ -209,7 +207,7 @@ def test_poll_and_display_logs(self, mocker: MockerFixture):
)
operator = WherobotsRunOperator(
task_id="test_poll_and_display_logs",
python=PythonRunPayload(uri=""),
run_python={"uri": ""},
dag=DAG("test_poll_and_display_logs"),
)
assert operator.poll_and_display_logs(hook, test_run, 0) == 2
Expand Down
31 changes: 0 additions & 31 deletions tests/unit_tests/wherobots/test_models.py

This file was deleted.

0 comments on commit aee5671

Please sign in to comment.