Skip to content

Commit

Permalink
Ensure column name is backticked to fix 859
Browse files Browse the repository at this point in the history
  • Loading branch information
benc-db committed Dec 2, 2024
1 parent 6398033 commit ce169ca
Show file tree
Hide file tree
Showing 8 changed files with 85 additions and 25 deletions.
16 changes: 15 additions & 1 deletion dbt/adapters/databricks/column.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from typing import ClassVar, Optional
from typing import Any, ClassVar, Optional

from dbt.adapters.databricks.utils import quote
from dbt.adapters.spark.column import SparkColumn


Expand Down Expand Up @@ -28,3 +29,16 @@ def data_type(self) -> str:

def __repr__(self) -> str:
return "<DatabricksColumn {} ({})>".format(self.name, self.data_type)

@staticmethod
def get_name(column: dict[str, Any]) -> str:
name = column["name"]
return quote(name) if column.get("quote", False) else name

@staticmethod
def format_remove_column_list(columns: list["DatabricksColumn"]) -> str:
return ", ".join([quote(c.name) for c in columns])

@staticmethod
def format_add_column_list(columns: list["DatabricksColumn"]) -> str:
return ", ".join([f"{quote(c.name)} {c.data_type}" for c in columns])
4 changes: 4 additions & 0 deletions dbt/adapters/databricks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,7 @@ def handle_missing_objects(exec: Callable[[], T], default: T) -> T:
if check_not_found_error(errmsg):
return default
raise e


def quote(name: str) -> str:
return f"`{name}`"
17 changes: 17 additions & 0 deletions dbt/include/databricks/macros/adapters/columns.sql
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,20 @@

{% do return(load_result('get_columns_comments_via_information_schema').table) %}
{% endmacro %}

{% macro databricks__alter_relation_add_remove_columns(relation, add_columns, remove_columns) %}
{% if remove_columns %}
{% if not relation.is_delta %}
{{ exceptions.raise_compiler_error('Delta format required for dropping columns from tables') }}
{% endif %}
{%- call statement('alter_relation_remove_columns') -%}
ALTER TABLE {{ relation }} DROP COLUMNS ({{ api.Column.format_remove_column_list(remove_columns) }})
{%- endcall -%}
{% endif %}

{% if add_columns %}
{%- call statement('alter_relation_add_columns') -%}
ALTER TABLE {{ relation }} ADD COLUMNS ({{ api.Column.format_add_column_list(add_columns) }})
{%- endcall -%}
{% endif %}
{% endmacro %}
23 changes: 3 additions & 20 deletions dbt/include/databricks/macros/adapters/persist_docs.sql
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
{% macro databricks__alter_column_comment(relation, column_dict) %}
{% if config.get('file_format', default='delta') in ['delta', 'hudi'] %}
{% for column_name in column_dict %}
{% set comment = column_dict[column_name]['description'] %}
{% for column in column_dict.values() %}
{% set comment = column['description'] %}
{% set escaped_comment = comment | replace('\'', '\\\'') %}
{% set comment_query %}
alter table {{ relation }} change column
{{ adapter.quote(column_name) if column_dict[column_name]['quote'] else column_name }}
comment '{{ escaped_comment }}';
alter table {{ relation }} change column {{ api.Column.get_name(column) }} comment '{{ escaped_comment }}';
{% endset %}
{% do run_query(comment_query) %}
{% endfor %}
Expand All @@ -30,18 +28,3 @@
{% do alter_column_comment(relation, columns_to_persist_docs) %}
{% endif %}
{% endmacro %}

{% macro get_column_comment_sql(column_name, column_dict) -%}
{% if column_name in column_dict and column_dict[column_name]["description"] -%}
{% set escaped_description = column_dict[column_name]["description"] | replace("'", "\\'") %}
{% set column_comment_clause = "comment '" ~ escaped_description ~ "'" %}
{%- endif -%}
{{ adapter.quote(column_name) }} {{ column_comment_clause }}
{% endmacro %}

{% macro get_persist_docs_column_list(model_columns, query_columns) %}
{% for column_name in query_columns %}
{{ get_column_comment_sql(column_name, model_columns) }}
{{- ", " if not loop.last else "" }}
{% endfor %}
{% endmacro %}
6 changes: 3 additions & 3 deletions dbt/include/databricks/macros/relations/constraints.sql
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@
{% for column_name in column_names %}
{% set column = model.get('columns', {}).get(column_name) %}
{% if column %}
{% set quoted_name = adapter.quote(column['name']) if column['quote'] else column['name'] %}
{% set quoted_name = api.Column.get_name(column) %}
{% set stmt = "alter table " ~ relation ~ " change column " ~ quoted_name ~ " set not null " ~ (constraint.expression or "") ~ ";" %}
{% do statements.append(stmt) %}
{% else %}
Expand All @@ -154,7 +154,7 @@
{% if not column %}
{{ exceptions.warn('Invalid primary key column: ' ~ column_name) }}
{% else %}
{% set quoted_name = adapter.quote(column['name']) if column['quote'] else column['name'] %}
{% set quoted_name = api.Column.get_name(column) %}
{% do quoted_names.append(quoted_name) %}
{% endif %}
{% endfor %}
Expand Down Expand Up @@ -203,7 +203,7 @@
{% if not column %}
{{ exceptions.warn('Invalid foreign key column: ' ~ column_name) }}
{% else %}
{% set quoted_name = adapter.quote(column['name']) if column['quote'] else column['name'] %}
{% set quoted_name = api.Column.get_name(column) %}
{% do quoted_names.append(quoted_name) %}
{% endif %}
{% endfor %}
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/macros/relations/test_constraint_macros.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from unittest.mock import Mock
from dbt.adapters.databricks.column import DatabricksColumn
import pytest

from tests.unit.macros.base import MacroTestBase
Expand All @@ -16,6 +18,7 @@ def macro_folders_to_load(self) -> list:
def modify_context(self, default_context) -> None:
# Mock local_md5
default_context["local_md5"] = lambda s: f"hash({s})"
default_context["api"] = Mock(Column=DatabricksColumn)

def render_constraints(self, template, *args):
return self.run_macro(template, "databricks_constraints_to_dbt", *args)
Expand Down
36 changes: 36 additions & 0 deletions tests/unit/test_column.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pytest

from dbt.adapters.databricks.column import DatabricksColumn


Expand All @@ -24,3 +26,37 @@ def test_convert_table_stats_with_bytes_and_rows(self):
"stats:rows:label": "rows",
"stats:rows:value": 12345678,
}

class TestColumnStatics:
@pytest.mark.parametrize(
"column, expected",
[
({"name": "foo", "quote": True}, "`foo`"),
({"name": "foo", "quote": False}, "foo"),
({"name": "foo"}, "foo"),
],
)
def test_get_name(self, column, expected):
assert DatabricksColumn.get_name(column) == expected

@pytest.mark.parametrize(
"columns, expected",
[
([], ""),
([DatabricksColumn("foo", "string")], "`foo`"),
([DatabricksColumn("foo", "string"), DatabricksColumn("bar", "int")], "`foo`, `bar`"),
],
)
def test_format_remove_column_list(self, columns, expected):
assert DatabricksColumn.format_remove_column_list(columns) == expected

@pytest.mark.parametrize(
"columns, expected",
[
([], ""),
([DatabricksColumn("foo", "string")], "`foo` string"),
([DatabricksColumn("foo", "string"), DatabricksColumn("bar", "int")], "`foo` string, `bar` int"),
],
)
def test_format_add_column_list(self, columns, expected):
assert DatabricksColumn.format_add_column_list(columns) == expected
5 changes: 4 additions & 1 deletion tests/unit/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dbt.adapters.databricks.utils import redact_credentials, remove_ansi
from dbt.adapters.databricks.utils import quote, redact_credentials, remove_ansi


class TestDatabricksUtils:
Expand Down Expand Up @@ -64,3 +64,6 @@ def test_remove_ansi(self):
72 # how to execute python model in notebook
"""
assert remove_ansi(test_string) == expected_string

def test_quote(self):
assert quote("table") == '`table`'

0 comments on commit ce169ca

Please sign in to comment.