diff --git a/examples/offline_profile.py b/examples/offline_profile.py
index ee16779776710..04e727103f683 100644
--- a/examples/offline_profile.py
+++ b/examples/offline_profile.py
@@ -1,11 +1,12 @@
 import argparse
-import torch
-import sys
-import json
 import inspect
-
-from dataclasses import dataclass, asdict
+import json
+import sys
+from dataclasses import asdict, dataclass
 from typing import Optional
+
+import torch
+
 from vllm import LLM, SamplingParams
 from vllm.profiler import nm_profile
 
@@ -68,6 +69,7 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
     output_len = context.output_len
 
     scheduler_config = llm.llm_engine.scheduler_config
+    max_model_len = llm.llm_engine.model_config.max_model_len
     max_num_batched_tokens = scheduler_config.max_num_batched_tokens
     max_num_seqs = scheduler_config.max_num_seqs
 
@@ -89,9 +91,9 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
                      llm.llm_engine.model_config.max_model_len)
     if prompt_len + output_len > llm.llm_engine.model_config.max_model_len:
         print(
-            f"ERROR: chosen prompt_len + output_len ({prompt_len} + {output_len} = "
-            f"{prompt_len + output_len}) is larger than the model's max_model_len "
-            f"({llm.llm_engine.model_config.max_model_len}), please choose a smaller "
+            f"ERROR: chosen prompt_len + output_len ({prompt_len} + "
+            f"{output_len} = {prompt_len + output_len}) is larger than the "
+            f"model's max_model_len ({max_model_len}), please choose a smaller "
             f"prompt_len or output_len, or increase --max-model-len")
         sys.exit(-1)
 
@@ -222,9 +224,8 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
         type=str,
         choices=['awq', 'gptq', 'squeezellm', 'marlin', 'smoothquant', None],
         default=None,
-        help="The method used to quantize the model weights, "
-        "options are \"marlin\", \"awq\", \"gptq\", \"squeezellm\", \"smoothquant\""
-    )
+        help="The method used to quantize the model weights, options are "
+        "\"marlin\", \"awq\", \"gptq\", \"squeezellm\", \"smoothquant\"")
     parser.add_argument("--dtype",
                         type=str,
                         default='auto',
@@ -233,7 +234,7 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
         "--max-model-len",
         type=int,
         default=None,
-        help=f"Maximum length of a sequence (including prompt and output)")
+        help="Maximum length of a sequence (including prompt and output)")
     parser.add_argument(
         "--max-num-batched-tokens",
         type=int,
diff --git a/neuralmagic/tools/profiler/print_table.py b/neuralmagic/tools/profiler/print_table.py
index ce4b6a7d7bbb9..9081583a9f95d 100644
--- a/neuralmagic/tools/profiler/print_table.py
+++ b/neuralmagic/tools/profiler/print_table.py
@@ -1,10 +1,10 @@
 import argparse
 import json
-
-from vllm.profiler.nm_profile import SummaryStatsEntry, ModelStatsEntry
-from vllm.profiler.utils import indent_string, TablePrinter
 from typing import Dict
 
+from vllm.profiler.nm_profile import ModelStatsEntry, SummaryStatsEntry
+from vllm.profiler.utils import TablePrinter, indent_string
+
 
 def flatten_entries(entry_cls, profile_dict: Dict):
     entries_and_depth = []
diff --git a/neuralmagic/tools/profiler/visualize_trace.py b/neuralmagic/tools/profiler/visualize_trace.py
index 09b1d01041d74..fd5659161b046 100644
--- a/neuralmagic/tools/profiler/visualize_trace.py
+++ b/neuralmagic/tools/profiler/visualize_trace.py
@@ -1,7 +1,8 @@
 import argparse
 import json
-import pandas as pd
+
 import matplotlib.pyplot as plt
+import pandas as pd
 
 
 def trim_string_back(string: str, width: int):
@@ -198,12 +199,11 @@ def plot_metric(metric: str, ax, add_totals=False):
     shorten_plot_legend_strings(legend, 50)
 
     context = profile_data["context"]
-    plt.suptitle(
-        f"{context['model']}\n"
-        f"Batch={context['batch_size']}, "
-        f"PromptLen={context['prompt_len']}, "
-        f"NumGpus={context['tensor_parallel_size']}"
-        f"{', Sparsity ' + context['sparsity'] if context.get('sparsity', None) else ''}"
-    )
+    sparsity = context.get('sparsity', None)
+    plt.suptitle(f"{context['model']}\n"
+                 f"Batch={context['batch_size']}, "
+                 f"PromptLen={context['prompt_len']}, "
+                 f"NumGpus={context['tensor_parallel_size']}"
+                 f"{', Sparsity ' + sparsity if sparsity else ''}")
     plt.savefig(output, bbox_inches='tight')
     print("Created: ", output)