Support extra arguments in new user tools CLI (#646)
* Support extra arguments in new user tools CLI

Signed-off-by: Partho Sarthi <[email protected]>

* Update tests

Signed-off-by: Partho Sarthi <[email protected]>

---------

Signed-off-by: Partho Sarthi <[email protected]>
parthosa authored Nov 2, 2023
1 parent defd16c commit ebca530
Showing 3 changed files with 21 additions and 7 deletions.
22 changes: 18 additions & 4 deletions user_tools/src/spark_rapids_tools/cmdli/tools_cli.py
@@ -48,7 +48,8 @@ def qualification(self,
                       global_discount: int = None,
                       gpu_cluster_recommendation: str = QualGpuClusterReshapeType.tostring(
                           QualGpuClusterReshapeType.get_default()),
-                      verbose: bool = False):
+                      verbose: bool = False,
+                      **rapids_options):
         """The Qualification cmd provides estimated running costs and speedups by migrating Apache
         Spark applications to GPU accelerated clusters.
@@ -98,6 +98,11 @@ def qualification(self,
                "CLUSTER": recommend optimal GPU cluster by cost for entire cluster;
                "JOB": recommend optimal GPU cluster by cost per job
         :param verbose: True or False to enable verbosity of the script.
+        :param rapids_options: A list of valid Qualification tool options.
+            Note that the wrapper ignores ["output-directory", "platform"] flags, and it does not support
+            multiple "spark-property" arguments.
+            For more details on Qualification tool options, please visit
+            https://nvidia.github.io/spark-rapids/docs/spark-qualification-tool.html#qualification-tool-options
         """
         if verbose:
             ToolLogging.enable_debug_mode()
@@ -118,15 +124,17 @@
         if qual_args:
             tool_obj = QualificationAsLocal(platform_type=qual_args['runtimePlatform'],
                                             output_folder=qual_args['outputFolder'],
-                                            wrapper_options=qual_args)
+                                            wrapper_options=qual_args,
+                                            rapids_options=rapids_options)
             tool_obj.launch()

     def profiling(self,
                   eventlogs: str = None,
                   cluster: str = None,
                   platform: str = None,
                   output_folder: str = None,
-                  verbose: bool = False):
+                  verbose: bool = False,
+                  **rapids_options):
         """The Profiling cmd provides information which can be used for debugging and profiling
         Apache Spark applications running on accelerated GPU cluster.
@@ -145,6 +153,11 @@ def profiling(self,
             and "databricks-azure".
         :param output_folder: path to store the output.
         :param verbose: True or False to enable verbosity of the script.
+        :param rapids_options: A list of valid Profiling tool options.
+            Note that the wrapper ignores ["output-directory", "worker-info"] flags, and it does not support
+            multiple "spark-property" arguments.
+            For more details on Profiling tool options, please visit
+            https://nvidia.github.io/spark-rapids/docs/spark-profiling-tool.html#profiling-tool-options
         """
         if verbose:
             ToolLogging.enable_debug_mode()
@@ -157,7 +170,8 @@
         if prof_args:
             tool_obj = ProfilingAsLocal(platform_type=prof_args['runtimePlatform'],
                                         output_folder=prof_args['outputFolder'],
-                                        wrapper_options=prof_args)
+                                        wrapper_options=prof_args,
+                                        rapids_options=rapids_options)
             tool_obj.launch()

     def bootstrap(self,
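The new `**rapids_options` parameter works because this CLI is driven by python-fire: any command-line flag that does not match a declared parameter is collected into that dict and handed to the tool wrapper. A minimal sketch of the pattern, assuming only that python-fire is installed (ToyCLI and its qualification stub below are hypothetical stand-ins, not the repo's classes; --per-sql is one example of a Qualification tool flag, see the linked docs):

import fire


class ToyCLI:
    """Hypothetical stand-in for ToolsCLI; only the kwargs capture matters."""

    def qualification(self, eventlogs: str = None, verbose: bool = False,
                      **rapids_options):
        # Flags that do not match a declared parameter are collected here
        # and can be forwarded to the underlying tool, e.g.
        #   python toy_cli.py qualification --eventlogs /logs --per-sql
        print(f'eventlogs={eventlogs!r}, verbose={verbose}')
        print(f'extra tool options: {rapids_options}')


if __name__ == '__main__':
    fire.Fire(ToyCLI())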
2 changes: 1 addition & 1 deletion user_tools/src/spark_rapids_tools/utils/util.py
@@ -92,7 +92,7 @@ def dump_tool_usage(tool_name: Optional[str], raise_sys_exit: Optional[bool] = T
     imported_module = __import__('spark_rapids_tools.cmdli', globals(), locals(), ['ToolsCLI'])
     wrapper_clzz = getattr(imported_module, 'ToolsCLI')
     help_name = 'ascli'
-    usage_cmd = f'{tool_name} --help'
+    usage_cmd = f'{tool_name} -- --help'
     try:
         fire.Fire(wrapper_clzz(), name=help_name, command=usage_cmd)
     except fire.core.FireExit:
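The one-line `dump_tool_usage` change follows from the same kwargs capture: once the commands accept `**rapids_options`, a plain trailing `--help` can be swallowed into that dict instead of triggering help, so the usage command switches to python-fire's separator form, in which arguments after an isolated `--` are consumed by Fire itself. A small sketch of that behavior, assuming Fire's documented `--` handling (the `greet` function is made up for illustration):

import fire


def greet(name: str = 'world', **extra_options):
    # Like the tools CLI after this change, unrecognized flags are
    # captured into extra_options rather than handled by Fire.
    print(f'hello {name}, extra={extra_options}')


if __name__ == '__main__':
    # `python demo.py --help`    -> --help may be captured into extra_options
    # `python demo.py -- --help` -> Fire consumes --help and prints usage
    fire.Fire(greet)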
@@ -166,8 +166,8 @@ def test_cluster_props_no_eventlogs_on_prem(self, capsys, tool_name):
                                             platform='onprem')
         assert pytest_wrapped_e.type == SystemExit
         captured = capsys.readouterr()
-        # Verify there is no URL in error message
-        assert 'https://' not in captured.err
+        # Verify there is no URL in error message except for the one from the documentation
+        assert 'https://' not in captured.err or 'nvidia.github.io' in captured.err

     @pytest.mark.skip(reason='Unit tests are not completed yet')
     def test_arg_cases_coverage(self):

