Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes for CI/CD compatibility #446

Merged
merged 7 commits into from
Oct 18, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 3 additions & 24 deletions examples/ML+DL-Examples/Spark-Rapids-ML/pca/notebooks/pca.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import time"
"import time\n",
"import os"
]
},
{
Expand Down Expand Up @@ -52,28 +53,6 @@
"from pyspark.sql import SparkSession\n",
"from pyspark import SparkConf\n",
"\n",
"def get_rapids_jar():\n",
" import os\n",
" import requests\n",
"\n",
" SPARK_RAPIDS_VERSION = \"24.08.1\"\n",
" rapids_jar = f\"rapids-4-spark_2.12-{SPARK_RAPIDS_VERSION}.jar\"\n",
"\n",
" if not os.path.exists(rapids_jar):\n",
" print(\"Downloading spark rapids jar\")\n",
" url = f\"https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_2.12/{SPARK_RAPIDS_VERSION}/{rapids_jar}\"\n",
" response = requests.get(url)\n",
" if response.status_code == 200:\n",
" with open(rapids_jar, \"wb\") as f:\n",
" f.write(response.content)\n",
" print(f\"File '{rapids_jar}' downloaded and saved successfully.\")\n",
" else:\n",
" print(f\"Failed to download the file. Status code: {response.status_code}\")\n",
" else:\n",
" print(\"File already exists. Skipping download.\")\n",
" \n",
" return rapids_jar\n",
"\n",
"def initialize_spark(rapids_jar: str):\n",
" '''\n",
" If no active Spark session is found, initialize and configure a new one. \n",
Expand Down Expand Up @@ -119,7 +98,7 @@
"# Check if Spark session is already active, if not, initialize it\n",
"if 'spark' not in globals():\n",
" print(\"No active Spark session found, initializing manually.\")\n",
" rapids_jar = get_rapids_jar()\n",
" rapids_jar = os.environ['RAPIDS_JAR']\n",
rishic3 marked this conversation as resolved.
Show resolved Hide resolved
" spark = initialize_spark(rapids_jar)\n",
"else:\n",
" print(\"Using existing Spark session.\")"
Expand Down