From 6e5e82d3dfdc2fcd768457b508379d93a1b5e3c5 Mon Sep 17 00:00:00 2001
From: Tianhao-Gu
Date: Tue, 4 Jun 2024 12:25:30 -0500
Subject: [PATCH] auto assigning spark app name

---
 src/spark/utils.py           | 16 ++++++++++------
 test/src/spark/utils_test.py |  2 +-
 2 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/src/spark/utils.py b/src/spark/utils.py
index 2a22b0e..288b5cb 100644
--- a/src/spark/utils.py
+++ b/src/spark/utils.py
@@ -1,4 +1,5 @@
 import os
+from datetime import datetime
 from threading import Timer
 
 from pyspark.conf import SparkConf
@@ -72,20 +73,23 @@ def get_base_spark_conf(app_name: str) -> SparkConf:
 
 
 def get_spark_session(
-        app_name: str,
+        app_name: str = None,
         local: bool = False,
-        delta_lake: bool = False,
+        delta_lake: bool = True,
         timeout_sec: int = 4 * 60 * 60) -> SparkSession:
     """
     Helper to get and manage the SparkSession and keep all of our spark configuration params in
     one place.
 
-    :param app_name: The name of the application
-    :param local: Whether to run the spark session locally or not
-    :param delta_lake: Build the spark session with Delta Lake support
-    :param timeout_sec: The timeout in seconds to stop the Spark session forcefully
+    :param app_name: The name of the application. If not provided, a default name will be generated.
+    :param local: Whether to run the spark session locally or not. Default is False.
+    :param delta_lake: Build the spark session with Delta Lake support. Default is True.
+    :param timeout_sec: The timeout in seconds to stop the Spark session forcefully. Default is 4 hours.
 
     :return: A SparkSession object
     """
+    if not app_name:
+        app_name = f"kbase_spark_session_{datetime.now().strftime('%Y%m%d%H%M%S')}"
+
     if local:
         return SparkSession.builder.appName(app_name).getOrCreate()
 
diff --git a/test/src/spark/utils_test.py b/test/src/spark/utils_test.py
index bdb4c3e..3813bfc 100644
--- a/test/src/spark/utils_test.py
+++ b/test/src/spark/utils_test.py
@@ -46,7 +46,7 @@ def spark_session_non_local(mock_spark_master):
 
     with mock.patch.dict('os.environ', {"SPARK_MASTER_URL": spark_master_url,
                                         "SPARK_TIMEOUT_SECONDS": "2"}):
-        spark_session = get_spark_session("TestApp", local=False)
+        spark_session = get_spark_session("TestApp", local=False, delta_lake=False)
         print("Created non-local Spark session.")
         try:
             yield spark_session, port
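
For reference, a minimal sketch of how the new defaults behave after this patch. The import path `spark.utils` is an assumption for illustration; the timestamped name pattern and the changed defaults come directly from the diff above.

    # Sketch only: assumes the module is importable as spark.utils.
    from spark.utils import get_spark_session

    # app_name omitted: the patch generates a timestamped default,
    # e.g. "kbase_spark_session_20240604122530".
    spark = get_spark_session(local=True)
    print(spark.sparkContext.appName)

    # An explicit name still takes precedence over the generated default.
    named = get_spark_session("MyApp", local=True)

Note that `delta_lake` now defaults to True, so non-local callers that do not want Delta Lake support must opt out explicitly; that is why the updated test fixture passes `delta_lake=False`. In the local path the session is returned before any Delta Lake configuration is applied, so the flag only matters when `local=False`.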