Refactor Dataset Pipeline
amarouane-ABDELHAK committed Dec 6, 2024
1 parent 0e5cc7c commit 53416d0
Showing 4 changed files with 96 additions and 97 deletions.
50 changes: 37 additions & 13 deletions dags/veda_data_pipeline/groups/discover_group.py
@@ -8,33 +8,26 @@
from veda_data_pipeline.utils.s3_discovery import (
s3_discovery_handler, EmptyFileListError
)
from deprecated import deprecated

group_kwgs = {"group_id": "Discover", "tooltip": "Discover"}


@task(retries=1, retry_delay=timedelta(minutes=1))
def discover_from_s3_task(ti=None, event={}, alt_payload = None, **kwargs):
def discover_from_s3_task(ti=None, event={}, **kwargs):
"""Discover grouped assets/files from S3 in batches of 2800. Produce a list of such files stored on S3 to process.
This task is used as part of the discover_group subdag and outputs data to EVENT_BUCKET.
"""
if alt_payload:
config = {
**event,
**alt_payload
}
else:
config = {
**event,
**ti.dag_run.conf,
}
config = {
**event,
**ti.dag_run.conf,
}
# TODO test that this context var is available in taskflow
last_successful_execution = kwargs.get("prev_start_date_success")
if event.get("schedule") and last_successful_execution:
config["last_successful_execution"] = last_successful_execution.isoformat()
# (event, chunk_size=2800, role_arn=None, bucket_output=None):

if event.get("item_assets") and event.get("assets"):
config["assets"] = event.get("item_assets")
airflow_vars = Variable.get("aws_dags_variables")
airflow_vars_json = json.loads(airflow_vars)
event_bucket = airflow_vars_json.get("EVENT_BUCKET")
@@ -56,6 +49,36 @@ def discover_from_s3_task(ti=None, event={}, alt_payload = None, **kwargs):


@task
def get_files_task(payload, ti=None):
"""
Get files from S3 produced by discovery or dataset tasks.
Handles both single payload and multiple payload scenarios.
"""
dag_run_id = ti.dag_run.run_id
results = []

# Handle multiple payloads (dataset and items case)
payloads = payload if isinstance(payload, list) else [payload]

for item in payloads:
if isinstance(item, LazyXComAccess): # Dynamic task mapping case
payloads_xcom = item[0].pop("payload", [])
base_payload = item[0]
else:
payloads_xcom = item.pop("payload", [])
base_payload = item

for indx, payload_xcom in enumerate(payloads_xcom):
results.append({
"run_id": f"{dag_run_id}_{uuid.uuid4()}_{indx}",
**base_payload,
"payload": payload_xcom,
})

return results

@task
@deprecated(reason="Please use get_files_task function that hundles both files and dataset files use cases")
def get_files_to_process(payload, ti=None):
"""Get files from S3 produced by the discovery task.
Used as part of both the parallel_run_process_rasters and parallel_run_process_vectors tasks.
Expand All @@ -74,6 +97,7 @@ def get_files_to_process(payload, ti=None):


@task
@deprecated(reason="Please use get_files_task airflow task instead. This will be removed in the new release")
def get_dataset_files_to_process(payload, ti=None):
"""Get files from S3 produced by the dataset task.
This is different from the get_files_to_process task as it produces a combined structure from repeated mappings.
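A minimal standalone sketch of the fan-out performed by the new `get_files_task`, with the Airflow context stripped out so it runs directly; the sample discovery output is an assumption based on this diff, not a documented payload contract.

```python
import uuid

# Hypothetical discovery output: collection-level metadata plus a list of chunk manifests on S3.
sample_discovery_output = {
    "collection": "example-collection",
    "prefix": "example-prefix/",
    "payload": ["s3://event-bucket/chunk_0.json", "s3://event-bucket/chunk_1.json"],
}


def flatten(payload, dag_run_id="manual__2024-12-06"):
    """Mirror get_files_task: one run config per discovered chunk, tagged with a unique run_id."""
    payloads = payload if isinstance(payload, list) else [payload]
    results = []
    for indx_item in payloads:
        chunks = indx_item.pop("payload", [])
        for indx, chunk in enumerate(chunks):
            results.append({
                "run_id": f"{dag_run_id}_{uuid.uuid4()}_{indx}",
                **indx_item,       # carry collection-level metadata forward
                "payload": chunk,  # but point each mapped run at a single chunk manifest
            })
    return results


print(flatten(dict(sample_discovery_output)))  # two entries, one per chunk
```

Each resulting entry then feeds one mapped `build_stac_task` instance downstream.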
8 changes: 6 additions & 2 deletions dags/veda_data_pipeline/groups/processing_tasks.py
@@ -35,6 +35,10 @@ def submit_to_stac_ingestor_task(built_stac: dict):
)
return event



@task(max_active_tis_per_dag=5)
def build_stac_task(payload):
from veda_data_pipeline.utils.build_stac.handler import stac_handler
airflow_vars_json = Variable.get("aws_dags_variables", deserialize_json=True)
event_bucket = airflow_vars_json.get("EVENT_BUCKET")
return stac_handler(payload_src=payload, bucket_output=event_bucket)
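The relocated `build_stac_task` also swaps the manual `json.loads` for Airflow's built-in deserialization; the two reads below are equivalent, assuming (as elsewhere in this diff) that the `aws_dags_variables` Variable holds a JSON object with an `EVENT_BUCKET` key.

```python
import json
from airflow.models.variable import Variable

# Older pattern, still used in some tasks: fetch the raw string, then parse it.
airflow_vars_json = json.loads(Variable.get("aws_dags_variables"))

# Consolidated pattern in processing_tasks.py: let Variable.get parse the JSON itself.
airflow_vars_json = Variable.get("aws_dags_variables", deserialize_json=True)

event_bucket = airflow_vars_json.get("EVENT_BUCKET")
```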

121 changes: 49 additions & 72 deletions dags/veda_data_pipeline/veda_dataset_pipeline.py
@@ -1,38 +1,38 @@
import pendulum
from airflow import DAG
from airflow.models.param import Param
from airflow.decorators import task
from veda_data_pipeline.groups.discover_group import discover_from_s3_task, get_files_task
from airflow.operators.dummy_operator import DummyOperator as EmptyOperator
from airflow.models.variable import Variable
import json
from veda_data_pipeline.groups.collection_group import collection_task_group
from veda_data_pipeline.groups.discover_group import discover_from_s3_task, get_dataset_files_to_process
from veda_data_pipeline.groups.processing_tasks import submit_to_stac_ingestor_task
from veda_data_pipeline.groups.processing_tasks import submit_to_stac_ingestor_task, build_stac_task

dag_doc_md = """
template_dag_run_conf = {
"collection": "<collection-id>",
"data_type": "cog",
"description": "<collection-description>",
"discovery_items": [
{
"bucket": "<bucket-name>",
"datetime_range": "<range>",
"discovery": "s3",
"filename_regex": "<regex>",
"prefix": "<example-prefix/>",
}
],
"is_periodic": "<true|false>",
"license": "<collection-LICENSE>",
"time_density": "<time-density>",
"title": "<collection-title>",
}

dag_doc_md = f"""
### Dataset Pipeline
Generates a collection and triggers the file discovery process
#### Notes
- This DAG can run with the following configuration <br>
```json
{
"collection": "collection-id",
"data_type": "cog",
"description": "collection description",
"discovery_items":
[
{
"bucket": "veda-data-store-staging",
"datetime_range": "year",
"discovery": "s3",
"filename_regex": "^(.*).tif$",
"prefix": "example-prefix/"
}
],
"is_periodic": true,
"license": "collection-LICENSE",
"time_density": "year",
"title": "collection-title"
}
{template_dag_run_conf}
```
"""

@@ -44,40 +44,6 @@
"tags": ["collection", "discovery"],
}


@task
def extract_discovery_items(**kwargs):
ti = kwargs.get("ti")
discovery_items = ti.dag_run.conf.get("discovery_items")
print(discovery_items)
return discovery_items


@task(max_active_tis_per_dag=3)
def build_stac_task(payload):
from veda_data_pipeline.utils.build_stac.handler import stac_handler
airflow_vars = Variable.get("aws_dags_variables")
airflow_vars_json = json.loads(airflow_vars)
event_bucket = airflow_vars_json.get("EVENT_BUCKET")
return stac_handler(payload_src=payload, bucket_output=event_bucket)

@task()
def mutate_payload(**kwargs):
ti = kwargs.get("ti")
payload = ti.dag_run.conf
if assets := payload.get("assets"):
# remove thumbnail asset if provided in collection config
if "thumbnail" in assets.keys():
assets.pop("thumbnail")
# if thumbnail was only asset, delete assets
if not assets:
payload.pop("assets")
# finally put the mutated assets back in the payload
else:
payload["assets"] = assets
return payload


template_dag_run_conf = {
"collection": "<collection-id>",
"data_type": "cog",
@@ -99,21 +65,32 @@ def mutate_payload(**kwargs):
}

with DAG("veda_dataset_pipeline", params=template_dag_run_conf, **dag_args) as dag:
# ECS dependency variable
start = EmptyOperator(task_id="start")
end = EmptyOperator(task_id="end")

start = EmptyOperator(task_id="start", dag=dag)
end = EmptyOperator(task_id="end", dag=dag)

collection_grp = collection_task_group()
mutate_payload_task = mutate_payload()
discover = discover_from_s3_task.partial(alt_payload=mutate_payload_task).expand(event=extract_discovery_items())
discover.set_upstream(collection_grp) # do not discover until collection exists
get_files = get_dataset_files_to_process(payload=discover)
@task()
def remove_thumbnail_asset(ti):
payload = ti.dag_run.conf.copy()
payloads = list()
assets = payload.get("assets", {})
if assets.get("thumbnail"):
assets.pop("thumbnail")
# if thumbnail was only asset, delete assets
if not assets:
payload.pop("assets")
for item in payload.get("discovery_items"):
payloads.append({
**payload,
**item
}
)

return payloads

build_stac = build_stac_task.expand(payload=get_files)
# .output is needed coming from a non-taskflow operator
submit_stac = submit_to_stac_ingestor_task.expand(built_stac=build_stac)

collection_grp.set_upstream(start)
mutate_payload_task.set_upstream(start)
submit_stac.set_downstream(end)
mutated_payloads = start >> collection_task_group() >> remove_thumbnail_asset()
discover = discover_from_s3_task.expand(event=mutated_payloads)
get_files = get_files_task(payload=discover)
build_stac = build_stac_task.expand(payload=get_files)
submit_stac = submit_to_stac_ingestor_task.expand(built_stac=build_stac) >> end
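A standalone sketch of the fan-out that `remove_thumbnail_asset` feeds into `discover_from_s3_task.expand`: the collection-level conf is merged into each discovery item after stripping a thumbnail asset. The conf values below are placeholders, not a real collection.

```python
# Placeholder run conf with a thumbnail asset and two discovery items.
conf = {
    "collection": "example-collection",
    "assets": {"thumbnail": {"href": "s3://example-bucket/thumb.png"}},
    "discovery_items": [
        {"bucket": "example-bucket-a", "prefix": "a/"},
        {"bucket": "example-bucket-b", "prefix": "b/"},
    ],
}

payload = dict(conf)
assets = payload.get("assets", {})
if assets.pop("thumbnail", None) is not None and not assets:
    payload.pop("assets")  # thumbnail was the only asset, so drop assets entirely

# One merged event per discovery item; discover_from_s3_task.expand maps over this list.
events = [{**payload, **item} for item in payload.get("discovery_items", [])]
print(len(events))  # 2 mapped discover tasks
```

Mapping directly over the merged discovery items replaces the earlier `extract_discovery_items` plus `alt_payload` wiring with a single `expand`.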
14 changes: 4 additions & 10 deletions dags/veda_data_pipeline/veda_discover_pipeline.py
@@ -4,8 +4,8 @@
from airflow.decorators import task
from airflow.models.variable import Variable
import json
from veda_data_pipeline.groups.discover_group import discover_from_s3_task, get_files_to_process
from veda_data_pipeline.groups.processing_tasks import submit_to_stac_ingestor_task
from veda_data_pipeline.groups.discover_group import discover_from_s3_task, get_files_task
from veda_data_pipeline.groups.processing_tasks import submit_to_stac_ingestor_task, build_stac_task

dag_doc_md = """
### Discover files from S3
@@ -72,13 +72,7 @@
}


@task(max_active_tis_per_dag=5)
def build_stac_task(payload):
from veda_data_pipeline.utils.build_stac.handler import stac_handler
airflow_vars = Variable.get("aws_dags_variables")
airflow_vars_json = json.loads(airflow_vars)
event_bucket = airflow_vars_json.get("EVENT_BUCKET")
return stac_handler(payload_src=payload, bucket_output=event_bucket)



def get_discover_dag(id, event=None):
@@ -98,7 +92,7 @@ def get_discover_dag(id, event=None):
# define DAG using taskflow notation

discover = discover_from_s3_task(event=event)
get_files = get_files_to_process(payload=discover)
get_files = get_files_task(payload=discover)
build_stac = build_stac_task.expand(payload=get_files)
# .output is needed coming from a non-taskflow operator
submit_stac = submit_to_stac_ingestor_task.expand(built_stac=build_stac)
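For context, a hypothetical sketch of how a factory like `get_discover_dag` is typically invoked at module import time so Airflow registers each generated DAG; the import path, DAG ids, and event values are illustrative assumptions, not part of this commit.

```python
# Hypothetical registration loop, assuming get_discover_dag returns a DAG object.
from veda_data_pipeline.veda_discover_pipeline import get_discover_dag

events = [
    None,  # manually triggered discover DAG with no baked-in event
    {"collection": "example-collection", "bucket": "example-bucket", "discovery": "s3"},
]

for event in events:
    dag_id = "veda_discover_pipeline" if event is None else f"veda_discover_{event['collection']}"
    globals()[dag_id] = get_discover_dag(id=dag_id, event=event)
```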
