diff --git a/user_tools/src/spark_rapids_pytools/common/utilities.py b/user_tools/src/spark_rapids_pytools/common/utilities.py index 5ac5cdd97..74c6b2e57 100644 --- a/user_tools/src/spark_rapids_pytools/common/utilities.py +++ b/user_tools/src/spark_rapids_pytools/common/utilities.py @@ -17,31 +17,24 @@ import datetime import logging.config import os -import re import secrets -import ssl import string import subprocess import sys import threading import time -import urllib -from shutil import make_archive, which from dataclasses import dataclass, field from logging import Logger +from shutil import make_archive, which from typing import Callable, Any -import certifi import chevron -from bs4 import BeautifulSoup from packaging.version import Version from progress.spinner import PixelSpinner from pygments import highlight from pygments.formatters import get_formatter_by_name from pygments.lexers import get_lexer_by_name -from spark_rapids_pytools import get_version - class Utils: """Utility class used to enclose common helpers and utilities.""" @@ -86,50 +79,6 @@ def reformat_release_version(cls, defined_version: Version) -> str: res = f'{version_comp[0]}.{version_comp[1]:02}.{version_comp[2]}' return res - @classmethod - def get_latest_available_jar_version(cls, url_base: str, loaded_version: str) -> str: - """ - Given the defined version in the python tools build, we want to be able to get the highest - version number of the jar available for download from the mvn repo. - The returned version is guaranteed to be LEQ to the defined version. For example, it is not - allowed to use jar version higher than the python tool itself. - :param url_base: the base url from which the jar file is downloaded. It can be mvn repo. - :param loaded_version: the version from the python tools in string format - :return: the string value of the jar that should be downloaded. - """ - context = ssl.create_default_context(cafile=certifi.where()) - defined_version = Version(loaded_version) - jar_version = Version(loaded_version) - version_regex = r'\d{2}\.\d{2}\.\d+' - version_pattern = re.compile(version_regex) - with urllib.request.urlopen(url_base, context=context) as resp: - html_content = resp.read() - # Parse the HTML content using BeautifulSoup - soup = BeautifulSoup(html_content, 'html.parser') - # Find all the links with title in the format of "xx.xx.xx" - links = soup.find_all('a', {'title': version_pattern}) - # Get the link with the highest value - for link in links: - curr_title = re.search(version_regex, link.get('title')) - if curr_title: - curr_version = Version(curr_title.group()) - if curr_version <= defined_version: - jar_version = curr_version - # get formatted string - return cls.reformat_release_version(jar_version) - - @classmethod - def get_base_release(cls) -> str: - """ - For now the tools_jar is always with major.minor.0. - this method makes sure that even if the package version is incremented, we will still - get the correct url. - :return: a string containing the release number 22.12.0, 23.02.0, amd 23.04.0..etc - """ - defined_version = Version(get_version(main=None)) - # get the release from version - return cls.reformat_release_version(defined_version) - @classmethod def is_system_tool(cls, tool_name: str) -> bool: """ diff --git a/user_tools/src/spark_rapids_pytools/rapids/rapids_tool.py b/user_tools/src/spark_rapids_pytools/rapids/rapids_tool.py index 41214f65c..cda9bc6fb 100644 --- a/user_tools/src/spark_rapids_pytools/rapids/rapids_tool.py +++ b/user_tools/src/spark_rapids_pytools/rapids/rapids_tool.py @@ -393,7 +393,7 @@ def _process_jar_arg(self): jar_file_name = FSUtil.get_resource_name(jar_path) version_match = re.search(r'\d{2}\.\d{2}\.\d+', jar_file_name) jar_version = version_match.group() if version_match else 'Unknown' - self.logger.info('Using Spark RAPIDS accelerator jar version %s', jar_version) + self.logger.info('Using Spark RAPIDS Accelerator Tools jar version %s', jar_version) # add jar file name to the tool args self.ctxt.add_rapids_args('jarFileName', jar_file_name) self.ctxt.add_rapids_args('jarFilePath', jar_path) diff --git a/user_tools/src/spark_rapids_pytools/rapids/tool_ctxt.py b/user_tools/src/spark_rapids_pytools/rapids/tool_ctxt.py index 20e79a3f0..9f1f98003 100644 --- a/user_tools/src/spark_rapids_pytools/rapids/tool_ctxt.py +++ b/user_tools/src/spark_rapids_pytools/rapids/tool_ctxt.py @@ -16,16 +16,17 @@ import os import tarfile -from glob import glob from dataclasses import dataclass, field +from glob import glob from logging import Logger from typing import Type, Any, ClassVar, List -from spark_rapids_tools import CspEnv from spark_rapids_pytools.cloud_api.sp_types import PlatformBase from spark_rapids_pytools.common.prop_manager import YAMLPropertiesContainer from spark_rapids_pytools.common.sys_storage import FSUtil from spark_rapids_pytools.common.utilities import ToolLogging, Utils +from spark_rapids_tools import CspEnv +from spark_rapids_tools.utils import Utilities @dataclass @@ -177,7 +178,7 @@ def get_rapids_jar_url(self) -> str: raise FileNotFoundError('In Fat Mode. No matching JAR files found.') return matching_files[0] mvn_base_url = self.get_value('sparkRapids', 'mvnUrl') - jar_version = Utils.get_latest_available_jar_version(mvn_base_url, Utils.get_base_release()) + jar_version = Utilities.get_latest_mvn_jar_from_metadata(mvn_base_url) rapids_url = self.get_value('sparkRapids', 'repoUrl').format(mvn_base_url, jar_version, jar_version) return rapids_url diff --git a/user_tools/src/spark_rapids_pytools/resources/dev/prepackage_mgr.py b/user_tools/src/spark_rapids_pytools/resources/dev/prepackage_mgr.py index 472260aee..59028689c 100644 --- a/user_tools/src/spark_rapids_pytools/resources/dev/prepackage_mgr.py +++ b/user_tools/src/spark_rapids_pytools/resources/dev/prepackage_mgr.py @@ -25,11 +25,10 @@ import fire -from spark_rapids_tools import CspEnv from spark_rapids_pytools.common.prop_manager import JSONPropertiesContainer from spark_rapids_pytools.common.sys_storage import FSUtil -from spark_rapids_pytools.common.utilities import Utils - +from spark_rapids_tools import CspEnv +from spark_rapids_tools.utils import Utilities # Defines the constants and static configurations prepackage_conf = { @@ -87,8 +86,7 @@ def __init__(self, self.dest_dir = FSUtil.get_abs_path(self.dest_dir) def _get_spark_rapids_jar_url(self) -> str: - jar_version = Utils.get_latest_available_jar_version(self._mvn_base_url, # pylint: disable=no-member - Utils.get_base_release()) + jar_version = Utilities.get_latest_mvn_jar_from_metadata(self._mvn_base_url) # pylint: disable=no-member return (f'{self._mvn_base_url}/' # pylint: disable=no-member f'{jar_version}/rapids-4-spark-tools_2.12-{jar_version}.jar') diff --git a/user_tools/src/spark_rapids_tools/utils/__init__.py b/user_tools/src/spark_rapids_tools/utils/__init__.py index 2f6637b99..61e226417 100644 --- a/user_tools/src/spark_rapids_tools/utils/__init__.py +++ b/user_tools/src/spark_rapids_tools/utils/__init__.py @@ -15,7 +15,7 @@ """init file of the utils package for the Accelerated Spark tools""" from .util import ( - get_elem_from_dict, get_elem_non_safe, is_http_file + get_elem_from_dict, get_elem_non_safe, is_http_file, Utilities ) from .propmanager import ( @@ -28,5 +28,6 @@ 'get_elem_non_safe', 'AbstractPropContainer', 'PropValidatorSchema', - 'is_http_file' + 'is_http_file', + 'Utilities' ] diff --git a/user_tools/src/spark_rapids_tools/utils/util.py b/user_tools/src/spark_rapids_tools/utils/util.py index 8ff5a1975..9d38b1e90 100644 --- a/user_tools/src/spark_rapids_tools/utils/util.py +++ b/user_tools/src/spark_rapids_tools/utils/util.py @@ -17,18 +17,24 @@ import os import pathlib import re +import ssl import sys +import urllib +import xml.etree.ElementTree as elem_tree from functools import reduce from operator import getitem from typing import Any, Optional +import certifi import fire +from packaging.version import Version from pydantic import ValidationError, AnyHttpUrl, TypeAdapter import spark_rapids_pytools -from spark_rapids_tools.exceptions import CspPathAttributeError -from spark_rapids_pytools.common.utilities import Utils +from spark_rapids_pytools import get_version from spark_rapids_pytools.common.sys_storage import FSUtil +from spark_rapids_pytools.common.utilities import Utils +from spark_rapids_tools.exceptions import CspPathAttributeError def get_elem_from_dict(data, keys): @@ -160,3 +166,62 @@ def init_environment(short_name: str): print(Utils.gen_report_sec_header('Application Logs')) print(f'Location: {log_file}') print('In case of any errors, please share the log file with the Spark RAPIDS team.\n') + + +class Utilities: + """Utility class used to enclose common helpers and utilities.""" + + @classmethod + def get_latest_mvn_jar_from_metadata(cls, url_base: str, + loaded_version: str = None) -> str: + """ + Given the defined version in the python tools build, we want to be able to get the highest + version number of the jar available for download from the mvn repo. + The returned version is guaranteed to be LEQ to the defined version. For example, it is not + allowed to use jar version higher than the python tool itself. + + The implementation relies on parsing the "$MVN_REPO/maven-metadata.xml" which guarantees + that any delays in updating the directory list won't block the python module + from pulling the latest jar. + + :param url_base: the base url from which the jar file is downloaded. It can be mvn repo. + :param loaded_version: the version from the python tools in string format + :return: the string value of the jar that should be downloaded. + """ + + if loaded_version is None: + loaded_version = cls.get_base_release() + context = ssl.create_default_context(cafile=certifi.where()) + defined_version = Version(loaded_version) + jar_version = Version(loaded_version) + xml_path = f'{url_base}/maven-metadata.xml' + with urllib.request.urlopen(xml_path, context=context) as resp: + xml_content = resp.read() + xml_root = elem_tree.fromstring(xml_content) + for version_elem in xml_root.iter('version'): + curr_version = Version(version_elem.text) + if curr_version <= defined_version: + jar_version = curr_version + # get formatted string + return cls.reformat_release_version(jar_version) + + @classmethod + def reformat_release_version(cls, defined_version: Version) -> str: + # get the release from version + version_tuple = defined_version.release + version_comp = list(version_tuple) + # release format is under url YY.MM.MICRO where MM is 02, 04, 06, 08, 10, and 12 + res = f'{version_comp[0]}.{version_comp[1]:02}.{version_comp[2]}' + return res + + @classmethod + def get_base_release(cls) -> str: + """ + For now the tools_jar is always with major.minor.0. + this method makes sure that even if the package version is incremented, we will still + get the correct url. + :return: a string containing the release number 22.12.0, 23.02.0, amd 23.04.0..etc + """ + defined_version = Version(get_version(main=None)) + # get the release from version + return cls.reformat_release_version(defined_version)