Skip to content

Commit

Permalink
Polling maven-metadata.xml to pull the latest tools jar (#703)
Browse files Browse the repository at this point in the history
* Polling maven-metadata.xml to pull the latest tools jar

Fixes #702

This PR is an alternative implementation to pull the latest version of
the Tools jar.
The python module parses the maven-metadata.xml to get the available
releases. It was found that this file is usually up-to-date with the
releases.

* Move some utility functions to the new package
* Fix the logging message of the python module to avoid confusing RAPIDS
  jars with Tools jars



---------

Signed-off-by: Ahmed Hussein (amahussein) <[email protected]>
  • Loading branch information
amahussein authored Dec 26, 2023
1 parent 03effc2 commit 69875a1
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 65 deletions.
53 changes: 1 addition & 52 deletions user_tools/src/spark_rapids_pytools/common/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,31 +17,24 @@
import datetime
import logging.config
import os
import re
import secrets
import ssl
import string
import subprocess
import sys
import threading
import time
import urllib
from shutil import make_archive, which
from dataclasses import dataclass, field
from logging import Logger
from shutil import make_archive, which
from typing import Callable, Any

import certifi
import chevron
from bs4 import BeautifulSoup
from packaging.version import Version
from progress.spinner import PixelSpinner
from pygments import highlight
from pygments.formatters import get_formatter_by_name
from pygments.lexers import get_lexer_by_name

from spark_rapids_pytools import get_version


class Utils:
"""Utility class used to enclose common helpers and utilities."""
Expand Down Expand Up @@ -86,50 +79,6 @@ def reformat_release_version(cls, defined_version: Version) -> str:
res = f'{version_comp[0]}.{version_comp[1]:02}.{version_comp[2]}'
return res

@classmethod
def get_latest_available_jar_version(cls, url_base: str, loaded_version: str) -> str:
"""
Given the defined version in the python tools build, we want to be able to get the highest
version number of the jar available for download from the mvn repo.
The returned version is guaranteed to be LEQ to the defined version. For example, it is not
allowed to use jar version higher than the python tool itself.
:param url_base: the base url from which the jar file is downloaded. It can be mvn repo.
:param loaded_version: the version from the python tools in string format
:return: the string value of the jar that should be downloaded.
"""
context = ssl.create_default_context(cafile=certifi.where())
defined_version = Version(loaded_version)
jar_version = Version(loaded_version)
version_regex = r'\d{2}\.\d{2}\.\d+'
version_pattern = re.compile(version_regex)
with urllib.request.urlopen(url_base, context=context) as resp:
html_content = resp.read()
# Parse the HTML content using BeautifulSoup
soup = BeautifulSoup(html_content, 'html.parser')
# Find all the links with title in the format of "xx.xx.xx"
links = soup.find_all('a', {'title': version_pattern})
# Get the link with the highest value
for link in links:
curr_title = re.search(version_regex, link.get('title'))
if curr_title:
curr_version = Version(curr_title.group())
if curr_version <= defined_version:
jar_version = curr_version
# get formatted string
return cls.reformat_release_version(jar_version)

@classmethod
def get_base_release(cls) -> str:
"""
For now the tools_jar is always with major.minor.0.
this method makes sure that even if the package version is incremented, we will still
get the correct url.
:return: a string containing the release number 22.12.0, 23.02.0, amd 23.04.0..etc
"""
defined_version = Version(get_version(main=None))
# get the release from version
return cls.reformat_release_version(defined_version)

@classmethod
def is_system_tool(cls, tool_name: str) -> bool:
"""
Expand Down
2 changes: 1 addition & 1 deletion user_tools/src/spark_rapids_pytools/rapids/rapids_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def _process_jar_arg(self):
jar_file_name = FSUtil.get_resource_name(jar_path)
version_match = re.search(r'\d{2}\.\d{2}\.\d+', jar_file_name)
jar_version = version_match.group() if version_match else 'Unknown'
self.logger.info('Using Spark RAPIDS accelerator jar version %s', jar_version)
self.logger.info('Using Spark RAPIDS Accelerator Tools jar version %s', jar_version)
# add jar file name to the tool args
self.ctxt.add_rapids_args('jarFileName', jar_file_name)
self.ctxt.add_rapids_args('jarFilePath', jar_path)
Expand Down
7 changes: 4 additions & 3 deletions user_tools/src/spark_rapids_pytools/rapids/tool_ctxt.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,17 @@

import os
import tarfile
from glob import glob
from dataclasses import dataclass, field
from glob import glob
from logging import Logger
from typing import Type, Any, ClassVar, List

from spark_rapids_tools import CspEnv
from spark_rapids_pytools.cloud_api.sp_types import PlatformBase
from spark_rapids_pytools.common.prop_manager import YAMLPropertiesContainer
from spark_rapids_pytools.common.sys_storage import FSUtil
from spark_rapids_pytools.common.utilities import ToolLogging, Utils
from spark_rapids_tools import CspEnv
from spark_rapids_tools.utils import Utilities


@dataclass
Expand Down Expand Up @@ -177,7 +178,7 @@ def get_rapids_jar_url(self) -> str:
raise FileNotFoundError('In Fat Mode. No matching JAR files found.')
return matching_files[0]
mvn_base_url = self.get_value('sparkRapids', 'mvnUrl')
jar_version = Utils.get_latest_available_jar_version(mvn_base_url, Utils.get_base_release())
jar_version = Utilities.get_latest_mvn_jar_from_metadata(mvn_base_url)
rapids_url = self.get_value('sparkRapids', 'repoUrl').format(mvn_base_url, jar_version, jar_version)
return rapids_url

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,10 @@

import fire

from spark_rapids_tools import CspEnv
from spark_rapids_pytools.common.prop_manager import JSONPropertiesContainer
from spark_rapids_pytools.common.sys_storage import FSUtil
from spark_rapids_pytools.common.utilities import Utils

from spark_rapids_tools import CspEnv
from spark_rapids_tools.utils import Utilities

# Defines the constants and static configurations
prepackage_conf = {
Expand Down Expand Up @@ -87,8 +86,7 @@ def __init__(self,
self.dest_dir = FSUtil.get_abs_path(self.dest_dir)

def _get_spark_rapids_jar_url(self) -> str:
jar_version = Utils.get_latest_available_jar_version(self._mvn_base_url, # pylint: disable=no-member
Utils.get_base_release())
jar_version = Utilities.get_latest_mvn_jar_from_metadata(self._mvn_base_url) # pylint: disable=no-member
return (f'{self._mvn_base_url}/' # pylint: disable=no-member
f'{jar_version}/rapids-4-spark-tools_2.12-{jar_version}.jar')

Expand Down
5 changes: 3 additions & 2 deletions user_tools/src/spark_rapids_tools/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""init file of the utils package for the Accelerated Spark tools"""

from .util import (
get_elem_from_dict, get_elem_non_safe, is_http_file
get_elem_from_dict, get_elem_non_safe, is_http_file, Utilities
)

from .propmanager import (
Expand All @@ -28,5 +28,6 @@
'get_elem_non_safe',
'AbstractPropContainer',
'PropValidatorSchema',
'is_http_file'
'is_http_file',
'Utilities'
]
69 changes: 67 additions & 2 deletions user_tools/src/spark_rapids_tools/utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,24 @@
import os
import pathlib
import re
import ssl
import sys
import urllib
import xml.etree.ElementTree as elem_tree
from functools import reduce
from operator import getitem
from typing import Any, Optional

import certifi
import fire
from packaging.version import Version
from pydantic import ValidationError, AnyHttpUrl, TypeAdapter

import spark_rapids_pytools
from spark_rapids_tools.exceptions import CspPathAttributeError
from spark_rapids_pytools.common.utilities import Utils
from spark_rapids_pytools import get_version
from spark_rapids_pytools.common.sys_storage import FSUtil
from spark_rapids_pytools.common.utilities import Utils
from spark_rapids_tools.exceptions import CspPathAttributeError


def get_elem_from_dict(data, keys):
Expand Down Expand Up @@ -160,3 +166,62 @@ def init_environment(short_name: str):
print(Utils.gen_report_sec_header('Application Logs'))
print(f'Location: {log_file}')
print('In case of any errors, please share the log file with the Spark RAPIDS team.\n')


class Utilities:
"""Utility class used to enclose common helpers and utilities."""

@classmethod
def get_latest_mvn_jar_from_metadata(cls, url_base: str,
loaded_version: str = None) -> str:
"""
Given the defined version in the python tools build, we want to be able to get the highest
version number of the jar available for download from the mvn repo.
The returned version is guaranteed to be LEQ to the defined version. For example, it is not
allowed to use jar version higher than the python tool itself.
The implementation relies on parsing the "$MVN_REPO/maven-metadata.xml" which guarantees
that any delays in updating the directory list won't block the python module
from pulling the latest jar.
:param url_base: the base url from which the jar file is downloaded. It can be mvn repo.
:param loaded_version: the version from the python tools in string format
:return: the string value of the jar that should be downloaded.
"""

if loaded_version is None:
loaded_version = cls.get_base_release()
context = ssl.create_default_context(cafile=certifi.where())
defined_version = Version(loaded_version)
jar_version = Version(loaded_version)
xml_path = f'{url_base}/maven-metadata.xml'
with urllib.request.urlopen(xml_path, context=context) as resp:
xml_content = resp.read()
xml_root = elem_tree.fromstring(xml_content)
for version_elem in xml_root.iter('version'):
curr_version = Version(version_elem.text)
if curr_version <= defined_version:
jar_version = curr_version
# get formatted string
return cls.reformat_release_version(jar_version)

@classmethod
def reformat_release_version(cls, defined_version: Version) -> str:
# get the release from version
version_tuple = defined_version.release
version_comp = list(version_tuple)
# release format is under url YY.MM.MICRO where MM is 02, 04, 06, 08, 10, and 12
res = f'{version_comp[0]}.{version_comp[1]:02}.{version_comp[2]}'
return res

@classmethod
def get_base_release(cls) -> str:
"""
For now the tools_jar is always with major.minor.0.
this method makes sure that even if the package version is incremented, we will still
get the correct url.
:return: a string containing the release number 22.12.0, 23.02.0, amd 23.04.0..etc
"""
defined_version = Version(get_version(main=None))
# get the release from version
return cls.reformat_release_version(defined_version)

0 comments on commit 69875a1

Please sign in to comment.