diff --git a/CHANGELOG.md b/CHANGELOG.md index 75caf04bb..2b56fde91 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ ### Features - Avoid show table extended command. ([#231](https://github.com/databricks/dbt-databricks/pull/231)) +- Use show table extended with table name list for get_catalog. ([#237](https://github.com/databricks/dbt-databricks/pull/237)) ## dbt-databricks 1.3.2 (November 9, 2022) diff --git a/dbt/adapters/databricks/column.py b/dbt/adapters/databricks/column.py index 9f6de4f78..a3e088335 100644 --- a/dbt/adapters/databricks/column.py +++ b/dbt/adapters/databricks/column.py @@ -1,9 +1,22 @@ from dataclasses import dataclass +from typing import ClassVar, Dict from dbt.adapters.spark.column import SparkColumn @dataclass class DatabricksColumn(SparkColumn): + TYPE_LABELS: ClassVar[Dict[str, str]] = { + "LONG": "BIGINT", + } + + @classmethod + def translate_type(cls, dtype: str) -> str: + return super(SparkColumn, cls).translate_type(dtype).lower() + + @property + def data_type(self) -> str: + return self.translate_type(self.dtype) + def __repr__(self) -> str: return "".format(self.name, self.data_type) diff --git a/dbt/adapters/databricks/impl.py b/dbt/adapters/databricks/impl.py index b4bb13bea..13731e2f1 100644 --- a/dbt/adapters/databricks/impl.py +++ b/dbt/adapters/databricks/impl.py @@ -1,14 +1,16 @@ from concurrent.futures import Future from contextlib import contextmanager +from itertools import chain from dataclasses import dataclass -from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Type, Union +import re +from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Type, Union from agate import Row, Table, Text from dbt.adapters.base import AdapterConfig, PythonJobHelper from dbt.adapters.base.impl import catch_as_completed from dbt.adapters.base.meta import available -from dbt.adapters.base.relation import BaseRelation +from dbt.adapters.base.relation import BaseRelation, InformationSchema from dbt.adapters.spark.impl import ( SparkAdapter, GET_COLUMNS_IN_RELATION_RAW_MACRO_NAME, @@ -18,9 +20,10 @@ LIST_SCHEMAS_MACRO_NAME, TABLE_OR_VIEW_NOT_FOUND_MESSAGES, ) -from dbt.clients.agate_helper import empty_table +from dbt.clients.agate_helper import DEFAULT_TYPE_TESTER, empty_table from dbt.contracts.connection import AdapterResponse, Connection from dbt.contracts.graph.manifest import Manifest +from dbt.contracts.graph.nodes import ResultNode from dbt.contracts.relation import RelationType import dbt.exceptions from dbt.events import AdapterLogger @@ -41,6 +44,7 @@ CURRENT_CATALOG_MACRO_NAME = "current_catalog" USE_CATALOG_MACRO_NAME = "use_catalog" +SHOW_TABLE_EXTENDED_MACRO_NAME = "show_table_extended" SHOW_TABLES_MACRO_NAME = "show_tables" SHOW_VIEWS_MACRO_NAME = "show_views" @@ -120,7 +124,10 @@ def list_relations_without_caching( # type: ignore[override] results = self.execute_macro(LIST_RELATIONS_MACRO_NAME, kwargs=kwargs) except dbt.exceptions.RuntimeException as e: errmsg = getattr(e, "msg", "") - if f"Database '{schema_relation}' not found" in errmsg: + if ( + "[SCHEMA_NOT_FOUND]" in errmsg + or f"Database '{schema_relation}' not found" in errmsg + ): return [] else: description = "Error while retrieving information about" @@ -139,6 +146,47 @@ def list_relations_without_caching( # type: ignore[override] ) ] + def _list_relations_with_information( + self, schema_relation: DatabricksRelation + ) -> List[Tuple[DatabricksRelation, str]]: + kwargs = {"schema_relation": schema_relation} + try: + # The catalog for `show table extended` needs to match the current catalog. + with self._catalog(schema_relation.database): + results = self.execute_macro(SHOW_TABLE_EXTENDED_MACRO_NAME, kwargs=kwargs) + except dbt.exceptions.RuntimeException as e: + errmsg = getattr(e, "msg", "") + if ( + "[SCHEMA_NOT_FOUND]" in errmsg + or f"Database '{schema_relation.without_identifier()}' not found" in errmsg + ): + results = [] + else: + description = "Error while retrieving information about" + logger.debug(f"{description} {schema_relation.without_identifier()}: {e.msg}") + results = [] + + relations: List[Tuple[DatabricksRelation, str]] = [] + for row in results: + if len(row) != 4: + raise dbt.exceptions.RuntimeException( + f'Invalid value from "show table extended ...", ' + f"got {len(row)} values, expected 4" + ) + _schema, name, _, information = row + rel_type = RelationType.View if "Type: VIEW" in information else RelationType.Table + relation = self.Relation.create( + database=schema_relation.database, + # Use `_schema` retrieved from the cluster to avoid mismatched case + # between the profile and the cluster. + schema=_schema, + identifier=name, + type=rel_type, + ) + relations.append((relation, information)) + + return relations + @available.parse(lambda *a, **k: empty_table()) def get_relations_without_caching(self, relation: DatabricksRelation) -> Table: kwargs = {"relation": relation} @@ -272,6 +320,32 @@ def _set_relation_information(self, relation: DatabricksRelation) -> DatabricksR return self._get_updated_relation(relation)[0] + def parse_columns_from_information( # type: ignore[override] + self, relation: DatabricksRelation, information: str + ) -> List[DatabricksColumn]: + owner_match = re.findall(self.INFORMATION_OWNER_REGEX, information) + owner = owner_match[0] if owner_match else None + matches = re.finditer(self.INFORMATION_COLUMNS_REGEX, information) + columns = [] + stats_match = re.findall(self.INFORMATION_STATISTICS_REGEX, information) + raw_table_stats = stats_match[0] if stats_match else None + table_stats = DatabricksColumn.convert_table_stats(raw_table_stats) + for match_num, match in enumerate(matches): + column_name, column_type, nullable = match.groups() + column = DatabricksColumn( + table_database=relation.database, + table_schema=relation.schema, + table_name=relation.table, + table_type=relation.type, + column_index=(match_num + 1), + table_owner=owner, + column=column_name, + dtype=DatabricksColumn.translate_type(column_type), + table_stats=table_stats, + ) + columns.append(column) + return columns + def get_catalog(self, manifest: Manifest) -> Tuple[Table, List[Exception]]: schema_map = self._get_catalog_schemas(manifest) @@ -287,10 +361,53 @@ def get_catalog(self, manifest: Manifest) -> Tuple[Table, List[Exception]]: catalogs, exceptions = catch_as_completed(futures) return catalogs, exceptions + def _get_one_catalog( + self, + information_schema: InformationSchema, + schemas: Set[str], + manifest: Manifest, + ) -> Table: + if len(schemas) != 1: + dbt.exceptions.raise_compiler_error( + f"Expected only one schema in spark _get_one_catalog, found " f"{schemas}" + ) + + database = information_schema.database + schema = list(schemas)[0] + + nodes: Iterator[ResultNode] = chain( + ( + node + for node in manifest.nodes.values() + if (node.is_relational and not node.is_ephemeral_model) + ), + manifest.sources.values(), + ) + + table_names: Set[str] = set() + for node in nodes: + if node.database == database and node.schema == schema: + relation = self.Relation.create_from(self.config, node) + if relation.identifier: + table_names.add(relation.identifier) + + columns: List[Dict[str, Any]] = [] + if len(table_names) > 0: + schema_relation = self.Relation.create( + database=database, + schema=schema, + identifier="|".join(table_names), + quote_policy=self.config.quoting, + ) + for relation, information in self._list_relations_with_information(schema_relation): + logger.debug("Getting table schema for relation {}", relation) + columns.extend(self._get_columns_for_catalog(relation, information)) + return Table.from_object(columns, column_types=DEFAULT_TYPE_TESTER) + def _get_columns_for_catalog( # type: ignore[override] - self, relation: DatabricksRelation + self, relation: DatabricksRelation, information: str ) -> Iterable[Dict[str, Any]]: - columns = self.get_columns_in_relation(relation) + columns = self.parse_columns_from_information(relation, information) for column in columns: # convert DatabricksRelation into catalog dicts diff --git a/dbt/include/databricks/macros/adapters.sql b/dbt/include/databricks/macros/adapters.sql index 392c84e89..8d4af7552 100644 --- a/dbt/include/databricks/macros/adapters.sql +++ b/dbt/include/databricks/macros/adapters.sql @@ -159,6 +159,18 @@ {{ return(adapter.get_relations_without_caching(schema_relation)) }} {% endmacro %} +{% macro show_table_extended(schema_relation) %} + {{ return(adapter.dispatch('show_table_extended', 'dbt')(schema_relation)) }} +{% endmacro %} + +{% macro databricks__show_table_extended(schema_relation) %} + {% call statement('show_table_extended', fetch_result=True) -%} + show table extended in {{ schema_relation.without_identifier() }} like '{{ schema_relation.identifier }}' + {% endcall %} + + {% do return(load_result('show_table_extended').table) %} +{% endmacro %} + {% macro show_tables(relation) %} {{ return(adapter.dispatch('show_tables', 'dbt')(relation)) }} {% endmacro %} diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py index 6707bb3f6..9480a33cb 100644 --- a/tests/unit/test_adapter.py +++ b/tests/unit/test_adapter.py @@ -624,3 +624,239 @@ def test_relation_with_database(self): assert r1.database is None r2 = adapter.Relation.create(database="something", schema="different", identifier="table") assert r2.database == "something" + + def test_parse_columns_from_information_with_table_type_and_delta_provider(self): + self.maxDiff = None + rel_type = DatabricksRelation.get_relation_type.Table + + # Mimics the output of Spark in the information column + information = ( + "Database: default_schema\n" + "Table: mytable\n" + "Owner: root\n" + "Created Time: Wed Feb 04 18:15:00 UTC 1815\n" + "Last Access: Wed May 20 19:25:00 UTC 1925\n" + "Created By: Spark 3.0.1\n" + "Type: MANAGED\n" + "Provider: delta\n" + "Statistics: 123456789 bytes\n" + "Location: /mnt/vo\n" + "Serde Library: org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe\n" + "InputFormat: org.apache.hadoop.mapred.SequenceFileInputFormat\n" + "OutputFormat: org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat\n" + "Partition Provider: Catalog\n" + "Partition Columns: [`dt`]\n" + "Schema: root\n" + " |-- col1: decimal(22,0) (nullable = true)\n" + " |-- col2: string (nullable = true)\n" + " |-- dt: date (nullable = true)\n" + " |-- struct_col: struct (nullable = true)\n" + " | |-- struct_inner_col: string (nullable = true)\n" + ) + relation = DatabricksRelation.create( + schema="default_schema", identifier="mytable", type=rel_type + ) + + config = self._get_target_databricks_sql_connector(self.project_cfg) + columns = DatabricksAdapter(config).parse_columns_from_information(relation, information) + self.assertEqual(len(columns), 4) + self.assertEqual( + columns[0].to_column_dict(omit_none=False), + { + "table_database": None, + "table_schema": relation.schema, + "table_name": relation.name, + "table_type": rel_type, + "table_owner": "root", + "column": "col1", + "column_index": 1, + "dtype": "decimal(22,0)", + "numeric_scale": None, + "numeric_precision": None, + "char_size": None, + "stats:bytes:description": "", + "stats:bytes:include": True, + "stats:bytes:label": "bytes", + "stats:bytes:value": 123456789, + }, + ) + + self.assertEqual( + columns[3].to_column_dict(omit_none=False), + { + "table_database": None, + "table_schema": relation.schema, + "table_name": relation.name, + "table_type": rel_type, + "table_owner": "root", + "column": "struct_col", + "column_index": 4, + "dtype": "struct", + "numeric_scale": None, + "numeric_precision": None, + "char_size": None, + "stats:bytes:description": "", + "stats:bytes:include": True, + "stats:bytes:label": "bytes", + "stats:bytes:value": 123456789, + }, + ) + + def test_parse_columns_from_information_with_view_type(self): + self.maxDiff = None + rel_type = DatabricksRelation.get_relation_type.View + information = ( + "Database: default_schema\n" + "Table: myview\n" + "Owner: root\n" + "Created Time: Wed Feb 04 18:15:00 UTC 1815\n" + "Last Access: UNKNOWN\n" + "Created By: Spark 3.0.1\n" + "Type: VIEW\n" + "View Text: WITH base (\n" + " SELECT * FROM source_table\n" + ")\n" + "SELECT col1, col2, dt FROM base\n" + "View Original Text: WITH base (\n" + " SELECT * FROM source_table\n" + ")\n" + "SELECT col1, col2, dt FROM base\n" + "View Catalog and Namespace: spark_catalog.default\n" + "View Query Output Columns: [col1, col2, dt]\n" + "Table Properties: [view.query.out.col.1=col1, view.query.out.col.2=col2, " + "transient_lastDdlTime=1618324324, view.query.out.col.3=dt, " + "view.catalogAndNamespace.part.0=spark_catalog, " + "view.catalogAndNamespace.part.1=default]\n" + "Serde Library: org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe\n" + "InputFormat: org.apache.hadoop.mapred.SequenceFileInputFormat\n" + "OutputFormat: org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat\n" + "Storage Properties: [serialization.format=1]\n" + "Schema: root\n" + " |-- col1: decimal(22,0) (nullable = true)\n" + " |-- col2: string (nullable = true)\n" + " |-- dt: date (nullable = true)\n" + " |-- struct_col: struct (nullable = true)\n" + " | |-- struct_inner_col: string (nullable = true)\n" + ) + relation = DatabricksRelation.create( + schema="default_schema", identifier="myview", type=rel_type + ) + + config = self._get_target_databricks_sql_connector(self.project_cfg) + columns = DatabricksAdapter(config).parse_columns_from_information(relation, information) + self.assertEqual(len(columns), 4) + self.assertEqual( + columns[1].to_column_dict(omit_none=False), + { + "table_database": None, + "table_schema": relation.schema, + "table_name": relation.name, + "table_type": rel_type, + "table_owner": "root", + "column": "col2", + "column_index": 2, + "dtype": "string", + "numeric_scale": None, + "numeric_precision": None, + "char_size": None, + }, + ) + + self.assertEqual( + columns[3].to_column_dict(omit_none=False), + { + "table_database": None, + "table_schema": relation.schema, + "table_name": relation.name, + "table_type": rel_type, + "table_owner": "root", + "column": "struct_col", + "column_index": 4, + "dtype": "struct", + "numeric_scale": None, + "numeric_precision": None, + "char_size": None, + }, + ) + + def test_parse_columns_from_information_with_table_type_and_parquet_provider(self): + self.maxDiff = None + rel_type = DatabricksRelation.get_relation_type.Table + + information = ( + "Database: default_schema\n" + "Table: mytable\n" + "Owner: root\n" + "Created Time: Wed Feb 04 18:15:00 UTC 1815\n" + "Last Access: Wed May 20 19:25:00 UTC 1925\n" + "Created By: Spark 3.0.1\n" + "Type: MANAGED\n" + "Provider: parquet\n" + "Statistics: 1234567890 bytes, 12345678 rows\n" + "Location: /mnt/vo\n" + "Serde Library: org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe\n" + "InputFormat: org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat\n" + "OutputFormat: org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat\n" + "Schema: root\n" + " |-- col1: decimal(22,0) (nullable = true)\n" + " |-- col2: string (nullable = true)\n" + " |-- dt: date (nullable = true)\n" + " |-- struct_col: struct (nullable = true)\n" + " | |-- struct_inner_col: string (nullable = true)\n" + ) + relation = DatabricksRelation.create( + schema="default_schema", identifier="mytable", type=rel_type + ) + + config = self._get_target_databricks_sql_connector(self.project_cfg) + columns = DatabricksAdapter(config).parse_columns_from_information(relation, information) + self.assertEqual(len(columns), 4) + self.assertEqual( + columns[2].to_column_dict(omit_none=False), + { + "table_database": None, + "table_schema": relation.schema, + "table_name": relation.name, + "table_type": rel_type, + "table_owner": "root", + "column": "dt", + "column_index": 3, + "dtype": "date", + "numeric_scale": None, + "numeric_precision": None, + "char_size": None, + "stats:bytes:description": "", + "stats:bytes:include": True, + "stats:bytes:label": "bytes", + "stats:bytes:value": 1234567890, + "stats:rows:description": "", + "stats:rows:include": True, + "stats:rows:label": "rows", + "stats:rows:value": 12345678, + }, + ) + + self.assertEqual( + columns[3].to_column_dict(omit_none=False), + { + "table_database": None, + "table_schema": relation.schema, + "table_name": relation.name, + "table_type": rel_type, + "table_owner": "root", + "column": "struct_col", + "column_index": 4, + "dtype": "struct", + "numeric_scale": None, + "numeric_precision": None, + "char_size": None, + "stats:bytes:description": "", + "stats:bytes:include": True, + "stats:bytes:label": "bytes", + "stats:bytes:value": 1234567890, + "stats:rows:description": "", + "stats:rows:include": True, + "stats:rows:label": "rows", + "stats:rows:value": 12345678, + }, + )