Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Polling maven-metadata.xml to pull the latest tools jar #703

Merged
merged 2 commits into from
Dec 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 1 addition & 52 deletions user_tools/src/spark_rapids_pytools/common/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,31 +17,24 @@
import datetime
import logging.config
import os
import re
import secrets
import ssl
import string
import subprocess
import sys
import threading
import time
import urllib
from shutil import make_archive, which
from dataclasses import dataclass, field
from logging import Logger
from shutil import make_archive, which
from typing import Callable, Any

import certifi
import chevron
from bs4 import BeautifulSoup
from packaging.version import Version
from progress.spinner import PixelSpinner
from pygments import highlight
from pygments.formatters import get_formatter_by_name
from pygments.lexers import get_lexer_by_name

from spark_rapids_pytools import get_version


class Utils:
"""Utility class used to enclose common helpers and utilities."""
Expand Down Expand Up @@ -86,50 +79,6 @@ def reformat_release_version(cls, defined_version: Version) -> str:
res = f'{version_comp[0]}.{version_comp[1]:02}.{version_comp[2]}'
return res

@classmethod
def get_latest_available_jar_version(cls, url_base: str, loaded_version: str) -> str:
    """
    Given the defined version in the python tools build, get the highest version number of
    the jar available for download from the mvn repo.
    The returned version is guaranteed to be LEQ to the defined version. For example, it is
    not allowed to use a jar version higher than the python tool itself.
    If the listing contains no version that satisfies the constraint, the defined version
    itself is returned.

    :param url_base: the base url from which the jar file is downloaded. It can be mvn repo.
    :param loaded_version: the version from the python tools in string format
    :return: the string value of the jar that should be downloaded.
    """
    # A bare "import urllib" does not guarantee that the "request" submodule is loaded;
    # import it explicitly so that urlopen is always reachable.
    import urllib.request  # pylint: disable=import-outside-toplevel

    context = ssl.create_default_context(cafile=certifi.where())
    defined_version = Version(loaded_version)
    # Compile once and reuse the same pattern for both filtering and extraction.
    version_pattern = re.compile(r'\d{2}\.\d{2}\.\d+')
    with urllib.request.urlopen(url_base, context=context) as resp:
        html_content = resp.read()
    # Parse the HTML content using BeautifulSoup
    soup = BeautifulSoup(html_content, 'html.parser')
    # Find all the links with title in the format of "xx.xx.xx"
    links = soup.find_all('a', {'title': version_pattern})
    candidates = []
    for link in links:
        title_match = version_pattern.search(link.get('title'))
        if title_match:
            curr_version = Version(title_match.group())
            if curr_version <= defined_version:
                candidates.append(curr_version)
    # Pick the highest qualifying version explicitly instead of relying on the order in
    # which the links happen to be listed on the page.
    jar_version = max(candidates, default=defined_version)
    # get formatted string
    return cls.reformat_release_version(jar_version)

@classmethod
def get_base_release(cls) -> str:
    """
    Return the release string derived from the installed python package version.

    The package version reported by get_version() is parsed and then rendered through
    reformat_release_version(), which yields strings such as 22.12.0, 23.02.0, and 23.04.0.
    :return: a string containing the release number.
    """
    package_version = get_version(main=None)
    parsed_version = Version(package_version)
    return cls.reformat_release_version(parsed_version)

@classmethod
def is_system_tool(cls, tool_name: str) -> bool:
"""
Expand Down
2 changes: 1 addition & 1 deletion user_tools/src/spark_rapids_pytools/rapids/rapids_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def _process_jar_arg(self):
jar_file_name = FSUtil.get_resource_name(jar_path)
version_match = re.search(r'\d{2}\.\d{2}\.\d+', jar_file_name)
jar_version = version_match.group() if version_match else 'Unknown'
self.logger.info('Using Spark RAPIDS accelerator jar version %s', jar_version)
self.logger.info('Using Spark RAPIDS Accelerator Tools jar version %s', jar_version)
# add jar file name to the tool args
self.ctxt.add_rapids_args('jarFileName', jar_file_name)
self.ctxt.add_rapids_args('jarFilePath', jar_path)
Expand Down
7 changes: 4 additions & 3 deletions user_tools/src/spark_rapids_pytools/rapids/tool_ctxt.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,17 @@

import os
import tarfile
from glob import glob
from dataclasses import dataclass, field
from glob import glob
from logging import Logger
from typing import Type, Any, ClassVar, List

from spark_rapids_tools import CspEnv
from spark_rapids_pytools.cloud_api.sp_types import PlatformBase
from spark_rapids_pytools.common.prop_manager import YAMLPropertiesContainer
from spark_rapids_pytools.common.sys_storage import FSUtil
from spark_rapids_pytools.common.utilities import ToolLogging, Utils
from spark_rapids_tools import CspEnv
from spark_rapids_tools.utils import Utilities


@dataclass
Expand Down Expand Up @@ -177,7 +178,7 @@ def get_rapids_jar_url(self) -> str:
raise FileNotFoundError('In Fat Mode. No matching JAR files found.')
return matching_files[0]
mvn_base_url = self.get_value('sparkRapids', 'mvnUrl')
jar_version = Utils.get_latest_available_jar_version(mvn_base_url, Utils.get_base_release())
jar_version = Utilities.get_latest_mvn_jar_from_metadata(mvn_base_url)
rapids_url = self.get_value('sparkRapids', 'repoUrl').format(mvn_base_url, jar_version, jar_version)
return rapids_url

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,10 @@

import fire

from spark_rapids_tools import CspEnv
from spark_rapids_pytools.common.prop_manager import JSONPropertiesContainer
from spark_rapids_pytools.common.sys_storage import FSUtil
from spark_rapids_pytools.common.utilities import Utils

from spark_rapids_tools import CspEnv
from spark_rapids_tools.utils import Utilities

# Defines the constants and static configurations
prepackage_conf = {
Expand Down Expand Up @@ -87,8 +86,7 @@ def __init__(self,
self.dest_dir = FSUtil.get_abs_path(self.dest_dir)

def _get_spark_rapids_jar_url(self) -> str:
jar_version = Utils.get_latest_available_jar_version(self._mvn_base_url, # pylint: disable=no-member
Utils.get_base_release())
jar_version = Utilities.get_latest_mvn_jar_from_metadata(self._mvn_base_url) # pylint: disable=no-member
amahussein marked this conversation as resolved.
Show resolved Hide resolved
return (f'{self._mvn_base_url}/' # pylint: disable=no-member
f'{jar_version}/rapids-4-spark-tools_2.12-{jar_version}.jar')

Expand Down
5 changes: 3 additions & 2 deletions user_tools/src/spark_rapids_tools/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""init file of the utils package for the Accelerated Spark tools"""

from .util import (
get_elem_from_dict, get_elem_non_safe, is_http_file
get_elem_from_dict, get_elem_non_safe, is_http_file, Utilities
)

from .propmanager import (
Expand All @@ -28,5 +28,6 @@
'get_elem_non_safe',
'AbstractPropContainer',
'PropValidatorSchema',
'is_http_file'
'is_http_file',
'Utilities'
]
69 changes: 67 additions & 2 deletions user_tools/src/spark_rapids_tools/utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,24 @@
import os
import pathlib
import re
import ssl
import sys
import urllib
import xml.etree.ElementTree as elem_tree
from functools import reduce
from operator import getitem
from typing import Any, Optional

import certifi
import fire
from packaging.version import Version
from pydantic import ValidationError, AnyHttpUrl, TypeAdapter

import spark_rapids_pytools
from spark_rapids_tools.exceptions import CspPathAttributeError
from spark_rapids_pytools.common.utilities import Utils
from spark_rapids_pytools import get_version
from spark_rapids_pytools.common.sys_storage import FSUtil
from spark_rapids_pytools.common.utilities import Utils
from spark_rapids_tools.exceptions import CspPathAttributeError


def get_elem_from_dict(data, keys):
Expand Down Expand Up @@ -160,3 +166,62 @@ def init_environment(short_name: str):
print(Utils.gen_report_sec_header('Application Logs'))
print(f'Location: {log_file}')
print('In case of any errors, please share the log file with the Spark RAPIDS team.\n')


class Utilities:
    """Utility class used to enclose common helpers and utilities."""

    @classmethod
    def get_latest_mvn_jar_from_metadata(cls, url_base: str,
                                         loaded_version: Optional[str] = None) -> str:
        """
        Given the defined version in the python tools build, we want to be able to get the highest
        version number of the jar available for download from the mvn repo.
        The returned version is guaranteed to be LEQ to the defined version. For example, it is not
        allowed to use jar version higher than the python tool itself.
        If the metadata lists no version that satisfies the constraint, the defined version
        itself is returned.

        The implementation relies on parsing the "$MVN_REPO/maven-metadata.xml" which guarantees
        that any delays in updating the directory list won't block the python module
        from pulling the latest jar.

        :param url_base: the base url from which the jar file is downloaded. It can be mvn repo.
        :param loaded_version: the version from the python tools in string format. When None,
               the value returned by get_base_release() is used.
        :return: the string value of the jar that should be downloaded.
        """
        # A bare "import urllib" does not guarantee that the "request" submodule is loaded;
        # import it explicitly so that urlopen is always reachable.
        import urllib.request  # pylint: disable=import-outside-toplevel
        from packaging.version import InvalidVersion  # pylint: disable=import-outside-toplevel

        if loaded_version is None:
            loaded_version = cls.get_base_release()
        defined_version = Version(loaded_version)
        context = ssl.create_default_context(cafile=certifi.where())
        # Tolerate a trailing slash in the configured base url.
        xml_path = f'{url_base.rstrip("/")}/maven-metadata.xml'
        with urllib.request.urlopen(xml_path, context=context) as resp:
            xml_content = resp.read()
        xml_root = elem_tree.fromstring(xml_content)
        candidates = []
        for version_elem in xml_root.iter('version'):
            raw_version = (version_elem.text or '').strip()
            try:
                curr_version = Version(raw_version)
            except InvalidVersion:
                # Skip empty or non PEP 440 entries instead of failing the whole lookup.
                continue
            if curr_version <= defined_version:
                candidates.append(curr_version)
        # Pick the highest qualifying version explicitly instead of relying on the order of
        # the <version> elements inside the metadata file.
        jar_version = max(candidates, default=defined_version)
        # get formatted string
        return cls.reformat_release_version(jar_version)

    @classmethod
    def reformat_release_version(cls, defined_version: Version) -> str:
        """
        Format the release components of a version as "YY.MM.MICRO".

        :param defined_version: the version object whose "release" tuple is formatted.
        :return: the formatted string where the second component is zero-padded to two digits,
                 e.g. 23.02.0. Missing trailing components are treated as 0 and any component
                 beyond the third is ignored.
        """
        # get the release from version
        version_comp = list(defined_version.release)
        # A version such as "23.10" only carries two release components; pad with zeros so
        # that the formatting below never fails on a short release tuple.
        version_comp.extend([0] * (3 - len(version_comp)))
        major, minor, micro = version_comp[:3]
        # release format is under url YY.MM.MICRO where MM is 02, 04, 06, 08, 10, and 12
        return f'{major}.{minor:02}.{micro}'

    @classmethod
    def get_base_release(cls) -> str:
        """
        Return the release string derived from the installed python package version.

        The package version reported by get_version() is parsed and then rendered through
        reformat_release_version(), which yields strings such as 22.12.0, 23.02.0, and 23.04.0.
        :return: a string containing the release number.
        """
        defined_version = Version(get_version(main=None))
        # get the release from version
        return cls.reformat_release_version(defined_version)
Loading