config executor cores #23

Merged · 6 commits · Jun 6, 2024
src/notebook/startup.py: 2 changes (1 addition, 1 deletion)
@@ -5,4 +5,4 @@
With the PYTHONPATH configured to /src by the Dockerfile, we can directly import modules from the src directory.
"""

- from spark.utils import get_spark_session, get_base_spark_conf
+ from spark.utils import get_spark_session
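
The notebook import shrinks to just `get_spark_session`, since the base-conf helper is now private to `spark.utils`. A quick, hedged sanity check of the import path, assuming the image really does set `PYTHONPATH=/src` as the docstring above describes:

```python
import sys

# Assumes the Dockerfile puts /src on PYTHONPATH, so `spark.utils` resolves
# as a top-level package from /src/spark/utils.py.
assert any(p.rstrip("/") == "/src" for p in sys.path)

from spark.utils import get_spark_session
```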
src/spark/utils.py: 20 changes (16 additions, 4 deletions)
@@ -10,6 +10,9 @@
HADOOP_AWS_VER = os.getenv('HADOOP_AWS_VER')
DELTA_SPARK_VER = os.getenv('DELTA_SPARK_VER')
SCALA_VER = os.getenv('SCALA_VER')
+ # the default number of CPU cores that each Spark executor will use
+ # If not specified, Spark will typically use all available cores on the worker nodes
+ DEFAULT_EXECUTOR_CORES = 1


def _get_jars(jar_names: list) -> str:
@@ -29,7 +32,9 @@ def _get_jars(jar_names: list) -> str:
return ", ".join(jars)


- def _get_delta_lake_conf(jars_str: str) -> dict:
+ def _get_delta_lake_conf(
+     jars_str: str,
+ ) -> dict:
"""
Helper function to get Delta Lake specific Spark configuration.

@@ -58,32 +63,39 @@ def _stop_spark_session(spark):
spark.stop()


- def get_base_spark_conf(app_name: str) -> SparkConf:
+ def _get_base_spark_conf(
+     app_name: str,
+     executor_cores: int,
+ ) -> SparkConf:
"""
Helper function to get the base Spark configuration.

:param app_name: The name of the application
+ :param executor_cores: The number of CPU cores that each Spark executor will use.

:return: A SparkConf object with the base configuration
"""
return SparkConf().setAll([
("spark.master", os.environ.get("SPARK_MASTER_URL", "spark://spark-master:7077")),
("spark.app.name", app_name),
("spark.executor.cores", executor_cores),
])


def get_spark_session(
app_name: str = None,
local: bool = False,
delta_lake: bool = True,
- timeout_sec: int = 4 * 60 * 60) -> SparkSession:
+ timeout_sec: int = 4 * 60 * 60,
+ executor_cores: int = DEFAULT_EXECUTOR_CORES) -> SparkSession:
"""
Helper to get and manage the SparkSession and keep all of our spark configuration params in one place.

:param app_name: The name of the application. If not provided, a default name will be generated.
:param local: Whether to run the spark session locally or not. Default is False.
:param delta_lake: Build the spark session with Delta Lake support. Default is True.
:param timeout_sec: The timeout in seconds to stop the Spark session forcefully. Default is 4 hours.
+ :param executor_cores: The number of CPU cores that each Spark executor will use. Default is 1.

:return: A SparkSession object
"""
@@ -93,7 +105,7 @@ def get_spark_session(
if local:
return SparkSession.builder.appName(app_name).getOrCreate()

- spark_conf = get_base_spark_conf(app_name)
+ spark_conf = _get_base_spark_conf(app_name, executor_cores)

if delta_lake:

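A minimal usage sketch of the new knob: `executor_cores` defaults to `DEFAULT_EXECUTOR_CORES` (1) and flows into `spark.executor.cores` via `_get_base_spark_conf`. The app name below is illustrative, and the call assumes a reachable Spark master (default `spark://spark-master:7077`):

```python
from spark.utils import get_spark_session

# Override the 1-core default for a heavier job; getOrCreate() reuses any
# existing session, so pick the core count before the first session is built.
spark = get_spark_session("wide_join_job", executor_cores=4)

# The setting is visible in the runtime conf; SparkConf stores values as strings.
print(spark.conf.get("spark.executor.cores"))  # "4"
```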
test/src/spark/utils_test.py: 10 changes (7 additions, 3 deletions)
@@ -5,7 +5,7 @@
from pyspark import SparkConf
from pyspark.sql import SparkSession

- from src.spark.utils import get_spark_session, _get_jars, get_base_spark_conf, JAR_DIR
+ from src.spark.utils import get_spark_session, _get_jars, _get_base_spark_conf, JAR_DIR


@pytest.fixture(scope="session")
@@ -95,20 +95,24 @@ def test_get_base_spark_conf():
app_name = "test_app"
expected_master_url = "spark://spark-master:7077"
expected_app_name = app_name
+ executor_cores = 3

with mock.patch.dict('os.environ', {}):
- result = get_base_spark_conf(app_name)
+ result = _get_base_spark_conf(app_name, executor_cores)
assert isinstance(result, SparkConf)
assert result.get("spark.master") == expected_master_url
assert result.get("spark.app.name") == expected_app_name
assert result.get("spark.executor.cores") == str(executor_cores)


def test_get_base_spark_conf_with_env():
app_name = "test_app"
custom_master_url = "spark://custom-master:7077"
+ executor_cores = 3

with mock.patch.dict('os.environ', {"SPARK_MASTER_URL": custom_master_url}):
- result = get_base_spark_conf(app_name)
+ result = _get_base_spark_conf(app_name, executor_cores)
assert isinstance(result, SparkConf)
assert result.get("spark.master") == custom_master_url
assert result.get("spark.app.name") == app_name
assert result.get("spark.executor.cores") == str(executor_cores)
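
The new assertions compare against `str(executor_cores)` because `SparkConf` coerces every value to a string on `set`/`setAll`. A minimal standalone check of that behavior, assuming only that `pyspark` is installed (no running cluster needed):

```python
from pyspark import SparkConf

# setAll() routes through set(), which stores str(value) internally.
conf = SparkConf().setAll([("spark.executor.cores", 3)])

assert conf.get("spark.executor.cores") == "3"  # the int reads back as a string
```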