Commit

ability to add user defined spark config
Tianhao-Gu committed Jun 11, 2024
1 parent 45cef54 commit ea77380
Showing 1 changed file with 7 additions and 1 deletion.
src/spark/utils.py (7 additions, 1 deletion)

@@ -99,7 +99,9 @@ def get_spark_session(
         local: bool = False,
         delta_lake: bool = True,
         timeout_sec: int = 4 * 60 * 60,
-        executor_cores: int = DEFAULT_EXECUTOR_CORES) -> SparkSession:
+        executor_cores: int = DEFAULT_EXECUTOR_CORES,
+        additional_conf: dict = None
+) -> SparkSession:
     """
     Helper to get and manage the SparkSession and keep all of our spark configuration params in one place.
@@ -108,6 +110,7 @@ def get_spark_session(
     :param delta_lake: Build the spark session with Delta Lake support. Default is True.
     :param timeout_sec: The timeout in seconds to stop the Spark session forcefully. Default is 4 hours.
     :param executor_cores: The number of CPU cores that each Spark executor will use. Default is 1.
+    :param additional_conf: Additional user supplied configuration to pass to the Spark session. e.g. {"spark.executor.memory": "2g"}
     :return: A SparkSession object
     """
@@ -129,6 +132,9 @@ def get_spark_session(
     for key, value in delta_conf.items():
         spark_conf.set(key, value)
 
+    if additional_conf:
+        spark_conf.setAll(additional_conf.items())
+
     spark = SparkSession.builder.config(conf=spark_conf).getOrCreate()
     timeout_sec = os.getenv('SPARK_TIMEOUT_SECONDS', timeout_sec)
     Timer(int(timeout_sec), _stop_spark_session, [spark]).start()

Codecov (codecov/patch) check warning on src/spark/utils.py#L136: Added line #L136 was not covered by tests.
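For context, here is a minimal sketch of how a caller might use the new parameter. The import path is assumed from the repository layout, and any other arguments the function takes are omitted; only additional_conf comes from the diff, with the example value taken from the new docstring.

    from spark.utils import get_spark_session  # import path assumed

    # User-supplied settings are applied to the SparkConf via setAll()
    # before the session is built.
    spark = get_spark_session(
        additional_conf={"spark.executor.memory": "2g"},
    )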
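Since Codecov flags the new setAll() call as untested, one way to cover it would be a unit test along these lines. This is a hedged sketch: the test name and conf key are made up, and it assumes that local=True and delta_lake=False (both parameters shown in the signature) allow a lightweight local session to be built.

    from spark.utils import get_spark_session  # import path assumed

    def test_additional_conf_is_applied():
        # Build a local session with a user-defined setting and confirm
        # it reached the session's runtime configuration.
        spark = get_spark_session(
            local=True,
            delta_lake=False,
            additional_conf={"spark.custom.testKey": "testValue"},
        )
        try:
            assert spark.conf.get("spark.custom.testKey") == "testValue"
        finally:
            spark.stop()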
