support for pyspark connection method #308

Closed · wants to merge 1 commit
31 changes: 30 additions & 1 deletion dbt/adapters/spark/connections.py
Expand Up @@ -42,6 +42,17 @@
import base64
import time


try:
from pyspark.rdd import _load_from_socket
import pyspark.sql.functions as F
from pyspark.sql import SparkSession

Collaborator: Functions (F) and SparkSession are not used in this file.

except ImportError:
SparkSession = None
_load_from_socket = None
F = None


logger = AdapterLogger("Spark")

NUMBERS = DECIMALS + (int, float)
Expand All @@ -56,7 +67,7 @@ class SparkConnectionMethod(StrEnum):
HTTP = 'http'
ODBC = 'odbc'
SESSION = 'session'

PYSPARK = 'pyspark'

@dataclass
class SparkCredentials(Credentials):
Expand All @@ -77,6 +88,7 @@ class SparkCredentials(Credentials):
use_ssl: bool = False
server_side_parameters: Dict[str, Any] = field(default_factory=dict)
retry_all: bool = False
python_module: Optional[str] = None

@classmethod
def __pre_deserialize__(cls, data):
Expand All @@ -99,6 +111,18 @@ def __post_init__(self):
)
self.database = None

if (
self.method == SparkConnectionMethod.PYSPARK
) and not (
_load_from_socket and SparkSession and F
):
raise dbt.exceptions.RuntimeException(
f"{self.method} connection method requires "
"additional dependencies. \n"
"Install the additional required dependencies with "
"`pip install pyspark`"
)

if self.method == SparkConnectionMethod.ODBC:
try:
import pyodbc # noqa: F401
Expand Down Expand Up @@ -462,6 +486,11 @@ def open(cls, connection):
SessionConnectionWrapper,
)
handle = SessionConnectionWrapper(Connection())
elif creds.method == SparkConnectionMethod.PYSPARK:
from .pysparkcon import ( # noqa: F401
PysparkConnectionWrapper,
)
handle = PysparkConnectionWrapper(creds.python_module)
else:
raise dbt.exceptions.DbtProfileError(
f"invalid credential method: {creds.method}"
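For context on the new python_module profile field: the pyspark method expects it to name an importable module exposing a create_spark_context() function, which PysparkConnectionWrapper (added below) imports and calls to obtain the SparkSession. A minimal sketch of such a hook module, with the module path and builder options chosen purely for illustration:

# my_project/spark_context.py -- hypothetical module referenced by the
# python_module credential; only the create_spark_context() name is assumed
# by PysparkConnectionWrapper.
from pyspark.sql import SparkSession


def create_spark_context():
    # Build (or reuse) the SparkSession that dbt will execute its SQL against.
    # The options below are ordinary Spark settings shown only as examples.
    return (
        SparkSession.builder
        .appName("dbt-spark-pyspark-method")
        .config("spark.sql.session.timeZone", "UTC")
        .getOrCreate()
    )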
93 changes: 93 additions & 0 deletions dbt/adapters/spark/pysparkcon.py
@@ -0,0 +1,93 @@

from __future__ import annotations

import importlib
from datetime import datetime

import sqlalchemy

import dbt.exceptions
from dbt.events import AdapterLogger
from dbt.utils import DECIMALS

import pyspark.sql.functions as F
from pyspark.sql import SparkSession

logger = AdapterLogger("Spark")
NUMBERS = DECIMALS + (int, float)


class PysparkConnectionWrapper(object):
"""Wrap a Spark context"""

def __init__(self, python_module):
self.result = None
if python_module:

Collaborator: I prefer to avoid such a hook; it's very specific. python_module is an unexpected parameter for PysparkConnectionWrapper, and it's unclear why it is needed and how it works. We could add docs about this, but it is still confusing to write PysparkConnectionWrapper(python_module).

Author: I can change it; what do you propose?

logger.debug(f"Loading spark context from python module {python_module}")
module = importlib.import_module(python_module)
create_spark_context = getattr(module, "create_spark_context")
self.spark = create_spark_context()
else:
# Fall back to the default SparkSession (create one if none exists)
self.spark = SparkSession.builder.getOrCreate()

def cursor(self):
return self

def rollback(self, *args, **kwargs):
logger.debug("NotImplemented: rollback")

def fetchall(self):
try:
rows = self.result.collect()
logger.debug(rows)
except Exception as e:
logger.debug(f"raising error {e}")
dbt.exceptions.raise_database_error(e)
return rows

def execute(self, sql, bindings=None):
if sql.strip().endswith(";"):
sql = sql.strip()[:-1]

if bindings is not None:
bindings = [self._fix_binding(binding) for binding in bindings]
sql = sql % tuple(bindings)
logger.debug(f"execute sql:{sql}")
try:
self.result = self.spark.sql(sql)
logger.debug("Executed with no errors")
if "show tables" in sql:
self.result = self.result.withColumn("description", F.lit(""))

Collaborator: Why add the description column?

Author: This is an Iceberg-specific issue; when using Iceberg the column is missing from the show tables result. I'll remove this from the PR.

except Exception as e:
logger.debug(f"raising error {e}")
dbt.exceptions.raise_database_error(e)

@classmethod
def _fix_binding(cls, value):
"""Convert complex datatypes to primitives that can be loaded by
the Spark driver"""
if isinstance(value, NUMBERS):
return float(value)
elif isinstance(value, datetime):
return "'" + value.strftime('%Y-%m-%d %H:%M:%S.%f')[:-3] + "'"
elif isinstance(value, str):
return "'" + value + "'"
else:
logger.debug(type(value))
return "'" + str(value) + "'"

@property
def description(self):
logger.debug(f"Description called returning list of columns: {self.result.columns}")
ret = []
# Not sure the type is ever used downstream; specify a string type anyway
string_type = sqlalchemy.types.String
for column_name in self.result.columns:
ret.append((column_name, string_type))
return ret
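
To make the wrapper's cursor-style surface concrete, here is a hedged sketch of the calls dbt's connection manager issues against it; in dbt the wrapper is constructed by the adapter's open() method shown above, so instantiating it directly like this is for illustration only:

# Hypothetical standalone exercise of PysparkConnectionWrapper.
from dbt.adapters.spark.pysparkcon import PysparkConnectionWrapper

conn = PysparkConnectionWrapper(python_module=None)  # falls back to SparkSession.builder.getOrCreate()
cursor = conn.cursor()  # the wrapper acts as its own cursor

# execute() strips a trailing ";" and interpolates bindings via _fix_binding()
cursor.execute("select %s as answer, %s as asked_at", bindings=[42, "2024-01-01"])

print(cursor.description)  # [("answer", sqlalchemy String type), ("asked_at", ...)]
print(cursor.fetchall())   # rows collected from the underlying DataFrame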

52 changes: 52 additions & 0 deletions dbt/adapters/spark/relation.py
Expand Up @@ -5,6 +5,17 @@
from dbt.adapters.base.relation import BaseRelation, Policy
from dbt.exceptions import RuntimeException

from typing import Any, Dict, Optional, Type, TypeVar

from dbt.contracts.graph.parsed import ParsedSourceDefinition
from dbt.events import AdapterLogger
from dbt.utils import deep_merge

import importlib

Self = TypeVar("Self", bound="BaseRelation")

logger = AdapterLogger("Spark")


@dataclass
class SparkQuotePolicy(Policy):
Expand All @@ -28,6 +39,8 @@ class SparkRelation(BaseRelation):
is_delta: Optional[bool] = None
is_hudi: Optional[bool] = None
information: str = None
source_meta: Optional[Dict[str, Any]] = None
meta: Optional[Dict[str, Any]] = None

def __post_init__(self):
if self.database != self.schema and self.database:
Expand All @@ -40,3 +53,42 @@ def render(self):
'include, but only one can be set'
)
return super().render()

@classmethod
def create_from_source(cls: Type[Self], source: ParsedSourceDefinition, **kwargs: Any) -> Self:
source_quoting = source.quoting.to_dict(omit_none=True)
source_quoting.pop("column", None)
quote_policy = deep_merge(
cls.get_default_quote_policy().to_dict(omit_none=True),
source_quoting,
kwargs.get("quote_policy", {}),
)

return cls.create(
database=source.database,
schema=source.schema,
identifier=source.identifier,
quote_policy=quote_policy,
source_meta=source.source_meta,
meta=source.meta,
**kwargs,
)

def load_python_module(self, start_time, end_time):
logger.debug(f"Creating pyspark view for {self.identifier}")
from pyspark.sql import SparkSession
spark = SparkSession._instantiatedSession
if self.meta and self.meta.get('python_module'):
path = self.meta.get('python_module')
logger.debug(f"Loading python module {path}")
module = importlib.import_module(path)
create_dataframe = getattr(module, "create_dataframe")
df = create_dataframe(spark, start_time, end_time)
df.createOrReplaceTempView(self.identifier)
elif self.source_meta and self.source_meta.get('python_module'):
path = self.source_meta.get('python_module')
logger.debug(f"Loading python module {path}")
module = importlib.import_module(path)
create_dataframe_for = getattr(module, "create_dataframe_for")
df = create_dataframe_for(spark, self.identifier, start_time, end_time)
df.createOrReplaceTempView(self.identifier)
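
load_python_module() above expects the module named by the python_module meta key to expose create_dataframe(spark, start_time, end_time) (table-level meta) or create_dataframe_for(spark, identifier, start_time, end_time) (source-level meta). A minimal sketch of such a module, with the input table and column names invented for illustration:

# my_project/python_sources.py -- hypothetical module referenced from a
# source's meta; only the two function names and signatures below are
# assumed by SparkRelation.load_python_module().
from pyspark.sql import DataFrame, SparkSession


def create_dataframe(spark: SparkSession, start_time, end_time) -> DataFrame:
    # Used when python_module is set on an individual source table (relation.meta).
    df = spark.read.table("raw.events")  # illustrative input table
    if start_time is not None and end_time is not None:
        df = df.where(df.event_ts.between(start_time, end_time))
    return df


def create_dataframe_for(spark: SparkSession, identifier: str, start_time, end_time) -> DataFrame:
    # Used when python_module is set at the source level (relation.source_meta);
    # identifier says which table of the source is being requested.
    if identifier == "events":
        return create_dataframe(spark, start_time, end_time)
    raise ValueError(f"unknown source table: {identifier}")

The source() macro override below then registers the returned DataFrame as a temp view named after the source identifier and returns that name to the model SQL.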
11 changes: 11 additions & 0 deletions dbt/include/spark/macros/source.sql
@@ -0,0 +1,11 @@
{% macro source(source_name, identifier, start_dt = None, end_dt = None) %}
{%- set relation = builtins.source(source_name, identifier) -%}

{%- if execute and (relation.source_meta.python_module or relation.meta.python_module) -%}
{%- do relation.load_python_module(start_dt, end_dt) -%}
{# Return the view name only. Spark views do not support schema or catalog names #}
{%- do return(relation.identifier) -%}
{% else -%}
{%- do return(relation) -%}
{% endif -%}
{% endmacro %}