From b778edbfc22259d94f8e51464f6a3cc3a903adec Mon Sep 17 00:00:00 2001 From: Ahmed Hussein <50450311+amahussein@users.noreply.github.com> Date: Mon, 21 Oct 2024 08:57:59 -0500 Subject: [PATCH] Use StorageLib to download dependencies (#1383) * Use StorageLib to download dependencies Signed-off-by: Ahmed Hussein Fixes #1364, Contributes to #1359 - Use the CspPath and CspFs to manage dependencies - This allows more flexibility in specifying custom dependencies including local disk storage. - Remove Pricing catalog from python package * use requests library as a default download utility Signed-off-by: Ahmed Hussein * address code reviews Signed-off-by: Ahmed Hussein --------- Signed-off-by: Ahmed Hussein --- user_tools/pyproject.toml | 12 +- .../spark_rapids_pytools/rapids/rapids_job.py | 2 +- .../rapids/rapids_tool.py | 73 ++-- .../resources/databricks_aws-configs.json | 98 +++-- .../resources/databricks_azure-configs.json | 62 +-- .../resources/dataproc-configs.json | 68 ++-- .../resources/dataproc_gke-configs.json | 68 ++-- .../resources/dev/prepackage_mgr.py | 37 +- .../resources/emr-configs.json | 98 +++-- .../resources/onprem-configs.json | 55 ++- user_tools/src/spark_rapids_tools/enums.py | 36 +- .../spark_rapids_tools/storagelib/cspfs.py | 22 +- .../spark_rapids_tools/storagelib/csppath.py | 54 ++- .../storagelib/tools/fs_utils.py | 186 +++++++++ .../tools/qualx/qualx_main.py | 2 +- .../src/spark_rapids_tools/utils/net_utils.py | 366 ++++++++++++++++++ 16 files changed, 930 insertions(+), 309 deletions(-) create mode 100644 user_tools/src/spark_rapids_tools/storagelib/tools/fs_utils.py create mode 100644 user_tools/src/spark_rapids_tools/utils/net_utils.py diff --git a/user_tools/pyproject.toml b/user_tools/pyproject.toml index 39d46ebfe..8decf7a82 100644 --- a/user_tools/pyproject.toml +++ b/user_tools/pyproject.toml @@ -24,7 +24,7 @@ dependencies = [ "numpy<=1.24.4", "chevron==0.14.0", "fastprogress==1.0.3", - "fastcore==1.5.29", + "fastcore==1.7.10", "fire>=0.5.0", "pandas==1.4.3", "pyYAML>=6.0", @@ -37,8 +37,8 @@ dependencies = [ "urllib3==1.26.19", "beautifulsoup4==4.11.2", "pygments==2.15.0", - # used to apply validator on objects and models - "pydantic==2.1.1", + # used to apply validator on objects and models. "2.9.2" contains from_json method. 
+ "pydantic==2.9.2", # used to help pylint understand pydantic "pylint-pydantic==0.3.0", # used for common API to access remote filesystems like local/s3/gcs/hdfs @@ -76,7 +76,11 @@ version = {attr = "spark_rapids_pytools.__version__"} repository = "https://github.com/NVIDIA/spark-rapids-tools/tree/main" [project.optional-dependencies] test = [ - "tox", 'pytest', 'cli_test_helpers', 'behave' + "tox", 'pytest', 'cli_test_helpers', 'behave', + # use flak-8 plugin for pydantic + 'flake8-pydantic', + # use pylint specific version + 'pylint==3.2.7', ] qualx = [ "holoviews", diff --git a/user_tools/src/spark_rapids_pytools/rapids/rapids_job.py b/user_tools/src/spark_rapids_pytools/rapids/rapids_job.py index f095daa3f..a77a94b00 100644 --- a/user_tools/src/spark_rapids_pytools/rapids/rapids_job.py +++ b/user_tools/src/spark_rapids_pytools/rapids/rapids_job.py @@ -166,7 +166,7 @@ def _get_hadoop_classpath(self) -> Optional[str]: conf_dir_path = LocalPath(conf_dir) if conf_dir_path.is_dir() and conf_dir_path.exists(): # return the first valid directory found without the URI prefix - return conf_dir_path.no_prefix + return conf_dir_path.no_scheme except Exception as e: # pylint: disable=broad-except self.logger.debug( 'Could not build hadoop classpath from %s. Reason: %s', dir_key, e) diff --git a/user_tools/src/spark_rapids_pytools/rapids/rapids_tool.py b/user_tools/src/spark_rapids_pytools/rapids/rapids_tool.py index e2d4c226e..55b91d81c 100644 --- a/user_tools/src/spark_rapids_pytools/rapids/rapids_tool.py +++ b/user_tools/src/spark_rapids_pytools/rapids/rapids_tool.py @@ -20,7 +20,6 @@ import os import re import sys -import tarfile import time from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field @@ -34,13 +33,16 @@ from spark_rapids_pytools.cloud_api.sp_types import get_platform, \ ClusterBase, DeployMode, NodeHWInfo from spark_rapids_pytools.common.prop_manager import YAMLPropertiesContainer, AbstractPropertiesContainer -from spark_rapids_pytools.common.sys_storage import FSUtil, FileVerifier +from spark_rapids_pytools.common.sys_storage import FSUtil from spark_rapids_pytools.common.utilities import ToolLogging, Utils, ToolsSpinner from spark_rapids_pytools.rapids.rapids_job import RapidsJobPropContainer from spark_rapids_pytools.rapids.tool_ctxt import ToolContext from spark_rapids_tools import CspEnv -from spark_rapids_tools.storagelib import LocalPath +from spark_rapids_tools.enums import HashAlgorithm +from spark_rapids_tools.storagelib import LocalPath, CspFs +from spark_rapids_tools.storagelib.tools.fs_utils import untar_file, FileHashAlgorithm from spark_rapids_tools.utils import Utilities +from spark_rapids_tools.utils.net_utils import DownloadTask @dataclass @@ -139,7 +141,7 @@ def _process_output_args(self): self.output_folder = Utils.get_rapids_tools_env('OUTPUT_DIRECTORY', os.getcwd()) try: output_folder_path = LocalPath(self.output_folder) - self.output_folder = output_folder_path.no_prefix + self.output_folder = output_folder_path.no_scheme except Exception as ex: # pylint: disable=broad-except self.logger.error('Failed in processing output arguments. 
Output_folder must be a local directory') raise ex @@ -411,6 +413,7 @@ class RapidsJarTool(RapidsTool): """ def _process_jar_arg(self): + # TODO: use the StorageLib to download the jar file jar_path = '' tools_jar_url = self.wrapper_options.get('toolsJar') try: @@ -533,43 +536,40 @@ def cache_single_dependency(dep: dict) -> str: """ Downloads the specified URL and saves it to disk """ - start_time = time.monotonic() self.logger.info('Checking dependency %s', dep['name']) dest_folder = self.ctxt.get_cache_folder() - resource_file_name = FSUtil.get_resource_name(dep['uri']) - resource_file = FSUtil.build_path(dest_folder, resource_file_name) - file_check_dict = {'size': dep['size']} - signature_file = FileVerifier.get_signature_file(dep['uri'], dest_folder) - if signature_file is not None: - file_check_dict['signatureFile'] = signature_file - algorithm = FileVerifier.get_integrity_algorithm(dep) - if algorithm is not None: - file_check_dict['hashlib'] = { - 'algorithm': algorithm, - 'hash': dep[algorithm] - } - is_created = FSUtil.cache_from_url(dep['uri'], resource_file, file_checks=file_check_dict) - if is_created: - self.logger.info('The dependency %s has been downloaded into %s', dep['uri'], - resource_file) - # check if we need to decompress files + verify_opts = {} + dep_verification = dep.get('verification') + if dep_verification is not None: + if 'size' in dep_verification: + verify_opts['size'] = dep_verification['size'] + hash_lib_alg = dep_verification.get('hashLib') + if hash_lib_alg: + verify_opts['file_hash'] = FileHashAlgorithm(HashAlgorithm(hash_lib_alg['type']), + hash_lib_alg['value']) + download_task = DownloadTask(src_url=dep['uri'], # pylint: disable=no-value-for-parameter) + dest_folder=dest_folder, + verification=verify_opts) + download_result = download_task.run_task() + self.logger.info('Completed downloading of dependency [%s] => %s', + dep['name'], + f'{download_result.pretty_print()}') + if not download_result.success: + msg = f'Failed to download dependency {dep["name"]}, reason: {download_result.download_error}' + raise RuntimeError(f'Could not download all dependencies. 
Aborting Executions.\n\t{msg}') + destination_path = self.ctxt.get_local_work_dir() + destination_cspath = LocalPath(destination_path) if dep['type'] == 'archive': - destination_path = self.ctxt.get_local_work_dir() - with tarfile.open(resource_file, mode='r:*') as tar: - tar.extractall(destination_path) - tar.close() - dep_item = FSUtil.remove_ext(resource_file_name) - if dep.get('relativePath') is not None: - dep_item = FSUtil.build_path(dep_item, dep.get('relativePath')) - dep_item = FSUtil.build_path(destination_path, dep_item) + uncompressed_cspath = untar_file(download_result.resource, destination_cspath) + dep_item = uncompressed_cspath.no_scheme + relative_path = dep.get('relativePath') + if relative_path is not None: + dep_item = f'{dep_item}/{relative_path}' else: # copy the jar into dependency folder - dep_item = self.ctxt.platform.storage.download_resource(resource_file, - self.ctxt.get_local_work_dir()) - end_time = time.monotonic() - self.logger.info('Completed downloading of dependency [%s] => %s seconds', - dep['name'], - f'{(end_time-start_time):,.3f}') + CspFs.copy_resources(download_result.resource, destination_cspath) + final_dep_csp = destination_cspath.create_sub_path(download_result.resource.base_name()) + dep_item = final_dep_csp.no_scheme return dep_item def cache_all_dependencies(dep_arr: List[dict]): @@ -593,7 +593,6 @@ def cache_all_dependencies(dep_arr: List[dict]): raise ex return results - # TODO: Verify the downloaded file by checking their MD5 deploy_mode = DeployMode.tostring(self.ctxt.get_deploy_mode()) depend_arr = self.get_rapids_tools_dependencies(deploy_mode, self.ctxt.platform.configs) if depend_arr: diff --git a/user_tools/src/spark_rapids_pytools/resources/databricks_aws-configs.json b/user_tools/src/spark_rapids_pytools/resources/databricks_aws-configs.json index 20afcaa93..7d64dfb44 100644 --- a/user_tools/src/spark_rapids_pytools/resources/databricks_aws-configs.json +++ b/user_tools/src/spark_rapids_pytools/resources/databricks_aws-configs.json @@ -7,52 +7,78 @@ { "name": "Apache Spark", "uri": "https://archive.apache.org/dist/spark/spark-3.5.0/spark-3.5.0-bin-hadoop3.tgz", + "verification": { + "hashLib": { + "type": "sha512", + "value": "8883c67e0a138069e597f3e7d4edbbd5c3a565d50b28644aad02856a1ec1da7cb92b8f80454ca427118f69459ea326eaa073cf7b1a860c3b796f4b07c2101319" + }, + "size": 400395283 + }, "type": "archive", - "relativePath": "jars/*", - "sha512": "8883c67e0a138069e597f3e7d4edbbd5c3a565d50b28644aad02856a1ec1da7cb92b8f80454ca427118f69459ea326eaa073cf7b1a860c3b796f4b07c2101319", - "size": 400395283 + "relativePath": "jars/*" }, { "name": "Hadoop AWS", "uri": "https://repo1.maven.org/maven2/org/apache/hadoop/hadoop-aws/3.3.4/hadoop-aws-3.3.4.jar", - "type": "jar", - "md5": "59907e790ce713441955015d79f670bc", - "sha1": "a65839fbf1869f81a1632e09f415e586922e4f80", - "size": 962685 + "verification": { + "hashLib": { + "type": "sha1", + "value": "a65839fbf1869f81a1632e09f415e586922e4f80" + }, + "size": 962685 + }, + "type": "jar" }, { "name": "AWS Java SDK Bundled", "uri": "https://repo1.maven.org/maven2/com/amazonaws/aws-java-sdk-bundle/1.12.262/aws-java-sdk-bundle-1.12.262.jar", - "type": "jar", - "md5": "8a22f2d30b7e8eee9ea44f04fb13b35a", - "sha1": "02deec3a0ad83d13d032b1812421b23d7a961eea", - "size": 280645251 + "verification": { + "hashLib": { + "type": "sha1", + "value": "02deec3a0ad83d13d032b1812421b23d7a961eea" + }, + "size": 280645251 + }, + "type": "jar" } ], "333": [ { "name": "Apache Spark", "uri": 
"https://archive.apache.org/dist/spark/spark-3.3.3/spark-3.3.3-bin-hadoop3.tgz", + "verification": { + "hashLib": { + "type": "sha512", + "value": "ebf79c7861f3120d5ed9465fdd8d5302a734ff30713a0454b714bbded7ab9f218b3108dc46a5de4cc2102c86e7be53908f84d2c7a19e59bc75880766eeefeef9" + }, + "size": 299426263 + }, "type": "archive", - "relativePath": "jars/*", - "sha512": "ebf79c7861f3120d5ed9465fdd8d5302a734ff30713a0454b714bbded7ab9f218b3108dc46a5de4cc2102c86e7be53908f84d2c7a19e59bc75880766eeefeef9", - "size": 299426263 + "relativePath": "jars/*" }, { "name": "Hadoop AWS", "uri": "https://repo1.maven.org/maven2/org/apache/hadoop/hadoop-aws/3.3.4/hadoop-aws-3.3.4.jar", - "type": "jar", - "md5": "59907e790ce713441955015d79f670bc", - "sha1": "a65839fbf1869f81a1632e09f415e586922e4f80", - "size": 962685 + "verification": { + "hashLib": { + "type": "sha1", + "value": "a65839fbf1869f81a1632e09f415e586922e4f80" + }, + "size": 962685 + }, + "type": "jar" }, { "name": "AWS Java SDK Bundled", "uri": "https://repo1.maven.org/maven2/com/amazonaws/aws-java-sdk-bundle/1.12.262/aws-java-sdk-bundle-1.12.262.jar", - "type": "jar", - "md5": "8a22f2d30b7e8eee9ea44f04fb13b35a", - "sha1": "02deec3a0ad83d13d032b1812421b23d7a961eea", - "size": 280645251 + "verification": { + "hashLib": { + "type": "sha1", + "value": "02deec3a0ad83d13d032b1812421b23d7a961eea" + }, + "size": 280645251 + }, + "type": "jar" } ] } @@ -205,34 +231,6 @@ } } }, - "pricing": { - "catalog": { - "onlineResources": [ - { - "resourceKey": "databricks-aws-catalog", - "onlineURL": "https://www.databricks.com/en-website-assets/data/pricing/AWS.json", - "//localFile": "the name of the file after downloading", - "localFile": "databricks-aws-catalog.json", - "backupArchive": { - "//description-1": "In case the file is stuck, we use this archive as a backup.", - "//description-2": "It is stored in the resources", - "archiveName": "databricks-aws-catalog.json" - } - }, - { - "resourceKey": "ec2-catalog", - "onlineURL": "https://pricing.us-east-1.amazonaws.com/offers/v1.0/aws/AmazonEC2/current/us-west-2/index.json", - "//localFile": "the name of the file after downloading", - "localFile": "aws_ec2_catalog_ec2_us-west-2.json", - "backupArchive": { - "//description-1": "In case the file is stuck, we use this archive as a backup.", - "//description-2": "It is stored in the resources", - "archiveName": "aws_ec2_catalog_ec2_us-west-2.json" - } - } - ] - } - }, "gpuConfigs": { "user-tools": { "supportedGpuInstances": { diff --git a/user_tools/src/spark_rapids_pytools/resources/databricks_azure-configs.json b/user_tools/src/spark_rapids_pytools/resources/databricks_azure-configs.json index 07040a0be..0e41fb1ae 100644 --- a/user_tools/src/spark_rapids_pytools/resources/databricks_azure-configs.json +++ b/user_tools/src/spark_rapids_pytools/resources/databricks_azure-configs.json @@ -7,36 +7,54 @@ { "name": "Apache Spark", "uri": "https://archive.apache.org/dist/spark/spark-3.5.0/spark-3.5.0-bin-hadoop3.tgz", + "verification": { + "hashLib": { + "type": "sha512", + "value": "8883c67e0a138069e597f3e7d4edbbd5c3a565d50b28644aad02856a1ec1da7cb92b8f80454ca427118f69459ea326eaa073cf7b1a860c3b796f4b07c2101319" + }, + "size": 400395283 + }, "type": "archive", - "relativePath": "jars/*", - "sha512": "8883c67e0a138069e597f3e7d4edbbd5c3a565d50b28644aad02856a1ec1da7cb92b8f80454ca427118f69459ea326eaa073cf7b1a860c3b796f4b07c2101319", - "size": 400395283 + "relativePath": "jars/*" }, { "name": "Hadoop Azure", "uri": 
"https://repo1.maven.org/maven2/org/apache/hadoop/hadoop-azure/3.3.4/hadoop-azure-3.3.4.jar", - "type": "jar", - "md5": "1ec4cbd59548412010fe1515070eef73", - "sha1": "a23f621bca9b2100554150f6b0b521f94b8b419e", - "size": 574116 + "verification": { + "hashLib": { + "type": "sha1", + "value": "a23f621bca9b2100554150f6b0b521f94b8b419e" + }, + "size": 574116 + }, + "type": "jar" } ], "333": [ { "name": "Apache Spark", "uri": "https://archive.apache.org/dist/spark/spark-3.3.3/spark-3.3.3-bin-hadoop3.tgz", + "verification": { + "hashLib": { + "type": "sha512", + "value": "ebf79c7861f3120d5ed9465fdd8d5302a734ff30713a0454b714bbded7ab9f218b3108dc46a5de4cc2102c86e7be53908f84d2c7a19e59bc75880766eeefeef9" + }, + "size": 299426263 + }, "type": "archive", - "relativePath": "jars/*", - "sha512": "ebf79c7861f3120d5ed9465fdd8d5302a734ff30713a0454b714bbded7ab9f218b3108dc46a5de4cc2102c86e7be53908f84d2c7a19e59bc75880766eeefeef9", - "size": 299426263 + "relativePath": "jars/*" }, { "name": "Hadoop Azure", "uri": "https://repo1.maven.org/maven2/org/apache/hadoop/hadoop-azure/3.3.4/hadoop-azure-3.3.4.jar", - "type": "jar", - "md5": "1ec4cbd59548412010fe1515070eef73", - "sha1": "a23f621bca9b2100554150f6b0b521f94b8b419e", - "size": 574116 + "verification": { + "hashLib": { + "type": "sha1", + "value": "a23f621bca9b2100554150f6b0b521f94b8b419e" + }, + "size": 574116 + }, + "type": "jar" } ] } @@ -154,22 +172,6 @@ } } }, - "pricing": { - "catalog": { - "onlineResources": [ - { - "resourceKey": "premium-databricks-azure-catalog", - "onlineURL": "https://azure.microsoft.com/en-us/pricing/details/databricks/", - "//localFile": "the name of the local file", - "localFile": "premium-databricks-azure-catalog.json", - "backupArchive": { - "//description": "We use this archive as a backup. 
It is stored in the resources", - "archiveName": "premium-databricks-azure-catalog.json" - } - } - ] - } - }, "gpuConfigs": { "user-tools": { "supportedGpuInstances": { diff --git a/user_tools/src/spark_rapids_pytools/resources/dataproc-configs.json b/user_tools/src/spark_rapids_pytools/resources/dataproc-configs.json index d25daa2db..1f99e7bac 100644 --- a/user_tools/src/spark_rapids_pytools/resources/dataproc-configs.json +++ b/user_tools/src/spark_rapids_pytools/resources/dataproc-configs.json @@ -7,36 +7,54 @@ { "name": "Apache Spark", "uri": "https://archive.apache.org/dist/spark/spark-3.5.0/spark-3.5.0-bin-hadoop3.tgz", + "verification": { + "hashLib": { + "type": "sha512", + "value": "8883c67e0a138069e597f3e7d4edbbd5c3a565d50b28644aad02856a1ec1da7cb92b8f80454ca427118f69459ea326eaa073cf7b1a860c3b796f4b07c2101319" + }, + "size": 400395283 + }, "type": "archive", - "relativePath": "jars/*", - "sha512": "8883c67e0a138069e597f3e7d4edbbd5c3a565d50b28644aad02856a1ec1da7cb92b8f80454ca427118f69459ea326eaa073cf7b1a860c3b796f4b07c2101319", - "size": 400395283 + "relativePath": "jars/*" }, { "name": "GCS Connector Hadoop3", "uri": "https://repo1.maven.org/maven2/com/google/cloud/bigdataoss/gcs-connector/hadoop3-2.2.19/gcs-connector-hadoop3-2.2.19-shaded.jar", - "type": "jar", - "md5": "2ee6ad7215304cf5da8e731afb36ad72", - "sha1": "3bea6d5e62663a2a5c03d8ca44dff4921aeb3170", - "size": 39359477 + "verification": { + "hashLib": { + "type": "sha1", + "value": "3bea6d5e62663a2a5c03d8ca44dff4921aeb3170" + }, + "size": 39359477 + }, + "type": "jar" } ], "333": [ { "name": "Apache Spark", "uri": "https://archive.apache.org/dist/spark/spark-3.3.3/spark-3.3.3-bin-hadoop3.tgz", + "verification": { + "hashLib": { + "type": "sha512", + "value": "ebf79c7861f3120d5ed9465fdd8d5302a734ff30713a0454b714bbded7ab9f218b3108dc46a5de4cc2102c86e7be53908f84d2c7a19e59bc75880766eeefeef9" + }, + "size": 299426263 + }, "type": "archive", - "relativePath": "jars/*", - "sha512": "ebf79c7861f3120d5ed9465fdd8d5302a734ff30713a0454b714bbded7ab9f218b3108dc46a5de4cc2102c86e7be53908f84d2c7a19e59bc75880766eeefeef9", - "size": 299426263 + "relativePath": "jars/*" }, { "name": "GCS Connector Hadoop3", "uri": "https://repo1.maven.org/maven2/com/google/cloud/bigdataoss/gcs-connector/hadoop3-2.2.17/gcs-connector-hadoop3-2.2.17-shaded.jar", - "type": "jar", - "md5": "41aea3add826dfbf3384a2c638148709", - "sha1": "06438f562692ff8fae5e8555eba2b9f95cb74f66", - "size": 38413466 + "verification": { + "hashLib": { + "type": "sha1", + "value": "06438f562692ff8fae5e8555eba2b9f95cb74f66" + }, + "size": 38413466 + }, + "type": "jar" } ] } @@ -141,28 +159,6 @@ } } }, - "pricing": { - "catalog": { - "onlineResources": [ - { - "resourceKey": "gcloud-catalog", - "onlineURL": "https://cloudpricingcalculator.appspot.com/static/data/pricelist.json", - "//localFile": "the name of the file after downloading", - "localFile": "gcloud-catalog.json", - "backupArchive": { - "//description-1": "In case the file is stuck, we use this archive as a backup.", - "//description-2": "It is stored in the resources", - "archiveName": "gcloud-catalog.tgz" - } - } - ], - "components": { - "ssd": { - "unitSizeFactor": 0.513698630136986 - } - } - } - }, "gpuConfigs": { "user-tools": { "gpuPerMachine": { diff --git a/user_tools/src/spark_rapids_pytools/resources/dataproc_gke-configs.json b/user_tools/src/spark_rapids_pytools/resources/dataproc_gke-configs.json index dec1d05db..861b127d7 100644 --- a/user_tools/src/spark_rapids_pytools/resources/dataproc_gke-configs.json +++ 
b/user_tools/src/spark_rapids_pytools/resources/dataproc_gke-configs.json @@ -7,36 +7,54 @@ { "name": "Apache Spark", "uri": "https://archive.apache.org/dist/spark/spark-3.5.0/spark-3.5.0-bin-hadoop3.tgz", + "verification": { + "hashLib": { + "type": "sha512", + "value": "8883c67e0a138069e597f3e7d4edbbd5c3a565d50b28644aad02856a1ec1da7cb92b8f80454ca427118f69459ea326eaa073cf7b1a860c3b796f4b07c2101319" + }, + "size": 400395283 + }, "type": "archive", - "relativePath": "jars/*", - "sha512": "8883c67e0a138069e597f3e7d4edbbd5c3a565d50b28644aad02856a1ec1da7cb92b8f80454ca427118f69459ea326eaa073cf7b1a860c3b796f4b07c2101319", - "size": 400395283 + "relativePath": "jars/*" }, { "name": "GCS Connector Hadoop3", "uri": "https://repo1.maven.org/maven2/com/google/cloud/bigdataoss/gcs-connector/hadoop3-2.2.19/gcs-connector-hadoop3-2.2.19-shaded.jar", - "type": "jar", - "md5": "2ee6ad7215304cf5da8e731afb36ad72", - "sha1": "3bea6d5e62663a2a5c03d8ca44dff4921aeb3170", - "size": 39359477 + "verification": { + "hashLib": { + "type": "sha1", + "value": "3bea6d5e62663a2a5c03d8ca44dff4921aeb3170" + }, + "size": 39359477 + }, + "type": "jar" } ], "333": [ { "name": "Apache Spark", "uri": "https://archive.apache.org/dist/spark/spark-3.3.3/spark-3.3.3-bin-hadoop3.tgz", + "verification": { + "hashLib": { + "type": "sha512", + "value": "ebf79c7861f3120d5ed9465fdd8d5302a734ff30713a0454b714bbded7ab9f218b3108dc46a5de4cc2102c86e7be53908f84d2c7a19e59bc75880766eeefeef9" + }, + "size": 299426263 + }, "type": "archive", - "relativePath": "jars/*", - "sha512": "ebf79c7861f3120d5ed9465fdd8d5302a734ff30713a0454b714bbded7ab9f218b3108dc46a5de4cc2102c86e7be53908f84d2c7a19e59bc75880766eeefeef9", - "size": 299426263 + "relativePath": "jars/*" }, { "name": "GCS Connector Hadoop3", "uri": "https://repo1.maven.org/maven2/com/google/cloud/bigdataoss/gcs-connector/hadoop3-2.2.17/gcs-connector-hadoop3-2.2.17-shaded.jar", - "type": "jar", - "md5": "41aea3add826dfbf3384a2c638148709", - "sha1": "06438f562692ff8fae5e8555eba2b9f95cb74f66", - "size": 38413466 + "verification": { + "hashLib": { + "type": "sha1", + "value": "06438f562692ff8fae5e8555eba2b9f95cb74f66" + }, + "size": 38413466 + }, + "type": "jar" } ] } @@ -136,28 +154,6 @@ } } }, - "pricing": { - "catalog": { - "onlineResources": [ - { - "resourceKey": "gcloud-catalog", - "onlineURL": "https://cloudpricingcalculator.appspot.com/static/data/pricelist.json", - "//localFile": "the name of the file after downloading", - "localFile": "gcloud-catalog.json", - "backupArchive": { - "//description-1": "In case the file is stuck, we use this archive as a backup.", - "//description-2": "It is stored in the resources", - "archiveName": "gcloud-catalog.tgz" - } - } - ], - "components": { - "ssd": { - "unitSizeFactor": 0.513698630136986 - } - } - } - }, "gpuConfigs": { "user-tools": { "gpuPerMachine": { diff --git a/user_tools/src/spark_rapids_pytools/resources/dev/prepackage_mgr.py b/user_tools/src/spark_rapids_pytools/resources/dev/prepackage_mgr.py index 4c9971689..1920c134b 100644 --- a/user_tools/src/spark_rapids_pytools/resources/dev/prepackage_mgr.py +++ b/user_tools/src/spark_rapids_pytools/resources/dev/prepackage_mgr.py @@ -20,7 +20,6 @@ import os import shutil import tarfile -from concurrent.futures import ThreadPoolExecutor from typing import Optional import fire @@ -30,6 +29,7 @@ from spark_rapids_pytools.rapids.rapids_tool import RapidsTool from spark_rapids_tools import CspEnv from spark_rapids_tools.utils import Utilities +from spark_rapids_tools.utils.net_utils import 
DownloadManager, DownloadTask # Defines the constants and static configurations prepackage_conf = { @@ -122,30 +122,23 @@ def _fetch_resources(self) -> dict: if uri: resource_uris[uri] = {'name': name, 'pbar_enabled': False} resource_uris[uri + '.asc'] = {'name': name + '.asc', 'pbar_enabled': False} - - # Add pricing files as resources - if platform_conf.get_value_silent('pricing'): - for pricing_entry in platform_conf.get_value('pricing', 'catalog', 'onlineResources'): - uri = pricing_entry.get('onlineURL') - name = pricing_entry.get('localFile') - if uri and name: - resource_uris[uri] = {'name': name, 'pbar_enabled': False} - return resource_uris def _download_resources(self, resource_uris: dict): - resource_uris_list = list(resource_uris.items()) - - def download_task(resource_uri, resource_info): - resource_name = resource_info['name'] - pbar_enabled = resource_info['pbar_enabled'] - resource_file_path = FSUtil.build_full_path(self.dest_dir, resource_name) - - print(f'Downloading {resource_name}') - FSUtil.fast_download_url(resource_uri, resource_file_path, pbar_enabled=pbar_enabled) - - with ThreadPoolExecutor() as executor: - executor.map(lambda x: download_task(x[0], x[1]), resource_uris_list) + download_tasks = [] + for res_uri, res_info in resource_uris.items(): + resource_name = res_info.get('name') + print(f'Creating download task: {resource_name}') + # All the downloadTasks enforces download + download_tasks.append(DownloadTask(src_url=res_uri, # pylint: disable=no-value-for-parameter) + dest_folder=self.dest_dir, + configs={'forceDownload': True})) + # Begin downloading the resources + download_results = DownloadManager(download_tasks, max_workers=12).submit() + print('----Download summary---') + for res in download_results: + print(res.pretty_print()) + print('-----------------------') def _compress_resources(self) -> Optional[str]: if not self.archive_enabled: diff --git a/user_tools/src/spark_rapids_pytools/resources/emr-configs.json b/user_tools/src/spark_rapids_pytools/resources/emr-configs.json index e273c224c..5b6a34bd2 100644 --- a/user_tools/src/spark_rapids_pytools/resources/emr-configs.json +++ b/user_tools/src/spark_rapids_pytools/resources/emr-configs.json @@ -7,52 +7,78 @@ { "name": "Apache Spark", "uri": "https://archive.apache.org/dist/spark/spark-3.5.0/spark-3.5.0-bin-hadoop3.tgz", + "verification": { + "hashLib": { + "type": "sha512", + "value": "8883c67e0a138069e597f3e7d4edbbd5c3a565d50b28644aad02856a1ec1da7cb92b8f80454ca427118f69459ea326eaa073cf7b1a860c3b796f4b07c2101319" + }, + "size": 400395283 + }, "type": "archive", - "relativePath": "jars/*", - "sha512": "8883c67e0a138069e597f3e7d4edbbd5c3a565d50b28644aad02856a1ec1da7cb92b8f80454ca427118f69459ea326eaa073cf7b1a860c3b796f4b07c2101319", - "size": 400395283 + "relativePath": "jars/*" }, { "name": "Hadoop AWS", "uri": "https://repo1.maven.org/maven2/org/apache/hadoop/hadoop-aws/3.3.4/hadoop-aws-3.3.4.jar", - "type": "jar", - "md5": "59907e790ce713441955015d79f670bc", - "sha1": "a65839fbf1869f81a1632e09f415e586922e4f80", - "size": 962685 + "verification": { + "hashLib": { + "type": "sha1", + "value": "a65839fbf1869f81a1632e09f415e586922e4f80" + }, + "size": 962685 + }, + "type": "jar" }, { "name": "AWS Java SDK Bundled", "uri": "https://repo1.maven.org/maven2/com/amazonaws/aws-java-sdk-bundle/1.12.262/aws-java-sdk-bundle-1.12.262.jar", - "type": "jar", - "md5": "8a22f2d30b7e8eee9ea44f04fb13b35a", - "sha1": "02deec3a0ad83d13d032b1812421b23d7a961eea", - "size": 280645251 + "verification": { + "hashLib": { + 
"type": "sha1", + "value": "02deec3a0ad83d13d032b1812421b23d7a961eea" + }, + "size": 280645251 + }, + "type": "jar" } ], "333": [ { "name": "Apache Spark", "uri": "https://archive.apache.org/dist/spark/spark-3.3.3/spark-3.3.3-bin-hadoop3.tgz", + "verification": { + "hashLib": { + "type": "sha512", + "value": "ebf79c7861f3120d5ed9465fdd8d5302a734ff30713a0454b714bbded7ab9f218b3108dc46a5de4cc2102c86e7be53908f84d2c7a19e59bc75880766eeefeef9" + }, + "size": 299426263 + }, "type": "archive", - "relativePath": "jars/*", - "sha512": "ebf79c7861f3120d5ed9465fdd8d5302a734ff30713a0454b714bbded7ab9f218b3108dc46a5de4cc2102c86e7be53908f84d2c7a19e59bc75880766eeefeef9", - "size": 299426263 + "relativePath": "jars/*" }, { "name": "Hadoop AWS", "uri": "https://repo1.maven.org/maven2/org/apache/hadoop/hadoop-aws/3.3.4/hadoop-aws-3.3.4.jar", - "type": "jar", - "md5": "59907e790ce713441955015d79f670bc", - "sha1": "a65839fbf1869f81a1632e09f415e586922e4f80", - "size": 962685 + "verification": { + "hashLib": { + "type": "sha1", + "value": "a65839fbf1869f81a1632e09f415e586922e4f80" + }, + "size": 962685 + }, + "type": "jar" }, { "name": "AWS Java SDK Bundled", "uri": "https://repo1.maven.org/maven2/com/amazonaws/aws-java-sdk-bundle/1.12.262/aws-java-sdk-bundle-1.12.262.jar", - "type": "jar", - "md5": "8a22f2d30b7e8eee9ea44f04fb13b35a", - "sha1": "02deec3a0ad83d13d032b1812421b23d7a961eea", - "size": 280645251 + "verification": { + "hashLib": { + "type": "sha1", + "value": "02deec3a0ad83d13d032b1812421b23d7a961eea" + }, + "size": 280645251 + }, + "type": "jar" } ] } @@ -170,34 +196,6 @@ } } }, - "pricing": { - "catalog": { - "onlineResources": [ - { - "resourceKey": "emr-catalog", - "onlineURL": "https://pricing.us-east-1.amazonaws.com/offers/v1.0/aws/ElasticMapReduce/current/us-west-2/index.json", - "//localFile": "the name of the file after downloading", - "localFile": "aws_ec2_catalog_emr_us-west-2.json", - "backupArchive": { - "//description-1": "In case the file is stuck, we use this archive as a backup.", - "//description-2": "It is stored in the resources", - "archiveName": "aws_ec2_catalog_emr_us-west-2.json" - } - }, - { - "resourceKey": "ec2-catalog", - "onlineURL": "https://pricing.us-east-1.amazonaws.com/offers/v1.0/aws/AmazonEC2/current/us-west-2/index.json", - "//localFile": "the name of the file after downloading", - "localFile": "aws_ec2_catalog_ec2_us-west-2.json", - "backupArchive": { - "//description-1": "In case the file is stuck, we use this archive as a backup.", - "//description-2": "It is stored in the resources", - "archiveName": "aws_ec2_catalog_ec2_us-west-2.json" - } - } - ] - } - }, "gpuConfigs": { "user-tools": { "supportedGpuInstances": { diff --git a/user_tools/src/spark_rapids_pytools/resources/onprem-configs.json b/user_tools/src/spark_rapids_pytools/resources/onprem-configs.json index 35eecae50..f7e9af218 100644 --- a/user_tools/src/spark_rapids_pytools/resources/onprem-configs.json +++ b/user_tools/src/spark_rapids_pytools/resources/onprem-configs.json @@ -7,57 +7,50 @@ { "name": "Apache Spark", "uri": "https://archive.apache.org/dist/spark/spark-3.5.0/spark-3.5.0-bin-hadoop3.tgz", + "verification": { + "hashLib": { + "type": "sha512", + "value": "8883c67e0a138069e597f3e7d4edbbd5c3a565d50b28644aad02856a1ec1da7cb92b8f80454ca427118f69459ea326eaa073cf7b1a860c3b796f4b07c2101319" + }, + "size": 400395283 + }, "type": "archive", - "relativePath": "jars/*", - "sha512": 
"8883c67e0a138069e597f3e7d4edbbd5c3a565d50b28644aad02856a1ec1da7cb92b8f80454ca427118f69459ea326eaa073cf7b1a860c3b796f4b07c2101319", - "size": 400395283 + "relativePath": "jars/*" } ], "342": [ { "name": "Apache Spark", "uri": "https://archive.apache.org/dist/spark/spark-3.4.2/spark-3.4.2-bin-hadoop3.tgz", + "verification": { + "hashLib": { + "type": "sha512", + "value": "c9470a557c96fe899dd1c9ea8d0dda3310eaf0155b2bb972f70a6d97fee8cdaf838b425c30df3d5856b2c31fc2be933537c111db72d0427eabb76c6abd92c1f1" + }, + "size": 388664780 + }, "type": "archive", - "relativePath": "jars/*", - "sha512": "c9470a557c96fe899dd1c9ea8d0dda3310eaf0155b2bb972f70a6d97fee8cdaf838b425c30df3d5856b2c31fc2be933537c111db72d0427eabb76c6abd92c1f1", - "size": 388664780 + "relativePath": "jars/*" } ], "333": [ { "name": "Apache Spark", "uri": "https://archive.apache.org/dist/spark/spark-3.3.3/spark-3.3.3-bin-hadoop3.tgz", + "verification": { + "hashLib": { + "type": "sha512", + "value": "ebf79c7861f3120d5ed9465fdd8d5302a734ff30713a0454b714bbded7ab9f218b3108dc46a5de4cc2102c86e7be53908f84d2c7a19e59bc75880766eeefeef9" + }, + "size": 299426263 + }, "type": "archive", - "relativePath": "jars/*", - "sha512": "ebf79c7861f3120d5ed9465fdd8d5302a734ff30713a0454b714bbded7ab9f218b3108dc46a5de4cc2102c86e7be53908f84d2c7a19e59bc75880766eeefeef9", - "size": 299426263 + "relativePath": "jars/*" } ] } } }, - "csp_pricing": { - "catalog": { - "onlineResources": [ - { - "resourceKey": "gcloud-catalog", - "onlineURL": "https://cloudpricingcalculator.appspot.com/static/data/pricelist.json", - "//localFile": "the name of the file after downloading", - "localFile": "gcloud-catalog.json", - "backupArchive": { - "//description-1": "In case the file is stuck, we use this archive as a backup.", - "//description-2": "It is stored in the resources", - "archiveName": "gcloud-catalog.tgz" - } - } - ], - "components": { - "ssd": { - "unitSizeFactor": 0.513698630136986 - } - } - } - }, "gpuConfigs": { "dataproc" : { "user-tools": { diff --git a/user_tools/src/spark_rapids_tools/enums.py b/user_tools/src/spark_rapids_tools/enums.py index 29bfcb2c9..61dca23a6 100644 --- a/user_tools/src/spark_rapids_tools/enums.py +++ b/user_tools/src/spark_rapids_tools/enums.py @@ -13,7 +13,7 @@ # limitations under the License. 
"""Enumeration types commonly used through the AS python implementations.""" - +import hashlib from enum import Enum, auto from typing import Union, cast, Callable @@ -67,11 +67,45 @@ def pretty_print(cls, value): value = cast(Enum, value) return str(value._value_) # pylint: disable=protected-access +############### +# Utility Enums +############### + + +class HashAlgorithm(EnumeratedType): + """Represents the supported hashing algorithms""" + MD5 = 'md5' + SHA1 = 'sha1' + SHA256 = 'sha256' + SHA512 = 'sha512' + + @classmethod + def get_default(cls): + return cls.SHA256 + + @classmethod + def _missing_(cls, value): + value = value.lower() + for member in cls: + if member.lower() == value: + return member + return None + + def get_hash_func(self) -> Callable: + """Maps the hash function to the appropriate hashing algorithm""" + hash_functions = { + self.MD5: hashlib.md5, + self.SHA1: hashlib.sha1, + self.SHA256: hashlib.sha256, + self.SHA512: hashlib.sha512, + } + return hash_functions[self] ########### # CSP Enums ########### + class CspEnv(EnumeratedType): """Represents the supported types of runtime CSP""" DATABRICKS_AWS = 'databricks_aws' diff --git a/user_tools/src/spark_rapids_tools/storagelib/cspfs.py b/user_tools/src/spark_rapids_tools/storagelib/cspfs.py index 9adb58ea4..6c00bfaf3 100644 --- a/user_tools/src/spark_rapids_tools/storagelib/cspfs.py +++ b/user_tools/src/spark_rapids_tools/storagelib/cspfs.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -99,6 +99,20 @@ def __init__(self, *args: Any, **kwargs: Any): def create_as_path(self, entry_path: Union[str, BoundedCspPath]) -> BoundedCspPath: return self._path_meta.path_class(entry_path=entry_path, fs_obj=self) + @classmethod + def copy_file(cls, src: BoundedCspPath, dest: BoundedCspPath): + """ + Copy a single file between FileSystems. This function assumes that + :param src: + :param dest: + :return: + """ + arrow_fs.copy_files(src.no_scheme, dest.no_scheme, + source_filesystem=src.fs_obj.fs, + destination_filesystem=dest.fs_obj.fs, + # 64 MB chunk size + chunk_size=64 * 1024 * 1024) + @classmethod def copy_resources(cls, src: BoundedCspPath, dest: BoundedCspPath): """ @@ -132,6 +146,8 @@ def copy_resources(cls, src: BoundedCspPath, dest: BoundedCspPath): dest.create_dirs() dest = dest.fs_obj.create_as_path(entry_path=dest_path) - arrow_fs.copy_files(src.no_prefix, dest.no_prefix, + arrow_fs.copy_files(src.no_scheme, dest.no_scheme, source_filesystem=src.fs_obj.fs, - destination_filesystem=dest.fs_obj.fs) + destination_filesystem=dest.fs_obj.fs, + # 64 MB chunk size + chunk_size=64 * 1024 * 1024) diff --git a/user_tools/src/spark_rapids_tools/storagelib/csppath.py b/user_tools/src/spark_rapids_tools/storagelib/csppath.py index 50d1b8c2d..59ec0aed6 100644 --- a/user_tools/src/spark_rapids_tools/storagelib/csppath.py +++ b/user_tools/src/spark_rapids_tools/storagelib/csppath.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ """ Abstract representation of a file path that can access local/URI values. 
-Similar to cloudpathlib project, this implementation uses dict registry to +Like to cloudpathlib project, this implementation uses dict registry to register an implementation. However, the path representation is built on top of pyArrow FS API. As a result, there is no need to write a full storage client to access remote files. This comes with a tradeoff in providing limited set of file @@ -279,11 +279,15 @@ def is_protocol_prefix(cls, value: str) -> bool: return value.lower().startswith(cls.protocol_prefix.lower()) @cached_property - def no_prefix(self) -> str: + def no_scheme(self) -> str: + """ + Get the path without the scheme. i.e., file:///path/to/file returns /path/to/file + :return: the full url without scheme part. + """ return self._fpath[len(self.protocol_prefix):] def _pull_file_info(self) -> FileInfo: - return self.fs_obj.get_file_info(self.no_prefix) + return self.fs_obj.get_file_info(self.no_scheme) @cached_property def file_info(self) -> FileInfo: @@ -308,16 +312,52 @@ def create_dirs(self, exist_ok: bool = True): # check that the file does not exist if self.exists(): raise CspFileExistsError(f'Path already Exists: {self}') - self.fs_obj.create_dir(self.no_prefix) + self.fs_obj.create_dir(self.no_scheme) # force the file information object to be retrieved again by invalidating the cached property if 'file_info' in self.__dict__: del self.__dict__['file_info'] def open_input_stream(self): - return self.fs_obj.open_input_stream(self.no_prefix) + return self.fs_obj.open_input_stream(self.no_scheme) def open_output_stream(self): - return self.fs_obj.open_output_stream(self.no_prefix) + return self.fs_obj.open_output_stream(self.no_scheme) + + def create_sub_path(self, relative: str) -> 'CspPath': + """ + Given a relative path, it will return a new CspPath object with the relative path appended to + the current path. This is just for building a path, and it does not call mkdirs. + For example, + ```py + root_folder = CspPath('gs://bucket-name/folder_00/subfolder_01') + new_path = root_folder.create_sub_path('subfolder_02') + print(new_path) + >> gs://bucket-name/folder_00/subfolder_01/subfolder_02 + ``` + :param relative: A relative path to append to the current path. + :return: A new path without creating the directory/file. + """ + postfix = '/' + sub_path = relative + if relative.startswith('/'): + sub_path = relative[1:] + if self._fpath.endswith('/'): + postfix = '' + new_path = f'{self._fpath}{postfix}{sub_path}' + return CspPath(new_path) + + @property + def size(self) -> int: + return self.file_info.size + + @property + def extension(self) -> str: + # this is used for existing files + return self.file_info.extension + + def extension_from_path(self) -> str: + # if file does not exist then get extension cannot use pull_info + return self.no_scheme.split('.')[-1] @classmethod def download_files(cls, src_url: str, dest_url: str): diff --git a/user_tools/src/spark_rapids_tools/storagelib/tools/fs_utils.py b/user_tools/src/spark_rapids_tools/storagelib/tools/fs_utils.py new file mode 100644 index 000000000..58dc56ebc --- /dev/null +++ b/user_tools/src/spark_rapids_tools/storagelib/tools/fs_utils.py @@ -0,0 +1,186 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions for file system operations""" + +import dataclasses +import tarfile +from functools import cached_property +from typing import Optional, Union, List + +from pydantic import model_validator, FilePath, AnyHttpUrl, StringConstraints, ValidationError +from pydantic.dataclasses import dataclass +from pydantic_core import PydanticCustomError +from typing_extensions import Annotated + +from spark_rapids_tools import CspPath +from spark_rapids_tools.enums import HashAlgorithm +from spark_rapids_tools.storagelib import LocalPath + +CspPathString = Annotated[str, StringConstraints(pattern=r'^\w+://.*')] +""" +A type alias for a string that represents a path. The path must start with a protocol scheme. +""" + + +def strip_extension(file_name: str, count: int = 1) -> str: + """ + Utility method to strip the extension from a file name. By default, it only removes the last extension. + The caller can override the count of extensions. For examples: + :param file_name: The file name. + :param count: The number of extensions to remove. i.e., this for files with multi-extensions. + :return: The file name without the extension. + + Examples: + ```py + strip_extension('foo.tgz') # returns 'foo' + strip_extension('foo.tar.gz', count = 2) # returns 'foo' + strip_extension('foo.tar.gz') # returns 'foo.tar' + ``` + """ + return file_name.rsplit('.', count)[0] + + +def raise_invalid_file(file_path: CspPath, msg: str, error_type: str = 'invalid_file') -> None: + """ + Utility method to raise a custom pydantic error for invalid files. See the custom pydantic + errors for more details: https://docs.pydantic.dev/latest/errors/errors/#custom-errors + :param file_path: The object instantiated with the file path. + :param msg: + :param error_type: the error type top be displayed in the error message. + :return: + """ + raise PydanticCustomError(error_type, f'File {str(file_path)} {msg}') + + +@dataclass +class FileHashAlgorithm: + """ + Represents a file hash algorithm and its value. Used for verification against an + existing file. + ```py + try: + file_algo = FileHashAlgorithm(algorithm=HashAlgorithm.SHA256, value='...') + file_algo.verify_file(CspPath('file://path/to/file')) + except ValidationError as e: + print(e) + ``` + """ + algorithm: HashAlgorithm + value: str + + def verify_file(self, file_path: CspPath) -> bool: + cb = self.algorithm.get_hash_func() + with file_path.open_input_stream() as stream: + res = cb(stream.readall()) + if res.hexdigest() != self.value: + raise_invalid_file(file_path, f'incorrect file hash [{HashAlgorithm.tostring(self.algorithm)}]') + return True + + +@dataclasses.dataclass +class FileVerificationResult: + """ + A class that represents the result of a file verification. + :param res_path: The path to the resource subject of verification. + :param opts: Checking options to be passed to the FileChecker + :param raise_on_error: Flag to raise an exception if the file is invalid. 
+ """ + res_path: CspPath + opts: dict + raise_on_error: bool = False + validation_error: Optional[ValidationError] = dataclasses.field(default=None, init=False) + + def __post_init__(self): + try: + CspFileChecker(**{'file_path': str(self.res_path), **self.opts}) + except ValidationError as err: + self.validation_error = err + if self.raise_on_error: + raise err + + @cached_property + def successful(self) -> bool: + return self.validation_error is None + + +@dataclass +class CspFileChecker: + """ + A class that represents a file checker. It is used as a pydantic model to validate the file. + :param file_path: The path to the file. It accepts any valid CspPath or AnyHttpUrl in case we want to + verify an url path extension even before we download it. + :param must_exist: When True, the file must exist. + :param is_file: When True, the file must be a file. Otherwise, it is a directory. + :param size: The expected size of the file. When 0, it is not checked. + :param extensions: A list of expected extensions. When None, it is not checked. + :param file_hash: A hash algorithm and its value to verify the file. + ```py + try: + TypeAdapter(CspFileChecker).validate_python({ + 'file_path': 'file:///var/tmp/spark_cache_folder_test/rapids-4-spark-tools_2.12-24.08.2.jar', + 'must_exist': True, + 'size': ....., + 'extensions': ['jar']}) + except ValidationError as e: + print(e) + ``` + """ + file_path: Union[CspPathString, FilePath, AnyHttpUrl] + must_exist: Optional[bool] = False + is_file: Optional[bool] = True + size: Optional[int] = 0 + extensions: Optional[List[str]] = None + file_hash: Optional[FileHashAlgorithm] = None + # TODO add verification for the modified time. + + @cached_property + def csp_path(self) -> CspPath: + return CspPath(str(self.file_path)) + + @model_validator(mode='after') + def verify_file(self) -> 'CspFileChecker': + if self.must_exist and not self.csp_path.exists(): + raise_invalid_file(self.file_path, 'does not exist') + if self.is_file and not self.csp_path.is_file(): + raise_invalid_file(self.file_path, 'expected a file but got a directory') + if self.size and self.size > 0: + if self.csp_path.size != self.size: + raise_invalid_file(self.file_path, f'size {self.csp_path.size} does not have the expected size') + if self.extensions: + file_ext = self.csp_path.extension_from_path() + if not any(file_ext == ext for ext in self.extensions): + raise_invalid_file(self.file_path, + f'[{file_ext}] does not match the expected extensions') + if self.file_hash: + self.file_hash.verify_file(self.csp_path) + return self + + +def untar_file(file_path: LocalPath, dest_folder: LocalPath) -> CspPath: + """ + Utility method to decompress a tgz file. + :param file_path: The path to the tar file. + :param dest_folder: The destination folder to untar the file. + """ + dest_folder.create_dirs() + # the compressed file must be local and must exist + FileVerificationResult(res_path=file_path, opts={'must_exist': True, 'is_file': True}, raise_on_error=True) + with tarfile.open(file_path.no_scheme, mode='r:*') as tar: + tar.extractall(dest_folder.no_scheme) + tar.close() + extracted_name = strip_extension(file_path.base_name()) + result = dest_folder.create_sub_path(extracted_name) + # we do not need to verify that it is a folder because the result might be a file. 
+ return result diff --git a/user_tools/src/spark_rapids_tools/tools/qualx/qualx_main.py b/user_tools/src/spark_rapids_tools/tools/qualx/qualx_main.py index d14849cfc..bde6da008 100644 --- a/user_tools/src/spark_rapids_tools/tools/qualx/qualx_main.py +++ b/user_tools/src/spark_rapids_tools/tools/qualx/qualx_main.py @@ -75,7 +75,7 @@ def _get_model_path(platform: str, model: Optional[str]) -> Path: f'Custom model file [{model}] is invalid. Please specify a valid JSON file.') # TODO: If the path is remote, we need to copy it locally in order to successfully # load it with xgboost. - model_path = Path(CspPath(model).no_prefix) + model_path = Path(CspPath(model).no_scheme) if not model_path.exists(): raise FileNotFoundError(f'Model JSON file not found: {model_path}') else: diff --git a/user_tools/src/spark_rapids_tools/utils/net_utils.py b/user_tools/src/spark_rapids_tools/utils/net_utils.py new file mode 100644 index 000000000..85d8158a2 --- /dev/null +++ b/user_tools/src/spark_rapids_tools/utils/net_utils.py @@ -0,0 +1,366 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions for network functions""" + +import concurrent +import dataclasses +import datetime +import json +import logging +import os +import shutil +import ssl +import time +import urllib +from concurrent.futures import Future, ThreadPoolExecutor +from functools import cached_property +from logging import Logger +from typing import Optional, List + +import certifi +import requests +from fastcore.all import urlsave +from fastprogress.fastprogress import progress_bar +from pydantic import Field, AnyUrl +from pydantic.dataclasses import dataclass +from typing_extensions import Annotated + +from spark_rapids_pytools.common.utilities import ToolLogging +from spark_rapids_tools import CspPath +from spark_rapids_tools.storagelib import CspFs +from spark_rapids_tools.storagelib.tools.fs_utils import FileVerificationResult + + +def download_url_request(url: str, fpath: str, timeout: float = None, + chunk_size: int = 32 * 1024 * 1024) -> str: + """ + Downloads a file from url source using the requests library. + This implementation is more suitable for large files as the chunk size is set to 32 MB. + For smaller file sizes, it might represent an overhead on memory consumption. + :param url: The source of the file to download. + :param fpath: The file path where the resource is saved. + :param timeout: Time in seconds before the requests times out. + :param chunk_size: Default buffer size to download the file. + :return: Local path where the file is downloaded. + """ + # disable the urllib3 debug messages + logging.getLogger('urllib3').setLevel(logging.WARNING) + with requests.get(url, stream=True, timeout=timeout) as r: + r.raise_for_status() + with open(fpath, 'wb') as f: + # set chunk size to 16 MB to lower the count of iterations. 
+ for chunk in r.iter_content(chunk_size=chunk_size): + f.write(chunk) + return fpath + + +def download_url_urllib(url: str, fpath: str) -> str: + """ + Download the given url to the file path. This function is a simple wrapper around the urllib.request.urlopen + :param url: URL to download. + :param fpath: the destination of the saved file. + :return: Local path to the saved file. + """ + # disable the urllib3 debug messages + logging.getLogger('urllib3').setLevel(logging.WARNING) + # We create a context here to fix and issue with urlib requests issue. + context = ssl.create_default_context(cafile=certifi.where()) + with urllib.request.urlopen(url, context=context) as resp: + with open(fpath, 'wb') as f: + shutil.copyfileobj(resp, f, 64 * 1024 * 1024) + return fpath + + +def download_url_fastcore(url: str, fpath: str, timeout=None, pbar_enabled=False) -> str: + """ + Download the given url and display a progress bar. This implementation uses the fastcore library. + We used this by default because it is faster than the download_from_url function. + """ + pbar = progress_bar([]) + + def progress_bar_cb(count=1, bsize=1, total_size=None): + pbar.total = total_size + pbar.update(count * bsize) + + return urlsave(url, fpath, reporthook=progress_bar_cb if pbar_enabled else None, timeout=timeout) + + +def default_download_options(custom_opts: dict = None) -> dict: + """ + Utility function to create the default options for the download. + :param custom_opts: allows user to override or extend the download options. + :return: A dictionary with the default options + custom options. + """ + custom_opts = custom_opts if custom_opts else {} + default_opts = { + # force the download even teh file exists + 'forceDownload': False, + # 3600 is 1 hour + 'timeOut': 3600, + # number of retries + 'retryCount': 3 + } + default_opts.update(custom_opts) + return default_opts + + +def default_verification_options(custom_opts: dict = None) -> dict: + """ + Utility function to create the default options for the verification. + :param custom_opts: Allows user to override and to extend the default verification options. + :return: A dictionary with the default options + custom options. + """ + custom_opts = custom_opts if custom_opts else {} + default_opts = { + # the file must exist + 'must_exist': True, + # it has to be a file + 'is_file': True, + # size of the file. 0 to ignore + 'size': 0, + # list of extensions to check + 'extensions': [] + } + default_opts.update(custom_opts) + return default_opts + + +def download_exception_handler(future: Future) -> None: + # Handle any exceptions raised by the task + exception = future.exception() + if exception: + print('Error while downloading dependency: %s', exception) + + +@dataclasses.dataclass +class DownloadResult: + """ + A class that represents the result of a download task. It contains the following information: + :param resource: The path where the downloaded resource is located. Notice that we use a CspPath + instead of a string or FilePath to represent different schemes and CspStorages. + :param origin_url: The original URL of the resource. + :param success: Whether the download is successful. + :param downloaded: Whether the resource is downloaded or loaded from an existing folder. + i.e., if the file already exists, and the download is not enforced, then the downloaded + value should be false. + :param download_time: The elapsed download time in seconds. + :param verified: Whether the resource is verified. 
+ :param download_error: The error that occurred during the download if any. + """ + resource: CspPath + origin_url: str + success: bool + downloaded: bool + download_time: float = 0 + verified: bool = True + download_error: Optional[Exception] = None + + def pretty_print(self) -> str: + json_str = { + 'resource': str(self.resource), + 'origin_url': self.origin_url, + 'success': self.success, + 'downloaded': self.downloaded, + 'download_time(seconds)': f'{self.download_time:,.3f}', + } + if self.download_error: + json_str['download_error'] = str(self.download_error) + + return json.dumps(json_str, indent=4) + + +@dataclass +class DownloadTask: + """ + A class that represents a download task. It contains the following information: + :param src_url: The URL of the resource to download. + :param dest_folder: The destination folder where the resource is downloaded. + :param configs: A dictionary of download options. See default_download_options() for more + information about acceptable options. + :param verification: A dictionary of verification options. See default_verification_options() + for more information about acceptable options. + """ + src_url: AnyUrl + dest_folder: str + configs: Annotated[Optional[dict], Field(default_factory=lambda: default_download_options())] # pylint: disable=unnecessary-lambda + verification: Annotated[Optional[dict], Field(default_factory=lambda: default_verification_options())] # pylint: disable=unnecessary-lambda + + def __post_init__(self): + # Add the defaults when the caller does not set default values. + self.configs = default_download_options(self.configs) + self.verification = default_verification_options(self.verification) + + @cached_property + def resource_base_name(self) -> str: + return CspPath(self.src_url.path).base_name() + + @cached_property + def dest_dir(self) -> CspPath: + dest_root = CspPath(self.dest_folder) + dest_root.create_dirs(exist_ok=True) + return dest_root + + @cached_property + def dest_res(self) -> CspPath: + return self.dest_dir.create_sub_path(self.resource_base_name) + + @cached_property + def force_download(self) -> bool: + return self.configs.get('forceDownload', False) + + def _download_resource(self, opts: dict) -> DownloadResult: + """ + Downloads a single Url path to a local file system. + :param opts: Options passed to the internal download call. + :return: A DownloadResult object. + """ + def download_from_weburl() -> None: + # by default use the requests library to download the file as it performs better for + # larger files. + download_url_request(opts['srcUrl'], opts['destPath'], timeout=opts['timeOut']) + + def download_from_cs() -> None: + # download the file from the Cloud storage including local file system. + csp_src = CspPath(opts['srcUrl']) + CspFs.copy_file(csp_src, self.dest_res) + + start_time = time.monotonic() + curr_time_stamp = datetime.datetime.now().timestamp() + download_exception = None + downloaded = False + success = False + try: + if self.src_url.scheme == 'https': + download_from_weburl() + else: + download_from_cs() + FileVerificationResult(res_path=self.dest_res, opts=self.verification, raise_on_error=True) + # update modified time and access time + os.utime(opts['destPath'], times=(curr_time_stamp, curr_time_stamp)) + success = True + downloaded = True + except Exception as e: # pylint: disable=broad-except + download_exception = e + # we need to create a new CsPath in order to refresh the fileInfo cached with the instance. 
+ return DownloadResult(resource=CspPath(self.dest_res), + origin_url=opts['srcUrl'], + success=success, + downloaded=downloaded, + download_time=time.monotonic() - start_time, + download_error=download_exception) + + def run_task(self) -> DownloadResult: + local_res = self.dest_res + if local_res.exists() and not self.force_download: + # verify that the file is correct using the verification options + if FileVerificationResult(res_path=self.dest_res, + opts=self.verification, raise_on_error=False).successful: + # the file already exists. Skip downloading it. + return DownloadResult(resource=local_res, origin_url=str(self.src_url), + success=True, downloaded=False) + # the file needs to be redownloaded. For example, it might be a corrupted attempt. + download_opts = { + 'srcUrl': str(self.src_url), + 'destPath': local_res.no_scheme, + # set default timeout oif a single task to 1 hour + 'timeOut': self.configs.get('timeOut', 3600), + } + download_res = self._download_resource(download_opts) + return download_res + + def async_submit(self, thread_executor: ThreadPoolExecutor, task_list: List[Future]) -> None: + futures = thread_executor.submit(self.run_task) + futures.add_done_callback(download_exception_handler) + task_list.append(futures) + + +@dataclass +class DownloadManager: + """ + A class that downloads a list of resources in parallel. It creates a threadPool and run a + DownloadTask for each. + https://stackoverflow.com/questions/6509261/how-to-use-concurrent-futures-with-timeouts + :param download_tasks: A list of DownloadTask objects. + :param max_workers: The maximum workers threads to run in parallel. + To disable parallelism, set the value to 1. + :param time_out: The maximum time to wait for the download to complete. Default is 30 minutes. + Note that there is a bug in the python module that ignores the timeout. The timeout exception + is not triggered for some reason. Nevertheless, we set the timeout hopefully this bug gets fixed. + The current behavior is to throws an exception after all the future tasks get completed. + + example usage: + ```py + DownloadManager( + [DownloadTask(src_url='https://urlpath/rapids-4-spark-tools_2.12-24.08.2.jar', + dest_folder='file:///var/tmp/spark_cache_folder_test/async', + configs={'forceDownload': True}, + verification={'size': ....}), + DownloadTask(src_url='https://urlpath/rapids-4-spark-tools_2.12-24.08.1.jar', + dest_folder='file:///var/tmp/spark_cache_folder_test/async', + configs={'forceDownload': True}, + verification={'file_hash': FileHashAlgorithm(HashAlgorithm('md5'), '.....')}), + # the following is file-to-file copy. + DownloadTask(src_url='file:///path/to/file.ext', + dest_folder='file:///var/tmp/spark_cache_folder_test/async'), + DownloadTask(src_url='https://urlpath/spark-3.5.0-bin-hadoop3.tgz', + dest_folder='file:///var/tmp/spark_cache_folder_test/async', + configs={'forceDownload': True}, + verification={'file_hash': FileHashAlgorithm(HashAlgorithm('sha512'), '....')}) + ]).submit() + ``` + """ + download_tasks: List[DownloadTask] + # set it to 1 to avoid parallelism + max_workers: Optional[int] = 4 + # set the timeout to 60 minutes. 
+ time_out: Optional[int] = 3600 + raise_on_error: Optional[bool] = True + enable_logging: Optional[bool] = False + + @cached_property + def get_logger(self) -> Logger: + return ToolLogging.get_and_setup_logger('rapids.tools.download_manager', debug_mode=True) + + def loginfo(self, msg: str) -> None: + if self.enable_logging: + self.get_logger.info(msg) + + def submit(self) -> List[DownloadResult]: + futures_list = [] + results = [] + final_results = [] + failed_downloads = [] + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + try: + for task in self.download_tasks: + self.loginfo(f'Submitting download task: {task.src_url}') + task.async_submit(executor, futures_list) + # set the timeout to 30 minutes. + for future in concurrent.futures.as_completed(futures_list, timeout=self.time_out): + results.append(future.result()) + for res in results: + if res is not None: + final_results.append(res) + self.loginfo(f'download result: {res.pretty_print()}') + if not res.success: + failed_downloads.append(res) + if self.raise_on_error: + if failed_downloads: + raise ValueError(f'Failed to download the following resources: {failed_downloads}') + if len(final_results) != len(self.download_tasks): + raise ValueError('Not all tasks are completed') + except concurrent.futures.TimeoutError as e: + raise ValueError('Timed out while downloading all tasks') from e + return final_results
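
The short sketches below are illustrative only and are not part of the patch; they show how the pieces added here are meant to fit together. First, hash verification on its own: `FileHashAlgorithm` (fs_utils.py) pairs a `HashAlgorithm` member (enums.py) with an expected digest and checks a file through the `CspPath` streaming API. The jar name and sha1 value come from the databricks_aws/emr config entries above; the local cache path is a placeholder.

```py
from spark_rapids_tools import CspPath
from spark_rapids_tools.enums import HashAlgorithm
from spark_rapids_tools.storagelib.tools.fs_utils import FileHashAlgorithm

# Expected digest taken from the 'verification.hashLib' entry of the Hadoop AWS dependency.
expected_sha1 = 'a65839fbf1869f81a1632e09f415e586922e4f80'
cached_jar = CspPath('file:///var/tmp/spark_rapids_cache/hadoop-aws-3.3.4.jar')  # placeholder path

file_hash = FileHashAlgorithm(HashAlgorithm('sha1'), expected_sha1)
# Reads the file through CspPath.open_input_stream() and raises a PydanticCustomError
# (surfaced as a validation error by CspFileChecker) when the digest does not match.
file_hash.verify_file(cached_jar)
```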
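
Next, a sketch of the end-to-end flow that `cache_single_dependency()` in rapids_tool.py now follows: a dependency entry in the new config layout is turned into a `DownloadTask` whose `verification` options carry the expected size and hash, and `run_task()` returns a `DownloadResult`. The entry below is the Apache Spark 3.5.0 record from the updated *-configs.json files; the destination folder is a placeholder (the tool uses its own cache folder).

```py
from spark_rapids_tools.enums import HashAlgorithm
from spark_rapids_tools.storagelib.tools.fs_utils import FileHashAlgorithm
from spark_rapids_tools.utils.net_utils import DownloadTask

# Apache Spark 3.5.0 entry as it now appears in the *-configs.json files.
dep = {
    'name': 'Apache Spark',
    'uri': 'https://archive.apache.org/dist/spark/spark-3.5.0/spark-3.5.0-bin-hadoop3.tgz',
    'type': 'archive',
    'relativePath': 'jars/*',
    'verification': {
        'hashLib': {
            'type': 'sha512',
            'value': '8883c67e0a138069e597f3e7d4edbbd5c3a565d50b28644aad02856a1ec1da7cb92b8f80454ca427118f69459ea326eaa073cf7b1a860c3b796f4b07c2101319'
        },
        'size': 400395283
    }
}

hash_lib = dep['verification']['hashLib']
verify_opts = {
    'size': dep['verification']['size'],
    'file_hash': FileHashAlgorithm(HashAlgorithm(hash_lib['type']), hash_lib['value'])
}
task = DownloadTask(src_url=dep['uri'],
                    dest_folder='file:///var/tmp/spark_rapids_cache',  # placeholder cache folder
                    verification=verify_opts)
result = task.run_task()
print(result.pretty_print())
if not result.success:
    raise RuntimeError(f'Failed to download {dep["name"]}: {result.download_error}')
```

When a previously cached copy already passes the same verification, `run_task()` returns it with `downloaded=False` instead of fetching the URL again, unless `configs={'forceDownload': True}` is set.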
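
For archive dependencies, the `'archive'` branch of `cache_single_dependency()` then unpacks the cached tarball with `untar_file()` and appends the configured `relativePath`. A minimal sketch, assuming the tarball cached in the previous step and a placeholder work directory:

```py
from spark_rapids_tools.storagelib import LocalPath
from spark_rapids_tools.storagelib.tools.fs_utils import untar_file

# Placeholder paths; the tool resolves these from its context (cache folder / local work dir).
tarball = LocalPath('file:///var/tmp/spark_rapids_cache/spark-3.5.0-bin-hadoop3.tgz')
work_dir = LocalPath('file:///var/tmp/spark_rapids_work_dir')

unpacked = untar_file(tarball, work_dir)
# untar_file() strips a single extension, so the result points at .../spark-3.5.0-bin-hadoop3
dep_item = f'{unpacked.no_scheme}/jars/*'   # 'jars/*' is the relativePath from the config entry
print(dep_item)
```

Plain jar dependencies skip this step: the cached file is copied into the work directory with `CspFs.copy_resources()` and resolved through `create_sub_path()`.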
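
Finally, the flexibility called out in the commit message: because `DownloadTask` routes non-https schemes through `CspFs.copy_file()`, a custom dependency can live on local disk (or any supported CSP storage) and still go through the same caching and verification path. The jar name and both paths in this sketch are hypothetical.

```py
from spark_rapids_tools.utils.net_utils import DownloadTask

# A custom, locally stored dependency; 'file://' sources are copied rather than downloaded.
local_task = DownloadTask(src_url='file:///opt/deps/my-custom-tools_2.12.jar',
                          dest_folder='file:///var/tmp/spark_rapids_cache',
                          configs={'forceDownload': True})
print(local_task.run_task().pretty_print())
```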