diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py
index 9652a24be..7600bc4eb 100644
--- a/dbt/adapters/spark/impl.py
+++ b/dbt/adapters/spark/impl.py
@@ -78,7 +78,7 @@ def _update(self, relation: _CachedRelation):
         self.relations[key].inner = relation.inner

-    def upsert_relation(self, relation):
+    def update_relation(self, relation):
         """Update the relation inner to the cache

         : param BaseRelation relation: The underlying relation.
@@ -126,7 +126,10 @@ class SparkAdapter(SQLAdapter):
     Column: TypeAlias = SparkColumn
     ConnectionManager: TypeAlias = SparkConnectionManager
     AdapterSpecificConfigs: TypeAlias = SparkConfig
-    cache = SparkRelationsCache()
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.cache = SparkRelationsCache()

     @classmethod
     def date_function(cls) -> str:
@@ -212,7 +215,6 @@ def get_relation(
             database = None

         cached = super().get_relation(database, schema, identifier)
-        logger.info(f">>> get_relation: {cached.render() if cached is not None else 'Empty'}")
         return self._set_relation_information(cached) if cached else None

     def parse_describe_extended(
@@ -262,17 +264,16 @@ def get_columns_in_relation(self, relation: Relation) -> List[SparkColumn]:
             None,
         )

-        logger.info(f">>> get_columns_in_relation: {relation.render() if relation is not None else 'Empty'}, "
-                    f"{cached_relation.render() if cached_relation is not None else 'Empty'}")
-
         if not cached_relation:
-            updated_relation = self.cache.add(self._get_updated_relation(relation))
+            updated_relation = self._get_updated_relation(relation)
+            if updated_relation:
+                self.cache.add(updated_relation)
         else:
             updated_relation = self._set_relation_information(cached_relation)

         return self._get_spark_columns(updated_relation)

-    def _get_updated_relation(self, relation: BaseRelation) -> SparkRelation:
+    def _get_updated_relation(self, relation: BaseRelation) -> Optional[SparkRelation]:
         metadata = None
         columns = []
@@ -295,7 +296,10 @@ def _get_updated_relation(self, relation: BaseRelation) -> SparkRelation:
         columns = [x for x in columns if x.name not in self.HUDI_METADATA_COLUMNS]

-        provider = metadata[KEY_TABLE_PROVIDER]
+        if not metadata:
+            return None
+
+        provider = metadata.get(KEY_TABLE_PROVIDER)
         return self.Relation.create(
             database=None,
             schema=relation.schema,
@@ -303,8 +307,8 @@ def _get_updated_relation(self, relation: BaseRelation) -> SparkRelation:
             type=relation.type,
             is_delta=(provider == 'delta'),
             is_hudi=(provider == 'hudi'),
-            owner=metadata[KEY_TABLE_OWNER],
-            stats=metadata[KEY_TABLE_STATISTICS],
+            owner=metadata.get(KEY_TABLE_OWNER),
+            stats=metadata.get(KEY_TABLE_STATISTICS),
             columns={x.column: x.dtype for x in columns}
         )

@@ -315,13 +319,16 @@ def _set_relation_information(self, relation: SparkRelation) -> SparkRelation:
         updated_relation = self._get_updated_relation(relation)

-        self.cache.upsert_relation(updated_relation)
+        self.cache.update_relation(updated_relation)
         return updated_relation

     @staticmethod
     def _get_spark_columns(
-        relation: SparkRelation
+        relation: Optional[SparkRelation]
     ) -> List[SparkColumn]:
+        if not relation:
+            return []
+
         return [SparkColumn(
             table_database=None,
             table_schema=relation.schema,
@@ -337,7 +344,6 @@ def _get_columns_for_catalog(
         self, relation: SparkRelation
     ) -> Iterable[Dict[str, Any]]:
-        logger.info(f">>> _get_columns_for_catalog: {relation.render() if relation is not None else 'Empty'}")
         updated_relation = self._set_relation_information(relation)

         for column in self._get_spark_columns(updated_relation):
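
Note on the `__init__` hunk (not part of the patch): `cache = SparkRelationsCache()` as a class attribute is evaluated once at class-definition time and shared by every `SparkAdapter` instance, so separate adapters would read and write the same cache. A minimal standalone sketch of the difference, with hypothetical class names rather than the real adapter:

```python
class SharedCache:
    cache: dict = {}              # class attribute: one dict bound to the class

class PerInstanceCache:
    def __init__(self):
        self.cache = {}           # instance attribute: a fresh dict per object

a, b = SharedCache(), SharedCache()
a.cache["t"] = 1
assert "t" in b.cache             # state leaks between instances

c, d = PerInstanceCache(), PerInstanceCache()
c.cache["t"] = 1
assert "t" not in d.cache         # each instance owns its cache
```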
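
The `Optional` plumbing follows the same idea end to end: `_get_updated_relation` now returns `None` when no metadata comes back instead of raising `KeyError` on `metadata[KEY_TABLE_PROVIDER]`, and `_get_spark_columns` maps `None` to an empty column list. A runnable sketch of that pattern, using hypothetical stand-in types rather than `SparkRelation`:

```python
from typing import Any, Dict, List, Optional

def get_updated_relation(metadata: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    if not metadata:
        return None                                    # nothing to describe
    return {"provider": metadata.get("provider")}      # .get() tolerates absent keys

def get_spark_columns(relation: Optional[Dict[str, Any]]) -> List[str]:
    if not relation:
        return []                                      # no relation, no columns
    return list(relation.keys())

assert get_spark_columns(get_updated_relation(None)) == []
assert get_spark_columns(get_updated_relation({"provider": "delta"})) == ["provider"]
```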