Avoid show table extended command. (#231)
### Description

Avoids the `show table extended` command.

This is based on dbt-labs/dbt-spark#433.

1. Build the table/view list with the `show tables in {{ relation }}` and `show views in {{ relation }}` commands, or with the `get_tables` API when a `catalog` is provided.
2. Retrieve additional information with the `describe extended {{ relation }}` command (see the sketch below), which serves:
    1. `get_relation` with `needs_information=True`
    2. `get_columns_in_relation`
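
In short, one expensive `show table extended` scan is replaced by a cheap listing pass plus an on-demand `describe extended`. A minimal sketch of that flow, assuming a hypothetical `run_sql` helper that executes a statement and returns dict-like rows (the real implementation is in the diff below):

```python
from typing import Callable, Dict, List, Tuple

Row = Dict[str, str]


def list_relations(run_sql: Callable[[str], List[Row]], schema: str) -> List[Tuple[str, str, str]]:
    """Step 1: cheap listing via `show tables` / `show views`."""
    view_names = {row["viewName"] for row in run_sql(f"show views in {schema}")}
    return [
        (row["database"], row["tableName"],
         "view" if row["tableName"] in view_names else "table")
        for row in run_sql(f"show tables in {schema}")
    ]


def relation_metadata(run_sql: Callable[[str], List[Row]], relation: str) -> Row:
    """Step 2: detailed info via `describe extended`, fetched only on demand."""
    rows = run_sql(f"describe extended {relation}")
    # Metadata rows (Provider, Owner, Statistics, ...) follow the column
    # rows, after a separator row with an empty `col_name`.
    sep = next((i for i, r in enumerate(rows) if not r["col_name"]), len(rows))
    return {r["col_name"]: r["data_type"] for r in rows[sep + 1:]}
```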
ueshin committed Dec 5, 2022
1 parent 8e420cc commit 02d50f3
Showing 9 changed files with 264 additions and 324 deletions.
13 changes: 13 additions & 0 deletions dbt/adapters/databricks/connections.py
@@ -359,6 +359,11 @@ def description(
def schemas(self, catalog_name: str, schema_name: Optional[str] = None) -> None:
self._cursor.schemas(catalog_name=catalog_name, schema_name=schema_name)

def tables(self, catalog_name: str, schema_name: str, table_name: Optional[str] = None) -> None:
self._cursor.tables(
catalog_name=catalog_name, schema_name=schema_name, table_name=table_name
)

def __del__(self) -> None:
if self._cursor.open:
# This should not happen. The cursor should explicitly be closed.
@@ -528,6 +533,14 @@ def list_schemas(self, database: str, schema: Optional[str] = None) -> Table:
lambda cursor: cursor.schemas(catalog_name=database, schema_name=schema),
)

def list_tables(self, database: str, schema: str, identifier: Optional[str] = None) -> Table:
return self._execute_cursor(
f"GetTables(database={database}, schema={schema}, identifier={identifier})",
lambda cursor: cursor.tables(
catalog_name=database, schema_name=schema, table_name=identifier
),
)

@classmethod
def open(cls, connection: Connection) -> Connection:
if connection.state == ConnectionState.OPEN:
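A hedged usage sketch for the new `list_tables` wrapper, which delegates to the SQL connector cursor's `tables()` metadata call rather than running SQL. The `conn_mgr` argument and the catalog/schema literals are assumptions for illustration; the column names match the GetTables result set consumed in `impl.py`:

```python
from agate import Table


def print_tables(conn_mgr) -> None:
    # `conn_mgr` is assumed to be an open DatabricksConnectionManager.
    table: Table = conn_mgr.list_tables(database="main", schema="analytics")
    for row in table:
        # JDBC-style GetTables columns, as consumed by impl.py.
        print(row["TABLE_CAT"], row["TABLE_SCHEM"],
              row["TABLE_NAME"], row["TABLE_TYPE"].lower())
```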
180 changes: 111 additions & 69 deletions dbt/adapters/databricks/impl.py
@@ -1,10 +1,9 @@
from concurrent.futures import Future
from contextlib import contextmanager
from dataclasses import dataclass
import re
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Type, Union, cast
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Type, Union

from agate import Row, Table
from agate import Row, Table, Text

from dbt.adapters.base import AdapterConfig, PythonJobHelper
from dbt.adapters.base.impl import catch_as_completed
@@ -18,6 +17,7 @@
LIST_RELATIONS_MACRO_NAME,
LIST_SCHEMAS_MACRO_NAME,
)
from dbt.clients.agate_helper import empty_table
from dbt.contracts.connection import AdapterResponse, Connection
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.relation import RelationType
@@ -40,6 +40,15 @@
CURRENT_CATALOG_MACRO_NAME = "current_catalog"
USE_CATALOG_MACRO_NAME = "use_catalog"

SHOW_TABLES_MACRO_NAME = "show_tables"
SHOW_VIEWS_MACRO_NAME = "show_views"

TABLE_OR_VIEW_NOT_FOUND_MESSAGES = (
"[TABLE_OR_VIEW_NOT_FOUND]",
"Table or view not found",
"NoSuchTableException",
)


@dataclass
class DatabricksConfig(AdapterConfig):
@@ -113,9 +122,7 @@ def list_relations_without_caching( # type: ignore[override]
) -> List[DatabricksRelation]:
kwargs = {"schema_relation": schema_relation}
try:
# The catalog for `show table extended` needs to match the current catalog.
with self._catalog(schema_relation.database):
results = self.execute_macro(LIST_RELATIONS_MACRO_NAME, kwargs=kwargs)
results = self.execute_macro(LIST_RELATIONS_MACRO_NAME, kwargs=kwargs)
except dbt.exceptions.RuntimeException as e:
errmsg = getattr(e, "msg", "")
if f"Database '{schema_relation}' not found" in errmsg:
@@ -125,35 +132,79 @@ def list_relations_without_caching( # type: ignore[override]
logger.debug(f"{description} {schema_relation}: {e.msg}")
return []

relations = []
for row in results:
if len(row) != 4:
raise dbt.exceptions.RuntimeException(
f'Invalid value from "show table extended ...", '
f"got {len(row)} values, expected 4"
)
_schema, name, _, information = row
rel_type = RelationType.View if "Type: VIEW" in information else RelationType.Table
is_delta = "Provider: delta" in information
is_hudi = "Provider: hudi" in information
relation = self.Relation.create(
database=schema_relation.database,
# Use `_schema` retrieved from the cluster to avoid mismatched case
# between the profile and the cluster.
schema=_schema,
return [
self.Relation.create(
database=database,
schema=schema,
identifier=name,
type=rel_type,
information=information,
is_delta=is_delta,
is_hudi=is_hudi,
type=self.Relation.get_relation_type(kind),
)
relations.append(relation)
for database, schema, name, kind in results.select(
["database_name", "schema_name", "name", "kind"]
)
]

@available.parse(lambda *a, **k: empty_table())
def get_relations_without_caching(self, relation: DatabricksRelation) -> Table:
kwargs = {"relation": relation}

new_rows: List[Tuple]
if relation.database is not None:
assert relation.schema is not None
tables = self.connections.list_tables(
database=relation.database, schema=relation.schema
)
new_rows = [
(row["TABLE_CAT"], row["TABLE_SCHEM"], row["TABLE_NAME"], row["TABLE_TYPE"].lower())
for row in tables
]
else:
tables = self.execute_macro(SHOW_TABLES_MACRO_NAME, kwargs=kwargs)
new_rows = [
(relation.database, row["database"], row["tableName"], "") for row in tables
]

if any(not row[3] for row in new_rows):
with self._catalog(relation.database):
views = self.execute_macro(SHOW_VIEWS_MACRO_NAME, kwargs=kwargs)

view_names = set(views.columns["viewName"].values())
new_rows = [
(
row[0],
row[1],
row[2],
str(RelationType.View if row[2] in view_names else RelationType.Table),
)
for row in new_rows
]

return Table(
new_rows,
column_names=["database_name", "schema_name", "name", "kind"],
column_types=[Text(), Text(), Text(), Text()],
)

def get_relation(
self,
database: Optional[str],
schema: str,
identifier: str,
*,
needs_information: bool = False,
) -> Optional[DatabricksRelation]:
cached: Optional[DatabricksRelation] = super(SparkAdapter, self).get_relation(
database=database, schema=schema, identifier=identifier
)

if not needs_information:
return cached

return relations
return self._set_relation_information(cached) if cached else None

def parse_describe_extended( # type: ignore[override]
self, relation: DatabricksRelation, raw_rows: List[Row]
) -> List[DatabricksColumn]:
) -> Tuple[Dict[str, Any], List[DatabricksColumn]]:
# Convert the Row to a dict
dict_rows = [dict(zip(row._keys, row._values)) for row in raw_rows]
# Find the separator between the rows and the metadata provided
@@ -166,7 +217,7 @@ def parse_describe_extended( # type: ignore[override]

raw_table_stats = metadata.get(KEY_TABLE_STATISTICS)
table_stats = DatabricksColumn.convert_table_stats(raw_table_stats)
return [
return metadata, [
DatabricksColumn(
table_database=relation.database,
table_schema=relation.schema,
@@ -184,56 +235,47 @@ def parse_describe_extended( # type: ignore[override]
def get_columns_in_relation( # type: ignore[override]
self, relation: DatabricksRelation
) -> List[DatabricksColumn]:
columns = []
return self._get_updated_relation(relation)[1]

def _get_updated_relation(
self, relation: DatabricksRelation
) -> Tuple[DatabricksRelation, List[DatabricksColumn]]:
try:
rows: List[Row] = self.execute_macro(
rows = self.execute_macro(
GET_COLUMNS_IN_RELATION_RAW_MACRO_NAME, kwargs={"relation": relation}
)
columns = self.parse_describe_extended(relation, rows)
metadata, columns = self.parse_describe_extended(relation, rows)
except dbt.exceptions.RuntimeException as e:
# Spark would throw an error when the table doesn't exist, whereas other
# CDWs would just return an empty list; normalize the behavior here
errmsg = getattr(e, "msg", "")
if any(
msg in errmsg
for msg in (
"[TABLE_OR_VIEW_NOT_FOUND]",
"Table or view not found",
"NoSuchTableException",
)
):
pass
found_msgs = (msg in errmsg for msg in TABLE_OR_VIEW_NOT_FOUND_MESSAGES)
if any(found_msgs):
metadata = None
columns = []
else:
raise e

# strip hudi metadata columns.
return [x for x in columns if x.name not in self.HUDI_METADATA_COLUMNS]
columns = [x for x in columns if x.name not in self.HUDI_METADATA_COLUMNS]

return (
self.Relation.create(
database=relation.database,
schema=relation.schema,
identifier=relation.identifier,
type=relation.type,
metadata=metadata,
),
columns,
)

def parse_columns_from_information( # type: ignore[override]
self, relation: DatabricksRelation
) -> List[DatabricksColumn]:
owner_match = re.findall(self.INFORMATION_OWNER_REGEX, cast(str, relation.information))
owner = owner_match[0] if owner_match else None
matches = re.finditer(self.INFORMATION_COLUMNS_REGEX, cast(str, relation.information))
columns = []
stats_match = re.findall(self.INFORMATION_STATISTICS_REGEX, cast(str, relation.information))
raw_table_stats = stats_match[0] if stats_match else None
table_stats = DatabricksColumn.convert_table_stats(raw_table_stats)
for match_num, match in enumerate(matches):
column_name, column_type, nullable = match.groups()
column = DatabricksColumn(
table_database=relation.database,
table_schema=relation.schema,
table_name=relation.table,
table_type=relation.type,
column_index=(match_num + 1),
table_owner=owner,
column=column_name,
dtype=column_type,
table_stats=table_stats,
)
columns.append(column)
return columns
def _set_relation_information(self, relation: DatabricksRelation) -> DatabricksRelation:
"""Update the information of the relation, or return it if it already exists."""
if relation.has_information():
return relation

return self._get_updated_relation(relation)[0]

def get_catalog(self, manifest: Manifest) -> Tuple[Table, List[Exception]]:
schema_map = self._get_catalog_schemas(manifest)
@@ -253,7 +295,7 @@ def get_catalog(self, manifest: Manifest) -> Tuple[Table, List[Exception]]:
def _get_columns_for_catalog( # type: ignore[override]
self, relation: DatabricksRelation
) -> Iterable[Dict[str, Any]]:
columns = self.parse_columns_from_information(relation)
columns = self.get_columns_in_relation(relation)

for column in columns:
# convert DatabricksRelation into catalog dicts
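Two details in this file are worth calling out, sketched below with assumed names. First, `get_relations_without_caching` normalizes its result to the same four-column `agate.Table` whether the rows came from the `get_tables` API or from the `show tables`/`show views` macros. Second, only `get_relation(..., needs_information=True)` pays for the `describe extended` round trip:

```python
from agate import Table, Text


def relations_shape() -> Table:
    # The normalized shape returned by get_relations_without_caching;
    # the rows here are invented for illustration.
    return Table(
        [
            ("main", "analytics", "orders", "table"),
            ("main", "analytics", "orders_v", "view"),
        ],
        column_names=["database_name", "schema_name", "name", "kind"],
        column_types=[Text(), Text(), Text(), Text()],
    )


def inspect_relation(adapter) -> None:
    # `adapter` is assumed to be a configured DatabricksAdapter instance.
    rel = adapter.get_relation("main", "analytics", "orders", needs_information=True)
    if rel is not None and rel.is_delta:
        print(rel.owner, rel.stats)
```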
47 changes: 38 additions & 9 deletions dbt/adapters/databricks/relation.py
@@ -1,12 +1,22 @@
from dataclasses import dataclass
from typing import Any, Dict
from typing import Any, Dict, Optional

from dbt.adapters.base.relation import Policy
from dbt.adapters.spark.relation import SparkRelation
from dbt.adapters.base.relation import BaseRelation, Policy
from dbt.adapters.spark.impl import KEY_TABLE_OWNER, KEY_TABLE_STATISTICS

from dbt.adapters.databricks.utils import remove_undefined


KEY_TABLE_PROVIDER = "Provider"


@dataclass
class DatabricksQuotePolicy(Policy):
database: bool = False
schema: bool = False
identifier: bool = False


@dataclass
class DatabricksIncludePolicy(Policy):
database: bool = True
@@ -15,8 +25,12 @@ class DatabricksIncludePolicy(Policy):


@dataclass(frozen=True, eq=False, repr=False)
class DatabricksRelation(SparkRelation):
include_policy: DatabricksIncludePolicy = DatabricksIncludePolicy() # type: ignore[assignment]
class DatabricksRelation(BaseRelation):
quote_policy = DatabricksQuotePolicy()
include_policy = DatabricksIncludePolicy()
quote_character: str = "`"

metadata: Optional[Dict[str, Any]] = None

@classmethod
def __pre_deserialize__(cls, data: Dict[Any, Any]) -> Dict[Any, Any]:
@@ -27,8 +41,23 @@ def __pre_deserialize__(cls, data: Dict[Any, Any]) -> Dict[Any, Any]:
data["path"]["database"] = remove_undefined(data["path"]["database"])
return data

def __post_init__(self) -> None:
return
def has_information(self) -> bool:
return self.metadata is not None

@property
def is_delta(self) -> bool:
assert self.metadata is not None
return self.metadata.get(KEY_TABLE_PROVIDER) == "delta"

@property
def is_hudi(self) -> bool:
assert self.metadata is not None
return self.metadata.get(KEY_TABLE_PROVIDER) == "hudi"

@property
def owner(self) -> Optional[str]:
return self.metadata.get(KEY_TABLE_OWNER) if self.metadata is not None else None

def render(self) -> str:
return super(SparkRelation, self).render()
@property
def stats(self) -> Optional[str]:
return self.metadata.get(KEY_TABLE_STATISTICS) if self.metadata is not None else None
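
A short sketch of the metadata-backed properties. The dict keys mirror the `describe extended` output (`Provider`, `Owner`, `Statistics`); the values below are invented for illustration:

```python
from dbt.adapters.databricks.relation import DatabricksRelation
from dbt.contracts.relation import RelationType

rel = DatabricksRelation.create(
    database="main",
    schema="analytics",
    identifier="orders",
    type=RelationType.Table,
    metadata={"Provider": "delta", "Owner": "data-eng"},
)
assert rel.has_information()
assert rel.is_delta and not rel.is_hudi
print(rel.owner)  # data-eng
print(rel.stats)  # None (no Statistics entry in this sketch)
```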
