fix(airflow): fix AthenaOperator extraction (datahub-project#11857)
Co-authored-by: Harshal Sheth <[email protected]>
steffengr and hsheth2 authored Dec 4, 2024
1 parent eef9759 commit 49b6284
Showing 6 changed files with 1,440 additions and 2 deletions.
metadata-ingestion-modules/airflow-plugin/setup.py (2 changes: 1 addition & 1 deletion)
@@ -96,7 +96,7 @@ def get_long_description():
     *plugins["datahub-kafka"],
     f"acryl-datahub[testing-utils]{_self_pin}",
     # Extra requirements for loading our test dags.
-    "apache-airflow[snowflake]>=2.0.2",
+    "apache-airflow[snowflake,amazon]>=2.0.2",
     # A collection of issues we've encountered:
     # - Connexion's new version breaks Airflow:
     #   See https://github.com/apache/airflow/issues/35234.
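Note: the added amazon extra is what pulls in the Amazon provider distribution, making the operator used by the new test DAG importable. A quick sanity check, as a sketch (the exact template-field list varies by provider version):

# With apache-airflow[amazon] installed, the Athena operator resolves, and
# its "query" field is Jinja-templated, which the test DAG below relies on.
from airflow.providers.amazon.aws.operators.athena import AthenaOperator

assert "query" in AthenaOperator.template_fields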
@@ -50,7 +50,6 @@ def __init__(self):
             "BigQueryOperator",
             "BigQueryExecuteQueryOperator",
             # Athena also does something similar.
-            "AthenaOperator",
             "AWSAthenaOperator",
             # Additional types that OL doesn't support. This is only necessary because
             # on older versions of Airflow, these operators don't inherit from SQLExecuteQueryOperator.
@@ -59,6 +58,8 @@
         for operator in _sql_operator_overrides:
             self.task_to_extractor.extractors[operator] = GenericSqlExtractor
 
+        self.task_to_extractor.extractors["AthenaOperator"] = AthenaOperatorExtractor
+
         self.task_to_extractor.extractors[
             "BigQueryInsertJobOperator"
         ] = BigQueryInsertJobOperatorExtractor
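Note: the registry here is keyed by operator class name, which is why AthenaOperator is removed from the generic SQL overrides above and registered separately. A minimal sketch of that dispatch, with hypothetical stand-in classes (the real mapping is OpenLineage's ExtractorManager via self.task_to_extractor.extractors):

from typing import Dict, Optional, Type

class BaseExtractor: ...  # stand-in for the real base class
class AthenaOperatorExtractor(BaseExtractor): ...  # stand-in

_extractors: Dict[str, Type[BaseExtractor]] = {
    "AthenaOperator": AthenaOperatorExtractor,
}

def resolve_extractor(task: object) -> Optional[Type[BaseExtractor]]:
    # Dispatch on the class name, mirroring how AthenaOperator tasks now
    # route to the dedicated extractor instead of GenericSqlExtractor.
    return _extractors.get(type(task).__name__)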
@@ -276,6 +277,27 @@ def extract(self) -> Optional[TaskMetadata]:
         )
 
 
+class AthenaOperatorExtractor(BaseExtractor):
+    def extract(self) -> Optional[TaskMetadata]:
+        from airflow.providers.amazon.aws.operators.athena import (
+            AthenaOperator,  # type: ignore
+        )
+
+        operator: "AthenaOperator" = self.operator
+        sql = operator.query
+        if not sql:
+            self.log.warning("No query found in AthenaOperator")
+            return None
+
+        return _parse_sql_into_task_metadata(
+            self,
+            sql,
+            platform="athena",
+            default_database=None,
+            default_schema=self.operator.database,
+        )
+
+
 def _snowflake_default_schema(self: "SnowflakeExtractor") -> Optional[str]:
     if hasattr(self.operator, "schema") and self.operator.schema is not None:
         return self.operator.schema
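Note: a dedicated extractor is needed because AthenaOperator holds its SQL in query rather than the sql attribute the generic SQL path reads, and because Athena's database fills the role of a schema when resolving unqualified table names. A small sketch of that resolution rule, assuming this is how default_schema is applied by the SQL parser:

def qualify(table: str, default_schema: str) -> str:
    # Assumed behavior: unqualified names pick up the operator's database
    # as their schema; already-qualified names are left alone.
    return table if "." in table else f"{default_schema}.{table}"

assert qualify("costs", "athena_db") == "athena_db.costs"
assert qualify("other_db.costs", "athena_db") == "other_db.costs"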
@@ -0,0 +1,43 @@
+from datetime import datetime
+
+from airflow import DAG
+from airflow.providers.amazon.aws.operators.athena import AthenaOperator
+
+ATHENA_COST_TABLE = "costs"
+ATHENA_PROCESSED_TABLE = "processed_costs"
+
+
+def _fake_athena_execute(*args, **kwargs):
+    pass
+
+
+with DAG(
+    "athena_operator",
+    start_date=datetime(2023, 1, 1),
+    schedule_interval=None,
+    catchup=False,
+) as dag:
+    # HACK: We don't want to send real requests to Athena. As a workaround,
+    # we can simply monkey-patch the operator.
+    AthenaOperator.execute = _fake_athena_execute  # type: ignore
+
+    transform_cost_table = AthenaOperator(
+        aws_conn_id="my_aws",
+        task_id="transform_cost_table",
+        database="athena_db",
+        query="""
+            CREATE OR REPLACE TABLE {{ params.out_table_name }} AS
+            SELECT
+                id,
+                month,
+                total_cost,
+                area,
+                total_cost / area as cost_per_area
+            FROM {{ params.in_table_name }}
+        """,
+        params={
+            "in_table_name": ATHENA_COST_TABLE,
+            "out_table_name": ATHENA_PROCESSED_TABLE,
+        },
+        output_location="s3://athena-results-bucket/",
+    )
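Note: Airflow renders the {{ params.* }} placeholders before execute() runs, so the extractor sees a concrete query. Roughly what that rendering produces, sketched with jinja2 directly rather than Airflow's own templating machinery:

from jinja2 import Template

rendered = Template(
    "CREATE OR REPLACE TABLE {{ params.out_table_name }} AS "
    "SELECT id, month, total_cost, area, total_cost / area as cost_per_area "
    "FROM {{ params.in_table_name }}"
).render(params={"in_table_name": "costs", "out_table_name": "processed_costs"})

print(rendered)
# -> CREATE OR REPLACE TABLE processed_costs AS SELECT ... FROM costs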