Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Spark-RAPIDS-ML PCA #440

Merged
merged 8 commits into from
Oct 8, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 96 additions & 4 deletions examples/ML+DL-Examples/Spark-Rapids-ML/pca/notebooks/pca.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,105 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"No active Spark session found, initializing manually.\n",
"File already exists. Skipping download.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"24/10/04 18:04:27 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\n",
"24/10/04 18:04:27 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n",
"24/10/04 18:04:27 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n",
"Setting default log level to \"WARN\".\n",
"To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
"24/10/04 18:04:27 WARN RapidsPluginUtils: RAPIDS Accelerator 24.08.1 using cudf 24.08.0, private revision 9fac64da220ddd6bf5626bd7bd1dd74c08603eac\n",
"24/10/04 18:04:27 WARN RapidsPluginUtils: RAPIDS Accelerator is enabled, to disable GPU support set `spark.rapids.sql.enabled` to false.\n",
"24/10/04 18:04:31 WARN GpuDeviceManager: RMM pool is disabled since spark.rapids.memory.gpu.pooling.enabled is set to false; however, this configuration is deprecated and the behavior may change in a future release.\n"
]
}
],
"source": [
"from pyspark.sql import SparkSession\n",
"from pyspark import SparkConf\n",
"\n",
"def get_rapids_jar():\n",
" import os\n",
" import requests\n",
"\n",
" SPARK_RAPIDS_VERSION = \"24.08.1\"\n",
" rapids_jar = f\"rapids-4-spark_2.12-{SPARK_RAPIDS_VERSION}.jar\"\n",
"\n",
" if not os.path.exists(rapids_jar):\n",
" print(\"Downloading spark rapids jar\")\n",
" url = f\"https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_2.12/{SPARK_RAPIDS_VERSION}/{rapids_jar}\"\n",
" response = requests.get(url)\n",
" if response.status_code == 200:\n",
" with open(rapids_jar, \"wb\") as f:\n",
" f.write(response.content)\n",
" print(f\"File '{rapids_jar}' downloaded and saved successfully.\")\n",
" else:\n",
" print(f\"Failed to download the file. Status code: {response.status_code}\")\n",
" else:\n",
" print(\"File already exists. Skipping download.\")\n",
" \n",
" return rapids_jar\n",
"\n",
"def initialize_spark(rapids_jar: str):\n",
" '''\n",
" If no active Spark session is found, initialize and configure a new one. \n",
" '''\n",
" conf = SparkConf()\n",
" conf.set(\"spark.task.maxFailures\", \"1\")\n",
" conf.set(\"spark.driver.memory\", \"10g\")\n",
" conf.set(\"spark.executor.memory\", \"8g\")\n",
" conf.set(\"spark.rpc.message.maxSize\", \"1024\")\n",
" conf.set(\"spark.sql.pyspark.jvmStacktrace.enabled\", \"true\")\n",
" conf.set(\"spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled\", \"false\")\n",
" conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n",
" conf.set(\"spark.python.worker.reuse\", \"true\")\n",
" conf.set(\"spark.rapids.ml.uvm.enabled\", \"true\")\n",
" conf.set(\"spark.jars\", rapids_jar)\n",
" conf.set(\"spark.executorEnv.PYTHONPATH\", rapids_jar)\n",
" conf.set(\"spark.rapids.memory.gpu.minAllocFraction\", \"0.0001\")\n",
" conf.set(\"spark.plugins\", \"com.nvidia.spark.SQLPlugin\")\n",
" conf.set(\"spark.locality.wait\", \"0s\")\n",
" conf.set(\"spark.sql.cache.serializer\", \"com.nvidia.spark.ParquetCachedBatchSerializer\")\n",
" conf.set(\"spark.rapids.memory.gpu.pooling.enabled\", \"false\")\n",
" conf.set(\"spark.sql.execution.sortBeforeRepartition\", \"false\")\n",
" conf.set(\"spark.rapids.sql.format.parquet.reader.type\", \"MULTITHREADED\")\n",
" conf.set(\"spark.rapids.sql.format.parquet.multiThreadedRead.maxNumFilesParallel\", \"20\")\n",
" conf.set(\"spark.rapids.sql.multiThreadedRead.numThreads\", \"20\")\n",
" conf.set(\"spark.rapids.sql.python.gpu.enabled\", \"true\")\n",
" conf.set(\"spark.rapids.memory.pinnedPool.size\", \"2G\")\n",
" conf.set(\"spark.python.daemon.module\", \"rapids.daemon\")\n",
" conf.set(\"spark.rapids.sql.batchSizeBytes\", \"512m\")\n",
" conf.set(\"spark.sql.adaptive.enabled\", \"false\")\n",
" conf.set(\"spark.sql.files.maxPartitionBytes\", \"512m\")\n",
" conf.set(\"spark.rapids.sql.concurrentGpuTasks\", \"1\")\n",
" conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", \"20000\")\n",
" conf.set(\"spark.rapids.sql.explain\", \"NONE\")\n",
" \n",
" spark = SparkSession.builder.appName(\"spark-rapids-ml-pca\").config(conf=conf).getOrCreate()\n",
" return spark\n",
"\n",
"# Check if Spark session is already active, if not, initialize it\n",
"if 'spark' not in globals():\n",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine, but it probably is ok to just run the above even if spark is initialized (e.g. if following README instructions).

" print(\"No active Spark session found, initializing manually.\")\n",
" rapids_jar = get_rapids_jar()\n",
" spark = initialize_spark(rapids_jar)\n",
"else:\n",
" print(\"Using existing Spark session.\")\n",
"\n",
"spark = SparkSession.builder.getOrCreate()\n",
"sc = spark.sparkContext"
]
},
{
Expand Down Expand Up @@ -439,7 +531,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "rapids-24.08",
"language": "python",
"name": "python3"
},
Expand Down