Skip to content

Commit

Permalink
add timeout arg
Browse files Browse the repository at this point in the history
  • Loading branch information
Tianhao-Gu committed May 31, 2024
1 parent 529d5f8 commit ac70c22
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/spark/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
HADOOP_AWS_VER = os.getenv('HADOOP_AWS_VER')
DELTA_SPARK_VER = os.getenv('DELTA_SPARK_VER')
SCALA_VER = os.getenv('SCALA_VER')
SPARK_TIMEOUT_SECONDS = os.getenv('SPARK_TIMEOUT_SECONDS', 4 * 60 * 60)


def _get_jars(jar_names: list) -> str:
Expand Down Expand Up @@ -75,13 +74,15 @@ def get_base_spark_conf(app_name: str) -> SparkConf:
def get_spark_session(
app_name: str,
local: bool = False,
delta_lake: bool = False) -> SparkSession:
delta_lake: bool = False,
timeout_sec: int = 4 * 60 * 60) -> SparkSession:
"""
Helper to get and manage the SparkSession and keep all of our spark configuration params in one place.
:param app_name: The name of the application
:param local: Whether to run the spark session locally or not
:param delta_lake: Build the spark session with Delta Lake support
:param timeout_sec: The timeout in seconds to stop the Spark session forcefully
:return: A SparkSession object
"""
Expand All @@ -101,6 +102,7 @@ def get_spark_session(
spark_conf.set(key, value)

spark = SparkSession.builder.config(conf=spark_conf).getOrCreate()
Timer(int(SPARK_TIMEOUT_SECONDS), _stop_spark_session, [spark]).start()
timeout_sec = os.getenv('SPARK_TIMEOUT_SECONDS', timeout_sec)
Timer(int(timeout_sec), _stop_spark_session, [spark]).start()

return spark

0 comments on commit ac70c22

Please sign in to comment.