From 8bc8f9e4d9d3098acbf524ae1f61837758f00982 Mon Sep 17 00:00:00 2001
From: Rishi <77904151+rishic3@users.noreply.github.com>
Date: Tue, 8 Oct 2024 00:31:00 -0400
Subject: [PATCH] Update Spark-RAPIDS-ML PCA (#440)

* Update Spark-RAPIDS-ML PCA

Signed-off-by: Rishi Chandra

* Reran with standalone

* Fix typo

* Delete Scala example, remove mean-centering

* Update README

* Update standalone setup script, README

* SparkSession init for CI

* remove sparkcontext

---------

Signed-off-by: Rishi Chandra
---
 .../Spark-Rapids-ML/pca/README.md             |  32 +
 .../Spark-Rapids-ML/pca/notebooks/pca.ipynb   | 556 ++++++++++++++++++
 .../Spark-Rapids-ML/pca/start-spark-rapids.sh |  80 +++
 .../ML+DL-Examples/Spark-cuML/pca/Dockerfile  |  80 ---
 .../ML+DL-Examples/Spark-cuML/pca/README.md   | 102 ----
 .../ML+DL-Examples/Spark-cuML/pca/main.scala  |  53 --
 .../pca/notebooks/Spark_PCA_End_to_End.ipynb  | 544 -----------------
 .../ML+DL-Examples/Spark-cuML/pca/pom.xml     |  87 ---
 .../com/nvidia/spark/examples/pca/Main.scala  |  83 ---
 .../Spark-cuML/pca/spark-env.sh               |  19 -
 .../Spark-cuML/pca/spark-submit.sh            |  43 --
 .../Spark-cuML/pca/start-spark.sh             |  19 -
 12 files changed, 668 insertions(+), 1030 deletions(-)
 create mode 100644 examples/ML+DL-Examples/Spark-Rapids-ML/pca/README.md
 create mode 100644 examples/ML+DL-Examples/Spark-Rapids-ML/pca/notebooks/pca.ipynb
 create mode 100755 examples/ML+DL-Examples/Spark-Rapids-ML/pca/start-spark-rapids.sh
 delete mode 100644 examples/ML+DL-Examples/Spark-cuML/pca/Dockerfile
 delete mode 100644 examples/ML+DL-Examples/Spark-cuML/pca/README.md
 delete mode 100644 examples/ML+DL-Examples/Spark-cuML/pca/main.scala
 delete mode 100644 examples/ML+DL-Examples/Spark-cuML/pca/notebooks/Spark_PCA_End_to_End.ipynb
 delete mode 100644 examples/ML+DL-Examples/Spark-cuML/pca/pom.xml
 delete mode 100644 examples/ML+DL-Examples/Spark-cuML/pca/scala/src/com/nvidia/spark/examples/pca/Main.scala
 delete mode 100644 examples/ML+DL-Examples/Spark-cuML/pca/spark-env.sh
 delete mode 100755 examples/ML+DL-Examples/Spark-cuML/pca/spark-submit.sh
 delete mode 100755 examples/ML+DL-Examples/Spark-cuML/pca/start-spark.sh

diff --git a/examples/ML+DL-Examples/Spark-Rapids-ML/pca/README.md b/examples/ML+DL-Examples/Spark-Rapids-ML/pca/README.md
new file mode 100644
index 000000000..2879b94a8
--- /dev/null
+++ b/examples/ML+DL-Examples/Spark-Rapids-ML/pca/README.md
@@ -0,0 +1,32 @@
+# Spark-Rapids-ML PCA example
+
+This is an example of the GPU-accelerated PCA algorithm from the [Spark-Rapids-ML](https://github.com/NVIDIA/spark-rapids-ml) library, which provides PySpark ML-compatible algorithms powered by RAPIDS cuML.
+The notebook uses PCA to reduce a random dataset with 2048 feature dimensions to 2 dimensions. We train both the GPU and CPU algorithms for comparison.
+
+## Build
+
+Please refer to the Spark-Rapids-ML [README](https://github.com/NVIDIA/spark-rapids-ml/blob/HEAD/python) for environment setup instructions and API usage.
+
+## Download RAPIDS Jar from Maven Central
+
+Download the RAPIDS jar from Maven Central: [rapids-4-spark_2.12-24.08.1.jar](https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_2.12/24.08.1/rapids-4-spark_2.12-24.08.1.jar).
+Alternatively, see the Spark-Rapids [download page](https://nvidia.github.io/spark-rapids/docs/download.html#download-rapids-accelerator-for-apache-spark-v24081) for other versions.
+
+## Running the Notebooks
+
+Once you have built your environment, follow these instructions to run the notebooks. Make sure `jupyterlab` is installed in the environment.
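+For example, to install it into an existing environment (assuming `pip` is available):
+```
+pip install jupyterlab
+```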
+ +**Note**: for demonstration purposes, these examples just use a local Spark Standalone cluster with a single executor, but you should be able to run them on any distributed Spark cluster. +``` +# setup environment variables +export SPARK_HOME=/path/to/spark +export RAPIDS_JAR=/path/to/rapids.jar + +# launches the standalone cluster and jupyter with pyspark +./start-spark-rapids.sh + +# BROWSE to localhost:8888 to view/run notebooks + +# stop spark standalone cluster +${SPARK_HOME}/sbin/stop-worker.sh; ${SPARK_HOME}/sbin/stop-master.sh +``` diff --git a/examples/ML+DL-Examples/Spark-Rapids-ML/pca/notebooks/pca.ipynb b/examples/ML+DL-Examples/Spark-Rapids-ML/pca/notebooks/pca.ipynb new file mode 100644 index 000000000..9f41e03eb --- /dev/null +++ b/examples/ML+DL-Examples/Spark-Rapids-ML/pca/notebooks/pca.ipynb @@ -0,0 +1,556 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Principal Component Analysis (PCA)\n", + "\n", + "In this notebook, we will demonstrate the end-to-end workflow of Spark RAPIDS accelerated PCA." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import time" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "No active Spark session found, initializing manually.\n", + "File already exists. Skipping download.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "24/10/04 18:04:27 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\n", + "24/10/04 18:04:27 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n", + "24/10/04 18:04:27 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n", + "Setting default log level to \"WARN\".\n", + "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n", + "24/10/04 18:04:27 WARN RapidsPluginUtils: RAPIDS Accelerator 24.08.1 using cudf 24.08.0, private revision 9fac64da220ddd6bf5626bd7bd1dd74c08603eac\n", + "24/10/04 18:04:27 WARN RapidsPluginUtils: RAPIDS Accelerator is enabled, to disable GPU support set `spark.rapids.sql.enabled` to false.\n", + "24/10/04 18:04:31 WARN GpuDeviceManager: RMM pool is disabled since spark.rapids.memory.gpu.pooling.enabled is set to false; however, this configuration is deprecated and the behavior may change in a future release.\n" + ] + } + ], + "source": [ + "from pyspark.sql import SparkSession\n", + "from pyspark import SparkConf\n", + "\n", + "def get_rapids_jar():\n", + " import os\n", + " import requests\n", + "\n", + " SPARK_RAPIDS_VERSION = \"24.08.1\"\n", + " rapids_jar = f\"rapids-4-spark_2.12-{SPARK_RAPIDS_VERSION}.jar\"\n", + "\n", + " if not os.path.exists(rapids_jar):\n", + " print(\"Downloading spark rapids jar\")\n", + " url = f\"https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_2.12/{SPARK_RAPIDS_VERSION}/{rapids_jar}\"\n", + " response = requests.get(url)\n", + " if response.status_code == 200:\n", + " with open(rapids_jar, \"wb\") as f:\n", + " f.write(response.content)\n", + " print(f\"File '{rapids_jar}' downloaded and saved successfully.\")\n", + " else:\n", + " print(f\"Failed to download the file. 
Status code: {response.status_code}\")\n", + " else:\n", + " print(\"File already exists. Skipping download.\")\n", + " \n", + " return rapids_jar\n", + "\n", + "def initialize_spark(rapids_jar: str):\n", + " '''\n", + " If no active Spark session is found, initialize and configure a new one. \n", + " '''\n", + " conf = SparkConf()\n", + " conf.set(\"spark.task.maxFailures\", \"1\")\n", + " conf.set(\"spark.driver.memory\", \"10g\")\n", + " conf.set(\"spark.executor.memory\", \"8g\")\n", + " conf.set(\"spark.rpc.message.maxSize\", \"1024\")\n", + " conf.set(\"spark.sql.pyspark.jvmStacktrace.enabled\", \"true\")\n", + " conf.set(\"spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled\", \"false\")\n", + " conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n", + " conf.set(\"spark.python.worker.reuse\", \"true\")\n", + " conf.set(\"spark.rapids.ml.uvm.enabled\", \"true\")\n", + " conf.set(\"spark.jars\", rapids_jar)\n", + " conf.set(\"spark.executorEnv.PYTHONPATH\", rapids_jar)\n", + " conf.set(\"spark.rapids.memory.gpu.minAllocFraction\", \"0.0001\")\n", + " conf.set(\"spark.plugins\", \"com.nvidia.spark.SQLPlugin\")\n", + " conf.set(\"spark.locality.wait\", \"0s\")\n", + " conf.set(\"spark.sql.cache.serializer\", \"com.nvidia.spark.ParquetCachedBatchSerializer\")\n", + " conf.set(\"spark.rapids.memory.gpu.pooling.enabled\", \"false\")\n", + " conf.set(\"spark.sql.execution.sortBeforeRepartition\", \"false\")\n", + " conf.set(\"spark.rapids.sql.format.parquet.reader.type\", \"MULTITHREADED\")\n", + " conf.set(\"spark.rapids.sql.format.parquet.multiThreadedRead.maxNumFilesParallel\", \"20\")\n", + " conf.set(\"spark.rapids.sql.multiThreadedRead.numThreads\", \"20\")\n", + " conf.set(\"spark.rapids.sql.python.gpu.enabled\", \"true\")\n", + " conf.set(\"spark.rapids.memory.pinnedPool.size\", \"2G\")\n", + " conf.set(\"spark.python.daemon.module\", \"rapids.daemon\")\n", + " conf.set(\"spark.rapids.sql.batchSizeBytes\", \"512m\")\n", + " conf.set(\"spark.sql.adaptive.enabled\", \"false\")\n", + " conf.set(\"spark.sql.files.maxPartitionBytes\", \"512m\")\n", + " conf.set(\"spark.rapids.sql.concurrentGpuTasks\", \"1\")\n", + " conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", \"20000\")\n", + " conf.set(\"spark.rapids.sql.explain\", \"NONE\")\n", + " \n", + " spark = SparkSession.builder.appName(\"spark-rapids-ml-pca\").config(conf=conf).getOrCreate()\n", + " return spark\n", + "\n", + "# Check if Spark session is already active, if not, initialize it\n", + "if 'spark' not in globals():\n", + " print(\"No active Spark session found, initializing manually.\")\n", + " rapids_jar = get_rapids_jar()\n", + " spark = initialize_spark(rapids_jar)\n", + "else:\n", + " print(\"Using existing Spark session.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Generate synthetic dataset\n", + "\n", + "Here we generate a 100,000 x 2048 random dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "24/10/04 18:04:45 WARN TaskSetManager: Stage 0 contains a task of very large size (160085 KiB). 
The maximum recommended task size is 1000 KiB.\n",
+      " \r"
+     ]
+    }
+   ],
+   "source": [
+    "rows = 100000\n",
+    "dim = 2048\n",
+    "dtype = 'float32'\n",
+    "np.random.seed(42)\n",
+    "\n",
+    "data = np.random.rand(rows, dim).astype(dtype)\n",
+    "pd_data = pd.DataFrame({\"features\": list(data)})\n",
+    "prepare_df = spark.createDataFrame(pd_data)\n",
+    "prepare_df.write.mode(\"overwrite\").parquet(\"PCA_data.parquet\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Spark-RAPIDS-ML accepts ArrayType input\n",
+    "\n",
+    "Note that the original Spark-ML PCA requires the input column to be vectorized:\n",
+    "\n",
+    "```python\n",
+    "from pyspark.ml.linalg import Vectors\n",
+    "data = [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),),\n",
+    "    (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),),\n",
+    "    (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)]\n",
+    "df = spark.createDataFrame(data, [\"features\"])\n",
+    "df.show()\n",
+    "```\n",
+    "\n",
+    "...whereas the Spark-RAPIDS-ML version needs no extra vectorization and accepts an ArrayType column directly as the input column:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "root\n",
+      " |-- features: array (nullable = true)\n",
+      " |    |-- element: float (containsNull = true)\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "data_df = spark.read.parquet(\"PCA_data.parquet\")\n",
+    "data_df.printSchema()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Using Spark-RAPIDS-ML PCA (GPU)\n",
+    "\n",
+    "Compared to the Spark-ML PCA training API:\n",
+    "\n",
+    "```python\n",
+    "from pyspark.ml.feature import PCA\n",
+    "pca = PCA(k=2, inputCol=\"features\")\n",
+    "pca.setOutputCol(\"pca_features\")\n",
+    "```\n",
+    "\n",
+    "Spark-RAPIDS-ML provides a drop-in class, so the user needs **no code change** beyond the import to get GPU acceleration:\n",
+    "\n",
+    "```python\n",
+    "from spark_rapids_ml.feature import PCA\n",
+    "pca = PCA(k=2, inputCol=\"features\")\n",
+    "pca.setOutputCol(\"pca_features\")\n",
+    "```"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "PCA_570681141389"
+      ]
+     },
+     "execution_count": 5,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "from spark_rapids_ml.feature import PCA\n",
+    "\n",
+    "gpu_pca = PCA(k=2, inputCol=\"features\")\n",
+    "gpu_pca.setOutputCol(\"pca_features\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The PCA estimator object can be persisted and reloaded."
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "estimator_path = \"/tmp/pca_estimator\"\n", + "gpu_pca.write().overwrite().save(estimator_path)\n", + "gpu_pca_loaded = PCA.load(estimator_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Fit" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "24/10/04 18:04:58 WARN MultiFileReaderThreadPool: Configuring the file reader thread pool with a max of 32 threads instead of spark.rapids.sql.multiThreadedRead.numThreads = 20\n", + "2024-10-04 18:04:58,487 - spark_rapids_ml.feature.PCA - INFO - CUDA managed memory enabled.\n", + "2024-10-04 18:04:58,570 - spark_rapids_ml.feature.PCA - INFO - Training spark-rapids-ml with 1 worker(s) ...\n", + "INFO: Process 2762394 found CUDA visible device(s): 0\n", + "2024-10-04 18:05:01,613 - spark_rapids_ml.feature.PCA - INFO - Loading data into python worker memory\n", + "2024-10-04 18:05:02,551 - spark_rapids_ml.feature.PCA - INFO - Initializing cuml context\n", + "2024-10-04 18:05:03,795 - spark_rapids_ml.feature.PCA - INFO - Invoking cuml fit\n", + "2024-10-04 18:05:05,326 - spark_rapids_ml.feature.PCA - INFO - Cuml fit complete\n", + "2024-10-04 18:05:06,858 - spark_rapids_ml.feature.PCA - INFO - Finished training\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "GPU PCA fit took: 8.90433144569397 sec\n" + ] + } + ], + "source": [ + "start_time = time.time()\n", + "gpu_pca_model = gpu_pca_loaded.fit(data_df)\n", + "gpu_fit_time = time.time() - start_time\n", + "print(f\"GPU PCA fit took: {gpu_fit_time} sec\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Transform" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+---------------------------+\n", + "|pca_features |\n", + "+---------------------------+\n", + "|[0.062363233, 0.4037608] |\n", + "|[0.49734917, 0.703541] |\n", + "|[0.0035427138, 0.29358602] |\n", + "|[-0.06798951, 0.37400067] |\n", + "|[0.10075127, 0.34651726] |\n", + "|[-0.22320557, 0.6660976] |\n", + "|[0.49608234, 0.6761328] |\n", + "|[0.25515205, 0.20352581] |\n", + "|[-0.5102935, 0.319284] |\n", + "|[-0.5109488, 0.2756377] |\n", + "|[0.411546, -0.17954555] |\n", + "|[0.21616393, -0.46268395] |\n", + "|[-0.0924304, 0.65660465] |\n", + "|[0.12355948, 0.9478601] |\n", + "|[0.49234354, 0.63746333] |\n", + "|[-0.86077166, 0.0037032962]|\n", + "|[-0.013956882, 0.663955] |\n", + "|[-0.30510652, 0.02372247] |\n", + "|[-0.05999008, 0.28261736] |\n", + "|[0.36605445, 0.9674797] |\n", + "+---------------------------+\n", + "only showing top 20 rows\n", + "\n", + "GPU PCA transform took: 0.43911027908325195 sec\n" + ] + } + ], + "source": [ + "start_time = time.time()\n", + "embeddings = gpu_pca_model.transform(data_df).select(\"pca_features\").show(truncate=False)\n", + "gpu_transform_time = time.time() - start_time\n", + "print(f\"GPU PCA transform took: {gpu_transform_time} sec\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Using Spark-ML PCA (CPU)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "PCA_58add243f20d" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], 
+ "source": [ + "from pyspark.ml.feature import PCA\n", + "\n", + "cpu_pca = PCA(k=2, inputCol=\"features\")\n", + "cpu_pca.setOutputCol(\"pca_features\")" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "root\n", + " |-- features: vector (nullable = true)\n", + "\n" + ] + } + ], + "source": [ + "from pyspark.ml.functions import array_to_vector\n", + "\n", + "vector_df = data_df.select(array_to_vector(\"features\").alias(\"features\"))\n", + "vector_df.printSchema()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Fit" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "24/10/04 17:07:07 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.lapack.JNILAPACK\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU PCA fit took: 63.37388610839844 sec\n" + ] + } + ], + "source": [ + "start_time = time.time()\n", + "cpu_pca_model = cpu_pca.fit(vector_df)\n", + "pca_fit_time = time.time() - start_time\n", + "print(f\"CPU PCA fit took: {pca_fit_time} sec\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Transform" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+-------------------------------------------+\n", + "|pca_features |\n", + "+-------------------------------------------+\n", + "|[0.24926765828229927,0.3425432972889563] |\n", + "|[-0.5175207040808384,0.48893065865444574] |\n", + "|[-0.2505049373829902,0.381272141155778] |\n", + "|[-0.39046980420292005,0.4870705091697811] |\n", + "|[-0.4024088726395023,0.707133448810984] |\n", + "|[-0.3061227832285992,0.5363554872099332] |\n", + "|[-0.6065136982526093,0.5205197626985932] |\n", + "|[-0.21870566838630084,0.6516598402789231] |\n", + "|[0.1910036552854184,0.6336513389989592] |\n", + "|[0.6139537641786907,0.6055187085018856] |\n", + "|[-0.026502904776425647,-0.0366087508156753]|\n", + "|[-0.2989311781309336,-0.05136110567458389] |\n", + "|[-0.5474468086054212,-0.18779964958125014] |\n", + "|[-0.6644746232216499,0.10351178251944647] |\n", + "|[-0.12685301272617464,0.47394431583661295] |\n", + "|[-0.4355221246718862,-0.00346289187881239] |\n", + "|[0.6222719258951077,0.5488293416698503] |\n", + "|[0.04966907735703511,0.7138677407505005] |\n", + "|[0.6260486995906139,0.3553228450428632] |\n", + "|[0.16396683091519929,0.7382693234881972] |\n", + "+-------------------------------------------+\n", + "only showing top 20 rows\n", + "\n", + "CPU PCA transform took: 0.19607114791870117 sec\n" + ] + } + ], + "source": [ + "start_time = time.time()\n", + "embeddings = cpu_pca_model.transform(vector_df).select(\"pca_features\").show(truncate=False)\n", + "pca_transform_time = time.time() - start_time\n", + "print(f\"CPU PCA transform took: {pca_transform_time} sec\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Summary" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU runtime: (64.02s + 0.20s)\n", + "GPU runtime: (8.76s + 0.42s)\n", + "End-to-end speedup: CPU / GPU = 7.00x\n" + ] + } + ], + "source": [ + "speedup = (pca_fit_time + pca_transform_time) / (gpu_fit_time + 
gpu_transform_time)\n", + "print(f\"CPU runtime: ({pca_fit_time:.2f}s + {pca_transform_time:.2f}s)\")\n", + "print(f\"GPU runtime: ({gpu_fit_time:.2f}s + {gpu_transform_time:.2f}s)\")\n", + "print(f\"End-to-end speedup: CPU / GPU = {speedup:.2f}x\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "rapids-24.08", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.10" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/ML+DL-Examples/Spark-Rapids-ML/pca/start-spark-rapids.sh b/examples/ML+DL-Examples/Spark-Rapids-ML/pca/start-spark-rapids.sh new file mode 100755 index 000000000..4e292bb4e --- /dev/null +++ b/examples/ML+DL-Examples/Spark-Rapids-ML/pca/start-spark-rapids.sh @@ -0,0 +1,80 @@ +#!/bin/bash +# +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Check if SPARK_HOME is set +if [ -z "$SPARK_HOME" ]; then + echo "Please set the SPARK_HOME environment variable before running this script." + exit 1 +fi + +# Check if RAPIDS_JAR is set +if [ -z "$RAPIDS_JAR" ]; then + echo "Please set the RAPIDS_JAR environment variable before running this script." + exit 1 +fi + +# Configuration +MASTER_HOSTNAME=$(hostname) +MASTER=spark://${MASTER_HOSTNAME}:7077 +CORES_PER_WORKER=8 +MEMORY_PER_WORKER=16G + +# Environment variables +export SPARK_HOME=${SPARK_HOME} +export MASTER=${MASTER} +export SPARK_WORKER_INSTANCES=1 +export CORES_PER_WORKER=${CORES_PER_WORKER} +export PYSPARK_DRIVER_PYTHON=jupyter +export PYSPARK_DRIVER_PYTHON_OPTS='lab' + +# Start standalone cluster +echo "Starting Spark standalone cluster..." +${SPARK_HOME}/sbin/start-master.sh +${SPARK_HOME}/sbin/start-worker.sh -c ${CORES_PER_WORKER} -m ${MEMORY_PER_WORKER} ${MASTER} + +# Start Jupyter with PySpark +echo "Launching PySpark with Jupyter..." 
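+# The --conf settings below mirror the SparkConf set in the notebook's
+# initialize_spark() fallback. Memory sizes, the pinned pool, and thread
+# counts are sized for this single-worker demo; adjust to your hardware.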
+${SPARK_HOME}/bin/pyspark --master ${MASTER} \ +--driver-memory 10G \ +--executor-memory 8G \ +--conf spark.task.maxFailures=1 \ +--conf spark.rpc.message.maxSize=1024 \ +--conf spark.sql.pyspark.jvmStacktrace.enabled=true \ +--conf spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled=false \ +--conf spark.sql.execution.arrow.pyspark.enabled=true \ +--conf spark.python.worker.reuse=true \ +--conf spark.rapids.ml.uvm.enabled=true \ +--conf spark.jars=${RAPIDS_JAR} \ +--conf spark.executorEnv.PYTHONPATH=${RAPIDS_JAR} \ +--conf spark.rapids.memory.gpu.minAllocFraction=0.0001 \ +--conf spark.plugins=com.nvidia.spark.SQLPlugin \ +--conf spark.locality.wait=0s \ +--conf spark.sql.cache.serializer=com.nvidia.spark.ParquetCachedBatchSerializer \ +--conf spark.rapids.memory.gpu.pooling.enabled=false \ +--conf spark.sql.execution.sortBeforeRepartition=false \ +--conf spark.rapids.sql.format.parquet.reader.type=MULTITHREADED \ +--conf spark.rapids.sql.format.parquet.multiThreadedRead.maxNumFilesParallel=20 \ +--conf spark.rapids.sql.multiThreadedRead.numThreads=20 \ +--conf spark.rapids.sql.python.gpu.enabled=true \ +--conf spark.rapids.memory.pinnedPool.size=2G \ +--conf spark.python.daemon.module=rapids.daemon \ +--conf spark.rapids.sql.batchSizeBytes=512m \ +--conf spark.sql.adaptive.enabled=false \ +--conf spark.sql.files.maxPartitionBytes=512m \ +--conf spark.rapids.sql.concurrentGpuTasks=1 \ +--conf spark.sql.execution.arrow.maxRecordsPerBatch=20000 \ +--conf spark.rapids.sql.explain=NONE \ No newline at end of file diff --git a/examples/ML+DL-Examples/Spark-cuML/pca/Dockerfile b/examples/ML+DL-Examples/Spark-cuML/pca/Dockerfile deleted file mode 100644 index deaef0ffd..000000000 --- a/examples/ML+DL-Examples/Spark-cuML/pca/Dockerfile +++ /dev/null @@ -1,80 +0,0 @@ -#!/bin/bash -# -# Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -ARG CUDA_VER=11.8.0 -FROM nvidia/cuda:${CUDA_VER}-devel-ubuntu20.04 -# Please do not update the BRANCH_VER version -ARG BRANCH_VER=24.08 - -RUN apt-get update -RUN apt-get install -y wget ninja-build git - -ENV PATH="/root/miniconda3/bin:${PATH}" -ARG PATH="/root/miniconda3/bin:${PATH}" -RUN wget --quiet \ - https://repo.anaconda.com/miniconda/Miniconda3-py38_4.10.3-Linux-x86_64.sh \ - && mkdir /root/.conda \ - && bash Miniconda3-py38_4.10.3-Linux-x86_64.sh -b \ - && rm -f Miniconda3-py38_4.10.3-Linux-x86_64.sh \ - && conda init - -SHELL ["conda", "run", "--no-capture-output", "-n", "base", "/bin/bash", "-c"] -RUN echo $PATH -RUN echo $CONDA_PREFIX -RUN conda --version - -RUN conda install -c conda-forge openjdk=8 maven=3.8.1 -y - -# install cuDF dependency. 
-RUN conda install -c rapidsai-nightly -c nvidia -c conda-forge cudf=${BRANCH_VER} python=3.8 -y - -RUN wget --quiet \ - https://github.com/Kitware/CMake/releases/download/v3.26.4/cmake-3.26.4-linux-x86_64.tar.gz \ - && tar -xzf cmake-3.21.3-linux-x86_64.tar.gz \ - && rm -rf cmake-3.21.3-linux-x86_64.tar.gz - -ENV PATH="/cmake-3.21.3-linux-x86_64/bin:${PATH}" - -RUN git clone -b branch-${BRANCH_VER} https://github.com/rapidsai/raft.git - -ENV RAFT_PATH=/raft - -#use main branch to download release jars -RUN git clone -b main https://github.com/NVIDIA/spark-rapids-ml.git \ - && cd spark-rapids-ml \ - && mvn clean -RUN cd /spark-rapids-ml \ - && mvn install -DskipTests - -ADD scala /workspace/scala -ADD pom.xml /workspace/ - -RUN cd /workspace/ \ - && mvn clean package - -# install spark-3.1.2-bin-hadoop3.2 -RUN wget --quiet \ - https://archive.apache.org/dist/spark/spark-3.1.2/spark-3.1.2-bin-hadoop3.2.tgz \ - && tar -xzf spark-3.1.2-bin-hadoop3.2.tgz -C /opt/ \ - && rm spark-3.1.2-bin-hadoop3.2.tgz -ENV SPARK_HOME=/opt/spark-3.1.2-bin-hadoop3.2 -# add spark env to conf -ADD spark-env.sh /opt/spark-3.1.2-bin-hadoop3.2/conf/ -ADD start-spark.sh /workspace/ -ADD spark-submit.sh /workspace/ - -WORKDIR /workspace diff --git a/examples/ML+DL-Examples/Spark-cuML/pca/README.md b/examples/ML+DL-Examples/Spark-cuML/pca/README.md deleted file mode 100644 index 75ced6296..000000000 --- a/examples/ML+DL-Examples/Spark-cuML/pca/README.md +++ /dev/null @@ -1,102 +0,0 @@ -# PCA example - -This is an example of the GPU accelerated PCA algorithm running on Spark. - -## Build - -Please refer to [README](https://github.com/NVIDIA/spark-rapids-ml#readme) in the [spark-rapids-ml](https://github.com/NVIDIA/spark-rapids-ml) github repository for build instructions and API usage. - -## Get jars from Maven Central - -User can also download the release jar from Maven central: - -[rapids-4-spark-ml_2.12-22.02.0-cuda11.jar](https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark-ml_2.12/22.02.0/rapids-4-spark-ml_2.12-22.02.0-cuda11.jar) - -[rapids-4-spark_2.12-23.04.0.jar](https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_2.12/23.04.0/rapids-4-spark_2.12-23.04.0.jar) - -Note: This demo could only work with v22.02.0 spark-ml version, and only compatible with spark-rapids versions prior to 23.04.0 . Please do not update the version in release. - -## Sample code - -User can find sample scala code in [`main.scala`](main.scala). In the sample code, we will generate random data with 2048 feature dimensions. Then we use PCA to reduce number of features to 3. - -Just copy the sample code into the spark-shell launched according to [this section](https://github.com/NVIDIA/spark-rapids-ml#how-to-use) and REPL will give out the algorithm results. - -## Notebook - -[Apache Toree](https://toree.apache.org/) is required to run PCA sample code in a Jupyter Notebook. - -It is assumed that a Standalone Spark cluster has been set up, the `SPARK_MASTER` and `SPARK_HOME` environment variables are defined and point to the master spark URL (e.g. `spark://localhost:7077`), and the home directory for Apache Spark respectively. - -1. Make sure you have jupyter notebook and [sbt](https://www.scala-sbt.org/1.x/docs/Installing-sbt-on-Linux.html) installed first. -2. Build the 'toree' locally to support scala 2.12, and install it. - - ``` bash - # Download toree - wget https://github.com/apache/incubator-toree/archive/refs/tags/v0.5.0-incubating-rc4.tar.gz - - tar -xvzf v0.5.0-incubating-rc4.tar.gz - - # Build the Toree pip package. 
- cd incubator-toree-0.5.0-incubating-rc4 - make pip-release - - # Install Toree - pip install dist/toree-pip/toree-0.5.0.tar.gz - ``` - -3. Install a new kernel with the jar(use $RAPIDS_ML_JAR for reference) built from section [Build](#build) and launch - - ``` bash - RAPIDS_ML_JAR=PATH_TO_rapids-4-spark-ml_2.12-22.02.0-cuda11.jar - PLUGIN_JAR=PATH_TO_rapids-4-spark_2.12-23.04.0.jar - - jupyter toree install \ - --spark_home=${SPARK_HOME} \ - --user \ - --toree_opts='--nosparkcontext' \ - --kernel_name="spark-rapids-ml-pca" \ - --spark_opts='--master ${SPARK_MASTER} \ - --jars ${RAPIDS_ML_JAR},${PLUGIN_JAR} \ - --conf spark.driver.memory=10G \ - --conf spark.executor.memory=10G \ - --conf spark.executor.heartbeatInterval=20s \ - --conf spark.rapids.sql.enabled=true \ - --conf spark.plugins=com.nvidia.spark.SQLPlugin \ - --conf spark.rapids.memory.gpu.allocFraction=0.35 \ - --conf spark.rapids.memory.gpu.maxAllocFraction=0.6 \ - --conf spark.executor.resource.gpu.amount=1 \ - --conf spark.task.resource.gpu.amount=1 \ - --conf spark.executor.resource.gpu.discoveryScript=./getGpusResources.sh \ - --files $SPARK_HOME/examples/src/main/scripts/getGpusResources.sh' - ``` - - Launch the notebook: - - ``` bash - jupyter notebook - ``` - - Please choose "spark-rapids-ml-pca" as your notebook kernel. - - - -## Submit app jar - -We also provide the spark-submit way to run the PCA example. We suggest using Dockerfile to get a clean runnning environment: - -```bash -docker build -f Dockerfile -t nvspark/pca:0.1 . -``` -Then get into the container of this image(`nvidia-docker` is required as we will use GPU then): -```bash -nvidia-docker run -it nvspark/pca:0.1 bash -``` - -In this docker image, we assume that user has 2 GPUs in his machine. If you are not the condition, please modify the `-Dspark.worker.resource.gpu.amount` in `spark-env.sh` according to your actual environment. - -Then just start the standalone Spark and submit the job: -``` bash -./start-spark.sh -./spark-submit.sh -``` diff --git a/examples/ML+DL-Examples/Spark-cuML/pca/main.scala b/examples/ML+DL-Examples/Spark-cuML/pca/main.scala deleted file mode 100644 index 408d6cb40..000000000 --- a/examples/ML+DL-Examples/Spark-cuML/pca/main.scala +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -import com.nvidia.spark.ml.feature.PCA -import org.apache.spark.ml.linalg._ -import org.apache.spark.sql.functions._ -val dim = 2048 -val rows = 1000 -val r = new scala.util.Random(0) - -// generate dummy data -val dataDf = spark.createDataFrame( - (0 until rows).map(_ => Tuple1(Array.fill(dim)(r.nextDouble)))).withColumnRenamed("_1", "feature") - - -// use RAPIDS PCA class and enable cuBLAS gemm -val pcaGpu = new com.nvidia.spark.ml.feature.PCA().setInputCol("feature").setOutputCol("pca_features").setK(3) - -// train -val pcaModelGpu = spark.time(pcaGpu.fit(dataDf)) - -// transform -pcaModelGpu.transform(dataDf).select("pca_features").show(false) - -// use original Spark ML PCA class -val pcaCpu = new org.apache.spark.ml.feature.PCA().setInputCol("feature_vec").setOutputCol("pca_features").setK(3) - -// use udf to meet standard CPU ML algo input requirement: Vector input -val convertToVector = udf((array: Seq[Double]) => { - Vectors.dense(array.map(_.toDouble).toArray) -}) - -val vectorDf = dataDf.withColumn("feature_vec", convertToVector(col("feature"))) - -// train -val pcaModelCpu = spark.time(pcaCpu.fit(vectorDf)) - -// transform -pcaModelCpu.transform(vectorDf).select("pca_features").show(false) - diff --git a/examples/ML+DL-Examples/Spark-cuML/pca/notebooks/Spark_PCA_End_to_End.ipynb b/examples/ML+DL-Examples/Spark-cuML/pca/notebooks/Spark_PCA_End_to_End.ipynb deleted file mode 100644 index 277505bd0..000000000 --- a/examples/ML+DL-Examples/Spark-cuML/pca/notebooks/Spark_PCA_End_to_End.ipynb +++ /dev/null @@ -1,544 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "90826259", - "metadata": {}, - "source": [ - "## Introduction\n", - "\n", - "In this notebook, we will show the integrated workflow of Spark RAPIDS accelerated ETL and PCA train & transform." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "0f555c95", - "metadata": {}, - "outputs": [], - "source": [ - "import org.apache.spark.ml.linalg._\n", - "import org.apache.spark.sql.functions._" - ] - }, - { - "cell_type": "markdown", - "id": "6d47a7f0", - "metadata": {}, - "source": [ - "### Generate dummy data for PCA benchmark\n", - "\n", - "Generate the sample data of 2048 columns and 50,000 rows" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "18e00096", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Waiting for a Spark session to start..." - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "rows = 50000\n", - "dim = 2048\n", - "r = scala.util.Random@18a448de\n", - "prepareDf = [array_feature[0]: double, array_feature[1]: double ... 2046 more fields]\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "[array_feature[0]: double, array_feature[1]: double ... 
2046 more fields]" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "val rows = 50000\n", - "val dim = 2048\n", - "val r = new scala.util.Random(0)\n", - "var prepareDf = spark.createDataFrame(\n", - " (0 until rows).map(_ => Tuple1(Array.fill(dim)(r.nextDouble))))\n", - " .withColumnRenamed(\"_1\", \"array_feature\")\n", - " .select((0 until dim).map(i => col(\"array_feature\").getItem(i)): _*)\n", - "prepareDf.write.mode(\"overwrite\").parquet(\"PCA_raw_parquet\")" - ] - }, - { - "cell_type": "markdown", - "id": "f2ed419b", - "metadata": {}, - "source": [ - "### Read raw parquet data\n", - "\n", - "The parquet file contains the raw data for PCA train and transform.\n", - "\n", - "There're 2048 columns in the table naming as \"array_feature[0], array_feature[1] ... array_feature[2047]\"." - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "0084a38f", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "df = [array_feature[0]: double, array_feature[1]: double ... 2046 more fields]\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "[array_feature[0]: double, array_feature[1]: double ... 2046 more fields]" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "val df = spark.read.parquet(\"PCA_raw_parquet\")" - ] - }, - { - "cell_type": "markdown", - "id": "58a33554", - "metadata": {}, - "source": [ - "### ETL: Calculate mean value for each column\n", - "\n", - "PCA algorithm is expecting mean centered data as input, so use a simple ETL process to do mean centering." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "a152bd3d", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "dim = 2048\n", - "avgValue = [0.5014784341440235,0.5007938298214618,0.4988382739107633,0.5004857021518329,0.4976086737881863,0.501459317390976,0.4998871629299758,0.5003749032337383,0.5004268051953419,0.4992212831312325,0.5002230208274252,0.49916485476370304,0.49928552249125024,0.5001192271170941,0.4974153011145406,0.500340861041902,0.500511698285404,0.5029175790341269,0.5000848064753295,0.49946358217105435,0.4991402970341374,0.4999057035861329,0.4993188619485362,0.49782509547668896,0.5001573241354326,0.4991954590903186,0.4988846878237177,0.5008673384728016,0.4982505290656533,0.5000069827383224,0.49830672380384944,0.49849188876978057,0.502253148518209,0.4995624384114367,0.5006052199700368,0.49922409882583835,0.4996825327694508,0.4983465266402566,0.5001149704952238...\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "[0.5014784341440235,0.5007938298214618,0.4988382739107633,0.5004857021518329,0.4976086737881863,0.501459317390976,0.4998871629299758,0.5003749032337383,0.5004268051953419,0.4992212831312325,0.5002230208274252,0.49916485476370304,0.49928552249125024,0.5001192271170941,0.4974153011145406,0.500340861041902,0.500511698285404,0.5029175790341269,0.5000848064753295,0.49946358217105435,0.4991402970341374,0.4999057035861329,0.4993188619485362,0.49782509547668896,0.5001573241354326,0.4991954590903186,0.4988846878237177,0.5008673384728016,0.4982505290656533,0.5000069827383224,0.49830672380384944,0.49849188876978057,0.502253148518209,0.4995624384114367,0.5006052199700368,0.49922409882583835,0.4996825327694508,0.4983465266402566,0.5001149704952238..." 
- ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "val dim = 2048\n", - "val avgValue = df.select(\n", - " (0 until dim).map(\"array_feature[\" + _ + \"]\").map(col).map(avg): _*).first()\n", - "val inputCols = (0 until dim).map(i =>\n", - " (col(\"array_feature[\" + i + \"]\") - avgValue.getDouble(i)).alias(\"feature_\"+i)\n", - " )\n", - "val meanCenterDf = df.select(inputCols:_*)" - ] - }, - { - "cell_type": "markdown", - "id": "0bc302cd", - "metadata": {}, - "source": [ - "### Spark RAPIDS accelerated PCA can accept ArrayType column as the input column.\n", - "\n", - "Comparing to the original Spark PCA requirement, there's no need to do extra `Vectorize` work for the input column.\n", - "\n", - "For example, the following code is required when using standard Spark PCA:\n", - "\n", - "```scala\n", - "val convertToVector = udf((array: Seq[Double]) => {\n", - " Vectors.dense(array.map(_.toDouble).toArray)\n", - "})\n", - "val vectorDf = dataDf.withColumn(\"feature_vec\", convertToVector(col(\"feature\")))\n", - "```" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "ac57b415", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "dataDf = [feature_0: double, feature_1: double ... 2047 more fields]\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "[feature_0: double, feature_1: double ... 2047 more fields]" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "val dataDf = meanCenterDf.withColumn(\"feature\",array(meanCenterDf.columns.map(col):_*))" - ] - }, - { - "cell_type": "markdown", - "id": "ec1138d3", - "metadata": {}, - "source": [ - "### Use Spark RAPIDS accelerated PCA\n", - "\n", - "Comparing to the original PCA training API:\n", - "\n", - "```scala\n", - "val pca = new org.apache.spark.ml.feature.PCA()\n", - " .setInputCol(\"feature\")\n", - " .setOutputCol(\"pca_features\")\n", - " .setK(3)\n", - " .fit(vectorDf)\n", - "```\n", - "\n", - "We used a customized class and user will need to do `no code change` to enjoy the GPU acceleration:\n", - "\n", - "```scala\n", - "val pca = new com.nvidia.spark.ml.feature.PCA()\n", - "...\n", - "```" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "8d9fb9bd", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "pcaGpu = pca_6b8d054604e4\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "pca_6b8d054604e4" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "val pcaGpu = new com.nvidia.spark.ml.feature.PCA().setInputCol(\"feature\").setOutputCol(\"pca_features\").setK(3)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "d95ce4f9", - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "pcaModelGpu = PCAModel: uid=pca_6b8d054604e4, k=3\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Time taken: 8280 ms\n" - ] - }, - { - "data": { - "text/plain": [ - "PCAModel: uid=pca_6b8d054604e4, k=3" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "val pcaModelGpu = spark.time(pcaGpu.fit(dataDf))" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "0263a1b2", - "metadata": { - "scrolled": true 
- }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "+-----------------------------------------+\n", - "|pca_features |\n", - "+-----------------------------------------+\n", - "|[0.568805417, 0.041445481, 0.621107902] |\n", - "|[0.378405859, -0.244389411, -0.358809445]|\n", - "|[0.421817533, -0.309621711, -0.159095405]|\n", - "|[0.424088954, 0.09907811, 0.252832213] |\n", - "|[0.481344556, 0.303004001, 0.06884068] |\n", - "|[0.837941281, 0.113648256, -0.319001501] |\n", - "|[0.093790516, -0.364140016, -0.33318393] |\n", - "|[0.103996026, -0.174265839, 0.226559042] |\n", - "|[-0.283206201, -0.487276589, 0.174362571]|\n", - "|[0.101710379, 0.569866637, 0.118964435] |\n", - "+-----------------------------------------+\n", - "only showing top 10 rows\n", - "\n", - "Time taken: 3284 ms\n" - ] - } - ], - "source": [ - "spark.time(pcaModelGpu.transform(dataDf).select(\"pca_features\").show(10, false))" - ] - }, - { - "cell_type": "markdown", - "id": "4e09e7fa", - "metadata": {}, - "source": [ - "### Use original Spark PCA" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "6db8b704", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "convertToVector = SparkUserDefinedFunction($Lambda$4927/1016137095@128d5536,org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7,List(Some(class[value[0]: array])),None,true,true)\n", - "vectorDf = [feature_0: double, feature_1: double ... 2048 more fields]\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "[feature_0: double, feature_1: double ... 2048 more fields]" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "val convertToVector = udf((array: Seq[Double]) => {\n", - " Vectors.dense(array.map(_.toDouble).toArray)\n", - "})\n", - "val vectorDf = dataDf.withColumn(\"feature_vec\", convertToVector(col(\"feature\")))" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "b870173b", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "pcaCpu = pca_3f8feb827742\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "pca_3f8feb827742" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "val pcaCpu = new org.apache.spark.ml.feature.PCA().setInputCol(\"feature_vec\").setOutputCol(\"pca_features\").setK(3)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "e3e1a9ca", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "pcaModelCpu = PCAModel: uid=pca_3f8feb827742, k=3\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Time taken: 140539 ms\n" - ] - }, - { - "data": { - "text/plain": [ - "PCAModel: uid=pca_3f8feb827742, k=3" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "val pcaModelCpu = spark.time(pcaCpu.fit(vectorDf))" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "504f0bb7", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "+--------------------------------------------------------------+\n", - "|pca_features |\n", - "+--------------------------------------------------------------+\n", - "|[0.5688054172628585,-0.04144548077183109,-0.6211079018457807] |\n", - 
"|[0.37840585922945796,0.24438941118757604,0.3588094451238177] |\n", - "|[0.4218175332258925,0.3096217108376109,0.15909540520858537] |\n", - "|[0.4240889539599815,-0.09907811042793396,-0.2528322129752815] |\n", - "|[0.4813445560531313,-0.3030040008580291,-0.06884068037876276] |\n", - "|[0.8379412808563966,-0.11364825624115062,0.3190015014324452] |\n", - "|[0.09379051625949268,0.3641400160124998,0.33318393004824964] |\n", - "|[0.10399602625088979,0.17426583892592548,-0.2265590421381768] |\n", - "|[-0.2832062006131796,0.4872765894121887,-0.1743625713365004] |\n", - "|[0.10171037937872408,-0.5698666372294762,-0.11896443456600647]|\n", - "+--------------------------------------------------------------+\n", - "only showing top 10 rows\n", - "\n", - "Time taken: 12738 ms\n" - ] - } - ], - "source": [ - "spark.time(pcaModelCpu.transform(vectorDf).select(\"pca_features\").show(10, false))" - ] - }, - { - "cell_type": "markdown", - "id": "d486eb9a", - "metadata": {}, - "source": [ - "### Summary\n", - "\n", - "With the data of 50,000 rows, we achived:\n", - "\n", - "the speedup for training: 140539 / 8280 = `16.97`\n", - "\n", - "the speedup for transform: 12738 / 3284 = `3.87`" - ] - }, - { - "cell_type": "markdown", - "id": "671bc74d", - "metadata": {}, - "source": [ - "### Note\n", - "\n", - "Some columns in GPU output have different signs from that in CPU output, this is due to the calculation nature of SVD algorithm which doesn't impact the effectiveness of the SVD results. More details could be found in the [wiki](https://en.wikipedia.org/wiki/Singular_value_decomposition#Relation_to_eigenvalue_decomposition)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4d53fb76", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "spark-rapids-ml-pca - Scala", - "language": "scala", - "name": "spark-rapids-ml-pca_scala" - }, - "language_info": { - "codemirror_mode": "text/x-scala", - "file_extension": ".scala", - "mimetype": "text/x-scala", - "name": "scala", - "pygments_lexer": "scala", - "version": "2.12.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/examples/ML+DL-Examples/Spark-cuML/pca/pom.xml b/examples/ML+DL-Examples/Spark-cuML/pca/pom.xml deleted file mode 100644 index 3d7f28b48..000000000 --- a/examples/ML+DL-Examples/Spark-cuML/pca/pom.xml +++ /dev/null @@ -1,87 +0,0 @@ - - - - 4.0.0 - - com.nvidia - PCAExample - jar - 23.04.0-SNAPSHOT - - - 8 - 8 - - - - - org.apache.spark - spark-core_2.12 - 3.1.2 - - - - org.apache.spark - spark-sql_2.12 - 3.1.2 - provided - - - - org.apache.spark - spark-mllib_2.12 - 3.1.2 - provided - - - com.nvidia - rapids-4-spark-ml_2.12 - 22.02.0 - - - - - - scala/src - - - org.apache.maven.plugins - maven-compiler-plugin - 3.8.1 - - 1.8 - 1.8 - - - - org.scala-tools - maven-scala-plugin - 2.15.2 - - - - compile - testCompile - - - - - - - - diff --git a/examples/ML+DL-Examples/Spark-cuML/pca/scala/src/com/nvidia/spark/examples/pca/Main.scala b/examples/ML+DL-Examples/Spark-cuML/pca/scala/src/com/nvidia/spark/examples/pca/Main.scala deleted file mode 100644 index 46a3dab85..000000000 --- a/examples/ML+DL-Examples/Spark-cuML/pca/scala/src/com/nvidia/spark/examples/pca/Main.scala +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -package com.nvidia.spark.examples.pca - -import org.apache.spark.ml.linalg._ -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.functions._ - -object Main { - def main(args: Array[String]): Unit = { - val spark = SparkSession.builder().appName("PCA Example").getOrCreate() - val dim = 2048 - val rows = 50000 - val r = new scala.util.Random(0) - - // generate dummy data - var prepareDf = spark.createDataFrame( - (0 until rows).map(_ => Tuple1(Array.fill(dim)(r.nextDouble)))) - .withColumnRenamed("_1", "array_feature") - .select((0 until dim).map(i => col("array_feature").getItem(i)): _*) - // save to parquet files - prepareDf.write.mode("overwrite").parquet("PCA_raw_parquet") - - // load the parquet files - val df = spark.read.parquet("PCA_raw_parquet") - - // mean centering via ETL - val avgValue = df.select( - (0 until dim).map("array_feature[" + _ + "]").map(col).map(avg): _*).first() - val inputCols = (0 until dim).map(i => - (col("array_feature[" + i + "]") - avgValue.getDouble(i)).alias("feature_"+i) - ) - val meanCenterDf = df.select(inputCols:_*) - - val dataDf = meanCenterDf.withColumn("feature",array(meanCenterDf.columns.map(col):_*)) - - val pcaGpu = new com.nvidia.spark.ml.feature.PCA().setInputCol("feature").setOutputCol("pca_features").setK(3) - // GPU train - val gpuStart = System.currentTimeMillis() - val pcaModelGpu = pcaGpu.fit(dataDf) - val gpuEnd = System.currentTimeMillis() - - // use udf to meet standard CPU ML algo input requirement: Vector input - val convertToVector = udf((array: Seq[Double]) => { - Vectors.dense(array.map(_.toDouble).toArray) - }) - - val vectorDf = dataDf.withColumn("feature_vec", convertToVector(col("feature"))) - - // use original Spark ML PCA class - val pcaCpu = new org.apache.spark.ml.feature.PCA().setInputCol("feature_vec").setOutputCol("pca_features").setK(3) - - // CPU train - val cpuStart = System.currentTimeMillis() - val pcaModelCpu = pcaCpu.fit(vectorDf) - val cpuEnd = System.currentTimeMillis() - - - println("GPU training: ") - println( (gpuEnd - gpuStart) / 1000 + " seconds") - println("CPU training: ") - println( (cpuEnd - cpuStart) / 1000 + " seconds") - - // transform - pcaModelGpu.transform(vectorDf).select("pca_features").show(false) - pcaModelCpu.transform(vectorDf).select("pca_features").show(false) - } -} diff --git a/examples/ML+DL-Examples/Spark-cuML/pca/spark-env.sh b/examples/ML+DL-Examples/Spark-cuML/pca/spark-env.sh deleted file mode 100644 index 82ead37c5..000000000 --- a/examples/ML+DL-Examples/Spark-cuML/pca/spark-env.sh +++ /dev/null @@ -1,19 +0,0 @@ -#!/bin/bash -# -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# - -SPARK_MASTER_HOST=127.0.0.1 -SPARK_WORKER_OPTS="-Dspark.worker.resource.gpu.amount=2 -Dspark.worker.resource.gpu.discoveryScript=/opt/spark-3.1.2-bin-hadoop3.2/examples/src/main/scripts/getGpusResources.sh" diff --git a/examples/ML+DL-Examples/Spark-cuML/pca/spark-submit.sh b/examples/ML+DL-Examples/Spark-cuML/pca/spark-submit.sh deleted file mode 100755 index 55a5c5057..000000000 --- a/examples/ML+DL-Examples/Spark-cuML/pca/spark-submit.sh +++ /dev/null @@ -1,43 +0,0 @@ -#!/bin/bash -# -# Copyright (c) 2021-2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Note that the last rapids-4-spark-ml release version is 22.02.0, snapshot version is 23.04.0-SNAPSHOT, please do not update the version in release -ML_JAR=/root/.m2/repository/com/nvidia/rapids-4-spark-ml_2.12/22.02.0/rapids-4-spark-ml_2.12-22.02.0.jar -PLUGIN_JAR=/root/.m2/repository/com/nvidia/rapids-4-spark_2.12/23.04.0/rapids-4-spark_2.12-23.04.0.jar -Note: The last rapids-4-spark-ml release version is 22.02.0, snapshot version is 23.04.0-SNAPSHOT. - -$SPARK_HOME/bin/spark-submit \ ---master spark://127.0.0.1:7077 \ ---conf spark.executor.cores=12 \ ---conf spark.executor.instances=2 \ ---driver-memory 30G \ ---executor-memory 30G \ ---conf spark.driver.maxResultSize=8G \ ---conf spark.rapids.sql.enabled=true \ ---conf spark.plugins=com.nvidia.spark.SQLPlugin \ ---conf spark.rapids.memory.gpu.allocFraction=0.35 \ ---conf spark.rapids.memory.gpu.maxAllocFraction=0.6 \ ---conf spark.task.resource.gpu.amount=0.08 \ ---conf spark.executor.extraClassPath=$ML_JAR:$PLUGIN_JAR \ ---conf spark.driver.extraClassPath=$ML_JAR:$PLUGIN_JAR \ ---conf spark.executor.resource.gpu.amount=1 \ ---conf spark.rpc.message.maxSize=2046 \ ---conf spark.executor.heartbeatInterval=500s \ ---conf spark.network.timeout=1000s \ ---jars $ML_JAR,$PLUGIN_JAR \ ---class com.nvidia.spark.examples.pca.Main \ -/workspace/target/PCAExample-23.04.0-SNAPSHOT.jar diff --git a/examples/ML+DL-Examples/Spark-cuML/pca/start-spark.sh b/examples/ML+DL-Examples/Spark-cuML/pca/start-spark.sh deleted file mode 100755 index c622e1cb4..000000000 --- a/examples/ML+DL-Examples/Spark-cuML/pca/start-spark.sh +++ /dev/null @@ -1,19 +0,0 @@ -#!/bin/bash -# -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -/opt/spark-3.1.2-bin-hadoop3.2/sbin/start-master.sh -/opt/spark-3.1.2-bin-hadoop3.2/sbin/start-slave.sh spark://127.0.0.1:7077