Skip to content

Commit

Permalink
fix NoneType error in get_columns_in_relation
Browse files Browse the repository at this point in the history
  • Loading branch information
TalkWIthKeyboard committed May 15, 2022
1 parent 6cd6f6d commit 345239b
Showing 1 changed file with 20 additions and 14 deletions.
34 changes: 20 additions & 14 deletions dbt/adapters/spark/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def _update(self, relation: _CachedRelation):

self.relations[key].inner = relation.inner

def upsert_relation(self, relation):
def update_relation(self, relation):
"""Update the relation inner to the cache
: param BaseRelation relation: The underlying relation.
Expand Down Expand Up @@ -126,7 +126,10 @@ class SparkAdapter(SQLAdapter):
Column: TypeAlias = SparkColumn
ConnectionManager: TypeAlias = SparkConnectionManager
AdapterSpecificConfigs: TypeAlias = SparkConfig
cache = SparkRelationsCache()

def __init__(self, config):
    """Initialize the adapter with its own relations cache.

    :param config: adapter/runtime configuration forwarded to the base
        SQLAdapter constructor.
    """
    super().__init__(config)
    # Per-instance cache (this commit moves it off the class body) so
    # separate adapter instances do not share cached relation state.
    self.cache = SparkRelationsCache()

@classmethod
def date_function(cls) -> str:
Expand Down Expand Up @@ -212,7 +215,6 @@ def get_relation(
database = None

cached = super().get_relation(database, schema, identifier)
logger.info(f">>> get_relation: {cached.render() if cached is not None else 'Empty'}")
return self._set_relation_information(cached) if cached else None

def parse_describe_extended(
Expand Down Expand Up @@ -262,17 +264,16 @@ def get_columns_in_relation(self, relation: Relation) -> List[SparkColumn]:
None,
)

logger.info(f">>> get_columns_in_relation: {relation.render() if relation is not None else 'Empty'}, "
f"{cached_relation.render() if cached_relation is not None else 'Empty'}")

if not cached_relation:
updated_relation = self.cache.add(self._get_updated_relation(relation))
updated_relation = self._get_updated_relation(relation)
if updated_relation:
self.cache.add(updated_relation)
else:
updated_relation = self._set_relation_information(cached_relation)

return self._get_spark_columns(updated_relation)

def _get_updated_relation(self, relation: BaseRelation) -> SparkRelation:
def _get_updated_relation(self, relation: BaseRelation) -> Optional[SparkRelation]:
metadata = None
columns = []

Expand All @@ -295,16 +296,19 @@ def _get_updated_relation(self, relation: BaseRelation) -> SparkRelation:
columns = [x for x in columns
if x.name not in self.HUDI_METADATA_COLUMNS]

provider = metadata[KEY_TABLE_PROVIDER]
if not metadata:
return None

provider = metadata.get(KEY_TABLE_PROVIDER)
return self.Relation.create(
database=None,
schema=relation.schema,
identifier=relation.identifier,
type=relation.type,
is_delta=(provider == 'delta'),
is_hudi=(provider == 'hudi'),
owner=metadata[KEY_TABLE_OWNER],
stats=metadata[KEY_TABLE_STATISTICS],
owner=metadata.get(KEY_TABLE_OWNER),
stats=metadata.get(KEY_TABLE_STATISTICS),
columns={x.column: x.dtype for x in columns}
)

Expand All @@ -315,13 +319,16 @@ def _set_relation_information(self, relation: SparkRelation) -> SparkRelation:

updated_relation = self._get_updated_relation(relation)

self.cache.upsert_relation(updated_relation)
self.cache.update_relation(updated_relation)
return updated_relation

@staticmethod
def _get_spark_columns(
relation: SparkRelation
relation: Optional[SparkRelation]
) -> List[SparkColumn]:
if not relation:
return []

return [SparkColumn(
table_database=None,
table_schema=relation.schema,
Expand All @@ -337,7 +344,6 @@ def _get_spark_columns(
def _get_columns_for_catalog(
self, relation: SparkRelation
) -> Iterable[Dict[str, Any]]:
logger.info(f">>> _get_columns_for_catalog: {relation.render() if relation is not None else 'Empty'}")
updated_relation = self._set_relation_information(relation)

for column in self._get_spark_columns(updated_relation):
Expand Down

0 comments on commit 345239b

Please sign in to comment.