diff --git a/dags/veda_data_pipeline/groups/discover_group.py b/dags/veda_data_pipeline/groups/discover_group.py
index 6c486f5..8cd2d20 100644
--- a/dags/veda_data_pipeline/groups/discover_group.py
+++ b/dags/veda_data_pipeline/groups/discover_group.py
@@ -8,33 +8,27 @@ from veda_data_pipeline.utils.s3_discovery import (
     s3_discovery_handler,
     EmptyFileListError
 )
+from deprecated import deprecated

 group_kwgs = {"group_id": "Discover", "tooltip": "Discover"}


 @task(retries=1, retry_delay=timedelta(minutes=1))
-def discover_from_s3_task(ti=None, event={}, alt_payload = None, **kwargs):
+def discover_from_s3_task(ti=None, event={}, **kwargs):
     """Discover grouped assets/files from S3 in batches of 2800. Produce a list of such files stored on S3 to process.
     This task is used as part of the discover_group subdag and outputs data to EVENT_BUCKET.
     """
-    if alt_payload:
-        config = {
-            **event,
-            **alt_payload
-        }
-    else:
-        config = {
-            **event,
-            **ti.dag_run.conf,
-        }
+    payload = kwargs.get("payload", ti.dag_run.conf)
+    config = {
+        **event,
+        **payload,
+    }
     # TODO test that this context var is available in taskflow
     last_successful_execution = kwargs.get("prev_start_date_success")
     if event.get("schedule") and last_successful_execution:
         config["last_successful_execution"] = last_successful_execution.isoformat()
     # (event, chunk_size=2800, role_arn=None, bucket_output=None):
-    if event.get("item_assets") and event.get("assets"):
-        config["assets"] = event.get("item_assets")

     airflow_vars = Variable.get("aws_dags_variables")
     airflow_vars_json = json.loads(airflow_vars)
     event_bucket = airflow_vars_json.get("EVENT_BUCKET")
@@ -56,6 +50,36 @@ def discover_from_s3_task(ti=None, event={}, alt_payload = None, **kwargs):


 @task
+def get_files_task(payload, ti=None):
+    """
+    Get files from S3 produced by discovery or dataset tasks.
+    Handles both single payload and multiple payload scenarios.
+    """
+    dag_run_id = ti.dag_run.run_id
+    results = []
+
+    # Handle multiple payloads (dataset and items case)
+    payloads = payload if isinstance(payload, list) else [payload]
+
+    for item in payloads:
+        if isinstance(item, LazyXComAccess):  # Dynamic task mapping case
+            payloads_xcom = item[0].pop("payload", [])
+            base_payload = item[0]
+        else:
+            payloads_xcom = item.pop("payload", [])
+            base_payload = item
+
+        for indx, payload_xcom in enumerate(payloads_xcom):
+            results.append({
+                "run_id": f"{dag_run_id}_{uuid.uuid4()}_{indx}",
+                **base_payload,
+                "payload": payload_xcom,
+            })
+
+    return results
+
+@task
+@deprecated(reason="Please use get_files_task function that handles both files and dataset files use cases")
 def get_files_to_process(payload, ti=None):
     """Get files from S3 produced by the discovery task.
     Used as part of both the parallel_run_process_rasters and parallel_run_process_vectors tasks.
@@ -74,6 +98,7 @@ def get_files_to_process(payload, ti=None):


 @task
+@deprecated(reason="Please use get_files_task airflow task instead. This will be removed in the new release")
 def get_dataset_files_to_process(payload, ti=None):
     """Get files from S3 produced by the dataset task.
     This is different from the get_files_to_process task as it produces a combined structure from repeated mappings.
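For reference, a minimal, Airflow-free sketch of the fan-out that the new `get_files_task` performs; the collection name, bucket, and chunk keys below are illustrative, not real values from the pipeline.

```python
# Hypothetical, plain-Python illustration of the fan-out in get_files_task:
# each upstream discovery result carries a "payload" list of chunked S3 manifests,
# and one record per chunk is emitted with a unique run_id suffix.
import uuid


def fan_out(discovery_result: dict, dag_run_id: str = "manual__2024-01-01") -> list[dict]:
    chunks = discovery_result.pop("payload", [])
    return [
        {
            "run_id": f"{dag_run_id}_{uuid.uuid4()}_{indx}",
            **discovery_result,
            "payload": chunk,
        }
        for indx, chunk in enumerate(chunks)
    ]


# Example: one discovery result with two chunked manifests yields two mapped records.
print(fan_out({
    "collection": "example-collection",
    "payload": ["s3://example-bucket/chunk_0.json", "s3://example-bucket/chunk_1.json"],
}))
```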
diff --git a/dags/veda_data_pipeline/groups/processing_tasks.py b/dags/veda_data_pipeline/groups/processing_tasks.py
index 48758fc..2dc48a4 100644
--- a/dags/veda_data_pipeline/groups/processing_tasks.py
+++ b/dags/veda_data_pipeline/groups/processing_tasks.py
@@ -36,5 +36,9 @@ def submit_to_stac_ingestor_task(built_stac: dict):

     return event

-
-
+@task(max_active_tis_per_dag=5)
+def build_stac_task(payload):
+    from veda_data_pipeline.utils.build_stac.handler import stac_handler
+    airflow_vars_json = Variable.get("aws_dags_variables", deserialize_json=True)
+    event_bucket = airflow_vars_json.get("EVENT_BUCKET")
+    return stac_handler(payload_src=payload, bucket_output=event_bucket)
diff --git a/dags/veda_data_pipeline/utils/submit_stac.py b/dags/veda_data_pipeline/utils/submit_stac.py
index d57b57f..22542ab 100644
--- a/dags/veda_data_pipeline/utils/submit_stac.py
+++ b/dags/veda_data_pipeline/utils/submit_stac.py
@@ -103,7 +103,7 @@ def submission_handler(
     cognito_app_secret=None,
     stac_ingestor_api_url=None,
     context=None,
-) -> None:
+) -> None | dict:
     if context is None:
         context = {}

@@ -121,7 +121,7 @@ def submission_handler(
         secret_id=cognito_app_secret,
         base_url=stac_ingestor_api_url,
     )
-    ingestor.submit(event=stac_item, endpoint=endpoint)
+    return ingestor.submit(event=stac_item, endpoint=endpoint)


 if __name__ == "__main__":
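As a side note, the relocated `build_stac_task` also switches to Airflow's built-in JSON deserialization when reading the `aws_dags_variables` Variable. A minimal sketch of the two equivalent patterns is below; running it requires a live Airflow metadata database with that Variable set.

```python
import json
from airflow.models.variable import Variable

# Old pattern used by the per-DAG copies of build_stac_task.
airflow_vars_json = json.loads(Variable.get("aws_dags_variables"))

# New pattern used by the shared build_stac_task in processing_tasks.py.
airflow_vars_json = Variable.get("aws_dags_variables", deserialize_json=True)

event_bucket = airflow_vars_json.get("EVENT_BUCKET")
```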
diff --git a/dags/veda_data_pipeline/veda_dataset_pipeline.py b/dags/veda_data_pipeline/veda_dataset_pipeline.py
index c735697..0acd6b3 100644
--- a/dags/veda_data_pipeline/veda_dataset_pipeline.py
+++ b/dags/veda_data_pipeline/veda_dataset_pipeline.py
@@ -1,38 +1,40 @@
 import pendulum
 from airflow import DAG
+from copy import deepcopy
+from airflow.models.param import Param
 from airflow.decorators import task
+from veda_data_pipeline.groups.discover_group import discover_from_s3_task, get_files_task
 from airflow.operators.dummy_operator import DummyOperator as EmptyOperator
-from airflow.models.variable import Variable
-import json
 from veda_data_pipeline.groups.collection_group import collection_task_group
-from veda_data_pipeline.groups.discover_group import discover_from_s3_task, get_dataset_files_to_process
-from veda_data_pipeline.groups.processing_tasks import submit_to_stac_ingestor_task
+from veda_data_pipeline.groups.processing_tasks import submit_to_stac_ingestor_task, build_stac_task

-dag_doc_md = """
+template_dag_run_conf = {
+    "collection": "",
+    "data_type": "cog",
+    "description": "",
+    "discovery_items":
+    [
+        {
+            "bucket": "",
+            "datetime_range": "",
+            "discovery": "s3",
+            "filename_regex": "",
+            "prefix": ""
+        }
+    ],
+    "is_periodic": Param(True, type="boolean"),
+    "license": "",
+    "time_density": "",
+    "title": ""
+}
+
+dag_doc_md = f"""
 ### Dataset Pipeline
 Generates a collection and triggers the file discovery process
 #### Notes
 - This DAG can run with the following configuration
 ```json
-{
-    "collection": "collection-id",
-    "data_type": "cog",
-    "description": "collection description",
-    "discovery_items":
-    [
-        {
-            "bucket": "veda-data-store-staging",
-            "datetime_range": "year",
-            "discovery": "s3",
-            "filename_regex": "^(.*).tif$",
-            "prefix": "example-prefix/"
-        }
-    ],
-    "is_periodic": true,
-    "license": "collection-LICENSE",
-    "time_density": "year",
-    "title": "collection-title"
-}
+{template_dag_run_conf}
 ```
 """

@@ -44,76 +46,33 @@
     "tags": ["collection", "discovery"],
 }

+with DAG("veda_dataset_pipeline", params=template_dag_run_conf, **dag_args) as dag:
+    start = EmptyOperator(task_id="start")
+    end = EmptyOperator(task_id="end")

-@task
-def extract_discovery_items(**kwargs):
-    ti = kwargs.get("ti")
-    discovery_items = ti.dag_run.conf.get("discovery_items")
-    print(discovery_items)
-    return discovery_items
-
-
-@task(max_active_tis_per_dag=3)
-def build_stac_task(payload):
-    from veda_data_pipeline.utils.build_stac.handler import stac_handler
-    airflow_vars = Variable.get("aws_dags_variables")
-    airflow_vars_json = json.loads(airflow_vars)
-    event_bucket = airflow_vars_json.get("EVENT_BUCKET")
-    return stac_handler(payload_src=payload, bucket_output=event_bucket)
-@task()
-def mutate_payload(**kwargs):
-    ti = kwargs.get("ti")
-    payload = ti.dag_run.conf
-    if assets := payload.get("assets"):
-        # remove thumbnail asset if provided in collection config
-        if "thumbnail" in assets.keys():
+    @task()
+    def remove_thumbnail_asset(ti):
+        payload = deepcopy(ti.dag_run.conf)
+        payloads = list()
+        assets = payload.get("assets", {})
+        if assets.get("thumbnail"):
             assets.pop("thumbnail")
             # if thumbnail was only asset, delete assets
             if not assets:
                 payload.pop("assets")
-            # finally put the mutated assets back in the payload
-        else:
-            payload["assets"] = assets
-    return payload
-
-
-template_dag_run_conf = {
-    "collection": "",
-    "data_type": "cog",
-    "description": "",
-    "discovery_items":
-    [
-        {
-            "bucket": "",
-            "datetime_range": "",
-            "discovery": "s3",
-            "filename_regex": "",
-            "prefix": ""
+        for item in payload.get("discovery_items"):
+            payloads.append({
+                **payload,
+                **item
        }
-    ],
-    "is_periodic": "",
-    "license": "",
-    "time_density": "",
-    "title": ""
-}
+            )

-with DAG("veda_dataset_pipeline", params=template_dag_run_conf, **dag_args) as dag:
-    # ECS dependency variable
-
-    start = EmptyOperator(task_id="start", dag=dag)
-    end = EmptyOperator(task_id="end", dag=dag)
+        return payloads

-    collection_grp = collection_task_group()
-    mutate_payload_task = mutate_payload()
-    discover = discover_from_s3_task.partial(alt_payload=mutate_payload_task).expand(event=extract_discovery_items())
-    discover.set_upstream(collection_grp)  # do not discover until collection exists
-    get_files = get_dataset_files_to_process(payload=discover)
+    mutated_payloads = start >> collection_task_group() >> remove_thumbnail_asset()
+    discover = discover_from_s3_task.expand(payload=mutated_payloads)
+    get_files = get_files_task(payload=discover)
     build_stac = build_stac_task.expand(payload=get_files)
-    # .output is needed coming from a non-taskflow operator
-    submit_stac = submit_to_stac_ingestor_task.expand(built_stac=build_stac)
-
-    collection_grp.set_upstream(start)
-    mutate_payload_task.set_upstream(start)
-    submit_stac.set_downstream(end)
+    submit_stac = submit_to_stac_ingestor_task.expand(built_stac=build_stac) >> end
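An Airflow-free sketch of what `remove_thumbnail_asset` produces from a dag run conf: the thumbnail asset is dropped and the collection-level conf is merged onto each discovery item, so every mapped `discover_from_s3_task` receives a self-contained payload. All values below are illustrative.

```python
from copy import deepcopy

# Illustrative dag_run.conf; not a real collection configuration.
conf = {
    "collection": "example-collection",
    "assets": {"thumbnail": {"href": "s3://example-bucket/thumbnail.png"}},
    "discovery_items": [
        {"bucket": "example-bucket", "prefix": "a/", "discovery": "s3", "filename_regex": "^(.*).tif$"},
        {"bucket": "example-bucket", "prefix": "b/", "discovery": "s3", "filename_regex": "^(.*).tif$"},
    ],
}

payload = deepcopy(conf)
assets = payload.get("assets", {})
if assets.get("thumbnail"):
    assets.pop("thumbnail")
    if not assets:  # thumbnail was the only asset, so drop the key entirely
        payload.pop("assets")

# One merged payload per discovery item, same as the task's fan-out.
payloads = [{**payload, **item} for item in payload["discovery_items"]]
print(len(payloads))  # 2
```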
diff --git a/dags/veda_data_pipeline/veda_discover_pipeline.py b/dags/veda_data_pipeline/veda_discover_pipeline.py
index 49d790b..0b0f8d1 100644
--- a/dags/veda_data_pipeline/veda_discover_pipeline.py
+++ b/dags/veda_data_pipeline/veda_discover_pipeline.py
@@ -1,11 +1,9 @@
 import pendulum
 from airflow import DAG
 from airflow.operators.dummy_operator import DummyOperator
-from airflow.decorators import task
-from airflow.models.variable import Variable
-import json
-from veda_data_pipeline.groups.discover_group import discover_from_s3_task, get_files_to_process
-from veda_data_pipeline.groups.processing_tasks import submit_to_stac_ingestor_task
+from veda_data_pipeline.groups.discover_group import discover_from_s3_task, get_files_task
+
+from veda_data_pipeline.groups.processing_tasks import submit_to_stac_ingestor_task, build_stac_task

 dag_doc_md = """
 ### Discover files from S3
@@ -72,13 +70,7 @@
 }


-@task(max_active_tis_per_dag=5)
-def build_stac_task(payload):
-    from veda_data_pipeline.utils.build_stac.handler import stac_handler
-    airflow_vars = Variable.get("aws_dags_variables")
-    airflow_vars_json = json.loads(airflow_vars)
-    event_bucket = airflow_vars_json.get("EVENT_BUCKET")
-    return stac_handler(payload_src=payload, bucket_output=event_bucket)
+


 def get_discover_dag(id, event=None):
@@ -98,7 +90,7 @@ def get_discover_dag(id, event=None):

     # define DAG using taskflow notation
     discover = discover_from_s3_task(event=event)
-    get_files = get_files_to_process(payload=discover)
+    get_files = get_files_task(payload=discover)
     build_stac = build_stac_task.expand(payload=get_files)
     # .output is needed coming from a non-taskflow operator
     submit_stac = submit_to_stac_ingestor_task.expand(built_stac=build_stac)
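For readers less familiar with the `.expand()` calls used in both DAGs, here is a minimal, hypothetical stand-alone DAG (assuming Airflow 2.4+) showing the dynamic task mapping pattern: one mapped task instance is created per element of the list returned by the upstream task. The DAG id, task names, and S3 keys are illustrative only.

```python
import pendulum
from airflow import DAG
from airflow.decorators import task

with DAG("mapping_example", start_date=pendulum.datetime(2024, 1, 1), schedule=None) as dag:

    @task
    def list_chunks() -> list[dict]:
        # Stand-in for get_files_task: one dict per chunked manifest.
        return [{"payload": "s3://example-bucket/chunk_0.json"},
                {"payload": "s3://example-bucket/chunk_1.json"}]

    @task
    def build(item: dict) -> str:
        # Stand-in for build_stac_task: runs once per mapped input.
        return item["payload"]

    # Two mapped instances of "build" are created at runtime, one per chunk.
    build.expand(item=list_chunks())
```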