From ea77380f5613a69d6a0e8157987df7ba222766fb Mon Sep 17 00:00:00 2001
From: Tianhao-Gu
Date: Tue, 11 Jun 2024 17:08:17 -0500
Subject: [PATCH] ability to add user-defined spark config

---
 src/spark/utils.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/src/spark/utils.py b/src/spark/utils.py
index bb6b979..60af4d5 100644
--- a/src/spark/utils.py
+++ b/src/spark/utils.py
@@ -99,7 +99,9 @@ def get_spark_session(
         local: bool = False,
         delta_lake: bool = True,
         timeout_sec: int = 4 * 60 * 60,
-        executor_cores: int = DEFAULT_EXECUTOR_CORES) -> SparkSession:
+        executor_cores: int = DEFAULT_EXECUTOR_CORES,
+        additional_conf: dict = None
+) -> SparkSession:
     """
     Helper to get and manage the SparkSession and keep all of our spark
     configuration params in one place.
@@ -108,6 +110,7 @@ def get_spark_session(
     :param delta_lake: Build the spark session with Delta Lake support. Default is True.
     :param timeout_sec: The timeout in seconds to stop the Spark session forcefully. Default is 4 hours.
     :param executor_cores: The number of CPU cores that each Spark executor will use. Default is 1.
+    :param additional_conf: Additional user-supplied configuration to pass to the Spark session, e.g. {"spark.executor.memory": "2g"}
 
     :return: A SparkSession object
     """
@@ -129,6 +132,9 @@ def get_spark_session(
     for key, value in delta_conf.items():
         spark_conf.set(key, value)
 
+    if additional_conf:
+        spark_conf.setAll(additional_conf.items())
+
     spark = SparkSession.builder.config(conf=spark_conf).getOrCreate()
     timeout_sec = os.getenv('SPARK_TIMEOUT_SECONDS', timeout_sec)
     Timer(int(timeout_sec), _stop_spark_session, [spark]).start()
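
Example usage (a reviewer's sketch, not part of the patch; assumes src/ is on the Python path and that any leading parameters of get_spark_session not shown in the hunk keep their defaults):

    from spark.utils import get_spark_session

    # Keys passed via additional_conf are applied with spark_conf.setAll()
    # after the built-in defaults, so they override any conflicting setting.
    spark = get_spark_session(
        additional_conf={
            "spark.executor.memory": "2g",         # example from the docstring
            "spark.sql.shuffle.partitions": "64",  # hypothetical extra setting
        },
    )

Note that because the session is created with getOrCreate(), additional_conf only takes effect when a new session is started; an already-running session keeps its existing configuration.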