Skip to content

Commit

Permalink
feat: lake formation tag support (#159)
Browse files Browse the repository at this point in the history
Signed-off-by: Henri Blancke <[email protected]>
Co-authored-by: nicor88 <[email protected]>
  • Loading branch information
henriblancke and nicor88 authored Mar 21, 2023
1 parent 8e42fda commit 514830d
Show file tree
Hide file tree
Showing 11 changed files with 102 additions and 6 deletions.
8 changes: 7 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ A dbt profile can be configured to run against AWS Athena using the following co
| aws_profile_name | Profile to use from your AWS shared credentials file. | Optional | `my-profile` |
| work_group | Identifier of Athena workgroup | Optional | `my-custom-workgroup` |
| num_retries | Number of times to retry a failing query | Optional | `3` |

| lf_tags | Default Lake Formation (LF) tags to apply to any database created by dbt | Optional | `{"origin": "dbt", "team": "analytics"}`|

**Example profiles.yml entry:**
```yaml
Expand All @@ -89,6 +89,9 @@ athena:
database: awsdatacatalog
aws_profile_name: my-profile
work_group: my-workgroup
lf_tags:
origin: dbt
team: analytics
```
_Additional information_
Expand Down Expand Up @@ -119,6 +122,9 @@ _Additional information_
* `field_delimiter` (`default=none`)
* Custom field delimiter, for when format is set to `TEXTFILE`
* `table_properties`: table properties to add to the table, valid for Iceberg only
* `lf_tags` (`default=none`)
* Lake Formation (LF) tags to associate with the table
* format: `{"tag1": "value1", "tag2": "value2"}`

#### Table location

Expand Down
2 changes: 2 additions & 0 deletions dbt/adapters/athena/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class AthenaCredentials(Credentials):
num_retries: Optional[int] = 5
s3_data_dir: Optional[str] = None
s3_data_naming: Optional[str] = "schema_table_unique"
lf_tags: Optional[Dict[str, str]] = None

@property
def type(self) -> str:
Expand All @@ -68,6 +69,7 @@ def _connection_keys(self) -> Tuple[str, ...]:
"endpoint_url",
"s3_data_dir",
"s3_data_naming",
"lf_tags",
)


Expand Down
58 changes: 56 additions & 2 deletions dbt/adapters/athena/impl.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import posixpath as path
from itertools import chain
from threading import Lock
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union
from urllib.parse import urlparse
from uuid import uuid4

Expand Down Expand Up @@ -56,6 +56,60 @@ def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str:
return "timestamp"

# TODO: Add more lf-tag unit tests when moto supports lakeformation
# moto issue: https://github.com/getmoto/moto/issues/5964
@available
def add_lf_tags(
    self,
    database: str,
    table: Optional[str] = None,
    lf_tags: Optional[Dict[str, str]] = None,
) -> None:
    """Attach Lake Formation tags to a database or to one of its tables.

    Args:
        database: Name of the database to tag, or the database that
            contains ``table`` when a table is given.
        table: Name of the table to tag. When omitted (or empty) the
            tags are applied to the database itself.
        lf_tags: Mapping of tag key to tag value. When omitted (or empty)
            the ``lf_tags`` configured on the connection credentials are
            used instead.

    Raises:
        DbtRuntimeError: If the Lake Formation response reports a failure
            for at least one tag. Each failure is logged individually
            before the error is raised.
    """
    conn = self.connections.get_thread_connection()
    client = conn.handle

    # Explicit tags take precedence; otherwise use the connection defaults.
    lf_tags = lf_tags or conn.credentials.lf_tags
    if not lf_tags:
        logger.debug("No LF tags configured")
        return

    # The resource targets either the table or the whole database.
    if table:
        resource = {"Table": {"DatabaseName": database, "Name": table}}
    else:
        resource = {"Database": {"Name": database}}

    # Only client creation is done while holding the shared boto3 lock.
    with boto3_client_lock:
        lf_client = client.session.client(
            "lakeformation", region_name=client.region_name, config=get_boto3_config()
        )

    # Each configured tag is sent as a key with a single-value list.
    response = lf_client.add_lf_tags_to_resource(
        Resource=resource,
        LFTags=[{"TagKey": key, "TagValues": [value]} for key, value in lf_tags.items()],
    )

    # Human-readable name of what was tagged: "db" or "db.table".
    target = f"{database}.{table}" if table else database

    failures = response.get("Failures", [])
    if not failures:
        logger.debug(f"Added LF tags: {lf_tags} to {target}")
        return

    # Log every per-tag failure before raising a single summary error.
    for failure in failures:
        tag = failure.get("LFTag", {}).get("TagKey")
        error = failure.get("Error", {}).get("ErrorMessage")
        logger.error(f"Failed to set {tag} for {target} - {error}")
    raise DbtRuntimeError(f"Failed to add LF tags: {lf_tags} to {target}")

@available
def get_work_group_output_location(self) -> Optional[str]:
conn = self.connections.get_thread_connection()
Expand Down Expand Up @@ -120,7 +174,7 @@ def s3_table_location(
return table_location

@available
def get_table_location(self, database_name: str, table_name: str) -> [str, None]:
def get_table_location(self, database_name: str, table_name: str) -> Union[str, None]:
"""
Helper function to get the S3 location of a table
"""
Expand Down
2 changes: 2 additions & 0 deletions dbt/include/athena/macros/adapters/schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
{%- call statement('create_schema') -%}
create schema if not exists {{ relation.without_identifier().render_hive() }}
{% endcall %}

{{ adapter.add_lf_tags(relation.schema) }}
{% endmacro %}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
{% set strategy = validate_get_incremental_strategy(raw_strategy, table_type) %}
{% set on_schema_change = incremental_validate_on_schema_change(config.get('on_schema_change'), default='ignore') %}

{% set lf_tags = config.get('lf_tags', default=none) %}
{% set partitioned_by = config.get('partitioned_by', default=none) %}
{% set target_relation = this.incorporate(type='table') %}
{% set existing_relation = load_relation(this) %}
Expand Down Expand Up @@ -83,6 +84,10 @@

{{ run_hooks(post_hooks, inside_transaction=False) }}

{% if lf_tags is not none %}
{{ adapter.add_lf_tags(target_relation.schema, target_relation.identifier, lf_tags) }}
{% endif %}

{{ return({'relations': [target_relation]}) }}

{%- endmaterialization %}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{% materialization table, adapter='athena' -%}
{%- set identifier = model['alias'] -%}

{%- set lf_tags = config.get('lf_tags', default=none) -%}
{%- set table_type = config.get('table_type', default='hive') | lower -%}
{%- set old_relation = adapter.get_relation(database=database, schema=schema, identifier=identifier) -%}
{%- set target_relation = api.Relation.create(identifier=identifier,
Expand All @@ -26,6 +27,10 @@

{{ run_hooks(post_hooks) }}

{% if lf_tags is not none %}
{{ adapter.add_lf_tags(target_relation.schema, identifier, lf_tags) }}
{% endif %}

{% do persist_docs(target_relation, model) %}

{{ return({'relations': [target_relation]}) }}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
{% macro create_or_replace_view(run_outside_transaction_hooks=True) %}
{%- set identifier = model['alias'] -%}

{%- set lf_tags = config.get('lf_tags', default=none) -%}
{%- set old_relation = adapter.get_relation(database=database, schema=schema, identifier=identifier) -%}
{%- set exists_as_view = (old_relation is not none and old_relation.is_view) -%}
{%- set target_relation = api.Relation.create(
identifier=identifier, schema=schema, database=database,
type='view') -%}
identifier=identifier,
schema=schema,
database=database,
type='view',
) -%}

{% if run_outside_transaction_hooks %}
-- no transactions on BigQuery
{{ run_hooks(pre_hooks, inside_transaction=False) }}
Expand All @@ -23,6 +29,10 @@
{{ create_view_as(target_relation, sql) }}
{%- endcall %}

{% if lf_tags is not none %}
{{ adapter.add_lf_tags(target_relation.schema, identifier, lf_tags) }}
{% endif %}

{{ run_hooks(post_hooks, inside_transaction=True) }}

{{ adapter.commit() }}
Expand Down
7 changes: 7 additions & 0 deletions dbt/include/athena/macros/materializations/seeds/helpers.sql
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
{% endmacro %}

{% macro athena__create_csv_table(model, agate_table) %}
{%- set identifier = model['alias'] -%}

{%- set lf_tags = config.get('lf_tags', default=none) -%}
{%- set column_override = config.get('column_types', {}) -%}
{%- set quote_seed_column = config.get('quote_columns', None) -%}
{%- set s3_data_dir = config.get('s3_data_dir', default=target.s3_data_dir) -%}
Expand All @@ -32,5 +35,9 @@
{{ sql }}
{%- endcall %}

{% if lf_tags is not none %}
{{ adapter.add_lf_tags(model.schema, identifier, lf_tags) }}
{% endif %}

{{ return(sql) }}
{% endmacro %}
3 changes: 3 additions & 0 deletions dbt/include/athena/sample_profiles.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ default:
region_name: [region_name]
database: [database name]
schema: [dev_schema]
lf_tags:
origin: dbt
team: analytics

prod:
type: athena
Expand Down
1 change: 1 addition & 0 deletions tests/unit/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
DATABASE_NAME = "test_dbt_athena"
BUCKET = "test-dbt-athena"
AWS_REGION = "eu-west-1"
S3_STAGING_DIR = "s3://my-bucket/test-dbt/"
ATHENA_WORKGROUP = "dbt-athena-adapter"
3 changes: 2 additions & 1 deletion tests/unit/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
BUCKET,
DATA_CATALOG_NAME,
DATABASE_NAME,
S3_STAGING_DIR,
)
from .utils import (
MockAWSService,
Expand All @@ -50,7 +51,7 @@ def setup_method(self, _):
"outputs": {
"test": {
"type": "athena",
"s3_staging_dir": "s3://my-bucket/test-dbt/",
"s3_staging_dir": S3_STAGING_DIR,
"region_name": AWS_REGION,
"database": DATA_CATALOG_NAME,
"work_group": ATHENA_WORKGROUP,
Expand Down

0 comments on commit 514830d

Please sign in to comment.