Merge pull request #268 from NASA-IMPACT/refactor-dataset-pipeline
Refactor Dataset Pipeline
amarouane-ABDELHAK authored Dec 11, 2024
2 parents 0e5cc7c + 74573b7 commit e6ca7d4
Showing 5 changed files with 96 additions and 116 deletions.
51 changes: 38 additions & 13 deletions dags/veda_data_pipeline/groups/discover_group.py
@@ -8,33 +8,27 @@
from veda_data_pipeline.utils.s3_discovery import (
s3_discovery_handler, EmptyFileListError
)
from deprecated import deprecated

group_kwgs = {"group_id": "Discover", "tooltip": "Discover"}


@task(retries=1, retry_delay=timedelta(minutes=1))
def discover_from_s3_task(ti=None, event={}, alt_payload = None, **kwargs):
def discover_from_s3_task(ti=None, event={}, **kwargs):
"""Discover grouped assets/files from S3 in batches of 2800. Produce a list of such files stored on S3 to process.
This task is used as part of the discover_group subdag and outputs data to EVENT_BUCKET.
"""
if alt_payload:
config = {
**event,
**alt_payload
}
else:
config = {
**event,
**ti.dag_run.conf,
}
payload = kwargs.get("payload", ti.dag_run.conf)
config = {
**event,
**payload,
}
# TODO test that this context var is available in taskflow
last_successful_execution = kwargs.get("prev_start_date_success")
if event.get("schedule") and last_successful_execution:
config["last_successful_execution"] = last_successful_execution.isoformat()
# (event, chunk_size=2800, role_arn=None, bucket_output=None):

if event.get("item_assets") and event.get("assets"):
config["assets"] = event.get("item_assets")
airflow_vars = Variable.get("aws_dags_variables")
airflow_vars_json = json.loads(airflow_vars)
event_bucket = airflow_vars_json.get("EVENT_BUCKET")
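A note on the new config-merge logic above: because `payload` is unpacked after `event`, any key supplied at trigger time (via `dag_run.conf` or an explicit `payload` kwarg) overrides the static `event` value. A small plain-Python illustration, with invented keys and values:

```python
# Illustration of the merge precedence used in the new task body;
# the sample keys/values are made up for this example.
event = {"collection": "static-default", "filename_regex": r"^(.*)\.tif$"}
payload = {"collection": "hls-l30-002"}  # stands in for ti.dag_run.conf
config = {**event, **payload}            # later unpacking wins on conflicts
print(config["collection"])              # -> hls-l30-002
print(config["filename_regex"])          # -> ^(.*)\.tif$
```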
@@ -56,6 +50,36 @@ def discover_from_s3_task(ti=None, event={}, alt_payload = None, **kwargs):


@task
def get_files_task(payload, ti=None):
"""
Get files from S3 produced by discovery or dataset tasks.
Handles both single payload and multiple payload scenarios.
"""
dag_run_id = ti.dag_run.run_id
results = []

# Handle multiple payloads (dataset and items case)
payloads = payload if isinstance(payload, list) else [payload]

for item in payloads:
if isinstance(item, LazyXComAccess): # Dynamic task mapping case
payloads_xcom = item[0].pop("payload", [])
base_payload = item[0]
else:
payloads_xcom = item.pop("payload", [])
base_payload = item

for indx, payload_xcom in enumerate(payloads_xcom):
results.append({
"run_id": f"{dag_run_id}_{uuid.uuid4()}_{indx}",
**base_payload,
"payload": payload_xcom,
})

return results

@task
@deprecated(reason="Please use get_files_task function that handles both files and dataset files use cases")
def get_files_to_process(payload, ti=None):
"""Get files from S3 produced by the discovery task.
Used as part of both the parallel_run_process_rasters and parallel_run_process_vectors tasks.
@@ -74,6 +98,7 @@ def get_files_to_process(payload, ti=None):


@task
@deprecated(reason="Please use get_files_task airflow task instead. This will be removed in the new release")
def get_dataset_files_to_process(payload, ti=None):
"""Get files from S3 produced by the dataset task.
This is different from the get_files_to_process task as it produces a combined structure from repeated mappings.
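To see what the new `get_files_task` produces, here is a hypothetical, Airflow-free rerun of its flattening logic (the `LazyXComAccess` branch is omitted and the sample payloads are invented):

```python
import uuid

def flatten(payload, dag_run_id="manual__2024-12-11T00:00:00"):
    """Mirror of get_files_task's normalization, minus the Airflow context."""
    results = []
    payloads = payload if isinstance(payload, list) else [payload]
    for item in payloads:
        payloads_xcom = item.pop("payload", [])
        base_payload = item
        for indx, payload_xcom in enumerate(payloads_xcom):
            results.append({
                "run_id": f"{dag_run_id}_{uuid.uuid4()}_{indx}",
                **base_payload,
                "payload": payload_xcom,
            })
    return results

# A single discovery result (discover pipeline) or a list of them (dataset
# pipeline) both collapse into one flat list, one entry per discovered chunk.
single = {"collection": "demo", "payload": ["s3://bucket/chunk_0.json"]}
multi = [{"collection": "demo", "payload": ["s3://bucket/chunk_0.json",
                                            "s3://bucket/chunk_1.json"]}]
print(len(flatten(single)), len(flatten(multi)))  # 1 2
```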
8 changes: 6 additions & 2 deletions dags/veda_data_pipeline/groups/processing_tasks.py
@@ -36,5 +36,9 @@ def submit_to_stac_ingestor_task(built_stac: dict):
return event




@task(max_active_tis_per_dag=5)
def build_stac_task(payload):
from veda_data_pipeline.utils.build_stac.handler import stac_handler
airflow_vars_json = Variable.get("aws_dags_variables", deserialize_json=True)
event_bucket = airflow_vars_json.get("EVENT_BUCKET")
return stac_handler(payload_src=payload, bucket_output=event_bucket)
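One quiet change in the relocated `build_stac_task`: it now asks `Variable.get` to deserialize the JSON itself instead of calling `json.loads` manually. A sketch of the equivalence, assuming an Airflow environment where the `aws_dags_variables` Variable exists:

```python
import json
from airflow.models.variable import Variable

# New style: Airflow parses the JSON for us.
airflow_vars_json = Variable.get("aws_dags_variables", deserialize_json=True)

# Old style, equivalent: fetch the raw string and parse it ourselves.
airflow_vars_json_manual = json.loads(Variable.get("aws_dags_variables"))

assert airflow_vars_json == airflow_vars_json_manual
event_bucket = airflow_vars_json.get("EVENT_BUCKET")
```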
4 changes: 2 additions & 2 deletions dags/veda_data_pipeline/utils/submit_stac.py
@@ -103,7 +103,7 @@ def submission_handler(
cognito_app_secret=None,
stac_ingestor_api_url=None,
context=None,
) -> None:
) -> None | dict:
if context is None:
context = {}

@@ -121,7 +121,7 @@
secret_id=cognito_app_secret,
base_url=stac_ingestor_api_url,
)
ingestor.submit(event=stac_item, endpoint=endpoint)
return ingestor.submit(event=stac_item, endpoint=endpoint)


if __name__ == "__main__":
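The only change to `submit_stac.py` is that `submission_handler` now returns the ingestor response instead of discarding it (note that the `None | dict` annotation, as written, needs Python 3.10+ or `from __future__ import annotations`). In TaskFlow terms the return matters because whatever a task returns is pushed to XCom. A minimal, hypothetical sketch of that behaviour, not the project's code:

```python
from airflow.decorators import task

@task
def submit(item: dict) -> dict:
    # Stand-in for submission_handler(...): whatever a TaskFlow task
    # returns is pushed to XCom for downstream consumers.
    return {"id": item["id"], "status": "accepted"}

@task
def log_response(response: dict):
    print(response["status"])

# Inside a DAG definition, chaining the calls wires the XCom through:
# log_response(submit({"id": "stac-item-1"}))
```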
131 changes: 45 additions & 86 deletions dags/veda_data_pipeline/veda_dataset_pipeline.py
@@ -1,38 +1,40 @@
import pendulum
from airflow import DAG
from copy import deepcopy
from airflow.models.param import Param
from airflow.decorators import task
from veda_data_pipeline.groups.discover_group import discover_from_s3_task, get_files_task
from airflow.operators.dummy_operator import DummyOperator as EmptyOperator
from airflow.models.variable import Variable
import json
from veda_data_pipeline.groups.collection_group import collection_task_group
from veda_data_pipeline.groups.discover_group import discover_from_s3_task, get_dataset_files_to_process
from veda_data_pipeline.groups.processing_tasks import submit_to_stac_ingestor_task
from veda_data_pipeline.groups.processing_tasks import submit_to_stac_ingestor_task, build_stac_task

dag_doc_md = """
template_dag_run_conf = {
"collection": "<collection-id>",
"data_type": "cog",
"description": "<collection-description>",
"discovery_items":
[
{
"bucket": "<bucket-name>",
"datetime_range": "<range>",
"discovery": "s3",
"filename_regex": "<regex>",
"prefix": "<example-prefix/>"
}
],
"is_periodic": Param(True, type="boolean"),
"license": "<collection-LICENSE>",
"time_density": "<time-density>",
"title": "<collection-title>"
}

dag_doc_md = f"""
### Dataset Pipeline
Generates a collection and triggers the file discovery process
#### Notes
- This DAG can run with the following configuration <br>
```json
{
"collection": "collection-id",
"data_type": "cog",
"description": "collection description",
"discovery_items":
[
{
"bucket": "veda-data-store-staging",
"datetime_range": "year",
"discovery": "s3",
"filename_regex": "^(.*).tif$",
"prefix": "example-prefix/"
}
],
"is_periodic": true,
"license": "collection-LICENSE",
"time_density": "year",
"title": "collection-title"
}
{template_dag_run_conf}
```
"""

@@ -44,76 +46,33 @@
"tags": ["collection", "discovery"],
}

with DAG("veda_dataset_pipeline", params=template_dag_run_conf, **dag_args) as dag:
start = EmptyOperator(task_id="start")
end = EmptyOperator(task_id="end")

@task
def extract_discovery_items(**kwargs):
ti = kwargs.get("ti")
discovery_items = ti.dag_run.conf.get("discovery_items")
print(discovery_items)
return discovery_items


@task(max_active_tis_per_dag=3)
def build_stac_task(payload):
from veda_data_pipeline.utils.build_stac.handler import stac_handler
airflow_vars = Variable.get("aws_dags_variables")
airflow_vars_json = json.loads(airflow_vars)
event_bucket = airflow_vars_json.get("EVENT_BUCKET")
return stac_handler(payload_src=payload, bucket_output=event_bucket)

@task()
def mutate_payload(**kwargs):
ti = kwargs.get("ti")
payload = ti.dag_run.conf
if assets := payload.get("assets"):
# remove thumbnail asset if provided in collection config
if "thumbnail" in assets.keys():
@task()
def remove_thumbnail_asset(ti):
payload = deepcopy(ti.dag_run.conf)
payloads = list()
assets = payload.get("assets", {})
if assets.get("thumbnail"):
assets.pop("thumbnail")
# if thumbnail was only asset, delete assets
if not assets:
payload.pop("assets")
# finally put the mutated assets back in the payload
else:
payload["assets"] = assets
return payload


template_dag_run_conf = {
"collection": "<collection-id>",
"data_type": "cog",
"description": "<collection-description>",
"discovery_items":
[
{
"bucket": "<bucket-name>",
"datetime_range": "<range>",
"discovery": "s3",
"filename_regex": "<regex>",
"prefix": "<example-prefix/>"
for item in payload.get("discovery_items"):
payloads.append({
**payload,
**item
}
],
"is_periodic": "<true|false>",
"license": "<collection-LICENSE>",
"time_density": "<time-density>",
"title": "<collection-title>"
}
)

with DAG("veda_dataset_pipeline", params=template_dag_run_conf, **dag_args) as dag:
# ECS dependency variable

start = EmptyOperator(task_id="start", dag=dag)
end = EmptyOperator(task_id="end", dag=dag)
return payloads

collection_grp = collection_task_group()
mutate_payload_task = mutate_payload()
discover = discover_from_s3_task.partial(alt_payload=mutate_payload_task).expand(event=extract_discovery_items())
discover.set_upstream(collection_grp) # do not discover until collection exists
get_files = get_dataset_files_to_process(payload=discover)

mutated_payloads = start >> collection_task_group() >> remove_thumbnail_asset()
discover = discover_from_s3_task.expand(payload=mutated_payloads)
get_files = get_files_task(payload=discover)
build_stac = build_stac_task.expand(payload=get_files)
# .output is needed coming from a non-taskflow operator
submit_stac = submit_to_stac_ingestor_task.expand(built_stac=build_stac)

collection_grp.set_upstream(start)
mutate_payload_task.set_upstream(start)
submit_stac.set_downstream(end)
submit_stac = submit_to_stac_ingestor_task.expand(built_stac=build_stac) >> end
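Because this diff is rendered without +/- markers, the old and new DAG bodies are interleaved above. Reading only the added lines, the new `veda_dataset_pipeline` body appears to be roughly the following; treat it as a reconstruction, not verbatim source (the indentation of the thumbnail check is inferred):

```python
with DAG("veda_dataset_pipeline", params=template_dag_run_conf, **dag_args) as dag:
    start = EmptyOperator(task_id="start")
    end = EmptyOperator(task_id="end")

    @task()
    def remove_thumbnail_asset(ti):
        # Work on a copy of the trigger conf so the original stays untouched.
        payload = deepcopy(ti.dag_run.conf)
        payloads = list()
        assets = payload.get("assets", {})
        if assets.get("thumbnail"):
            assets.pop("thumbnail")
            # if thumbnail was the only asset, drop the assets key entirely
            if not assets:
                payload.pop("assets")
        # fan the collection-level config out over the discovery items
        for item in payload.get("discovery_items"):
            payloads.append({**payload, **item})
        return payloads

    mutated_payloads = start >> collection_task_group() >> remove_thumbnail_asset()
    discover = discover_from_s3_task.expand(payload=mutated_payloads)
    get_files = get_files_task(payload=discover)
    build_stac = build_stac_task.expand(payload=get_files)
    submit_stac = submit_to_stac_ingestor_task.expand(built_stac=build_stac) >> end
```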
18 changes: 5 additions & 13 deletions dags/veda_data_pipeline/veda_discover_pipeline.py
@@ -1,11 +1,9 @@
import pendulum
from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.decorators import task
from airflow.models.variable import Variable
import json
from veda_data_pipeline.groups.discover_group import discover_from_s3_task, get_files_to_process
from veda_data_pipeline.groups.processing_tasks import submit_to_stac_ingestor_task
from veda_data_pipeline.groups.discover_group import discover_from_s3_task, get_files_task

from veda_data_pipeline.groups.processing_tasks import submit_to_stac_ingestor_task, build_stac_task

dag_doc_md = """
### Discover files from S3
@@ -72,13 +70,7 @@
}


@task(max_active_tis_per_dag=5)
def build_stac_task(payload):
from veda_data_pipeline.utils.build_stac.handler import stac_handler
airflow_vars = Variable.get("aws_dags_variables")
airflow_vars_json = json.loads(airflow_vars)
event_bucket = airflow_vars_json.get("EVENT_BUCKET")
return stac_handler(payload_src=payload, bucket_output=event_bucket)



def get_discover_dag(id, event=None):
@@ -98,7 +90,7 @@ def get_discover_dag(id, event=None):
# define DAG using taskflow notation

discover = discover_from_s3_task(event=event)
get_files = get_files_to_process(payload=discover)
get_files = get_files_task(payload=discover)
build_stac = build_stac_task.expand(payload=get_files)
# .output is needed coming from a non-taskflow operator
submit_stac = submit_to_stac_ingestor_task.expand(built_stac=build_stac)
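For completeness, a hypothetical call to the `get_discover_dag` factory shown above; the id and the event keys here are illustrative (drawn from the conf template earlier in this diff), not taken verbatim from the repository:

```python
# Hypothetical usage; the DAG id and event values are made up for illustration.
dag = get_discover_dag(
    id="veda_discover_example",
    event={"bucket": "my-bucket", "prefix": "example-prefix/", "schedule": "0 0 * * *"},
)
```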
