Skip to content

Commit

Permalink
Auto-assign a Spark app name when one is not provided
Browse files Browse the repository at this point in the history
  • Loading branch information
Tianhao-Gu committed Jun 4, 2024
1 parent d9bc820 commit 6e5e82d
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
16 changes: 10 additions & 6 deletions src/spark/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from datetime import datetime
from threading import Timer

from pyspark.conf import SparkConf
Expand Down Expand Up @@ -72,20 +73,23 @@ def get_base_spark_conf(app_name: str) -> SparkConf:


def get_spark_session(
app_name: str,
app_name: str = None,
local: bool = False,
delta_lake: bool = False,
delta_lake: bool = True,
timeout_sec: int = 4 * 60 * 60) -> SparkSession:
"""
Helper to get and manage the SparkSession and keep all of our spark configuration params in one place.
:param app_name: The name of the application
:param local: Whether to run the spark session locally or not
:param delta_lake: Build the spark session with Delta Lake support
:param timeout_sec: The timeout in seconds to stop the Spark session forcefully
:param app_name: The name of the application. If not provided, a default name will be generated.
:param local: Whether to run the spark session locally or not. Default is False.
:param delta_lake: Build the spark session with Delta Lake support. Default is True.
:param timeout_sec: The timeout in seconds to stop the Spark session forcefully. Default is 4 hours.
:return: A SparkSession object
"""
if not app_name:
app_name = f"kbase_spark_session_{datetime.now().strftime('%Y%m%d%H%M%S')}"

Check warning on line 91 in src/spark/utils.py

View check run for this annotation

Codecov / codecov/patch

src/spark/utils.py#L91

Added line #L91 was not covered by tests

if local:
return SparkSession.builder.appName(app_name).getOrCreate()

Expand Down
2 changes: 1 addition & 1 deletion test/src/spark/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def spark_session_non_local(mock_spark_master):

with mock.patch.dict('os.environ', {"SPARK_MASTER_URL": spark_master_url,
"SPARK_TIMEOUT_SECONDS": "2"}):
spark_session = get_spark_session("TestApp", local=False)
spark_session = get_spark_session("TestApp", local=False, delta_lake=False)
print("Created non-local Spark session.")
try:
yield spark_session, port
Expand Down

0 comments on commit 6e5e82d

Please sign in to comment.