diff --git a/examples/offline_profile.py b/examples/offline_profile.py
new file mode 100644
index 0000000000000..c25df445eb4f9
--- /dev/null
+++ b/examples/offline_profile.py
@@ -0,0 +1,286 @@
+import argparse
+import torch
+import sys
+import json
+import inspect
+
+from dataclasses import dataclass, asdict
+from typing import Optional
+from vllm import LLM, SamplingParams
+from vllm.profiler import nm_profile
+
+BATCH_SIZE_DEFAULT = 1
+PROMPT_LEN_DEFAULT = 256
+OUTPUT_LEN_DEFAULT = 2
+
+
+@dataclass
+class ProfileContext:
+    model: str
+    tokenizer: str
+    model_revision: str
+    quantization: str
+    max_model_len: int
+    max_num_batched_tokens: int
+    prompt_len: int
+    output_len: int
+    batch_size: int
+    dtype: str
+    tensor_parallel_size: int
+    allow_cuda_graphs: bool
+
+
+def get_dtype(dtype: str):
+    if dtype == "torch.float":
+        return torch.float
+    else:
+        return dtype
+
+
+def run_profile(context: ProfileContext, csv_output: Optional[str],
+                json_output: Optional[str]):
+    print("Run profile with:")
+    for key, value in asdict(context).items():
+        print(f"  {key} = {value}")
+
+    # Create sampling params
+    sampling_params = SamplingParams(temperature=0.8,
+                                     top_p=0.95,
+                                     max_tokens=context.output_len,
+                                     ignore_eos=True)
+
+    # NOTE: sparsity support is planned but not wired up here yet
+    # Create LLM
+    llm = LLM(model=context.model,
+              tokenizer=context.tokenizer
+              if context.tokenizer is not None else context.model,
+              revision=context.model_revision,
+              enforce_eager=not context.allow_cuda_graphs,
+              tensor_parallel_size=context.tensor_parallel_size,
+              gpu_memory_utilization=0.9,
+              max_model_len=context.max_model_len,
+              quantization=context.quantization,
+              dtype=get_dtype(context.dtype),
+              max_num_batched_tokens=context.max_num_batched_tokens)
+
+    batch_size = context.batch_size
+    prompt_len = context.prompt_len
+    output_len = context.output_len
+
+    scheduler_config = llm.llm_engine.scheduler_config
+    max_model_len = llm.llm_engine.model_config.max_model_len
+    max_num_batched_tokens = scheduler_config.max_num_batched_tokens
+    max_num_seqs = scheduler_config.max_num_seqs
+
+    if batch_size * prompt_len > max_num_batched_tokens:
+        print(f"ERROR: chosen batch_size * prompt_len "
+              f"({batch_size} * {prompt_len} = {batch_size * prompt_len}) is "
+              f"larger than max_num_batched_tokens ({max_num_batched_tokens}) "
+              f"and therefore cannot be run in a single profile step, please "
+              f"choose a smaller batch size or prompt length, or increase "
+              f"--max-num-batched-tokens")
+        sys.exit(-1)
+    if batch_size > max_num_seqs:
+        print(
+            f"ERROR: chosen batch_size ({batch_size}) is larger than "
+            f"max_num_seqs ({max_num_seqs}) and therefore cannot be run in a "
+            f"single profile step, please choose a smaller batch size")
+        sys.exit(-1)
+    print("llm.llm_engine.model_config.max_model_len: ",
+          llm.llm_engine.model_config.max_model_len)
+    if prompt_len + output_len > llm.llm_engine.model_config.max_model_len:
+        print(
+            f"ERROR: chosen prompt_len + output_len ({prompt_len} + "
+            f"{output_len} = {prompt_len + output_len}) is larger than the "
+            f"model's max_model_len ({max_model_len}), please choose a smaller "
+            f"prompt_len or output_len, or increase --max-model-len")
+        sys.exit(-1)
+
+    for i in range(batch_size):
+        prompt_token_ids = torch.randint(
+            llm.llm_engine.model_config.get_vocab_size(),
+            size=(prompt_len, )).tolist()
+
+        llm.llm_engine.add_request(
+            request_id=f"seq{i}",
+            inputs={'prompt_token_ids': prompt_token_ids},
+            params=sampling_params)
+
+    with nm_profile() as prefill_prof:
+        llm.llm_engine.step()  # First step is prefill
+
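+    # Because every request was queued before the first step and the checks
+    # above guarantee the whole batch fits in one iteration, the step() above
+    # prefills all sequences at once; each subsequent step() then decodes one
+    # token per sequence.
+    # Optional sanity check (editor's sketch, not part of the original
+    # script; assumes LLMEngine.has_unfinished_requests() is available): when
+    # more tokens remain to be decoded, every request should still be active
+    # after the prefill step, since ignore_eos=True.
+    assert output_len == 1 or llm.llm_engine.has_unfinished_requests(), \
+        "expected requests to remain active after the prefill step"
+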
+    decode_results_list = []
+    for _ in range(output_len - 1):
+        with nm_profile() as decode_prof:
+            llm.llm_engine.step()
+        decode_results_list.append(decode_prof.results)
+
+    prefill_results = prefill_prof.results
+    has_decode = len(decode_results_list) > 0
+
+    print("=" * 80)
+    print(f"= Prefill Model Table "
+          f"(prompt_len={prompt_len}, batch_size={batch_size})")
+    print("=" * 80)
+    print()
+    prefill_results.print_model_table()
+
+    if has_decode:
+        print()
+        print("=" * 80)
+        print(f"= First Decode Step Model Table "
+              f"(prompt_len={prompt_len}, batch_size={batch_size})")
+        print("=" * 80)
+        print()
+        decode_results_list[0].print_model_table()
+
+    print()
+    print("=" * 80)
+    print(f"= Prefill Summary Table "
+          f"(prompt_len={prompt_len}, batch_size={batch_size})")
+    print("=" * 80)
+    print()
+    prefill_results.print_summary_table()
+
+    if has_decode:
+        print()
+        print("=" * 80)
+        print(f"= First Decode Step Summary Table "
+              f"(prompt_len={prompt_len}, batch_size={batch_size})")
+        print("=" * 80)
+        print()
+        decode_results_list[0].print_summary_table()
+
+    if csv_output:
+        # str.rstrip strips a set of characters, not a suffix, so remove a
+        # trailing ".csv" explicitly before building the per-table filenames
+        csv_filename_base = (csv_output[:-len(".csv")]
+                             if csv_output.endswith(".csv") else csv_output)
+        prefill_results.export_model_stats_table_csv(
+            csv_filename_base + "_prefill_model_table.csv")
+        prefill_results.export_summary_stats_table_csv(
+            csv_filename_base + "_prefill_summary_table.csv")
+
+        if has_decode:
+            decode_results_list[0].export_model_stats_table_csv(
+                csv_filename_base + "_decode_model_table.csv")
+            decode_results_list[0].export_summary_stats_table_csv(
+                csv_filename_base + "_decode_summary_table.csv")
+
+    if json_output:
+        cuda_devices = [
+            torch.cuda.get_device_properties(dev_idx)
+            for dev_idx in range(torch.cuda.device_count())
+        ]
+
+        json_dict = {
+            "context": {
+                "python_version": f"{sys.version}",
+                "torch_version": f"{torch.__version__}",
+                "torch_cuda_version": f"{torch.version.cuda}",
+                "cuda_devices": f"{cuda_devices}",
+                **asdict(context)
+            },
+            "prefill": prefill_results.convert_stats_to_dict(),
+        }
+
+        if has_decode:
+            for idx, dr in enumerate(decode_results_list):
+                json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict()
+
+        json_filename = json_output if json_output.endswith(
+            ".json") else json_output + ".json"
+        with open(json_filename, "w+") as f:
+            json.dump(json_dict, f, indent=2)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--model",
+        type=str,
+        required=True,
+        help='The name or path of a HuggingFace Transformers model.')
+    parser.add_argument("--tokenizer",
+                        type=str,
+                        default=None,
+                        help="Path to the tokenizer")
+
+    parser.add_argument("--model-revision", type=str, default=None)
+    parser.add_argument(
+        "--csv",
+        type=str,
+        default=None,
+        help="Export the results as multiple csv files. This should be the "
+        "root filename; will create <root>_prefill_model_table.csv, "
+        "<root>_prefill_summary_table.csv, "
+        "<root>_decode_model_table.csv, and "
+        "<root>_decode_summary_table.csv")
+    parser.add_argument(
+        "--json",
+        type=str,
+        default=None,
+        help="Export the results as a json file. This should be the filename")
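+    # The file written for --json has roughly this shape (see
+    # convert_stats_to_dict() in vllm/profiler/nm_profile.py); decode keys
+    # are 1-indexed, one per profiled decode step:
+    #   {
+    #     "context":  {"model": ..., "batch_size": ..., ...},
+    #     "prefill":  {"summary_stats": [...], "model_stats": [...]},
+    #     "decode_1": {"summary_stats": [...], "model_stats": [...]},
+    #     ...
+    #   }
+    # Each stats list is a tree of {"entry": {...}, "children": [...]} nodes;
+    # this is what neuralmagic/tools/profiler/print_table.py and
+    # visualize_trace.py consume.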
+    parser.add_argument(
+        "--quantization",
+        "-q",
+        type=str,
+        choices=['awq', 'gptq', 'squeezellm', 'marlin', 'smoothquant', None],
+        default=None,
+        help="The method used to quantize the model weights, options are "
+        "\"marlin\", \"awq\", \"gptq\", \"squeezellm\", \"smoothquant\"")
+    parser.add_argument("--dtype",
+                        type=str,
+                        default='auto',
+                        help="model dtype")
+    parser.add_argument(
+        "--max-model-len",
+        type=int,
+        default=None,
+        help="Maximum length of a sequence (including prompt and output)")
+    parser.add_argument(
+        "--max-num-batched-tokens",
+        type=int,
+        default=None,
+        help="Maximum number of tokens to be processed in a single iteration. "
+        "Should be greater than batch-size * prompt-len so the prefill can "
+        "run in a single iteration.")
+    parser.add_argument(
+        "--prompt-len",
+        type=int,
+        default=PROMPT_LEN_DEFAULT,
+        help=f"Length of the random prompt to use when profiling, all batched "
+        f"requests use the same prompt_len, default={PROMPT_LEN_DEFAULT}")
+    parser.add_argument("--batch-size",
+                        type=int,
+                        default=BATCH_SIZE_DEFAULT,
+                        help=f"Number of requests to run as a single batch, "
+                        f"default={BATCH_SIZE_DEFAULT}")
+    parser.add_argument("--tensor-parallel-size",
+                        "-tp",
+                        type=int,
+                        default=1,
+                        help="Number of GPUs to use i.e. tensor parallelism, "
+                        "default=1")
+    parser.add_argument(
+        "--allow-cuda-graphs",
+        action='store_true',
+        help="Enable cuda graphs. Note that this will remove a lot of the "
+        "module-level info in the profiler results, since almost everything "
+        "runs inside the graph, where we do not have access to an "
+        "informative stack trace")
+    parser.add_argument("--output-len",
+                        type=int,
+                        default=OUTPUT_LEN_DEFAULT,
+                        help=f"Number of llm steps to run (includes prefill "
+                        f"and decode), default={OUTPUT_LEN_DEFAULT}")
+
+    args = parser.parse_args()
+
+    context = ProfileContext(
+        **{
+            k: v
+            for k, v in vars(args).items()
+            if k in inspect.signature(ProfileContext).parameters
+        })
+    run_profile(context, csv_output=args.csv, json_output=args.json)
diff --git a/neuralmagic/tools/profiler/print_table.py b/neuralmagic/tools/profiler/print_table.py
new file mode 100644
index 0000000000000..ce4b6a7d7bbb9
--- /dev/null
+++ b/neuralmagic/tools/profiler/print_table.py
@@ -0,0 +1,77 @@
+import argparse
+import json
+
+from vllm.profiler.nm_profile import SummaryStatsEntry, ModelStatsEntry
+from vllm.profiler.utils import indent_string, TablePrinter
+from typing import Dict
+
+
+def flatten_entries(entry_cls, profile_dict: Dict):
+    entries_and_depth = []
+
+    def get_entries(node, curr_depth=0):
+        entries_and_depth.append((entry_cls(**node["entry"]), curr_depth))
+
+        for child in node["children"]:
+            get_entries(
+                child,
+                curr_depth=curr_depth + 1,
+            )
+
+    for root in profile_dict:
+        get_entries(root)
+
+    return entries_and_depth
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--json-trace",
+                        type=str,
+                        required=True,
+                        help="json trace file output by "
+                        "examples/offline_profile.py")
+    parser.add_argument("--phase",
+                        type=str,
+                        choices=["prefill", "decode_1"],
+                        required=True,
+                        help="The phase to print the table for.")
+    parser.add_argument("--table",
+                        type=str,
+                        choices=["summary", "model"],
+                        default="summary",
+                        help="Which table to print, the summary table or the "
+                        "layerwise model table")
+
+    args = parser.parse_args()
+
+    with open(args.json_trace, "r") as f:
+        profile_data = json.load(f)
+
+    if args.table == "summary":
+        entries_and_depths = flatten_entries(
SummaryStatsEntry, profile_data[args.phase]["summary_stats"]) + column_widths = dict(name=80, + cuda_time_us=12, + pct_cuda_time=12, + invocations=15) + elif args.table == "model": + entries_and_depths = flatten_entries( + ModelStatsEntry, profile_data[args.phase]["model_stats"]) + column_widths = dict(name=60, + cpu_time_us=12, + cuda_time_us=12, + pct_cuda_time=12, + trace=60) + + # ident entry names based on the depth + entries = [] + for entry, depth in entries_and_depths: + entry.name = indent_string( + entry.name, + indent=depth, + indent_style=lambda indent: "|" + "-" * indent + " ") + entries.append(entry) + + TablePrinter(type(entries[0]), column_widths).print_table(entries) diff --git a/neuralmagic/tools/profiler/visualize_trace.py b/neuralmagic/tools/profiler/visualize_trace.py new file mode 100644 index 0000000000000..f4b1449281f69 --- /dev/null +++ b/neuralmagic/tools/profiler/visualize_trace.py @@ -0,0 +1,431 @@ +import argparse +import copy +import json +import math +from pathlib import Path +from typing import Any, List, Optional, Tuple + +import matplotlib.pyplot as plt +import pandas as pd + +## JSON parsing utils #### + + +def largest_dist_from_leaf(node: dict, depth: int = 0): + if len(node["children"]) == 0: + return depth + return max([ + largest_dist_from_leaf(child, depth=depth + 1) + for child in node["children"] + ]) + + +def get_entries_at_depth(depth: int, + entries_and_traces: List[Tuple[Any, Any]], + node: dict, + curr_depth: int = 0, + trace=()): + # assert that the query is at kernel or module level + assert depth == -1 or depth == -2 + + if curr_depth == 0 and largest_dist_from_leaf(node) <= (abs(depth) - 1): + # The tree is not tall enough! + entries_and_traces.append((node["entry"], trace)) + return + + if largest_dist_from_leaf(node) == (abs(depth) - 1): + entries_and_traces.append((node["entry"], trace)) + + trace = (node["entry"]["name"], ) + trace + for child in node["children"]: + get_entries_at_depth(depth, + entries_and_traces, + child, + curr_depth=curr_depth + 1, + trace=trace) + + +def fold_nodes(root: dict, nodes_to_fold: List[str]): + + stack: List[dict] = [root] + while len(stack) != 0: + node = stack.pop() + if node['entry']['name'] in nodes_to_fold: + node["children"] = [] + continue + for child in node["children"]: + stack.append(child) + return root + + +## Operation name cleanup utils #### + + +def trim_string_back(string: str, width: int) -> str: + if len(string) > width: + offset = len(string) - width + 3 + string = string[:-offset] + if len(string) > 3: + string = string + "..." 
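+    # Note on the depth convention used by get_entries_at_depth() above:
+    # depth == -1 selects leaf entries (individual CUDA kernels) while
+    # depth == -2 selects their parents (module-level entries); the
+    # --level {kernel,module} flag defined below maps onto these two values.
+    # trim_string_back() keeps the front of an over-long name and marks the
+    # cut, e.g. a width of 10 turns "MergedColumnParallelLinear" into
+    # "MergedC...".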
+ return string + + +def shorten_plot_legend_strings(legend, max_char_len: int): + for t in legend.get_texts(): + t.set_text( + trim_string_back(abbreviate_known_names(t.get_text()), + max_char_len)) + + +def abbreviate_known_names(name: str) -> str: + abbreviations = { + "MergedColumnParallelLinear": "MCPLinear", + "QKVParallelLinear": "QKVPLinear", + "RowParallelLinear": "RPLinear", + "weight=": "w=", + "bfloat16": "bf16", + "float16": "f16", + } + for key, value in abbreviations.items(): + name = name.replace(key, value) + return name + + +def attempt_to_make_names_unique(entries_and_traces): + names, non_unique_names = (set(), set()) + + def all_the_same(items) -> bool: + return all(i == items[0] for i in items) + + for entry, _ in entries_and_traces: + if entry["name"] in names: + non_unique_names.add(entry["name"]) + else: + names.add(entry["name"]) + + for name in non_unique_names: + entries_and_traces_with_name = [(entry, trace) + for entry, trace in entries_and_traces + if entry["name"] == name] + + zipped_traces = list( + zip(*[trace for _, trace in entries_and_traces_with_name])) + first_trace_difference = next( + (i for i, trace_eles in enumerate(zipped_traces) + if not all_the_same(trace_eles)), None) + + if first_trace_difference is None: + # can't create a unique name, leave them names as the + # are they will get aggregated by the pivot_table call + continue + + for entry, trace in entries_and_traces_with_name: + entry["name"] = " <- ".join((entry["name"], ) + + trace[:first_trace_difference + 1]) + + +## Operation grouping utils #### +''' + Group operations in the given dataframe by some high-level ops like, + - gemms + - attention + - rms_norm + etc. +''' + + +def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame: + + def is_rms_norm(op_name: str): + if "rms_norm_kernel" in op_name: + return True + + def is_attention_block(op_name: str): + if "flash_fwd" in op_name or \ + "reshape_and_cache_flash_kernel" in op_name: + return True + + def is_quant(op_name: str): + if "scaled_fp8_quant" in op_name or \ + "scaled_int8_quant" in op_name: + return True + + def is_gemm_op(op_name: str): + if is_quant(op_name): + return False + if "xmma_gemm" in op_name or \ + "gemv2T_kernel" in op_name or \ + "splitKreduce" in op_name or \ + "void cutlass::Kernel" in op_name or \ + "void cutlass::device_kernel" in op_name or \ + "s16816gemm" in op_name: + return True + + def is_elementwise_op(op_name: str): + return "elementwise_kernel" in op_name + + def is_mem_op(op_name: str): + return "memcpy" in op_name.lower() or \ + "memset" in op_name.lower() + + def is_vocab_embedding_op(op_name: str): + return "vocabparallelembed" in op_name.lower() + + headers = list(trace_df) + ops = copy.deepcopy(headers) + + attention_ops = list(filter(lambda x: is_attention_block(x), ops)) + ops = list(filter(lambda x: x not in attention_ops, ops)) + + quant_ops = list(filter(lambda x: is_quant(x), ops)) + ops = list(filter(lambda x: x not in quant_ops, ops)) + + gemm_ops = list(filter(lambda x: is_gemm_op(x), ops)) + ops = list(filter(lambda x: x not in gemm_ops, ops)) + + rms_norm_ops = list(filter(lambda x: is_rms_norm(x), ops)) + ops = list(filter(lambda x: x not in rms_norm_ops, ops)) + + vocab_embed_ops = list(filter(lambda x: is_vocab_embedding_op(x), ops)) + ops = list(filter(lambda x: x not in vocab_embed_ops, ops)) + + mem_ops = list(filter(lambda x: is_mem_op(x), ops)) + ops = list(filter(lambda x: x not in mem_ops, ops)) + + elementwise_ops = list(filter(lambda x: is_elementwise_op(x), 
ops)) + ops = list(filter(lambda x: x not in elementwise_ops, ops)) + + if len(attention_ops): + trace_df['attention'] = trace_df[attention_ops].agg("sum", axis=1) + if len(quant_ops): + trace_df['quant_ops'] = trace_df[quant_ops].agg("sum", axis=1) + if len(gemm_ops): + trace_df['gemm_ops'] = trace_df[gemm_ops].agg("sum", axis=1) + if len(rms_norm_ops): + trace_df['rms_norm_ops'] = trace_df[rms_norm_ops].agg("sum", axis=1) + if len(vocab_embed_ops): + trace_df['vocab_embed_ops'] = trace_df[vocab_embed_ops].agg("sum", + axis=1) + if len(mem_ops): + trace_df['mem_ops'] = trace_df[mem_ops].agg("sum", axis=1) + if len(elementwise_ops): + trace_df['elementwise_ops'] = trace_df[elementwise_ops].agg("sum", + axis=1) + + trace_df.drop(attention_ops + quant_ops + gemm_ops + rms_norm_ops + + vocab_embed_ops + mem_ops + elementwise_ops, + axis=1, + inplace=True) + return trace_df + + +## Data plotting utils #### + + +def plot_trace_df(traces_df: pd.DataFrame, + plot_metric: str, + plot_title: str, + output: Optional[Path] = None): + + phases = traces_df['phase'].unique() + traces_df = traces_df.pivot_table(index="phase", + columns="name", + values=plot_metric, + aggfunc="sum") + + traces_df = group_trace_by_operations(traces_df) + + # Make the figure + fig, ax = plt.subplots(1, figsize=(5, 8), sharex=True) + + # Draw the stacked bars + ops = list(traces_df) + bottom = [0] * len(phases) + for op in ops: + values = [traces_df[op][phase] for phase in phases] + values = list(map(lambda x: 0.0 if math.isnan(x) else x, values)) + ax.bar(phases, values, label=op, bottom=bottom) + bottom = [bottom[j] + values[j] for j in range(len(phases))] + + # Write the values as text on the bars + for bar in ax.patches: + if bar.get_height() != 0: + ax.text(bar.get_x() + bar.get_width() / 2, + bar.get_height() / 2 + bar.get_y(), + f"{round(bar.get_height(), 2)}", + ha='center', + color='w', + weight='bold', + size=5) + + # Setup legend + handles, labels = plt.gca().get_legend_handles_labels() + legend = fig.legend(handles, + labels, + loc='center left', + bbox_to_anchor=(1, 1)) + shorten_plot_legend_strings(legend, 50) + + # Setup labels and title + plt.setp(ax.get_xticklabels(), rotation=90) + ax.set_ylabel(plot_metric) + plt.suptitle(plot_title) + + plt.savefig(output, bbox_inches='tight') + print("Created: ", output) + + +def main( + json_trace: Path, + output_directory: Path, + depth: int, # Fetch/Plot operations at this depth of the Json tree + plot_metric: str, + make_names_unique: bool, + top_k: int, + json_nodes_to_fold: List[str]): + + def prepare_data(profile_json: dict, step_keys: List[str]) -> pd.DataFrame: + + def get_entries_and_traces(key: str): + entries_and_traces: List[Tuple[Any, Any]] = [] + for root in profile_json[key]["summary_stats"]: + # Fold nodes in the traces as per user request. i.e. simply + # make the requested nodes leaf-nodes. 
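+            # With the default --fold-json-node of ['Sampler',
+            # 'LogitsProcessor'], everything recorded underneath those
+            # modules is dropped from the tree and their already-cumulative
+            # cuda_time_us is reported as a single aggregate entry for the
+            # module itself, which keeps the stacked bar charts readable.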
+ root = fold_nodes(root, json_nodes_to_fold) + get_entries_at_depth(depth, entries_and_traces, root) + return entries_and_traces + + def keep_only_top_entries(df: pd.DataFrame, + metric: str, + top_k: int = 9) -> pd.DataFrame: + df.loc[df.nsmallest(len(df) - top_k + 1, metric).index, + ["name"]] = "others" + return df + + # Get data for each key + traces = list(map(lambda x: get_entries_and_traces(x), step_keys)) + + # Attempt some cleanup + if make_names_unique: + for trace in traces: + attempt_to_make_names_unique(trace) + + # To pandas dataframe + trace_dfs = list( + map(lambda t: pd.DataFrame([entry for entry, _ in t]).fillna(0), + traces)) + + # Respect top_k + if top_k: + trace_dfs = list( + map( + lambda trace_df: keep_only_top_entries( + trace_df, "cuda_time_us", top_k), trace_dfs)) + + # Fill in information about the step-keys + for trace_df, step_key in zip(trace_dfs, step_keys): + trace_df['phase'] = step_key + + # Combine all data frames so they can be put in a single plot + traces_df = pd.concat(trace_dfs) + + # Add a derived metric `cuda_time_ms` + traces_df["cuda_time_ms"] = traces_df["cuda_time_us"] / 1000 + traces_df = traces_df.fillna(0) + + return traces_df + + def make_plot_title_suffix(profile_json: dict) -> str: + context = profile_json["context"] + sparsity = context.get('sparsity', None) + return (f"{context['model']}\n" + f"Batch={context['batch_size']}, " + f"PromptLen={context['prompt_len']}, " + f"OutputLen={context['output_len']}," + f"NumGpus={context['tensor_parallel_size']}" + f"{', Sparsity ' + sparsity if sparsity else ''}") + + profile_json = None + with open(json_trace, "r") as f: + profile_json = json.load(f) + assert profile_json is not None + + # Get all `llm.generate.step()` profile + step_traces = list(profile_json.keys()) + assert (step_traces[0] == 'context') + step_traces = step_traces[1:] # have only prefill and decodes + prefills = list(filter(lambda x: "prefill" in x, step_traces)) + all_decodes = list(filter(lambda x: "decode" in x, step_traces)) + assert len(prefills) + len(all_decodes) == len(step_traces) + assert len(prefills) == 1 + + decodes = all_decodes[::args.step_plot_interval] + if decodes[-1] != all_decodes[-1]: + # Always have the last decode + decodes.append(all_decodes[-1]) + + prefill_traces = prepare_data(profile_json, prefills) + decode_traces = prepare_data(profile_json, decodes) + + plot_title_suffix = make_plot_title_suffix(profile_json) + + plot_trace_df(prefill_traces, plot_metric, "prefill " + plot_title_suffix, + output_directory / Path("prefill.png")) + plot_trace_df(decode_traces, plot_metric, "decodes " + plot_title_suffix, + output_directory / Path("decode_steps.png")) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--json-trace", + type=str, + required=True, + help="json trace file output by examples/offline_profile.py") + parser.add_argument("--output-directory", + type=str, + required=False, + help="Directory to output plots") + parser.add_argument("--level", + type=str, + default="module", + choices=["module", "kernel"]) + parser.add_argument("--top-k", + type=int, + default=12, + help="Only graph the top `top_k` entries by time.") + parser.add_argument("--fold-json-node", + nargs='+', + default=['Sampler', 'LogitsProcessor'], + help='Do not plot the children of these nodes. Let, \ + the node represent the aggregate of all its \ + children') + parser.add_argument("--plot-metric", + type=str, + default="cuda_time_ms", + help='Metric to plot. 
some options are cuda_time_ms, \ + pct_cuda_time') + parser.add_argument( + "--step-plot-interval", + type=int, + default=4, + help="For every `step_plot_interval` steps, plot 1 step") + + args = parser.parse_args() + + # Prepare/Extract relevant args + make_names_unique = False + if args.level == "module": + depth = -2 + make_names_unique = True + elif args.level == "kernel": + depth = -1 + else: + raise Exception(f"Unexpected level value ({args.level})") + + output_directory = args.output_directory if args.output_directory else Path( + args.json_trace).parent + + main(Path(args.json_trace), output_directory, depth, + args.plot_metric, make_names_unique, args.top_k, args.fold_json_node) diff --git a/vllm/profiler/__init__.py b/vllm/profiler/__init__.py new file mode 100644 index 0000000000000..93ec4a800e600 --- /dev/null +++ b/vllm/profiler/__init__.py @@ -0,0 +1,5 @@ +from .nm_profile import nm_profile + +__all__ = [ + "nm_profile", +] diff --git a/vllm/profiler/nm_profile.py b/vllm/profiler/nm_profile.py new file mode 100644 index 0000000000000..30be0e5ba0c50 --- /dev/null +++ b/vllm/profiler/nm_profile.py @@ -0,0 +1,348 @@ +import pandas as pd +import copy + +from collections import defaultdict +from dataclasses import dataclass, field, asdict +from vllm.profiler.utils import (indent_string, TablePrinter, event_has_module, + event_is_torch_op, event_module_repr, + event_torch_op_stack_trace) +from typing import Dict, List, Union, Optional, Tuple, Callable, TypeAlias +from torch.profiler import profile, ProfilerActivity +from torch.autograd.profiler import FunctionEvent +from torch._C._autograd import _ProfilerResult, _KinetoEvent, DeviceType +from torch._C._profiler import _EventType, _ProfilerEvent, _ExperimentalConfig + + +@dataclass +class _ModuleTreeNode: + event: _ProfilerEvent + parent: Optional['_ModuleTreeNode'] = None + children: List['_ModuleTreeNode'] = field(default_factory=list) + trace: str = "" + + @property + def is_leaf(self): + return (self.event.children is None or len(self.event.children) == 0) + + @property + def is_torch_op(self): + return event_is_torch_op(self.event) + + @property + def is_cuda(self): + return (self.event.tag == _EventType.Kineto + and self.event.typed[1].device_type == DeviceType.CUDA) + + +@dataclass +class SummaryStatsEntry: + name: str + cuda_time_us: float + pct_cuda_time: float + invocations: int + + +@dataclass +class ModelStatsEntry: + name: str + cpu_time_us: float + cuda_time_us: float + pct_cuda_time: float + trace: str + + +StatsEntry: TypeAlias = Union[ModelStatsEntry, SummaryStatsEntry] + + +@dataclass +class _StatsTreeNode: + entry: StatsEntry + children: List[StatsEntry] + parent: Optional[StatsEntry] + + +@dataclass +class NMProfileResults(profile): + _kineto_results: _ProfilerResult + _kineto_event_correlation_map: Dict[int, + List[_KinetoEvent]] = field(init=False) + _event_correlation_map: Dict[int, List[FunctionEvent]] = field(init=False) + _module_tree: List[_ModuleTreeNode] = field(init=False) + _model_stats_tree: List[_StatsTreeNode] = field(init=False) + _summary_stats_tree: List[_StatsTreeNode] = field(init=False) + + def __post_init__(self): + self._build_correlation_map() + self._build_module_tree() + self._build_stats_trees() + + def print_model_table(self, column_widths: Dict[str, int] = None): + _column_widths = dict(name=60, + cpu_time_us=12, + cuda_time_us=12, + pct_cuda_time=12, + trace=60) + if column_widths: + _column_widths.update(**column_widths) + filtered_model_table = [ + (depth, row) + for depth, row in 
self._flatten_stats_tree(self._model_stats_tree) + if row.cuda_time_us > 0 or row.cpu_time_us > 0 + ] + TablePrinter(ModelStatsEntry, _column_widths).print_table( + self._indent_row_names_based_on_depth( + filtered_model_table, + indent_style=lambda indent: "|" + "-" * indent + " ")) + + def print_summary_table(self, column_widths: Dict[str, int] = None): + _column_widths = dict(name=80, + cuda_time_us=12, + pct_cuda_time=12, + invocations=15) + if column_widths: + _column_widths.update(**column_widths) + filtered_summary_table = [(depth, row) + for depth, row in self._flatten_stats_tree( + self._summary_stats_tree) + if row.cuda_time_us > 0] + TablePrinter(SummaryStatsEntry, _column_widths).print_table( + self._indent_row_names_based_on_depth( + filtered_summary_table, + indent_style=lambda indent: "|" + "-" * indent + " ")) + + def export_model_stats_table_csv(self, filename: str): + df = pd.DataFrame([ + asdict(row) + for _, row in self._flatten_stats_tree(self._model_stats_tree) + ]) + df.to_csv(filename) + + def export_summary_stats_table_csv(self, filename: str): + df = pd.DataFrame([ + asdict(row) + for _, row in self._flatten_stats_tree(self._summary_stats_tree) + ]) + df.to_csv(filename) + + def convert_stats_to_dict(self) -> str: + return { + "summary_stats": + self._convert_stats_tree_to_dict(self._summary_stats_tree), + "model_stats": + self._convert_stats_tree_to_dict(self._model_stats_tree) + } + + @staticmethod + def _indent_row_names_based_on_depth(depths_rows: List[Tuple[int, + StatsEntry]], + indent_style: Union[Callable[[int], + str], + str] = " "): + indented_rows = [] + for depth, row in depths_rows: + if row.cuda_time_us == 0: + continue + indented_row = copy.deepcopy(row) + indented_row.name = indent_string(indented_row.name, depth, + indent_style) + indented_rows.append(indented_row) + return indented_rows + + def _build_correlation_map(self): + self._kineto_event_correlation_map = defaultdict(list) + for event in self._kineto_results.events(): + self._kineto_event_correlation_map[event.correlation_id()].append( + event) + + def _build_module_tree(self): + self._module_tree = [] + event_tree = self._kineto_results.experimental_event_tree() + + def _df_traversal(event: _ProfilerEvent, + curr_node: Optional[_ModuleTreeNode] = None): + if event_has_module(event): + node = _ModuleTreeNode(event=event, parent=curr_node) + if curr_node: + curr_node.children.append(node) + else: + self._module_tree.append(node) + curr_node = node + + is_leaf = (event.children is None or len(event.children) == 0) + if is_leaf and curr_node: + node = _ModuleTreeNode( + event=event, + parent=curr_node, + trace=event_torch_op_stack_trace( + event, until=lambda x: event_has_module(x))) + curr_node.children.append(node) + curr_node = node + + for child in event.children: + _df_traversal(child, curr_node) + + for root in event_tree: + _df_traversal(root) + + def _get_kineto_gpu_event(self, node: _ModuleTreeNode): + if node.event.tag != _EventType.Kineto: + return None + correlated_kineto_events = self._kineto_event_correlation_map.get( + node.event.correlation_id, []) + iterator = (x for x in correlated_kineto_events + if x.device_type() == DeviceType.CUDA + and x.name() == node.event.name) + return next(iterator, None) + + def _cumulative_cuda_time(self, node: _ModuleTreeNode): + 'Return cuda time in microseconds' + + def _cumulative_cuda_time_recursive(node: _ModuleTreeNode): + if node.is_leaf and (gpu_kineto_event := + self._get_kineto_gpu_event(node)): + return gpu_kineto_event.duration_ns() 
/ 1000.0 + else: + cumulative_cuda_time = 0 + for child in node.children: + cumulative_cuda_time += _cumulative_cuda_time_recursive( + child) + return cumulative_cuda_time + + return _cumulative_cuda_time_recursive(node) + + def _total_cuda_time(self): + return sum( + [self._cumulative_cuda_time(root) for root in self._module_tree]) + + def _build_stats_trees(self): + summary_dict: Dict[str, self.StatsTreeNode] = {} + total_cuda_time = self._total_cuda_time() + + def pct_cuda_time(cuda_time_us): + return (cuda_time_us / total_cuda_time) * 100 + + def build_summary_stats_tree_df( + node: _ModuleTreeNode, + parent: Optional[_StatsTreeNode] = None, + summary_trace: Tuple[str] = ()): + + if event_has_module(node.event): + name = event_module_repr(node.event) + cuda_time_us = self._cumulative_cuda_time(node) + elif (gpu_kineto_event := self._get_kineto_gpu_event(node)): + name = gpu_kineto_event.name() + cuda_time_us = gpu_kineto_event.duration_ns() / 1000.0 + else: + return None + + summary_trace = summary_trace + (name, ) + if summary_trace in summary_dict: + entry = summary_dict[summary_trace].entry + entry.cuda_time_us += cuda_time_us + entry.invocations += 1 + entry.pct_cuda_time = pct_cuda_time(entry.cuda_time_us) + else: + new_node = _StatsTreeNode(entry=SummaryStatsEntry( + name=name, + cuda_time_us=cuda_time_us, + pct_cuda_time=pct_cuda_time(cuda_time_us), + invocations=1), + children=[], + parent=parent) + if parent: + parent.children.append(new_node) + summary_dict[summary_trace] = new_node + + for child in node.children: + build_summary_stats_tree_df(child, summary_dict[summary_trace], + summary_trace) + + return summary_dict[summary_trace] + + self._summary_stats_tree = [] + for root in self._module_tree: + self._summary_stats_tree.append(build_summary_stats_tree_df(root)) + + def build_model_stats_tree_df(node: _ModuleTreeNode, + parent: Optional[_StatsTreeNode] = None): + if event_has_module(node.event, ): + name = event_module_repr(node.event) + cuda_time_us = self._cumulative_cuda_time(node) + cpu_time_us = node.event.duration_time_ns / 1000 + trace = "" + elif (gpu_kineto_event := self._get_kineto_gpu_event(node)): + name = gpu_kineto_event.name() + cuda_time_us = gpu_kineto_event.duration_ns() / 1000.0 + cpu_time_us = 0 + trace = node.trace + else: + return None + + new_node = _StatsTreeNode(entry=ModelStatsEntry( + name=name, + cpu_time_us=cpu_time_us, + cuda_time_us=cuda_time_us, + pct_cuda_time=pct_cuda_time(cuda_time_us), + trace=trace), + parent=parent, + children=[]) + if parent: + parent.children.append(new_node) + + for child in node.children: + build_model_stats_tree_df(child, new_node) + + return new_node + + self._model_stats_tree = [] + for root in self._module_tree: + self._model_stats_tree.append(build_model_stats_tree_df(root)) + + def _flatten_stats_tree( + self, tree: List[_StatsTreeNode]) -> List[Tuple[int, StatsEntry]]: + entries: List[Tuple[int, StatsEntry]] = [] + + def df_traversal(node: _StatsTreeNode, depth=0): + entries.append((depth, node.entry)) + for child in node.children: + df_traversal(child, depth=depth + 1) + + for root in tree: + df_traversal(root) + + return entries + + def _convert_stats_tree_to_dict(self, + tree: List[_StatsTreeNode]) -> List[Dict]: + root_dicts: List[Dict] = [] + + def df_traversal(node: _StatsTreeNode, curr_json_list: List[Dict]): + curr_json_list.append({ + "entry": asdict(node.entry), + "children": [] + }) + for child in node.children: + df_traversal(child, curr_json_list[-1]["children"]) + + for root in tree: + 
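+            # Recap of how the times in these entries were computed by the
+            # methods above: leaf Kineto CUDA events are matched to module
+            # nodes through the correlation map, duration_ns() / 1000.0
+            # converts them to microseconds, _cumulative_cuda_time() sums a
+            # subtree, and pct_cuda_time is taken relative to
+            # _total_cuda_time() across all module roots.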
df_traversal(root, root_dicts) + + return root_dicts + + +class nm_profile(profile): + + def __init__(self): + super().__init__( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + record_shapes=True, + with_stack=True, + with_modules=True, + experimental_config=_ExperimentalConfig(verbose=True)) + + def __enter__(self): + return super().__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + super().__exit__(exc_type, exc_val, exc_tb) + self.results = NMProfileResults(self.profiler.kineto_results) diff --git a/vllm/profiler/utils.py b/vllm/profiler/utils.py new file mode 100644 index 0000000000000..f8ead593d178b --- /dev/null +++ b/vllm/profiler/utils.py @@ -0,0 +1,146 @@ +import dataclasses + +from typing import Callable, Dict, Type, List, Union + +from torch._C._profiler import _EventType, _ProfilerEvent, _TensorMetadata + +# +# String / Print Manipulation +# + + +def trim_string_front(string, width): + if len(string) > width: + offset = len(string) - width + 3 + string = string[offset:] + if len(string) > 3: + string = "..." + string[3:] + return string + + +def trim_string_back(string, width): + if len(string) > width: + offset = len(string) - width + 3 + string = string[:-offset] + if len(string) > 3: + string = string + "..." + return string + + +class TablePrinter: + + def __init__(self, row_cls: Type[dataclasses.dataclass], + column_widths: Dict[str, int]): + self.row_cls = row_cls + self.fieldnames = [x.name for x in dataclasses.fields(row_cls)] + self.column_widths = column_widths + assert set(self.column_widths.keys()) == set(self.fieldnames) + + def print_table(self, rows: List[dataclasses.dataclass]): + self._print_header() + self._print_line() + for row in rows: + self._print_row(row) + + def _print_header(self): + for i, f in enumerate(self.fieldnames): + last = (i == len(self.fieldnames) - 1) + col_width = self.column_widths[f] + print(trim_string_back(f, col_width).ljust(col_width), + end=" | " if not last else "\n") + + def _print_row(self, row): + assert isinstance(row, self.row_cls) + + for i, f in enumerate(self.fieldnames): + last = (i == len(self.fieldnames) - 1) + col_width = self.column_widths[f] + val = getattr(row, f) + + val_str = "" + if isinstance(val, str): + val_str = trim_string_back(val, col_width).ljust(col_width) + elif type(val) in [float, int]: + val_str = f"{float(val):>.2f}".rjust(col_width) + else: + val_str = f"{val}".rjust(col_width) + print(val_str, end=" | " if not last else "\n") + + def _print_line(self): + total_col_width = 0 + for column_width in self.column_widths.values(): + total_col_width += column_width + print("=" * (total_col_width + 3 * (len(self.column_widths) - 1))) + + +def indent_string(string: str, + indent: int, + indent_style: Union[Callable[[int], str], str] = " ") -> str: + if indent: + if isinstance(indent_style, str): + return indent_style * indent + string + else: + return indent_style(indent) + string + else: + return string + + +# +# _ProfilerEvent utils +# + + +def event_has_module(event: _ProfilerEvent) -> bool: + event_type, typed_event = event.typed + if event_type == _EventType.PyCall: + return typed_event.module is not None + return False + + +def event_is_torch_op(event: _ProfilerEvent) -> bool: + return event.tag == _EventType.TorchOp + + +def event_arg_repr(arg) -> str: + if arg is None or type(arg) in [float, int, bool, str]: + return f"{arg}" + elif isinstance(arg, list): + return f"[{', '.join([event_arg_repr(x) for x in arg])}]" + elif isinstance(arg, tuple): + return f"({', 
'.join([event_arg_repr(x) for x in arg])})" + else: + assert isinstance(arg, + _TensorMetadata), f"Unsupported type: {type(arg)}" + sizes_str = ', '.join([str(x) for x in arg.sizes]) + return f"{str(arg.dtype).replace('torch.', '')}[{sizes_str}]" + + +def event_torch_op_repr(event: _ProfilerEvent) -> str: + assert event.tag == _EventType.TorchOp + args_str = ', '.join([event_arg_repr(x) for x in event.typed[1].inputs]) + return f"{event.name}({args_str})".replace("aten::", "") + + +def event_module_repr(event: _ProfilerEvent) -> str: + assert event_has_module(event) + module = event.typed[1].module + if module.parameters and len(module.parameters) > 0: + args_str = ', '.join( + [f'{x[0]}={event_arg_repr(x[1])}' for x in module.parameters]) + return f"{module.cls_name}({args_str})" + else: + return module.cls_name + + +def event_torch_op_stack_trace(curr_event: _ProfilerEvent, + until: Callable[[_ProfilerEvent], bool]) -> str: + trace = "" + curr_event = curr_event.parent + while curr_event and not until(curr_event): + if event_is_torch_op(curr_event): + if len(trace) > 0: + trace += " <- " + trace += event_torch_op_repr(curr_event) + curr_event = curr_event.parent + + return trace diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index f9c26e0c318b1..d2a5fbe396b70 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,3 +1,5 @@ +# This file has been modified by Neural Magic + import dataclasses import gc import time @@ -1402,10 +1404,13 @@ def execute_model( return [output] +# NOTE: this is nn.Module so the profiler can properly capture/group +# kernels calls made within the graph +class CUDAGraphRunner(nn.Module): -class CUDAGraphRunner: + def __init__(self, model: nn.Module, backend_name:str): + super().__init__() - def __init__(self, model: nn.Module, backend_name: str): self.model = model self.backend_name = backend_name @@ -1555,9 +1560,6 @@ def forward( return self.output_buffers - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) - def _get_graph_batch_size(batch_size: int) -> int: """Returns the padded batch size given actual batch size.
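Example end-to-end usage of the new tooling (illustrative only; the model
name is a placeholder, and the flags are those defined in
examples/offline_profile.py and the tools above):

    python examples/offline_profile.py --model <model-name-or-path> \
        --batch-size 4 --prompt-len 512 --output-len 8 --json profile

    python neuralmagic/tools/profiler/print_table.py \
        --json-trace profile.json --phase prefill --table summary

    python neuralmagic/tools/profiler/visualize_trace.py \
        --json-trace profile.json --level module --top-k 12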