diff --git a/dbt/adapters/spark/connections.py b/dbt/adapters/spark/connections.py
index 608ab2b45..d63d7a5f8 100644
--- a/dbt/adapters/spark/connections.py
+++ b/dbt/adapters/spark/connections.py
@@ -42,6 +42,21 @@
 import base64
 import time
+
+try:
+    from pyspark.rdd import _load_from_socket
+    import pyspark.sql.functions as F
+    from pyspark.sql import SparkSession
+except ImportError:
+    SparkSession = None
+    _load_from_socket = None
+    F = None
+
+import importlib
+import sqlalchemy
+import re
+
+
 
 logger = AdapterLogger("Spark")
 
 NUMBERS = DECIMALS + (int, float)
@@ -55,7 +70,7 @@ class SparkConnectionMethod(StrEnum):
     THRIFT = 'thrift'
     HTTP = 'http'
     ODBC = 'odbc'
-
+    PYSPARK = 'pyspark'
 
 @dataclass
 class SparkCredentials(Credentials):
@@ -76,6 +91,7 @@ class SparkCredentials(Credentials):
     use_ssl: bool = False
     server_side_parameters: Dict[str, Any] = field(default_factory=dict)
     retry_all: bool = False
+    python_module: Optional[str] = None
 
     @classmethod
     def __pre_deserialize__(cls, data):
@@ -98,6 +114,18 @@ def __post_init__(self):
             )
         self.database = None
 
+        if (
+            self.method == SparkConnectionMethod.PYSPARK
+        ) and not (
+            _load_from_socket and SparkSession and F
+        ):
+            raise dbt.exceptions.RuntimeException(
+                f"{self.method} connection method requires "
+                "additional dependencies. \n"
+                "Install the additional required dependencies with "
+                "`pip install pyspark`"
+            )
+
         if self.method == SparkConnectionMethod.ODBC:
             try:
                 import pyodbc  # noqa: F401
@@ -145,6 +173,76 @@ def _connection_keys(self):
         return ('host', 'port', 'cluster',
                 'endpoint', 'schema', 'organization')
 
 
+class PysparkConnectionWrapper(object):
+    """Wrap a SparkSession so it can be driven like a DB-API connection"""
+
+    def __init__(self, python_module):
+        self.result = None
+        if python_module:
+            logger.debug(f"Loading spark context from python module {python_module}")
+            module = importlib.import_module(python_module)
+            create_spark_context = getattr(module, "create_spark_context")
+            self.spark = create_spark_context()
+        else:
+            # Create a default SparkSession
+            self.spark = SparkSession.builder.getOrCreate()
+
+    def cursor(self):
+        return self
+
+    def rollback(self, *args, **kwargs):
+        logger.debug("NotImplemented: rollback")
+
+    def fetchall(self):
+        try:
+            rows = self.result.collect()
+            logger.debug(rows)
+        except Exception as e:
+            logger.debug(f"raising error {e}")
+            dbt.exceptions.raise_database_error(e)
+        return rows
+
+    def execute(self, sql, bindings=None):
+        if sql.strip().endswith(";"):
+            sql = sql.strip()[:-1]
+
+        if bindings is not None:
+            bindings = [self._fix_binding(binding) for binding in bindings]
+            sql = sql % tuple(bindings)
+        logger.debug(f"execute sql:{sql}")
+        try:
+            self.result = self.spark.sql(sql)
+            logger.debug("Executed with no errors")
+            if "show tables" in sql:
+                self.result = self.result.withColumn("description", F.lit(""))
+        except Exception as e:
+            logger.debug(f"raising error {e}")
+            dbt.exceptions.raise_database_error(e)
+
+    @classmethod
+    def _fix_binding(cls, value):
+        """Convert complex datatypes to primitives that can be loaded by
+        the Spark driver"""
+        if isinstance(value, NUMBERS):
+            return float(value)
+        elif isinstance(value, datetime):
+            return "'" + value.strftime('%Y-%m-%d %H:%M:%S.%f')[:-3] + "'"
+        elif isinstance(value, str):
+            return "'" + value + "'"
+        else:
+            logger.debug(type(value))
+            return "'" + str(value) + "'"
+
+    @property
+    def description(self):
+        logger.debug(f"Description called returning list of columns: {self.result.columns}")
+        ret = []
+        # The column type is not actually consumed downstream, so report every column as a string
+        string_type = sqlalchemy.types.String
+        for column_name in self.result.columns:
+            ret.append((column_name, string_type))
+        return ret
+
 class PyhiveConnectionWrapper(object):
     """Wrap a Spark connection in a way that no-ops transactions"""
@@ -346,7 +444,9 @@ def open(cls, connection):
 
         for i in range(1 + creds.connect_retries):
             try:
-                if creds.method == SparkConnectionMethod.HTTP:
+                if creds.method == SparkConnectionMethod.PYSPARK:
+                    handle = PysparkConnectionWrapper(creds.python_module)
+                elif creds.method == SparkConnectionMethod.HTTP:
                     cls.validate_creds(creds, ['token', 'host', 'port',
                                                'cluster', 'organization'])
 
diff --git a/dbt/adapters/spark/relation.py b/dbt/adapters/spark/relation.py
index 043cabfa0..2af69e0cb 100644
--- a/dbt/adapters/spark/relation.py
+++ b/dbt/adapters/spark/relation.py
@@ -5,6 +5,17 @@
 from dbt.adapters.base.relation import BaseRelation, Policy
 from dbt.exceptions import RuntimeException
 
+from typing import Optional, TypeVar, Any, Type, Dict, Union, Iterator, Tuple, Set
+
+Self = TypeVar("Self", bound="BaseRelation")
+from dbt.contracts.graph.parsed import ParsedSourceDefinition, ParsedNode
+from dbt.utils import filter_null_values, deep_merge, classproperty
+from dbt.events import AdapterLogger
+
+import importlib
+from datetime import timezone, datetime
+
+logger = AdapterLogger("Spark")
 
 @dataclass
 class SparkQuotePolicy(Policy):
@@ -28,6 +39,8 @@ class SparkRelation(BaseRelation):
     is_delta: Optional[bool] = None
     is_hudi: Optional[bool] = None
     information: str = None
+    source_meta: Optional[Dict[str, Any]] = None
+    meta: Optional[Dict[str, Any]] = None
 
     def __post_init__(self):
         if self.database != self.schema and self.database:
@@ -40,3 +53,42 @@ def render(self):
                 'include, but only one can be set'
             )
         return super().render()
+
+    @classmethod
+    def create_from_source(cls: Type[Self], source: ParsedSourceDefinition, **kwargs: Any) -> Self:
+        source_quoting = source.quoting.to_dict(omit_none=True)
+        source_quoting.pop("column", None)
+        quote_policy = deep_merge(
+            cls.get_default_quote_policy().to_dict(omit_none=True),
+            source_quoting,
+            kwargs.get("quote_policy", {}),
+        )
+
+        return cls.create(
+            database=source.database,
+            schema=source.schema,
+            identifier=source.identifier,
+            quote_policy=quote_policy,
+            source_meta=source.source_meta,
+            meta=source.meta,
+            **kwargs,
+        )
+
+    def load_python_module(self, start_time, end_time):
+        logger.debug(f"Creating pyspark view for {self.identifier}")
+        from pyspark.sql import SparkSession
+        spark = SparkSession._instantiatedSession
+        if self.meta and self.meta.get('python_module'):
+            path = self.meta.get('python_module')
+            logger.debug(f"Loading python module {path}")
+            module = importlib.import_module(path)
+            create_dataframe = getattr(module, "create_dataframe")
+            df = create_dataframe(spark, start_time, end_time)
+            df.createOrReplaceTempView(self.identifier)
+        elif self.source_meta and self.source_meta.get('python_module'):
+            path = self.source_meta.get('python_module')
+            logger.debug(f"Loading python module {path}")
+            module = importlib.import_module(path)
+            create_dataframe_for = getattr(module, "create_dataframe_for")
+            df = create_dataframe_for(spark, self.identifier, start_time, end_time)
+            df.createOrReplaceTempView(self.identifier)
\ No newline at end of file
diff --git a/dbt/include/spark/macros/source.sql b/dbt/include/spark/macros/source.sql
new file mode 100644
index 000000000..722b831c7
--- /dev/null
+++ b/dbt/include/spark/macros/source.sql
@@ -0,0 +1,11 @@
+{% macro source(source_name, identifier, start_dt = None, end_dt = None) %}
+  {%- set relation = builtins.source(source_name, identifier) -%}
+
+  {%- if execute and (relation.source_meta.python_module or relation.meta.python_module) -%}
+    {%- do relation.load_python_module(start_dt, end_dt) -%}
+    {# Return only the view name. Spark temporary views do not support schema or catalog qualifiers #}
+    {%- do return(relation.identifier) -%}
+  {% else -%}
+    {%- do return(relation) -%}
+  {% endif -%}
+{% endmacro %}
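
For reference, a minimal sketch of the module contract that `PysparkConnectionWrapper` expects when `python_module` is set in the profile: the module must expose a zero-argument `create_spark_context()` returning an object whose `.sql()` method runs a statement (normally a `SparkSession`). The module path and builder options below are illustrative assumptions, not part of the patch.

```python
# my_project/spark_context.py -- hypothetical module, referenced from the
# profile as `python_module: my_project.spark_context`.
from pyspark.sql import SparkSession


def create_spark_context():
    # Called with no arguments by PysparkConnectionWrapper; every dbt
    # statement is then executed through the returned session's .sql().
    return (
        SparkSession.builder
        .appName("dbt-spark")      # illustrative settings only
        .enableHiveSupport()
        .getOrCreate()
    )
```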
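Similarly, `SparkRelation.load_python_module` looks up `create_dataframe(spark, start_time, end_time)` when `python_module` is set in a table's `meta`, or `create_dataframe_for(spark, identifier, start_time, end_time)` when it is set on the source, and registers the result as a temp view named after the identifier. A sketch of such a module, with the storage locations and column names as assumptions:

```python
# my_project/python_sources.py -- hypothetical module, referenced from a
# source's (or table's) meta as python_module: my_project.python_sources
from pyspark.sql import DataFrame, SparkSession


def create_dataframe(spark: SparkSession, start_time, end_time) -> DataFrame:
    # Table-level hook: build the dataframe for one specific table.
    df = spark.read.parquet("/data/raw/events")  # assumed location
    if start_time and end_time:
        df = df.where(df.event_ts.between(start_time, end_time))
    return df


def create_dataframe_for(spark: SparkSession, identifier: str,
                         start_time, end_time) -> DataFrame:
    # Source-level hook: one factory serves every table, keyed by identifier.
    df = spark.read.table(f"raw.{identifier}")   # assumed database name
    if start_time and end_time:
        df = df.where(df.event_ts.between(start_time, end_time))
    return df
```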
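With `method: pyspark` (and optionally `python_module`) in the profile, a model would reference such a source through the overridden `source()` macro, for example `{{ source('events', 'clicks', start_dt=var('start_dt'), end_dt=var('end_dt')) }}` (source and variable names are illustrative). When a `python_module` is configured, the macro materializes the temp view and returns only `relation.identifier`, because Spark temp views cannot carry a schema or catalog qualifier.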