diff --git a/user_tools/pyproject.toml b/user_tools/pyproject.toml index 8decf7a82..c92bff303 100644 --- a/user_tools/pyproject.toml +++ b/user_tools/pyproject.toml @@ -27,7 +27,9 @@ dependencies = [ "fastcore==1.7.10", "fire>=0.5.0", "pandas==1.4.3", - "pyYAML>=6.0", + "pyYAML>=6.0,<=7.0", + # This is used to resolve env-variable in yaml files. It requires between 5.0 and 6.0 + "pyaml-env==1.2.1", "tabulate==0.8.10", "importlib-resources==5.10.2", "requests==2.31.0", diff --git a/user_tools/src/spark_rapids_pytools/common/prop_manager.py b/user_tools/src/spark_rapids_pytools/common/prop_manager.py index 637b139b0..2b2f8588d 100644 --- a/user_tools/src/spark_rapids_pytools/common/prop_manager.py +++ b/user_tools/src/spark_rapids_pytools/common/prop_manager.py @@ -21,6 +21,7 @@ from typing import Any, Callable, Union import yaml +from pyaml_env import parse_config from spark_rapids_tools import get_elem_from_dict, get_elem_non_safe @@ -104,14 +105,13 @@ def __open_json_file(self): def __open_yaml_file(self): try: - with open(self.prop_arg, 'r', encoding='utf-8') as yaml_file: - try: - self.props = yaml.safe_load(yaml_file) - except yaml.YAMLError as e: - raise RuntimeError('Incorrect format of Yaml File') from e + # parse_config sets the default encoding to utf-8 + self.props = parse_config(path=self.prop_arg) + except yaml.YAMLError as e: + raise RuntimeError('Incorrect format of Yaml File') from e except OSError as err: - raise RuntimeError('Please ensure the properties file exists ' - 'and you have the required access privileges.') from err + raise RuntimeError('Please ensure the properties file exists and you have the required ' + 'access privileges.') from err def _load_as_yaml(self): if self.file_load: diff --git a/user_tools/src/spark_rapids_pytools/resources/qualification-conf.yaml b/user_tools/src/spark_rapids_pytools/resources/qualification-conf.yaml index 74c138a11..c6edb6431 100644 ---
a/user_tools/src/spark_rapids_pytools/resources/qualification-conf.yaml +++ b/user_tools/src/spark_rapids_pytools/resources/qualification-conf.yaml @@ -249,7 +249,7 @@ local: columnWidth: 14 totalCoreSecCol: 'Total Core Seconds' # This is total core seconds of an 8-core machine running for 24 hours - totalCoreSecThreshold: 691200 + totalCoreSecThreshold: !ENV ${RAPIDS_USER_TOOLS_CORE_SECONDS_THRESHOLD:691200} speedupCategories: speedupColumnName: 'Estimated GPU Speedup' categoryColumnName: 'Estimated GPU Speedup Category' @@ -293,7 +293,7 @@ local: columns: - 'stageId' - 'SQL Nodes(IDs)' - spillThresholdBytes: 10737418240 + spillThresholdBytes: !ENV ${RAPIDS_USER_TOOLS_SPILL_BYTES_THRESHOLD:10737418240} allowedExecs: - 'Aggregate' - 'Join' diff --git a/user_tools/src/spark_rapids_tools/tools/additional_heuristics.py b/user_tools/src/spark_rapids_tools/tools/additional_heuristics.py index 0396957e2..48f01fac4 100644 --- a/user_tools/src/spark_rapids_tools/tools/additional_heuristics.py +++ b/user_tools/src/spark_rapids_tools/tools/additional_heuristics.py @@ -109,7 +109,8 @@ def heuristics_based_on_spills(self, app_id_path: str) -> (bool, str): 'sqlToStageInfo', 'columns')] # Identify stages with significant spills - spill_threshold_bytes = self.props.get_value('spillBased', 'spillThresholdBytes') + # Convert the string to int because the parse_config method returns a string + spill_threshold_bytes = int(self.props.get_value('spillBased', 'spillThresholdBytes')) spill_condition = stage_agg_metrics['memoryBytesSpilled_sum'] > spill_threshold_bytes stages_with_spills = stage_agg_metrics[spill_condition] diff --git a/user_tools/src/spark_rapids_tools/tools/top_candidates.py b/user_tools/src/spark_rapids_tools/tools/top_candidates.py index 1572aa538..fb207406b 100644 --- a/user_tools/src/spark_rapids_tools/tools/top_candidates.py +++ b/user_tools/src/spark_rapids_tools/tools/top_candidates.py @@ -74,7 +74,8 @@ def _filter_apps(self) -> None: # Filter based on total core 
seconds threshold total_core_sec_col = self.props.get('totalCoreSecCol') - total_core_sec_threshold = self.props.get('totalCoreSecThreshold') + # Convert the string to int because the parse_config method returns a string + total_core_sec_threshold = int(self.props.get('totalCoreSecThreshold')) total_core_sec_condition = self.tools_processed_apps[total_core_sec_col] > total_core_sec_threshold filter_condition = filter_condition & total_core_sec_condition diff --git a/user_tools/src/spark_rapids_tools/utils/propmanager.py b/user_tools/src/spark_rapids_tools/utils/propmanager.py index c7dd6d822..584bbf7ac 100644 --- a/user_tools/src/spark_rapids_tools/utils/propmanager.py +++ b/user_tools/src/spark_rapids_tools/utils/propmanager.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ from typing import Union, Any, TypeVar, ClassVar, Type, Tuple, Optional import yaml +from pyaml_env import parse_config from pydantic import BaseModel, ConfigDict, model_validator, ValidationError from spark_rapids_tools.exceptions import JsonLoadException, YamlLoadException, InvalidPropertiesSchema @@ -45,7 +46,7 @@ def load_yaml(file_path: Union[str, CspPathT]) -> Any: file_path = CspPath(file_path) with file_path.open_input_stream() as fis: try: - return yaml.safe_load(fis) + return parse_config(data=fis.readall()) except yaml.YAMLError as e: raise YamlLoadException('Incorrect format of Yaml File') from e