Improve new CLI testing ensuring complete coverage of arguments cases #652

Merged: 11 commits, Nov 17, 2023
user_tools/src/spark_rapids_tools/cmdli/argprocessor.py (43 changes: 30 additions & 13 deletions)

@@ -168,11 +168,21 @@ def detect_platform_from_eventlogs_prefix(self):
             self.p_args['toolArgs']['platform'] = map_storage_to_platform[storage_type]
 
     def validate_onprem_with_cluster_name(self):
-        if self.platform == CspEnv.ONPREM:
+        # this field has already been populated during initialization
+        selected_platform = self.p_args['toolArgs']['platform']
+        if selected_platform == CspEnv.ONPREM:
             raise PydanticCustomError(
                 'invalid_argument',
                 f'Cannot run cluster by name with platform [{CspEnv.ONPREM}]\n Error:')
 
+    def validate_onprem_with_cluster_props_without_eventlogs(self):
+        # this field has already been populated during initialization
+        selected_platform = self.p_args['toolArgs']['platform']
+        if selected_platform == CspEnv.ONPREM:
+            raise PydanticCustomError(
+                'invalid_argument',
+                f'Cannot run cluster by properties with platform [{CspEnv.ONPREM}] without event logs\n Error:')
+
     def init_extra_arg_cases(self) -> list:
         return []
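Both rejection helpers raise PydanticCustomError from pydantic_core, so argument failures surface as typed validation errors rather than bare exceptions. A short sketch of that mechanism, assuming pydantic v2; the try/except handler is illustrative, and the message text is copied from the diff above:

# Sketch, assuming pydantic v2 (pydantic_core) is installed.
from pydantic_core import PydanticCustomError

try:
    # mirrors the raise in validate_onprem_with_cluster_props_without_eventlogs
    raise PydanticCustomError(
        'invalid_argument',
        'Cannot run cluster by properties with platform [ONPREM] without event logs\n Error:')
except PydanticCustomError as exc:
    # the error carries a machine-readable type plus a rendered message
    print(exc.type)       # invalid_argument
    print(exc.message())  # Cannot run cluster by properties with platform ...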
@@ -202,39 +212,39 @@ def define_extra_arg_cases(self):
     def build_tools_args(self) -> dict:
         pass
 
-    def apply_arg_cases(self):
-        for curr_cases in [self.rejected, self.detected, self.extra]:
+    def apply_arg_cases(self, cases_list: list):
+        for curr_cases in cases_list:
             for case_key, case_value in curr_cases.items():
                 if any(ArgValueCase.array_equal(self.argv_cases, case_i) for case_i in case_value['cases']):
                     # debug the case key
                     self.logger.info('...applying argument case: %s', case_key)
                     case_value['callable']()
 
+    def apply_all_arg_cases(self):
+        self.apply_arg_cases([self.rejected, self.detected, self.extra])
+
     def validate_arguments(self):
         self.init_tool_args()
         self.init_arg_cases()
         self.define_invalid_arg_cases()
         self.define_detection_cases()
         self.define_extra_arg_cases()
-        self.apply_arg_cases()
+        self.apply_all_arg_cases()
 
     def get_or_set_platform(self) -> CspEnv:
         if self.p_args['toolArgs']['platform'] is None:
             # set the platform to default onPrem
             runtime_platform = CspEnv.get_default()
         else:
             runtime_platform = self.p_args['toolArgs']['platform']
-        self.post_platform_assignment_validation(runtime_platform)
+        self.post_platform_assignment_validation()
         return runtime_platform
 
-    def post_platform_assignment_validation(self, assigned_platform):
-        # do some validation after we decide the cluster type
-        if self.argv_cases[1] == ArgValueCase.VALUE_A:
-            if assigned_platform == CspEnv.ONPREM:
-                # it is not allowed to run cluster_by_name on an OnPrem platform
-                raise PydanticCustomError(
-                    'invalid_argument',
-                    f'Cannot run cluster by name with platform [{CspEnv.ONPREM}]\n Error:')
+    def post_platform_assignment_validation(self):
+        # Update argv_cases to reflect the platform
+        self.argv_cases[0] = ArgValueCase.VALUE_A
+        # Any validation post platform assignment should be done here
+        self.apply_arg_cases([self.rejected, self.extra])
 
 
 @dataclass
@@ -278,6 +288,13 @@ def define_invalid_arg_cases(self):
                 [ArgValueCase.VALUE_A, ArgValueCase.VALUE_A, ArgValueCase.IGNORE]
             ]
         }
+        self.rejected['Cluster By Properties Cannot go with OnPrem'] = {
+            'valid': False,
+            'callable': partial(self.validate_onprem_with_cluster_props_without_eventlogs),
+            'cases': [
+                [ArgValueCase.VALUE_A, ArgValueCase.VALUE_B, ArgValueCase.UNDEFINED]
+            ]
+        }
 
     def define_detection_cases(self):
         self.detected['Define Platform from Cluster Properties file'] = {
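Taken together, the argprocessor.py changes make the case-table dispatch reusable: apply_arg_cases now scans an explicit list of tables, apply_all_arg_cases keeps the original behavior for validate_arguments, and post_platform_assignment_validation pins the platform slot of argv_cases to VALUE_A and re-applies only the rejected and extra tables. A self-contained sketch of that pattern; CaseDispatcher and its table contents are illustrative stand-ins, not the project's actual types:

from dataclasses import dataclass, field
from enum import Enum


class ArgValueCase(Enum):
    # illustrative subset of the real enum
    UNDEFINED = 0
    VALUE_A = 1
    VALUE_B = 2
    IGNORE = 3

    @classmethod
    def array_equal(cls, actual: list, expected: list) -> bool:
        # IGNORE acts as a wildcard slot in a case pattern
        return len(actual) == len(expected) and all(
            exp == cls.IGNORE or act == exp for act, exp in zip(actual, expected))


@dataclass
class CaseDispatcher:
    # hypothetical stand-in for the validator's rejected/detected/extra tables
    argv_cases: list
    rejected: dict = field(default_factory=dict)
    detected: dict = field(default_factory=dict)
    extra: dict = field(default_factory=dict)

    def apply_arg_cases(self, cases_list: list) -> None:
        # same shape as the refactored method: scan only the given tables
        for curr_cases in cases_list:
            for case_key, case_value in curr_cases.items():
                if any(ArgValueCase.array_equal(self.argv_cases, case_i)
                       for case_i in case_value['cases']):
                    print(f'...applying argument case: {case_key}')
                    case_value['callable']()

    def apply_all_arg_cases(self) -> None:
        self.apply_arg_cases([self.rejected, self.detected, self.extra])


def _reject():
    raise ValueError('rejected: cluster props with OnPrem and no event logs')


# usage: a rejected case that fires on [VALUE_A, VALUE_B, UNDEFINED]
dispatcher = CaseDispatcher(
    argv_cases=[ArgValueCase.VALUE_A, ArgValueCase.VALUE_B, ArgValueCase.UNDEFINED])
dispatcher.rejected['Cluster By Properties Cannot go with OnPrem'] = {
    'valid': False,
    'callable': _reject,
    'cases': [[ArgValueCase.VALUE_A, ArgValueCase.VALUE_B, ArgValueCase.UNDEFINED]],
}
dispatcher.apply_all_arg_cases()  # raises ValueError

Because the rejected table is scanned again after platform assignment, a case like 'Cluster By Properties Cannot go with OnPrem' fires whether the platform was given explicitly or merely defaulted to OnPrem.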
user_tools/tests/spark_rapids_tools_ut/conftest.py (5 changes: 3 additions & 2 deletions)

@@ -16,7 +16,7 @@
 
 import sys
 
-import pytest # pylint: disable=import-error
+import pytest  # pylint: disable=import-error
 
 
 def get_test_resources_path():
@@ -46,9 +46,10 @@ def gen_cpu_cluster_props():
 # all csps except onprem
 csps = ['dataproc', 'dataproc_gke', 'emr', 'databricks_aws', 'databricks_azure']
 all_csps = csps + ['onprem']
+autotuner_prop_path = 'worker_info.yaml'
 
 
-class SparkRapidsToolsUT: # pylint: disable=too-few-public-methods
+class SparkRapidsToolsUT:  # pylint: disable=too-few-public-methods
 
     @pytest.fixture(autouse=True)
     def get_ut_data_dir(self):
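The new autotuner_prop_path constant names the YAML resource added in the next file. A hypothetical usage sketch, not part of this PR; the test function below is illustrative, while get_test_resources_path is the helper defined earlier in this conftest:

# Hypothetical usage sketch (not from the PR).
import os

from conftest import autotuner_prop_path, get_test_resources_path  # pylint: disable=import-error


def test_autotuner_worker_info_resource_exists():
    # worker_info.yaml should live under the shared UT resources directory
    prop_path = os.path.join(get_test_resources_path(), autotuner_prop_path)
    assert os.path.isfile(prop_path)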
user_tools/tests/spark_rapids_tools_ut/resources/worker_info.yaml (19 changes: 19 additions & 0 deletions)

@@ -0,0 +1,19 @@
+system:
+  numCores: 32
+  memory: 212992MiB
+  numWorkers: 5
+gpu:
+  memory: 15109MiB
+  count: 4
+  name: T4
+softwareProperties:
+  spark.driver.maxResultSize: 7680m
+  spark.driver.memory: 15360m
+  spark.executor.cores: '8'
+  spark.executor.instances: '2'
+  spark.executor.memory: 47222m
+  spark.executorEnv.OPENBLAS_NUM_THREADS: '1'
+  spark.scheduler.mode: FAIR
+  spark.sql.cbo.enabled: 'true'
+  spark.ui.port: '0'
+  spark.yarn.am.memory: 640m
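The fixture follows the AutoTuner's worker-info layout: cluster-wide facts under system, per-worker GPU facts under gpu, and a map of Spark settings under softwareProperties. A minimal sketch, assuming PyYAML, of reading it back; the loader below is illustrative, not the tools' actual parser:

import yaml  # assumes PyYAML is installed


def load_worker_info(path: str) -> dict:
    # parse the worker-info YAML and summarize the fields shown above
    with open(path, encoding='utf-8') as fh:
        info = yaml.safe_load(fh)
    return {
        'cores_per_worker': info['system']['numCores'],
        'num_workers': info['system']['numWorkers'],
        'gpus_per_worker': info['gpu']['count'],
        'gpu_name': info['gpu']['name'],
        'spark_props': info.get('softwareProperties', {}),
    }


if __name__ == '__main__':
    summary = load_worker_info('worker_info.yaml')
    print(f"{summary['num_workers']} workers x {summary['gpus_per_worker']} {summary['gpu_name']} GPUs")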