From 514830d22be5e3efda06731161751fddb21c9e34 Mon Sep 17 00:00:00 2001 From: Henri Blancke Date: Tue, 21 Mar 2023 05:04:14 -0400 Subject: [PATCH] feat: lake formation tag support (#159) Signed-off-by: Henri Blancke Co-authored-by: nicor88 <6278547+nicor88@users.noreply.github.com> --- README.md | 8 ++- dbt/adapters/athena/connections.py | 2 + dbt/adapters/athena/impl.py | 58 ++++++++++++++++++- dbt/include/athena/macros/adapters/schema.sql | 2 + .../models/incremental/incremental.sql | 5 ++ .../materializations/models/table/table.sql | 5 ++ .../models/view/create_or_replace_view.sql | 14 ++++- .../macros/materializations/seeds/helpers.sql | 7 +++ dbt/include/athena/sample_profiles.yml | 3 + tests/unit/constants.py | 1 + tests/unit/test_adapter.py | 3 +- 11 files changed, 102 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 99a60eba..53a4ba9a 100644 --- a/README.md +++ b/README.md @@ -72,7 +72,7 @@ A dbt profile can be configured to run against AWS Athena using the following co | aws_profile_name | Profile to use from your AWS shared credentials file. | Optional | `my-profile` | | work_group | Identifier of Athena workgroup | Optional | `my-custom-workgroup` | | num_retries | Number of times to retry a failing query | Optional | `3` | - +| lf_tags | Default lf tags to apply to any database created by dbt | Optional | `{"origin": "dbt", "team": "analytics"}`| **Example profiles.yml entry:** ```yaml @@ -89,6 +89,9 @@ athena: database: awsdatacatalog aws_profile_name: my-profile work_group: my-workgroup + lf_tags: + origin: dbt + team: analytics ``` _Additional information_ @@ -119,6 +122,9 @@ _Additional information_ * `field_delimiter` (`default=none`) * Custom field delimiter, for when format is set to `TEXTFILE` * `table_properties`: table properties to add to the table, valid for Iceberg only +* `lf_tags` (`default=none`) + * lf tags to associate with the table + * format: `{"tag1": "value1", "tag2": "value2"}` #### Table location diff --git a/dbt/adapters/athena/connections.py b/dbt/adapters/athena/connections.py index 750f5204..fb7d5f1b 100644 --- a/dbt/adapters/athena/connections.py +++ b/dbt/adapters/athena/connections.py @@ -47,6 +47,7 @@ class AthenaCredentials(Credentials): num_retries: Optional[int] = 5 s3_data_dir: Optional[str] = None s3_data_naming: Optional[str] = "schema_table_unique" + lf_tags: Optional[Dict[str, str]] = None @property def type(self) -> str: @@ -68,6 +69,7 @@ def _connection_keys(self) -> Tuple[str, ...]: "endpoint_url", "s3_data_dir", "s3_data_naming", + "lf_tags", ) diff --git a/dbt/adapters/athena/impl.py b/dbt/adapters/athena/impl.py index 86abed7d..b3ba3203 100755 --- a/dbt/adapters/athena/impl.py +++ b/dbt/adapters/athena/impl.py @@ -1,7 +1,7 @@ import posixpath as path from itertools import chain from threading import Lock -from typing import Any, Dict, Iterator, List, Optional, Set, Tuple +from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union from urllib.parse import urlparse from uuid import uuid4 @@ -56,6 +56,60 @@ def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str: def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str: return "timestamp" + # TODO: Add more lf-tag unit tests when moto supports lakeformation + # moto issue: https://github.com/getmoto/moto/issues/5964 + @available + def add_lf_tags(self, database: str, table: str = None, lf_tags: Dict[str, str] = None): + conn = self.connections.get_thread_connection() + client = conn.handle + + lf_tags = lf_tags or conn.credentials.lf_tags + if not lf_tags: + logger.debug("No LF tags configured") + return + + resource = { + "Database": {"Name": database}, + } + + if table: + resource = { + "Table": { + "DatabaseName": database, + "Name": table, + } + } + + with boto3_client_lock: + lf_client = client.session.client( + "lakeformation", region_name=client.region_name, config=get_boto3_config() + ) + + response = lf_client.add_lf_tags_to_resource( + Resource=resource, + LFTags=[ + { + "TagKey": key, + "TagValues": [ + value, + ], + } + for key, value in lf_tags.items() + ], + ) + + failures = response.get("Failures", []) + tbl_appendix = f".{table}" if table else "" + if failures: + base_msg = f"Failed to add LF tags: {lf_tags} to {database}" + tbl_appendix + for failure in failures: + tag = failure.get("LFTag", {}).get("TagKey") + error = failure.get("Error", {}).get("ErrorMessage") + logger.error(f"Failed to set {tag} for {database}" + tbl_appendix + f" - {error}") + raise DbtRuntimeError(base_msg) + else: + logger.debug(f"Added LF tags: {lf_tags} to {database}" + tbl_appendix) + @available def get_work_group_output_location(self) -> Optional[str]: conn = self.connections.get_thread_connection() @@ -120,7 +174,7 @@ def s3_table_location( return table_location @available - def get_table_location(self, database_name: str, table_name: str) -> [str, None]: + def get_table_location(self, database_name: str, table_name: str) -> Union[str, None]: """ Helper function to S3 get table location """ diff --git a/dbt/include/athena/macros/adapters/schema.sql b/dbt/include/athena/macros/adapters/schema.sql index 777a3690..312f2bf6 100644 --- a/dbt/include/athena/macros/adapters/schema.sql +++ b/dbt/include/athena/macros/adapters/schema.sql @@ -2,6 +2,8 @@ {%- call statement('create_schema') -%} create schema if not exists {{ relation.without_identifier().render_hive() }} {% endcall %} + + {{ adapter.add_lf_tags(relation.schema) }} {% endmacro %} diff --git a/dbt/include/athena/macros/materializations/models/incremental/incremental.sql b/dbt/include/athena/macros/materializations/models/incremental/incremental.sql index 404bbec9..966d098b 100644 --- a/dbt/include/athena/macros/materializations/models/incremental/incremental.sql +++ b/dbt/include/athena/macros/materializations/models/incremental/incremental.sql @@ -5,6 +5,7 @@ {% set strategy = validate_get_incremental_strategy(raw_strategy, table_type) %} {% set on_schema_change = incremental_validate_on_schema_change(config.get('on_schema_change'), default='ignore') %} + {% set lf_tags = config.get('lf_tags', default=none) %} {% set partitioned_by = config.get('partitioned_by', default=none) %} {% set target_relation = this.incorporate(type='table') %} {% set existing_relation = load_relation(this) %} @@ -83,6 +84,10 @@ {{ run_hooks(post_hooks, inside_transaction=False) }} + {% if lf_tags is not none %} + {{ adapter.add_lf_tags(target_relation.schema, target_relation.identifier, lf_tags) }} + {% endif %} + {{ return({'relations': [target_relation]}) }} {%- endmaterialization %} diff --git a/dbt/include/athena/macros/materializations/models/table/table.sql b/dbt/include/athena/macros/materializations/models/table/table.sql index 088f67bf..d43c9c1f 100644 --- a/dbt/include/athena/macros/materializations/models/table/table.sql +++ b/dbt/include/athena/macros/materializations/models/table/table.sql @@ -1,6 +1,7 @@ {% materialization table, adapter='athena' -%} {%- set identifier = model['alias'] -%} + {%- set lf_tags = config.get('lf_tags', default=none) -%} {%- set table_type = config.get('table_type', default='hive') | lower -%} {%- set old_relation = adapter.get_relation(database=database, schema=schema, identifier=identifier) -%} {%- set target_relation = api.Relation.create(identifier=identifier, @@ -26,6 +27,10 @@ {{ run_hooks(post_hooks) }} + {% if lf_tags is not none %} + {{ adapter.add_lf_tags(target_relation.schema, identifier, lf_tags) }} + {% endif %} + {% do persist_docs(target_relation, model) %} {{ return({'relations': [target_relation]}) }} diff --git a/dbt/include/athena/macros/materializations/models/view/create_or_replace_view.sql b/dbt/include/athena/macros/materializations/models/view/create_or_replace_view.sql index f019c8b3..583dfeae 100644 --- a/dbt/include/athena/macros/materializations/models/view/create_or_replace_view.sql +++ b/dbt/include/athena/macros/materializations/models/view/create_or_replace_view.sql @@ -1,10 +1,16 @@ {% macro create_or_replace_view(run_outside_transaction_hooks=True) %} {%- set identifier = model['alias'] -%} + + {%- set lf_tags = config.get('lf_tags', default=none) -%} {%- set old_relation = adapter.get_relation(database=database, schema=schema, identifier=identifier) -%} {%- set exists_as_view = (old_relation is not none and old_relation.is_view) -%} {%- set target_relation = api.Relation.create( - identifier=identifier, schema=schema, database=database, - type='view') -%} + identifier=identifier, + schema=schema, + database=database, + type='view', + ) -%} + {% if run_outside_transaction_hooks %} -- no transactions on BigQuery {{ run_hooks(pre_hooks, inside_transaction=False) }} @@ -23,6 +29,10 @@ {{ create_view_as(target_relation, sql) }} {%- endcall %} + {% if lf_tags is not none %} + {{ adapter.add_lf_tags(target_relation.schema, identifier, lf_tags) }} + {% endif %} + {{ run_hooks(post_hooks, inside_transaction=True) }} {{ adapter.commit() }} diff --git a/dbt/include/athena/macros/materializations/seeds/helpers.sql b/dbt/include/athena/macros/materializations/seeds/helpers.sql index 1815c429..c2b36cdb 100644 --- a/dbt/include/athena/macros/materializations/seeds/helpers.sql +++ b/dbt/include/athena/macros/materializations/seeds/helpers.sql @@ -8,6 +8,9 @@ {% endmacro %} {% macro athena__create_csv_table(model, agate_table) %} + {%- set identifier = model['alias'] -%} + + {%- set lf_tags = config.get('lf_tags', default=none) -%} {%- set column_override = config.get('column_types', {}) -%} {%- set quote_seed_column = config.get('quote_columns', None) -%} {%- set s3_data_dir = config.get('s3_data_dir', default=target.s3_data_dir) -%} @@ -32,5 +35,9 @@ {{ sql }} {%- endcall %} + {% if lf_tags is not none %} + {{ adapter.add_lf_tags(model.schema, identifier, lf_tags) }} + {% endif %} + {{ return(sql) }} {% endmacro %} diff --git a/dbt/include/athena/sample_profiles.yml b/dbt/include/athena/sample_profiles.yml index f9cbd0cf..54068c2a 100644 --- a/dbt/include/athena/sample_profiles.yml +++ b/dbt/include/athena/sample_profiles.yml @@ -7,6 +7,9 @@ default: region_name: [region_name] database: [database name] schema: [dev_schema] + lf_tags: + origin: dbt + team: analytics prod: type: athena diff --git a/tests/unit/constants.py b/tests/unit/constants.py index 57f9e3b8..4178f9be 100644 --- a/tests/unit/constants.py +++ b/tests/unit/constants.py @@ -3,4 +3,5 @@ DATABASE_NAME = "test_dbt_athena" BUCKET = "test-dbt-athena" AWS_REGION = "eu-west-1" +S3_STAGING_DIR = "s3://my-bucket/test-dbt/" ATHENA_WORKGROUP = "dbt-athena-adapter" diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py index 9435a79e..525c314d 100644 --- a/tests/unit/test_adapter.py +++ b/tests/unit/test_adapter.py @@ -26,6 +26,7 @@ BUCKET, DATA_CATALOG_NAME, DATABASE_NAME, + S3_STAGING_DIR, ) from .utils import ( MockAWSService, @@ -50,7 +51,7 @@ def setup_method(self, _): "outputs": { "test": { "type": "athena", - "s3_staging_dir": "s3://my-bucket/test-dbt/", + "s3_staging_dir": S3_STAGING_DIR, "region_name": AWS_REGION, "database": DATA_CATALOG_NAME, "work_group": ATHENA_WORKGROUP,