diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 4385f250856e7..398fdc5f0ae2b 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -184,6 +184,7 @@ steps: - python3 offline_inference_vision_language_multi_image.py - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - python3 offline_inference_encoder_decoder.py + - python3 offline_profile.py --model facebook/opt-125m - label: Prefix Caching Test # 9min #mirror_hardwares: [amd] diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu index 30831efdfa1a2..3a464c5f327ad 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.cu +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -55,6 +55,7 @@ void set_conv_params_fwd(ConvParamsBase ¶ms, const at::Tensor out, const c10::optional& bias, bool silu_activation, + int64_t pad_slot_id, const c10::optional& query_start_loc = std::nullopt, const c10::optional& cache_indices = std::nullopt, const c10::optional& has_initial_state = std::nullopt) { @@ -66,6 +67,7 @@ void set_conv_params_fwd(ConvParamsBase ¶ms, params.dim = dim; params.seqlen = seqlen; params.width = width; + params.pad_slot_id = pad_slot_id; params.silu_activation = silu_activation; @@ -90,14 +92,16 @@ void set_conv_params_fwd(ConvParamsBase ¶ms, } -at::Tensor -causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, +void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, const c10::optional &bias_, const c10::optional &conv_states, const c10::optional &query_start_loc, const c10::optional &cache_indices, const c10::optional &has_initial_state, - bool silu_activation) { + bool silu_activation, + // used to identify padding entries if cache_indices provided + // in case of padding, the kernel will return early + int64_t pad_slot_id) { auto input_type = x.scalar_type(); auto weight_type = weight.scalar_type(); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); @@ -153,12 +157,13 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, CHECK_SHAPE(cache_indices_, batch_size); } - at::Tensor out = torch::empty_like(x); + at::Tensor out = x; ConvParamsBase params; set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, bias_, silu_activation, + pad_slot_id, query_start_loc, cache_indices, has_initial_state @@ -183,18 +188,19 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] { causal_conv1d_fwd_cuda(params, stream); }); - return out; } -at::Tensor -causal_conv1d_update(const at::Tensor &x, +void causal_conv1d_update(const at::Tensor &x, const at::Tensor &conv_state, const at::Tensor &weight, const c10::optional &bias_, bool silu_activation, const c10::optional &cache_seqlens_, - const c10::optional &conv_state_indices_) { + const c10::optional &conv_state_indices_, + // used to identify padding entries if cache_indices provided + // in case of padding, the kernel will return early + int64_t pad_slot_id) { auto input_type = x.scalar_type(); auto weight_type = weight.scalar_type(); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); @@ -227,12 +233,13 @@ causal_conv1d_update(const at::Tensor &x, CHECK_SHAPE(bias, dim); } - at::Tensor out = torch::empty_like(x); + at::Tensor out = x; ConvParamsBase params; set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, bias_, - silu_activation); + silu_activation, + pad_slot_id); params.conv_state_ptr = conv_state.data_ptr(); params.conv_state_len = conv_state_len; // All stride are in elements, not bytes. @@ -274,7 +281,6 @@ causal_conv1d_update(const at::Tensor &x, DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] { causal_conv1d_update_cuda(params, stream); }); - return out; } template @@ -340,7 +346,10 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) { int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr : reinterpret_cast(params.cache_indices_ptr); int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id]; - + // cache_index == params.pad_slot_id is defined as padding, so we exit early + if (cache_index == params.pad_slot_id){ + return; + } input_t *conv_states = params.conv_states_ptr == nullptr ? nullptr : reinterpret_cast(params.conv_states_ptr) + cache_index * params.conv_states_batch_stride + channel_id * params.conv_states_c_stride; @@ -528,6 +537,10 @@ void causal_conv1d_update_kernel(ConvParamsBase params) { const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr ? batch_id : params.conv_state_indices_ptr[batch_id]; + // conv_state_batch_coord == params.pad_slot_id is defined as padding so we exit early + if (conv_state_batch_coord == params.pad_slot_id){ + return; + } input_t *conv_state = reinterpret_cast(params.conv_state_ptr) + conv_state_batch_coord * params.conv_state_batch_stride + channel_id * params.conv_state_c_stride; diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.h b/csrc/mamba/causal_conv1d/causal_conv1d.h index 49e37ee4528be..e26684a2b98b8 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.h +++ b/csrc/mamba/causal_conv1d/causal_conv1d.h @@ -13,6 +13,7 @@ struct ConvParamsBase { using index_t = uint32_t; int batch, dim, seqlen, width; + int64_t pad_slot_id; bool silu_activation; index_t x_batch_stride; diff --git a/csrc/mamba/mamba_ssm/selective_scan.h b/csrc/mamba/mamba_ssm/selective_scan.h index 580d0b2e17e74..563d2fe4ef65b 100644 --- a/csrc/mamba/mamba_ssm/selective_scan.h +++ b/csrc/mamba/mamba_ssm/selective_scan.h @@ -21,6 +21,7 @@ struct SSMParamsBase { int dim_ngroups_ratio; bool is_variable_B; bool is_variable_C; + int64_t pad_slot_id; bool delta_softplus; diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu index 6b225b41d295d..71624696338d0 100644 --- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -115,6 +115,10 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr : reinterpret_cast(params.cache_indices_ptr); const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id]; + // cache_index == params.pad_slot_id is defined as padding, so we exit early + if (cache_index == params.pad_slot_id){ + return; + } input_t *u = reinterpret_cast(params.u_ptr) + sequence_start_index * params.u_batch_stride + dim_id * kNRows * params.u_d_stride; input_t *delta = reinterpret_cast(params.delta_ptr) + sequence_start_index * params.delta_batch_stride @@ -387,7 +391,6 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, const size_t seqlen, const size_t dstate, const size_t n_groups, - const size_t n_chunks, const bool is_variable_B, const bool is_variable_C, // device pointers @@ -407,7 +410,8 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, const c10::optional& query_start_loc, const c10::optional& cache_indices, const c10::optional& has_initial_state, - bool varlen) { + bool varlen, + int64_t pad_slot_id) { // Reset the parameters memset(¶ms, 0, sizeof(params)); @@ -417,8 +421,8 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, params.seqlen = seqlen; params.dstate = dstate; params.n_groups = n_groups; - params.n_chunks = n_chunks; params.dim_ngroups_ratio = dim / n_groups; + params.pad_slot_id = pad_slot_id; params.delta_softplus = delta_softplus; @@ -507,7 +511,10 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, const c10::optional &query_start_loc, const c10::optional &cache_indices, const c10::optional &has_initial_state, - const torch::Tensor &ssm_states) { + const torch::Tensor &ssm_states, + // used to identify padding entries if cache_indices provided + // in case of padding, the kernel will return early + int64_t pad_slot_id) { auto input_type = u.scalar_type(); auto weight_type = A.scalar_type(); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); @@ -618,18 +625,14 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, out_z = z; - const int n_chunks = (seqlen + 2048 - 1) / 2048; - // const int n_chunks = (seqlen + 1024 - 1) / 1024; - // at::Tensor out = torch::empty_like(u); // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout at::Tensor out = delta; TORCH_CHECK(ssm_states.scalar_type() == input_type); TORCH_CHECK(ssm_states.is_cuda()); TORCH_CHECK(ssm_states.stride(-1) == 1); - CHECK_SHAPE(ssm_states, batch_size, dim, dstate); SSMParamsBase params; - set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, + set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, is_variable_B, is_variable_C, u, delta, A, B, C, out, z, out_z, D_, delta_bias_, @@ -639,7 +642,8 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, query_start_loc, cache_indices, has_initial_state, - varlen + varlen, + pad_slot_id ); diff --git a/csrc/ops.h b/csrc/ops.h index fce545f95a7cc..c10c34e085750 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -157,21 +157,23 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta, const c10::optional& query_start_loc, const c10::optional& cache_indices, const c10::optional& has_initial_state, - const torch::Tensor& ssm_states); - -at::Tensor causal_conv1d_update( - const at::Tensor& x, const at::Tensor& conv_state, const at::Tensor& weight, - const c10::optional& bias_, bool silu_activation, - const c10::optional& cache_seqlens_, - const c10::optional& conv_state_indices_); - -at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight, - const c10::optional& bias_, - const c10::optional& conv_states, - const c10::optional& query_start_loc, - const c10::optional& cache_indices, - const c10::optional& has_initial_state, - bool silu_activation); + const torch::Tensor& ssm_states, int64_t pad_slot_id); + +void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state, + const at::Tensor& weight, + const c10::optional& bias_, + bool silu_activation, + const c10::optional& cache_seqlens_, + const c10::optional& conv_state_indices_, + int64_t pad_slot_id); + +void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight, + const c10::optional& bias_, + const c10::optional& conv_states, + const c10::optional& query_start_loc, + const c10::optional& cache_indices, + const c10::optional& has_initial_state, + bool silu_activation, int64_t pad_slot_id); #ifndef USE_ROCM using fptr_t = int64_t; diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index a0100b4a85edd..d69c4e5afb4a7 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -278,7 +278,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor? query_start_loc," "Tensor? cache_indices," "Tensor? has_initial_state," - "Tensor! ssm_states) -> ()"); + "Tensor! ssm_states," + "int pad_slot_id) -> ()"); ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd); ops.def( @@ -288,7 +289,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor? bias_," "bool silu_activation," "Tensor? cache_seqlens_," - "Tensor? conv_state_indices) -> Tensor"); + "Tensor? conv_state_indices," + "int pad_slot_id) -> ()"); ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update); ops.def( @@ -298,7 +300,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor? query_start_loc," "Tensor? cache_indices," "Tensor? has_initial_state," - "bool silu_activation) -> Tensor"); + "bool silu_activation," + "int pad_slot_id) -> ()"); ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd); #endif diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 7f1b2443824a2..b5fa83b437ac4 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -159,7 +159,7 @@ Text Generation - * - :code:`MiniCPMForCausalLM` - MiniCPM - - :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc. + - :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, :code:`openbmb/MiniCPM-S-1B-sft`, etc. - ✅︎ - ✅︎ * - :code:`MiniCPM3ForCausalLM` diff --git a/examples/offline_profile.py b/examples/offline_profile.py new file mode 100644 index 0000000000000..1d415b82cddb6 --- /dev/null +++ b/examples/offline_profile.py @@ -0,0 +1,282 @@ +import inspect +import json +import os +import sys +from argparse import RawTextHelpFormatter +from dataclasses import asdict, dataclass +from typing import Optional + +import torch + +from vllm import LLM, SamplingParams +from vllm.engine.arg_utils import EngineArgs +from vllm.profiler import layerwise_profile +from vllm.utils import FlexibleArgumentParser + +BATCH_SIZE_DEFAULT = 1 +PROMPT_LEN_DEFAULT = 256 +OUTPUT_LEN_DEFAULT = 2 + + +@dataclass +class ProfileContext: + engine_args: EngineArgs + prompt_len: int + output_len: int + batch_size: int + save_chrome_traces_folder: Optional[str] + + +def get_dtype(dtype: str): + if dtype == "torch.float": + return torch.float + else: + return dtype + + +def run_profile(context: ProfileContext, csv_output: Optional[str], + json_output: Optional[str]): + print("Run profile with:") + for key, value in asdict(context).items(): + print(f" {key} = {value}") + + # Create sampling params + sampling_params = SamplingParams(temperature=0.8, + top_p=0.95, + max_tokens=args.output_len, + ignore_eos=True) + + # Create LLM + llm = LLM(**asdict(context.engine_args)) + batch_size = context.batch_size + prompt_len = context.prompt_len + output_len = context.output_len + + scheduler_config = llm.llm_engine.scheduler_config + max_model_len = llm.llm_engine.model_config.max_model_len + max_num_batched_tokens = scheduler_config.max_num_batched_tokens + max_num_seqs = scheduler_config.max_num_seqs + + if batch_size * prompt_len > max_num_batched_tokens: + print(f"ERROR: chosen batch_size * prompt_len " + f"({batch_size} * {prompt_len} = {batch_size * prompt_len}) is " + f"larger than max_num_batched_tokens ({max_num_batched_tokens}) " + f"and therefore cannot be run in a single profile step, please " + f"choose a smaller batch size or prompt length, or increase " + f"--max-num-batched-tokens") + sys.exit(-1) + if batch_size >= max_num_seqs: + print( + f"ERROR: chosen batch_size ({batch_size}) is larger than " + f"max_num_seqs ({max_num_seqs}) and therefore cannot be run in a " + f"single profile step, please choose a smaller batch size") + sys.exit(-1) + print("llm.llm_engine.model_config.max_model_len: ", + llm.llm_engine.model_config.max_model_len) + if prompt_len + output_len > llm.llm_engine.model_config.max_model_len: + print( + f"ERROR: chosen prompt_len + output_len ({prompt_len} + " + f"{output_len} = {prompt_len + output_len}) is larger than the " + f"model's max_model_len ({max_model_len}), please choose a smaller " + f"prompt_len or output_len, or increase --max-model-len") + sys.exit(-1) + + def add_requests(): + for i in range(batch_size): + prompt_token_ids = torch.randint( + llm.llm_engine.model_config.get_vocab_size(), + size=(prompt_len, )).tolist() + + llm.llm_engine.add_request( + request_id=f"seq{i}", + prompt={'prompt_token_ids': prompt_token_ids}, + params=sampling_params) + + def abort_requests(): + for i in range(batch_size): + llm.llm_engine.abort_request(f"seq{i}") + + # Warm up run + print("Warm up run ...") + add_requests() + llm.llm_engine.step() # Prefill + llm.llm_engine.step() # Decode + abort_requests() + + print("Profile run ...") + add_requests() + + with layerwise_profile() as prefill_prof: + llm.llm_engine.step() # First step is prefill + + decode_profs = [] + for x in range(args.output_len - 1): + with layerwise_profile() as decode_prof: + llm.llm_engine.step() + decode_profs.append(decode_prof) + + decode_results_list = [prof.results for prof in decode_profs] + prefill_results = prefill_prof.results + has_decode = len(decode_results_list) > 0 + + LINE_WIDTH = 80 + print("=" * LINE_WIDTH) + print(f"= Prefill Model Table " + f"(prompt_len={prompt_len}, batch_size={batch_size})") + print("=" * LINE_WIDTH) + print() + prefill_results.print_model_table() + + if has_decode: + print() + print("=" * LINE_WIDTH) + print(f"= First Decode Step Model Table " + f"(prompt_len={prompt_len}, batch_size={batch_size})") + print("=" * LINE_WIDTH) + print() + decode_results_list[0].print_model_table() + + print() + print("=" * LINE_WIDTH) + print(f"= Prefill Summary Table " + f"(prompt_len={prompt_len}, batch_size={batch_size})") + print("=" * LINE_WIDTH) + print() + prefill_results.print_summary_table() + + if has_decode: + print() + print("=" * LINE_WIDTH) + print(f"= First Decode Step Summary Table " + f"(prompt_len={prompt_len}, batch_size={batch_size})") + print("=" * LINE_WIDTH) + print() + decode_results_list[0].print_summary_table() + + if csv_output: + csv_filename_base = csv_output.rstrip(".csv") + prefill_results.export_model_stats_table_csv( + csv_filename_base + "_prefill_model_table.csv") + prefill_results.export_summary_stats_table_csv( + csv_filename_base + "_prefill_summary_table.csv") + + if has_decode: + decode_results_list[0].export_model_stats_table_csv(\ + csv_filename_base + "_decode_model_table.csv") + decode_results_list[0].export_summary_stats_table_csv( + csv_filename_base + "_decode_summary_table.csv") + + if json_output: + cuda_devices = [ + torch.cuda.get_device_properties(dev_idx) + for dev_idx in range(torch.cuda.device_count()) + ] + + json_dict = { + "context": { + "python_version": f"{sys.version}", + "torch_version": f"{torch.__version__}", + "torch_cuda_version": f"{torch.version.cuda}", + "cuda_devices": f"{cuda_devices}", + **asdict(context) + }, + "prefill": prefill_results.convert_stats_to_dict(), + } + + if has_decode: + for idx, dr in enumerate(decode_results_list): + json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict() + + for idx, dr in enumerate(decode_results_list[1:]): + json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict() + + with open(json_output.rstrip(".json") + ".json", "w+") as f: + json.dump(json_dict, f, indent=2) + pass + + if context.save_chrome_traces_folder is not None: + os.makedirs(context.save_chrome_traces_folder, exist_ok=True) + prefill_prof.profiler.export_chrome_trace( + context.save_chrome_traces_folder + "/prefill.json") + for idx, decode_prof in enumerate(decode_profs): + decode_prof.profiler.export_chrome_trace( + context.save_chrome_traces_folder + f"/decode_{idx + 1}.json") + print("Traces saved as prefill.json and decode_1.json, etc." + f" in folder {context.save_chrome_traces_folder}") + + +if __name__ == "__main__": + parser = FlexibleArgumentParser(description=""" +Profile a model + + example: + ``` + python examples/offline_profile.py \\ + --model neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 --batch-size 4 \\ + --prompt-len 512 --max-num-batched-tokens 8196 --json Llama31-8b-FP8 \\ + --enforce-eager + ``` + + then you can use various tools to analyze the json output + terminal ascii tables: + ``` + python tools/profiler/print_layerwise_table.py \\ + --json-trace Llama31-8b-FP8.json --phase prefill --table summary + ``` + or create matplotlib stacked bar charts: + ``` + python tools/profiler/visualize_layerwise_profile.py \\ + --json-trace Llama31-8b-FP8.json \\ + --output-directory profile_breakdown --plot-metric pct_cuda_time + ``` +""", + formatter_class=RawTextHelpFormatter) + parser.add_argument( + "--csv", + type=str, + default=None, + help="Export the results as multiple csv file. This should be the root " + "filename, will create _prefill_model_table.csv, " + "_prefill_summary_table.csv, " + "_decode_model_table.csv, and " + "_decode_summary_table.csv") + parser.add_argument( + "--json", + type=str, + default=None, + help="Export the results as a json file. This should be the filename") + parser.add_argument("--save-chrome-traces-folder", + type=str, + help="Save chrome traces for the prefill and decode " + "will save traces as prefill.json and decode_1.json, " + "etc. inside this folder") + parser.add_argument( + "--prompt-len", + type=int, + default=PROMPT_LEN_DEFAULT, + help=f"Length of the random prompt to use when profiling, all batched " + f"requests use the same prompt_len, default={PROMPT_LEN_DEFAULT}") + parser.add_argument("--batch-size", + type=int, + default=BATCH_SIZE_DEFAULT, + help=f"Number of requests to run as a single batch, " + f"default={BATCH_SIZE_DEFAULT}") + parser.add_argument( + "--output-len", + type=int, + default=OUTPUT_LEN_DEFAULT, + help="Number of llm steps to run (includes prefill and decode) " + "- default={OUTPUT_LEN_DEFAULT}") + + EngineArgs.add_cli_args(parser) + + args = parser.parse_args() + + context = ProfileContext( + engine_args=EngineArgs.from_cli_args(args), + **{ + k: v + for k, v in vars(args).items() + if k in inspect.signature(ProfileContext).parameters + }) + run_profile(context, csv_output=args.csv, json_output=args.json) diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py index 069020a536d0e..277d7e4977d73 100644 --- a/tests/kernels/test_causal_conv1d.py +++ b/tests/kernels/test_causal_conv1d.py @@ -6,6 +6,7 @@ from tests.kernels.utils import opcheck from vllm import _custom_ops as ops # noqa: F401 +from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) from vllm.utils import seed_everything @@ -114,16 +115,15 @@ def causal_conv1d_update_ref(x, @pytest.mark.parametrize("itype", [torch.bfloat16, torch.float]) @pytest.mark.parametrize("silu_activation", [True]) @pytest.mark.parametrize("has_bias", [True]) -def causal_conv1d_opcheck_fn( - x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - cu_seq_len: Optional[torch.Tensor] = None, - cache_indices: Optional[torch.Tensor] = None, - has_initial_state: Optional[torch.Tensor] = None, - conv_states: Optional[torch.Tensor] = None, - activation: Optional[str] = "silu", -): +def causal_conv1d_opcheck_fn(x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + cu_seq_len: Optional[torch.Tensor] = None, + cache_indices: Optional[torch.Tensor] = None, + has_initial_state: Optional[torch.Tensor] = None, + conv_states: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", + pad_slot_id: int = PAD_SLOT_ID): """ x: (batch, dim, seqlen) weight: (dim, width) @@ -141,16 +141,9 @@ def causal_conv1d_opcheck_fn( x = x.contiguous() bias = bias.contiguous() if bias is not None else None - opcheck(torch.ops._C.causal_conv1d_fwd, ( - x, - weight, - bias, - conv_states, - cu_seq_len, - cache_indices, - has_initial_state, - activation in ["silu", "swish"], - )) + opcheck(torch.ops._C.causal_conv1d_fwd, + (x, weight, bias, conv_states, cu_seq_len, cache_indices, + has_initial_state, activation in ["silu", "swish"], pad_slot_id)) @pytest.mark.parametrize("itype", [torch.bfloat16, torch.float]) @@ -233,17 +226,11 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, seed_everything(0) batch = 2 x = torch.randn(batch, dim, seqlen, device=device, dtype=itype) + x_ref = x.clone() conv_state = torch.randn(batch, dim, width - 1, device=device, dtype=itype) - weight = torch.randn(dim, - width, - device=device, - dtype=itype, - requires_grad=True) - if has_bias: - bias = torch.randn(dim, device=device, dtype=itype, requires_grad=True) - else: - bias = None + weight = torch.randn(dim, width, device=device, dtype=itype) + bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None conv_state_ref = conv_state.detach().clone() activation = None if not silu_activation else "silu" out = causal_conv1d_update(x, @@ -251,7 +238,7 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, weight, bias, activation=activation) - out_ref = causal_conv1d_update_ref(x, + out_ref = causal_conv1d_update_ref(x_ref, conv_state_ref, weight, bias, @@ -260,15 +247,9 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, assert torch.equal(conv_state, conv_state_ref) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) - opcheck(torch.ops._C.causal_conv1d_update, ( - x, - conv_state, - weight, - bias, - activation in ["silu", "swish"], - None, - None, - )) + opcheck(torch.ops._C.causal_conv1d_update, + (x, conv_state, weight, bias, activation + in ["silu", "swish"], None, None, PAD_SLOT_ID)) @pytest.mark.parametrize("itype", @@ -278,37 +259,48 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, @pytest.mark.parametrize("seqlen", [1, 4, 5]) @pytest.mark.parametrize("width", [2, 3, 4]) @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) -def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias, +# tests correctness in case subset of the sequences are padded +@pytest.mark.parametrize("with_padding", [True, False]) +def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width, + seqlen, has_bias, silu_activation, itype): device = "cuda" rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: rtol, atol = 1e-2, 5e-2 - # set )seed + # set seed seed_everything(0) - batch = 64 - x = torch.randn(batch, dim, 1, device=device, dtype=itype) + batch_size = 3 + padding = 5 if with_padding else 0 + padded_batch_size = batch_size + padding + total_entries = 10 * batch_size - total_entries = 10 * batch + x = torch.randn(padded_batch_size, dim, 1, device=device, dtype=itype) + x_ref = x.clone() + + conv_state_indices = torch.randperm(total_entries)[:batch_size].to( + dtype=torch.int32, device=device) + unused_states_bool = torch.ones(total_entries, + dtype=torch.bool, + device=device) + unused_states_bool[conv_state_indices] = False + padded_state_indices = torch.concat([ + conv_state_indices, + torch.as_tensor( + [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device) + ], + dim=0) conv_state = torch.randn(total_entries, dim, width - 1, device=device, dtype=itype) - conv_state_indices = torch.randperm(total_entries)[:batch].to( - dtype=torch.int32, device=device) + conv_state_for_padding_test = conv_state.clone() - weight = torch.randn(dim, - width, - device=device, - dtype=itype, - requires_grad=True) - if has_bias: - bias = torch.randn(dim, device=device, dtype=itype, requires_grad=True) - else: - bias = None + weight = torch.randn(dim, width, device=device, dtype=itype) + bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None conv_state_ref = conv_state[conv_state_indices, :].detach().clone() activation = None if not silu_activation else "silu" out = causal_conv1d_update(x, @@ -316,45 +308,50 @@ def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias, weight, bias, activation=activation, - conv_state_indices=conv_state_indices) - out_ref = causal_conv1d_update_ref(x, + conv_state_indices=padded_state_indices, + pad_slot_id=PAD_SLOT_ID) + out_ref = causal_conv1d_update_ref(x_ref[:batch_size], conv_state_ref, weight, bias, activation=activation) assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref) - assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol) + assert torch.equal(conv_state[unused_states_bool], + conv_state_for_padding_test[unused_states_bool]) - opcheck(torch.ops._C.causal_conv1d_update, ( - x, - conv_state, - weight, - bias, - activation in ["silu", "swish"], - None, - conv_state_indices, - )) + opcheck(torch.ops._C.causal_conv1d_update, + (x, conv_state, weight, bias, activation + in ["silu", "swish"], None, padded_state_indices, PAD_SLOT_ID)) @pytest.mark.parametrize("itype", [torch.bfloat16]) @pytest.mark.parametrize("silu_activation", [True]) @pytest.mark.parametrize("has_bias", [True]) @pytest.mark.parametrize("width", [4]) -@pytest.mark.parametrize('seqlen', - [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096]) +@pytest.mark.parametrize( + 'seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 2049, 4096]) @pytest.mark.parametrize('dim', [64, 4096]) -def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation, - itype): +# tests correctness in case subset of the sequences are padded +@pytest.mark.parametrize('with_padding', [True, False]) +def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias, + silu_activation, itype): device = "cuda" + torch.cuda.empty_cache() rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: rtol, atol = 1e-2, 5e-2 # set seed seed_everything(0) - batch = 1 seqlens = [] - nsplits = 3 + batch_size = 4 + if seqlen < 10: + batch_size = 1 + padding = 3 if with_padding else 0 + padded_batch_size = batch_size + padding + nsplits = padded_batch_size - 1 + eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values seqlens.append( torch.diff( @@ -364,10 +361,11 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation, assert sum(seqlens[-1]) == seqlen assert all(s > 0 for s in seqlens[-1]) + total_entries = batch_size * 10 cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32) cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], dim=0) - x = torch.randn(batch, 4096 + dim + 64, seqlen, device=device, + x = torch.randn(1, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :] weight = torch.randn(dim, width, device=device, dtype=itype) bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None @@ -375,7 +373,7 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation, weight_ref = weight.clone() bias_ref = bias.clone() if bias is not None else None activation = None if not silu_activation else "silu" - final_states = torch.randn(nsplits + 1, + final_states = torch.randn(total_entries, dim, width - 1, device=x.device, @@ -385,18 +383,27 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation, 2, (cumsum.shape[0] - 1, ), dtype=torch.bool, device=x.device) - cache_indices = torch.randperm(cumsum.shape[0] - 1, + state_indices = torch.randperm(total_entries, dtype=torch.int32, - device=x.device) + device=x.device)[:batch_size] + padded_state_indices = torch.concat([ + state_indices, + torch.as_tensor( + [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), + ], + dim=-1) + out = causal_conv1d_fn(x.squeeze(0), weight, bias, cumsum.cuda(), - cache_indices, has_initial_states, final_states, - activation) + padded_state_indices, has_initial_states, + final_states, activation, PAD_SLOT_ID) out_ref = [] out_ref_b = [] splits = [torch.split(var, seqlens[0], dim=-1) for var in (x_ref)] for i in range(len(seqlens[0])): x_s = [v[i].unsqueeze(0) for v in splits][0] + if padded_state_indices[i] == PAD_SLOT_ID: + continue out_ref_b.append( causal_conv1d_ref( x_s, @@ -404,21 +411,17 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation, bias_ref, activation=activation, return_final_states=True, - final_states_out=final_states_ref[cache_indices[i]].unsqueeze( - 0), - initial_states=final_states_ref[cache_indices[i]].unsqueeze(0) - if has_initial_states[i] else None)) + final_states_out=final_states_ref[ + padded_state_indices[i]].unsqueeze(0), + initial_states=final_states_ref[padded_state_indices[i]]. + unsqueeze(0) if has_initial_states[i] else None)) out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2)) - out_ref = torch.cat(out_ref, dim=0) - - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - print("Output state max diff" - f":{(final_states - final_states_ref).abs().max()}") - print("Output state mean diff" - f":{(final_states - final_states_ref).abs().mean()}") - assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + out_ref_tensor = torch.cat(out_ref, dim=0) + + unpadded_out = out[:, :out_ref_tensor.shape[-1]] + assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol) assert torch.allclose(final_states, final_states_ref, rtol=rtol, atol=atol) + causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(), - cache_indices, has_initial_states, final_states, - activation) + padded_state_indices, has_initial_states, + final_states, activation) diff --git a/tests/kernels/test_mamba_ssm.py b/tests/kernels/test_mamba_ssm.py index 8fa55e75f6c11..e92d401368a7b 100644 --- a/tests/kernels/test_mamba_ssm.py +++ b/tests/kernels/test_mamba_ssm.py @@ -5,6 +5,7 @@ from tests.kernels.utils import opcheck from vllm import _custom_ops as ops # noqa: F401 +from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( selective_scan_fn, selective_state_update) from vllm.utils import seed_everything @@ -174,7 +175,8 @@ def selective_scan_opcheck_fn(u, cu_seq_len=None, cache_indices=None, has_initial_state=None, - ssm_states=None): + ssm_states=None, + pad_slot_id=PAD_SLOT_ID): """if return_last_state is True, returns (out, last_state) last_state has shape (batch, dim, dstate). """ @@ -203,7 +205,7 @@ def selective_scan_opcheck_fn(u, # a bogus error. opcheck(torch.ops._C.selective_scan_fwd, (u, delta, A, B, C, D, z, delta_bias, delta_softplus, cu_seq_len, - cache_indices, has_initial_state, ssm_states), + cache_indices, has_initial_state, ssm_states, pad_slot_id), test_utils=["test_schema", "test_faketensor"]) @@ -404,9 +406,12 @@ def test_selective_state_update(dim, dstate, has_z, itype): @pytest.mark.parametrize("varBC_groups", [1, 2]) @pytest.mark.parametrize("is_variable_C", [True]) @pytest.mark.parametrize("is_variable_B", [True]) -def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups, - has_D, has_z, has_delta_bias, delta_softplus, - return_last_state, seqlen, itype, wtype): +# tests correctness in case subset of the sequences are padded +@pytest.mark.parametrize("with_padding", [False, True]) +def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C, + varBC_groups, has_D, has_z, has_delta_bias, + delta_softplus, return_last_state, seqlen, + itype, wtype): if varBC_groups > 1 and (not is_variable_B or not is_variable_C): pytest.skip() # This config is not applicable device = 'cuda' @@ -420,18 +425,27 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups, # set seed torch.random.manual_seed(0) seqlens = [] - nsplits = 3 + batch_size = 4 if seqlen < 10: - nsplits = 0 + batch_size = 1 + padding = 3 if with_padding else 0 + padded_batch_size = batch_size + padding + + if with_padding and seqlen < padded_batch_size: + pytest.skip() + + nsplits = padded_batch_size - 1 eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values seqlens.append( torch.diff( torch.cat( [torch.tensor([-1]), eos_pos, torch.tensor([seqlen - 1])])).tolist()) + assert sum(seqlens[-1]) == seqlen assert all(s > 0 for s in seqlens[-1]) + total_entries = batch_size * 10 cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32) cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], dim=0).cuda() @@ -462,22 +476,33 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups, delta_ref = delta.clone() out = None out_ref = None - prev_state_shape = (cumsum.shape[0] - 1, u.shape[0], int(A.shape[1])) + + prev_state_shape = (total_entries, u.shape[0], int(A.shape[1])) prev_state = torch.randn(prev_state_shape, device=u.device, dtype=itype, requires_grad=False) prev_state_ref = prev_state.clone() - cache_indices = torch.randperm(cumsum.shape[0] - 1, + state_indices = torch.randperm(total_entries, dtype=torch.int32, - device=u.device) + device=u.device)[:batch_size] + unused_states_bool = torch.ones(total_entries, + dtype=torch.bool, + device=device) + unused_states_bool[state_indices] = False + padded_state_indices = torch.concat([ + state_indices, + torch.as_tensor( + [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), + ], + dim=-1) has_initial_state = torch.randint(0, 2, (cumsum.shape[0] - 1, ), dtype=torch.bool, device=u.device) out = selective_scan_fn(u, prev_state, delta, A, B, C, D, z, delta_bias, - delta_softplus, cumsum, cache_indices, + delta_softplus, cumsum, padded_state_indices, has_initial_state) outs_ref = [] splits = [ @@ -486,6 +511,8 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups, ] for i in range(len(seqlens[0])): u_s, delta_s, B_s, C_s, z_s = [v[i].unsqueeze(0) for v in splits] + if padded_state_indices[i] == PAD_SLOT_ID: + continue out_ref_s, _ = selective_scan_ref( u_s, delta_s, @@ -497,21 +524,22 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups, delta_bias=delta_bias, delta_softplus=delta_softplus, return_last_state=return_last_state, - prev_state=prev_state_ref[cache_indices[i]].unsqueeze(0) + prev_state=prev_state_ref[padded_state_indices[i]].unsqueeze(0) if has_initial_state[i] else None, - final_state_out=prev_state_ref[cache_indices[i]].unsqueeze(0)) + final_state_out=prev_state_ref[padded_state_indices[i]].unsqueeze( + 0)) outs_ref.append(out_ref_s) - out_ref = torch.cat(outs_ref, dim=-1) if len(outs_ref) > 1 else outs_ref[0] + out_ref = torch.cat(outs_ref, dim=-1)[0] - print("Output diff max", (out - out_ref[0]).max()) - print("Output diff mean", (out - out_ref[0]).mean()) + unpadded_out = out[:, :out_ref[0].shape[-1]] + print("Output diff max", (unpadded_out - out_ref).max()) + print("Output diff mean", (unpadded_out - out_ref).mean()) print("Output state diff max", (prev_state - prev_state_ref).max()) print("Output state diff mean", (prev_state - prev_state_ref).mean()) assert torch.allclose(prev_state, prev_state_ref, rtol=rtol, atol=atol) - assert torch.allclose(out, out_ref[0], rtol=rtol, atol=atol) - + assert torch.allclose(unpadded_out, out_ref, rtol=rtol, atol=atol) selective_scan_opcheck_fn(u, delta, A, B, C, D, z, delta_bias, - delta_softplus, cumsum, cache_indices, + delta_softplus, cumsum, padded_state_indices, has_initial_state, prev_state) @@ -520,7 +548,10 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups, @pytest.mark.parametrize("has_z", [True]) @pytest.mark.parametrize("dstate", [16, 32, 64]) @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) -def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype): +# tests correctness in case subset of the sequences are padded +@pytest.mark.parametrize("with_padding", [True, False]) +def test_selective_state_update_with_batch_indices(with_padding, dim, dstate, + has_z, itype): device = "cuda" rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2) if itype == torch.bfloat16: @@ -530,21 +561,32 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype): # set seed torch.random.manual_seed(0) batch_size = 3 - + padding = 5 if with_padding else 0 + padded_batch_size = batch_size + padding total_entries = 10 * batch_size state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device) state_indices = torch.randperm(total_entries)[:batch_size].to( dtype=torch.int32, device=device) - - x = torch.randn(batch_size, dim, device=device, dtype=itype) - dt = torch.randn(batch_size, dim, device=device, dtype=itype) + unused_states_bool = torch.ones(total_entries, + dtype=torch.bool, + device=device) + unused_states_bool[state_indices] = False + padded_state_indices = torch.concat([ + state_indices, + torch.as_tensor( + [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device) + ], + dim=0) + x = torch.randn(padded_batch_size, dim, device=device, dtype=itype) + dt = torch.randn(padded_batch_size, dim, device=device, dtype=itype) dt_bias = torch.rand(dim, device=device) - 4.0 A = -torch.rand(dim, dstate, device=device) - 1.0 - B = torch.randn(batch_size, dstate, device=device) - C = torch.randn(batch_size, dstate, device=device) + B = torch.randn(padded_batch_size, dstate, device=device) + C = torch.randn(padded_batch_size, dstate, device=device) D = torch.randn(dim, device=device) z = torch.randn_like(x) if has_z else None - state_ref = state[state_indices, :].detach().clone() + state_ref = state[state_indices, :].clone() + state_before = state.clone() out = selective_state_update(state, x, dt, @@ -555,15 +597,16 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype): z=z, dt_bias=dt_bias, dt_softplus=True, - state_batch_indices=state_indices) + state_batch_indices=padded_state_indices, + pad_slot_id=PAD_SLOT_ID) out_ref = selective_state_update_ref(state_ref, - x, - dt, + x[:batch_size], + dt[:batch_size], A, - B, - C, + B[:batch_size], + C[:batch_size], D=D, - z=z, + z=z[:batch_size], dt_bias=dt_bias, dt_softplus=True) @@ -572,11 +615,21 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype): print("Output state diff max", (state[state_indices, :] - state_ref).max()) print("Output state diff mean", (state[state_indices, :] - state_ref).mean()) + # test padded entries stay the same + if with_padding: + assert torch.equal(state_before[unused_states_bool], + state[unused_states_bool]) + assert torch.equal(x[batch_size + 1:], x[batch_size + 1:]) + assert torch.equal(dt[batch_size + 1:], dt[batch_size + 1:]) + assert torch.equal(B[batch_size + 1:], B[batch_size + 1:]) + assert torch.equal(C[batch_size + 1:], C[batch_size + 1:]) + + # test "real" entries assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol) - assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol) @pytest.mark.parametrize("itype", @@ -645,7 +698,8 @@ def test_selective_state_update_with_heads_with_batch_indices( z=z, dt_bias=dt_bias, dt_softplus=True, - state_batch_indices=state_indices) + state_batch_indices=state_indices, + pad_slot_id=PAD_SLOT_ID) out_ref = selective_state_update_ref(state_ref, x, dt, diff --git a/tests/models/decoder_only/language/test_jamba.py b/tests/models/decoder_only/language/test_jamba.py index 408d12cd5ff5c..384ec77e5455a 100644 --- a/tests/models/decoder_only/language/test_jamba.py +++ b/tests/models/decoder_only/language/test_jamba.py @@ -1,5 +1,6 @@ import pytest +from tests.utils import multi_gpu_test from vllm.sampling_params import SamplingParams from vllm.worker.model_runner import _get_graph_batch_size @@ -270,6 +271,30 @@ def test_state_cleanup( "could be related to finished_requests_ids") +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [64]) +def test_jamba_distributed_produces_identical_generation( + vllm_runner, model: str, dtype: str, max_tokens: int, + example_prompts) -> None: + + with vllm_runner(model, dtype=dtype, tensor_parallel_size=2) as vllm_model: + vllm_outputs_tp_2 = vllm_model.generate_greedy(example_prompts, + max_tokens) + + with vllm_runner(model, dtype=dtype, tensor_parallel_size=1) as vllm_model: + vllm_outputs_tp_1 = vllm_model.generate_greedy(example_prompts, + max_tokens) + + check_outputs_equal( + outputs_0_lst=vllm_outputs_tp_1, + outputs_1_lst=vllm_outputs_tp_2, + name_0="vllm_tp_1", + name_1="vllm_tp_2", + ) + + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) def test_model_print( diff --git a/tools/profiler/print_layerwise_table.py b/tools/profiler/print_layerwise_table.py new file mode 100644 index 0000000000000..bbd24b085e3a7 --- /dev/null +++ b/tools/profiler/print_layerwise_table.py @@ -0,0 +1,77 @@ +import argparse +import json +from typing import Dict + +from vllm.profiler.layerwise_profile import ModelStatsEntry, SummaryStatsEntry +from vllm.profiler.utils import TablePrinter, indent_string + + +def flatten_entries(entry_cls, profile_dict: Dict): + entries_and_depth = [] + + def get_entries(node, curr_depth=0): + entries_and_depth.append((entry_cls(**node["entry"]), curr_depth)) + + for child in node["children"]: + get_entries( + child, + curr_depth=curr_depth + 1, + ) + + for root in profile_dict: + get_entries(root) + + return entries_and_depth + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--json-trace", + type=str, + required=True, + help="json trace file output by " + "examples/offline_profile.py") + parser.add_argument("--phase", + type=str, + choices=["prefill", "decode_1"], + required=True, + help="The phase to print the table for.") + parser.add_argument("--table", + type=str, + choices=["summary", "model"], + default="summary", + help="Which table to print, the summary table or the " + "layerwise model table") + + args = parser.parse_args() + + with open(args.json_trace, "r") as f: + profile_data = json.load(f) + + if args.table == "summary": + entries_and_depths = flatten_entries( + SummaryStatsEntry, profile_data[args.phase]["summary_stats"]) + column_widths = dict(name=80, + cuda_time_us=12, + pct_cuda_time=12, + invocations=15) + elif args.table == "model": + entries_and_depths = flatten_entries( + ModelStatsEntry, profile_data[args.phase]["model_stats"]) + column_widths = dict(name=60, + cpu_time_us=12, + cuda_time_us=12, + pct_cuda_time=12, + trace=60) + + # indent entry names based on the depth + entries = [] + for entry, depth in entries_and_depths: + entry.name = indent_string( + entry.name, + indent=depth, + indent_style=lambda indent: "|" + "-" * indent + " ") + entries.append(entry) + + TablePrinter(type(entries[0]), column_widths).print_table(entries) diff --git a/tools/profiler/visualize_layerwise_profile.py b/tools/profiler/visualize_layerwise_profile.py new file mode 100644 index 0000000000000..65ee3ae108ae1 --- /dev/null +++ b/tools/profiler/visualize_layerwise_profile.py @@ -0,0 +1,522 @@ +import argparse +import copy +import json +import math +import os +from pathlib import Path +from typing import Any, List, Optional, Tuple + +import matplotlib.pyplot as plt +import pandas as pd + +## JSON parsing utils #### + + +def largest_dist_from_leaf(node: dict, depth: int = 0): + if len(node["children"]) == 0: + return depth + return max([ + largest_dist_from_leaf(child, depth=depth + 1) + for child in node["children"] + ]) + + +def get_entries_at_depth(depth: int, + entries_and_traces: List[Tuple[Any, Any]], + node: dict, + curr_depth: int = 0, + trace=()): + # assert that the query is at kernel or module level + assert depth == -1 or depth == -2 + + if curr_depth == 0 and largest_dist_from_leaf(node) <= (abs(depth) - 1): + # The tree is not tall enough! + entries_and_traces.append((node["entry"], trace)) + return + + if largest_dist_from_leaf(node) == (abs(depth) - 1): + entries_and_traces.append((node["entry"], trace)) + + trace = (node["entry"]["name"], ) + trace + for child in node["children"]: + get_entries_at_depth(depth, + entries_and_traces, + child, + curr_depth=curr_depth + 1, + trace=trace) + + +def fold_nodes(root: dict, nodes_to_fold: List[str]): + + stack: List[dict] = [root] + while len(stack) != 0: + node = stack.pop() + if node['entry']['name'] in nodes_to_fold: + node["children"] = [] + continue + for child in node["children"]: + stack.append(child) + return root + + +## Operation name cleanup utils #### + + +def trim_string_back(string: str, width: int) -> str: + if len(string) > width: + offset = len(string) - width + 3 + string = string[:-offset] + if len(string) > 3: + string = string + "..." + return string + + +def shorten_plot_legend_strings(legend, max_char_len: int): + for t in legend.get_texts(): + t.set_text( + trim_string_back(abbreviate_known_names(t.get_text()), + max_char_len)) + + +def abbreviate_known_names(name: str) -> str: + abbreviations = { + "MergedColumnParallelLinear": "MCPLinear", + "QKVParallelLinear": "QKVPLinear", + "RowParallelLinear": "RPLinear", + "weight=": "w=", + "bfloat16": "bf16", + "float16": "f16", + } + for key, value in abbreviations.items(): + name = name.replace(key, value) + return name + + +def attempt_to_make_names_unique(entries_and_traces): + names, non_unique_names = (set(), set()) + + def all_the_same(items) -> bool: + return all(i == items[0] for i in items) + + for entry, _ in entries_and_traces: + if entry["name"] in names: + non_unique_names.add(entry["name"]) + else: + names.add(entry["name"]) + + for name in non_unique_names: + entries_and_traces_with_name = [(entry, trace) + for entry, trace in entries_and_traces + if entry["name"] == name] + + zipped_traces = list( + zip(*[trace for _, trace in entries_and_traces_with_name])) + first_trace_difference = next( + (i for i, trace_eles in enumerate(zipped_traces) + if not all_the_same(trace_eles)), None) + + if first_trace_difference is None: + # can't create a unique name, leave them names as the + # are they will get aggregated by the pivot_table call + continue + + for entry, trace in entries_and_traces_with_name: + entry["name"] = " <- ".join((entry["name"], ) + + trace[:first_trace_difference + 1]) + + +## Operation grouping utils #### +''' + Group operations in the given dataframe by some high-level ops like, + - gemms + - attention + - rms_norm + etc. +''' + + +def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame: + + def is_rms_norm(op_name: str): + if "rms_norm_kernel" in op_name: + return True + + def is_attention_block(op_name: str): + if "flash_fwd" in op_name or \ + "reshape_and_cache_flash_kernel" in op_name: + return True + + def is_quant(op_name: str): + if "scaled_fp8_quant" in op_name or \ + "scaled_int8_quant" in op_name: + return True + + def is_gemm_op(op_name: str): + if is_quant(op_name): + return False + if "xmma_gemm" in op_name or \ + "gemv2T_kernel" in op_name or \ + "splitKreduce" in op_name or \ + "void cutlass::Kernel" in op_name or \ + "void cutlass::device_kernel" in op_name or \ + "s16816gemm" in op_name: + return True + + def is_elementwise_op(op_name: str): + return "elementwise_kernel" in op_name + + def is_mem_op(op_name: str): + return "memcpy" in op_name.lower() or \ + "memset" in op_name.lower() + + def is_vocab_embedding_op(op_name: str): + return "vocabparallelembed" in op_name.lower() + + # nccl ops + def is_nccl_op(op_name: str): + return "nccl" in op_name.lower() + + def is_nccl_all_reduce(op_name: str): + return is_nccl_op(op_name) and \ + ("all_reduce" in op_name.lower() or \ + "allreduce" in op_name.lower()) + + def is_nccl_gather(op_name: str): + return is_nccl_op(op_name) and \ + "gather" in op_name.lower() + + def is_nccl_broadcast(op_name: str): + return is_nccl_op(op_name) and \ + "broadcast" in op_name.lower() + + # Reduce ops types + def is_cross_device_reduce_1stage(op_name: str): + return "cross_device_reduce_1stage" in op_name + + def is_cross_device_reduce_2stage(op_name: str): + return "cross_device_reduce_2stage" in op_name + + def is_custom_ar_all_reduce_unreg(op_name: str): + return "_C_custom_ar::all_reduce_unreg" in op_name + + def is_reduce_kernel(op_name: str): + return "reduce_kernel" in op_name + + headers = list(trace_df) + ops = copy.deepcopy(headers) + + attention_ops = list(filter(lambda x: is_attention_block(x), ops)) + ops = list(filter(lambda x: x not in attention_ops, ops)) + + quant_ops = list(filter(lambda x: is_quant(x), ops)) + ops = list(filter(lambda x: x not in quant_ops, ops)) + + gemm_ops = list(filter(lambda x: is_gemm_op(x), ops)) + ops = list(filter(lambda x: x not in gemm_ops, ops)) + + rms_norm_ops = list(filter(lambda x: is_rms_norm(x), ops)) + ops = list(filter(lambda x: x not in rms_norm_ops, ops)) + + vocab_embed_ops = list(filter(lambda x: is_vocab_embedding_op(x), ops)) + ops = list(filter(lambda x: x not in vocab_embed_ops, ops)) + + mem_ops = list(filter(lambda x: is_mem_op(x), ops)) + ops = list(filter(lambda x: x not in mem_ops, ops)) + + elementwise_ops = list(filter(lambda x: is_elementwise_op(x), ops)) + ops = list(filter(lambda x: x not in elementwise_ops, ops)) + + nccl_all_reduce_ops = list(filter(lambda x: is_nccl_all_reduce(x), ops)) + ops = list(filter(lambda x: x not in nccl_all_reduce_ops, ops)) + + nccl_gather_ops = list(filter(lambda x: is_nccl_gather(x), ops)) + ops = list(filter(lambda x: x not in nccl_gather_ops, ops)) + + nccl_broadcast_ops = list(filter(lambda x: is_nccl_broadcast(x), ops)) + ops = list(filter(lambda x: x not in nccl_broadcast_ops, ops)) + + nccl_other_ops = list(filter(lambda x: is_nccl_op(x), ops)) + ops = list(filter(lambda x: x not in nccl_other_ops, ops)) + + cross_device_reduce_1stage_ops = list( + filter(lambda x: is_cross_device_reduce_1stage(x), ops)) + ops = list(filter(lambda x: x not in cross_device_reduce_1stage_ops, ops)) + + cross_device_reduce_2stage_ops = list( + filter(lambda x: is_cross_device_reduce_2stage(x), ops)) + ops = list(filter(lambda x: x not in cross_device_reduce_2stage_ops, ops)) + + custom_ar_all_reduce_unreg_ops = list( + filter(lambda x: is_custom_ar_all_reduce_unreg(x), ops)) + ops = list(filter(lambda x: x not in custom_ar_all_reduce_unreg_ops, ops)) + + reduce_kernel_ops = list(filter(lambda x: is_reduce_kernel(x), ops)) + ops = list(filter(lambda x: x not in reduce_kernel_ops, ops)) + + if len(attention_ops): + trace_df['attention'] = trace_df[attention_ops].agg("sum", axis=1) + if len(quant_ops): + trace_df['quant_ops'] = trace_df[quant_ops].agg("sum", axis=1) + if len(gemm_ops): + trace_df['gemm_ops'] = trace_df[gemm_ops].agg("sum", axis=1) + if len(rms_norm_ops): + trace_df['rms_norm_ops'] = trace_df[rms_norm_ops].agg("sum", axis=1) + if len(vocab_embed_ops): + trace_df['vocab_embed_ops'] = trace_df[vocab_embed_ops].agg("sum", + axis=1) + if len(mem_ops): + trace_df['mem_ops'] = trace_df[mem_ops].agg("sum", axis=1) + if len(elementwise_ops): + trace_df['elementwise_ops'] = trace_df[elementwise_ops].agg("sum", + axis=1) + + if len(nccl_all_reduce_ops): + trace_df['nccl_all_reduce_ops'] = trace_df[nccl_all_reduce_ops].agg( + "sum", axis=1) + if len(nccl_gather_ops): + trace_df['nccl_gather_ops'] = trace_df[nccl_gather_ops].agg("sum", + axis=1) + if len(nccl_broadcast_ops): + trace_df['nccl_broadcast_ops'] = trace_df[nccl_broadcast_ops].agg( + "sum", axis=1) + if len(nccl_other_ops): + trace_df['nccl_other_ops'] = trace_df[nccl_other_ops].agg("sum", + axis=1) + + if len(cross_device_reduce_1stage_ops): + trace_df['cross_device_reduce_1stage_ops'] = trace_df[ + cross_device_reduce_1stage_ops].agg("sum", axis=1) + if len(cross_device_reduce_2stage_ops): + trace_df['cross_device_reduce_2stage_ops'] = trace_df[ + cross_device_reduce_2stage_ops].agg("sum", axis=1) + if len(custom_ar_all_reduce_unreg_ops): + trace_df['custom_ar_all_reduce_unreg_ops'] = trace_df[ + custom_ar_all_reduce_unreg_ops].agg("sum", axis=1) + if len(reduce_kernel_ops): + trace_df['reduce_kernel_ops'] = trace_df[reduce_kernel_ops].agg("sum", + axis=1) + + trace_df.drop( + attention_ops + quant_ops + gemm_ops + rms_norm_ops + vocab_embed_ops + + mem_ops + elementwise_ops + nccl_all_reduce_ops + nccl_gather_ops + + nccl_broadcast_ops + nccl_other_ops + cross_device_reduce_1stage_ops + + cross_device_reduce_2stage_ops + custom_ar_all_reduce_unreg_ops + + reduce_kernel_ops, + axis=1, + inplace=True) + return trace_df + + +## Data plotting utils #### + + +def plot_trace_df(traces_df: pd.DataFrame, + plot_metric: str, + plot_title: str, + output: Optional[Path] = None): + + phases = traces_df['phase'].unique() + traces_df = traces_df.pivot_table(index="phase", + columns="name", + values=plot_metric, + aggfunc="sum") + + traces_df = group_trace_by_operations(traces_df) + + # Make the figure + fig, ax = plt.subplots(1, figsize=(5, 8), sharex=True) + + # Draw the stacked bars + ops = list(traces_df) + bottom = [0] * len(phases) + for op in ops: + values = [traces_df[op][phase] for phase in phases] + values = list(map(lambda x: 0.0 if math.isnan(x) else x, values)) + ax.bar(phases, values, label=op, bottom=bottom) + bottom = [bottom[j] + values[j] for j in range(len(phases))] + + # Write the values as text on the bars + for bar in ax.patches: + if bar.get_height() != 0: + ax.text(bar.get_x() + bar.get_width() / 2, + bar.get_height() / 2 + bar.get_y(), + f"{round(bar.get_height(), 2)}", + ha='center', + color='w', + weight='bold', + size=5) + + # Setup legend + handles, labels = plt.gca().get_legend_handles_labels() + legend = fig.legend(handles, + labels, + loc='center left', + bbox_to_anchor=(1, 1)) + shorten_plot_legend_strings(legend, 50) + + # Setup labels and title + plt.setp(ax.get_xticklabels(), rotation=90) + ax.set_ylabel(plot_metric) + plt.suptitle(plot_title) + + plt.savefig(output, bbox_inches='tight') + print("Created: ", output) + + +def main( + json_trace: Path, + output_directory: Path, + depth: int, # Fetch/Plot operations at this depth of the Json tree + plot_metric: str, + make_names_unique: bool, + top_k: int, + json_nodes_to_fold: List[str]): + + def prepare_data(profile_json: dict, step_keys: List[str]) -> pd.DataFrame: + + def get_entries_and_traces(key: str): + entries_and_traces: List[Tuple[Any, Any]] = [] + for root in profile_json[key]["summary_stats"]: + # Fold nodes in the traces as per user request. i.e. simply + # make the requested nodes leaf-nodes. + root = fold_nodes(root, json_nodes_to_fold) + get_entries_at_depth(depth, entries_and_traces, root) + return entries_and_traces + + def keep_only_top_entries(df: pd.DataFrame, + metric: str, + top_k: int = 9) -> pd.DataFrame: + df.loc[df.nsmallest(len(df) - top_k + 1, metric).index, + ["name"]] = "others" + return df + + # Get data for each key + traces = list(map(lambda x: get_entries_and_traces(x), step_keys)) + + # Attempt some cleanup + if make_names_unique: + for trace in traces: + attempt_to_make_names_unique(trace) + + # To pandas dataframe + trace_dfs = list( + map(lambda t: pd.DataFrame([entry for entry, _ in t]).fillna(0), + traces)) + + # Respect top_k + if top_k: + trace_dfs = list( + map( + lambda trace_df: keep_only_top_entries( + trace_df, "cuda_time_us", top_k), trace_dfs)) + + # Fill in information about the step-keys + for trace_df, step_key in zip(trace_dfs, step_keys): + trace_df['phase'] = step_key + + # Combine all data frames so they can be put in a single plot + traces_df = pd.concat(trace_dfs) + + # Add a derived metric `cuda_time_ms` + traces_df["cuda_time_ms"] = traces_df["cuda_time_us"] / 1000 + traces_df = traces_df.fillna(0) + + return traces_df + + def make_plot_title_suffix(profile_json: dict) -> str: + context = profile_json["context"] + sparsity = context.get('sparsity', None) + return (f"{context['model']}\n" + f"Batch={context['batch_size']}, " + f"PromptLen={context['prompt_len']}, " + f"OutputLen={context['output_len']}," + f"NumGpus={context['tensor_parallel_size']}" + f"{', Sparsity ' + sparsity if sparsity else ''}") + + profile_json = None + with open(json_trace, "r") as f: + profile_json = json.load(f) + assert profile_json is not None + + # Get all `llm.generate.step()` profile + step_traces = list(profile_json.keys()) + assert (step_traces[0] == 'context') + step_traces = step_traces[1:] # have only prefill and decodes + prefills = list(filter(lambda x: "prefill" in x, step_traces)) + all_decodes = list(filter(lambda x: "decode" in x, step_traces)) + assert len(prefills) + len(all_decodes) == len(step_traces) + assert len(prefills) == 1 + + decodes = all_decodes[::args.step_plot_interval] + if decodes[-1] != all_decodes[-1]: + # Always have the last decode + decodes.append(all_decodes[-1]) + + prefill_traces = prepare_data(profile_json, prefills) + decode_traces = prepare_data(profile_json, decodes) + + plot_title_suffix = make_plot_title_suffix(profile_json) + + plot_trace_df(prefill_traces, plot_metric, "prefill " + plot_title_suffix, + output_directory / Path("prefill.png")) + plot_trace_df(decode_traces, plot_metric, "decodes " + plot_title_suffix, + output_directory / Path("decode_steps.png")) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--json-trace", + type=str, + required=True, + help="json trace file output by examples/offline_profile.py") + parser.add_argument("--output-directory", + type=str, + required=False, + help="Directory to output plots") + parser.add_argument("--level", + type=str, + default="module", + choices=["module", "kernel"]) + parser.add_argument("--top-k", + type=int, + default=12, + help="Only graph the top `top_k` entries by time.") + parser.add_argument("--fold-json-node", + nargs='+', + default=['Sampler', 'LogitsProcessor'], + help='Do not plot the children of these nodes. Let, \ + the node represent the aggregate of all its \ + children') + parser.add_argument("--plot-metric", + type=str, + default="cuda_time_ms", + help='Metric to plot. some options are cuda_time_ms, \ + pct_cuda_time') + parser.add_argument( + "--step-plot-interval", + type=int, + default=4, + help="For every `step_plot_interval` steps, plot 1 step") + + args = parser.parse_args() + + # Prepare/Extract relevant args + make_names_unique = False + if args.level == "module": + depth = -2 + make_names_unique = True + elif args.level == "kernel": + depth = -1 + else: + raise Exception(f"Unexpected level value ({args.level})") + + output_directory = args.output_directory if args.output_directory else Path( + args.json_trace).parent + + if not os.path.exists(output_directory): + os.makedirs(output_directory) + + main(Path(args.json_trace), output_directory, depth, args.plot_metric, + make_names_unique, args.top_k, args.fold_json_node) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 3a23692285efe..ec035f137c3a6 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -464,16 +464,18 @@ def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor, cu_seq_len: Optional[torch.Tensor], cache_indices: Optional[torch.Tensor], has_initial_state: Optional[torch.Tensor], - silu_activation: bool) -> torch.Tensor: - return torch.empty_like(x) + silu_activation: bool, pad_slot_id: int): + return None @register_fake("_C::causal_conv1d_update") - def causal_conv1d_update_fake( - x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor, - bias_: Optional[torch.Tensor], silu_activation: bool, - cache_seqlens: Optional[torch.Tensor], - conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor: - return torch.empty_like(x) + def causal_conv1d_update_fake(x: torch.Tensor, conv_state: torch.Tensor, + weight: torch.Tensor, + bias_: Optional[torch.Tensor], + silu_activation: bool, + cache_seqlens: Optional[torch.Tensor], + conv_state_indices: Optional[torch.Tensor], + pad_slot_id: int) -> None: + return None @register_fake("_C::selective_scan_fwd") def selective_scan_fwd_fake(u: torch.Tensor, delta: torch.Tensor, @@ -485,7 +487,8 @@ def selective_scan_fwd_fake(u: torch.Tensor, delta: torch.Tensor, cu_seq_len: Optional[torch.Tensor], cache_indices: Optional[torch.Tensor], has_initial_state: Optional[torch.Tensor], - ssm_states: Optional[torch.Tensor]) -> None: + ssm_states: Optional[torch.Tensor], + pad_slot_id: int) -> None: return None @@ -800,33 +803,37 @@ def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor, query_start_loc: Optional[torch.Tensor], cache_indices: Optional[torch.Tensor], has_initial_state: Optional[torch.Tensor], - silu_activation: bool) -> torch.Tensor: - return torch.ops._C.causal_conv1d_fwd(x, weight, bias_, conv_states, - query_start_loc, cache_indices, - has_initial_state, silu_activation) - - -def causal_conv1d_update( - x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor, - bias_: Optional[torch.Tensor], silu_activation: bool, - cache_seqlens: Optional[torch.Tensor], - conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor: - return torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_, - silu_activation, cache_seqlens, - conv_state_indices) - - -def selective_scan_fwd( - u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, B: torch.Tensor, - C: torch.Tensor, D_: Optional[torch.Tensor], - z_: Optional[torch.Tensor], delta_bias_: Optional[torch.Tensor], - delta_softplus: bool, query_start_loc: Optional[torch.Tensor], - cache_indices: Optional[torch.Tensor], - has_initial_state: Optional[torch.Tensor], ssm_states: torch.Tensor): + silu_activation: bool, pad_slot_id: int): + torch.ops._C.causal_conv1d_fwd(x, weight, bias_, conv_states, + query_start_loc, cache_indices, + has_initial_state, silu_activation, + pad_slot_id) + + +def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor, + weight: torch.Tensor, bias_: Optional[torch.Tensor], + silu_activation: bool, + cache_seqlens: Optional[torch.Tensor], + conv_state_indices: Optional[torch.Tensor], + pad_slot_id: int): + torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_, + silu_activation, cache_seqlens, + conv_state_indices, pad_slot_id) + + +def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, + B: torch.Tensor, C: torch.Tensor, + D_: Optional[torch.Tensor], z_: Optional[torch.Tensor], + delta_bias_: Optional[torch.Tensor], + delta_softplus: bool, + query_start_loc: Optional[torch.Tensor], + cache_indices: Optional[torch.Tensor], + has_initial_state: Optional[torch.Tensor], + ssm_states: torch.Tensor, pad_slot_id: int): torch.ops._C.selective_scan_fwd(u, delta, A, B, C, D_, z_, delta_bias_, delta_softplus, query_start_loc, cache_indices, has_initial_state, - ssm_states) + ssm_states, pad_slot_id) # moe diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 43056786d35c9..f2ea53cad9f2a 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -13,6 +13,33 @@ from vllm.model_executor.utils import set_weight_attrs +class FatreluAndMul(CustomOp): + """An activation function for FATReLU. + + The function computes x -> FATReLU(x[:d]) * x[d:] where + d = x.shape[-1] // 2. + This is used in openbmb/MiniCPM-S-1B-sft. + + Shapes: + x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d) + return: (num_tokens, d) or (batch_size, seq_len, d) + """ + + def __init__(self, threshold: float = 0.): + super().__init__() + self.threshold = threshold + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + x1 = x[..., :d] + x2 = x[..., d:] + x1 = F.threshold(x1, self.threshold, 0.0) + return x1 * x2 + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + return self.forward_native(x) + + class SiluAndMul(CustomOp): """An activation function for SwiGLU. diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index ed7241af6cd14..be5639df985fa 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -6,18 +6,18 @@ import torch from vllm import _custom_ops as ops +from vllm.attention.backends.utils import PAD_SLOT_ID -def causal_conv1d_fn( - x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - query_start_loc: Optional[torch.Tensor] = None, - cache_indices: Optional[torch.Tensor] = None, - has_initial_state: Optional[torch.Tensor] = None, - conv_states: Optional[torch.Tensor] = None, - activation: Optional[str] = "silu", -): +def causal_conv1d_fn(x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + query_start_loc: Optional[torch.Tensor] = None, + cache_indices: Optional[torch.Tensor] = None, + has_initial_state: Optional[torch.Tensor] = None, + conv_states: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", + pad_slot_id: int = PAD_SLOT_ID): """ x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen sequences are concatenated from left to right for varlen @@ -37,6 +37,13 @@ def causal_conv1d_fn( conv_states: (...,dim,width - 1) itype updated inplace if provided activation: either None or "silu" or "swish" + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + out: (batch, dim, seqlen) """ @@ -46,10 +53,10 @@ def causal_conv1d_fn( x = x.contiguous() bias = bias.contiguous() if bias is not None else None - out = ops.causal_conv1d_fwd(x, weight, bias, conv_states, query_start_loc, - cache_indices, has_initial_state, activation - in ["silu", "swish"]) - return out + ops.causal_conv1d_fwd(x, weight, bias, conv_states, query_start_loc, + cache_indices, has_initial_state, activation + in ["silu", "swish"], pad_slot_id) + return x def causal_conv1d_update(x: torch.Tensor, @@ -58,7 +65,8 @@ def causal_conv1d_update(x: torch.Tensor, bias: Optional[torch.Tensor] = None, activation: Optional[str] = None, cache_seqlens: Optional[torch.Tensor] = None, - conv_state_indices: Optional[torch.Tensor] = None): + conv_state_indices: Optional[torch.Tensor] = None, + pad_slot_id: int = PAD_SLOT_ID): """ x: (batch, dim) or (batch, dim, seqlen) conv_state: (batch, dim, state_len), where state_len >= width - 1 @@ -73,7 +81,12 @@ def causal_conv1d_update(x: torch.Tensor, If not None, the conv_state is a larger tensor along the batch dim, and we are selecting the batch coords specified by conv_state_indices. Useful for a continuous batching scenario. - + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 out: (batch, dim) or (batch, dim, seqlen) """ if activation not in [None, "silu", "swish"]: @@ -82,8 +95,8 @@ def causal_conv1d_update(x: torch.Tensor, unsqueeze = x.dim() == 2 if unsqueeze: x = x.unsqueeze(-1) - out = ops.causal_conv1d_update(x, conv_state, weight, bias, activation_val, - cache_seqlens, conv_state_indices) + ops.causal_conv1d_update(x, conv_state, weight, bias, activation_val, + cache_seqlens, conv_state_indices, pad_slot_id) if unsqueeze: - out = out.squeeze(-1) - return out + x = x.squeeze(-1) + return x diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 08b016c20c42d..1484b79815ab9 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -1,14 +1,13 @@ # Copyright (c) 2024, Tri Dao, Albert Gu. # Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/selective_state_update.py -from typing import Tuple - import torch import triton import triton.language as tl from packaging import version from vllm import _custom_ops as ops +from vllm.attention.backends.utils import PAD_SLOT_ID TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0") @@ -50,6 +49,7 @@ def _selective_scan_update_kernel( z_ptr, out_ptr, state_batch_indices_ptr, + pad_slot_id, # Matrix dimensions batch, nheads, @@ -143,10 +143,11 @@ def _selective_scan_update_kernel( if HAS_Z: z_ptrs = z_ptr + offs_m * stride_z_dim out_ptrs = out_ptr + offs_m * stride_out_dim + mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate) + if HAS_STATE_BATCH_INDICES: + mask &= (state_batch_idx != pad_slot_id) + state = tl.load(state_ptrs, mask=mask, other=0.0) - state = tl.load(state_ptrs, - mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), - other=0.0) x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) if not TIE_HDIM: dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) @@ -177,9 +178,11 @@ def _selective_scan_update_kernel( dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt state = state * dA + dB * x[:, None] - tl.store(state_ptrs, - state, - mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)) + + mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate) + if HAS_STATE_BATCH_INDICES: + mask &= (state_batch_idx != pad_slot_id) + tl.store(state_ptrs, state, mask=mask) out = tl.sum(state * C[None, :], axis=1) if HAS_D: out += x * D @@ -198,7 +201,8 @@ def selective_state_update(state, z=None, dt_bias=None, dt_softplus=False, - state_batch_indices=None): + state_batch_indices=None, + pad_slot_id=PAD_SLOT_ID): """ Argument: state: (batch, dim, dstate) or (batch, nheads, dim, dstate) @@ -210,6 +214,12 @@ def selective_state_update(state, D: (dim,) or (nheads, dim) z: (batch, dim) or (batch, nheads, dim) dt_bias: (dim,) or (nheads, dim) + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 Return: out: (batch, dim) or (batch, nheads, dim) """ @@ -276,6 +286,7 @@ def selective_state_update(state, z, out, state_batch_indices, + pad_slot_id, batch, nheads, dim, @@ -319,22 +330,25 @@ def selective_state_update(state, return out -def selective_scan_fn( - u, - ssm_states, - delta, - A, - B, - C, - D=None, - z=None, - delta_bias=None, - delta_softplus=False, - query_start_loc=None, - cache_indices=None, - has_initial_state=None) -> Tuple[torch.Tensor, torch.Tensor]: +def selective_scan_fn(u, + ssm_states, + delta, + A, + B, + C, + D=None, + z=None, + delta_bias=None, + delta_softplus=False, + query_start_loc=None, + cache_indices=None, + has_initial_state=None, + pad_slot_id=PAD_SLOT_ID) -> torch.Tensor: """ u: (dim, total_length) for varlen or (batch, dim, seqlen) + applies changes in place. + ssm_states: (batch, dim, dstate) or (batch, nheads, dim, dstate) + applies changes in place. delta: (dim, total_length) for varlen or (batch, dim, seqlen) A: (dim, dstate) B: (ngroups, dstate, total_length) for varlen or @@ -357,12 +371,14 @@ def selective_scan_fn( indicate if the ssm_state at the corresponding index should be used as initial state. Not providing argument assumes there's no initial state - + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padding entries + that will not be processed, + for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] + in this case, the kernel will not process entries at indices 0 and 3 returns output: (dim, total_length) for varlen or (batch, dim, seqlen) supports inplace replacement - last_state has shape (batch, dim, dstate). - supports inplace replacement if ssm_state was provided """ if u.stride(-1) != 1: u = u.contiguous() @@ -387,7 +403,7 @@ def selective_scan_fn( ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus, query_start_loc, cache_indices, has_initial_state, - ssm_states) + ssm_states, pad_slot_id) if z is None: return delta # output written inplace to delta diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index ac251b88e872c..fddd39fb8c85b 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -1,6 +1,5 @@ # coding=utf-8 """Inference-only Jamba model.""" -from dataclasses import dataclass from typing import Iterable, List, Optional, Tuple import torch @@ -29,7 +28,8 @@ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( composed_weight_loader, default_weight_loader, sharded_weight_loader) -from vllm.model_executor.models.mamba_cache import MambaCacheManager +from vllm.model_executor.models.mamba_cache import (MambaCacheManager, + MambaCacheParams) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors @@ -41,13 +41,6 @@ KVCache = Tuple[torch.Tensor, torch.Tensor] -@dataclass -class MambaCacheParams: - is_prompt: bool = False - conv_state: torch.Tensor = torch.Tensor() - ssm_state: torch.Tensor = torch.Tensor() - - # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer class JambaMambaMixer(nn.Module): """ @@ -60,10 +53,9 @@ class JambaMambaMixer(nn.Module): **selective** state spaces) """ - def __init__(self, config: JambaConfig, layer_idx): + def __init__(self, config: JambaConfig): super().__init__() self.config = config - self.layer_idx = layer_idx self.hidden_size = config.hidden_size self.ssm_state_size = config.mamba_d_state self.conv_kernel_size = config.mamba_d_conv @@ -129,8 +121,8 @@ def __init__(self, config: JambaConfig, layer_idx): eps=config.rms_norm_eps) def forward(self, hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, conv_state: torch.Tensor, - ssm_state: torch.Tensor): + attn_metadata: AttentionMetadata, + mamba_cache_params: MambaCacheParams): # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) @@ -153,17 +145,18 @@ def forward(self, hidden_states: torch.Tensor, conv_weights, self.conv1d.bias, activation=self.activation, - conv_states=conv_state, + conv_states=mamba_cache_params.conv_state, has_initial_state=attn_metadata.context_lens_tensor > 0, + cache_indices=mamba_cache_params.state_indices_tensor, query_start_loc=attn_metadata.query_start_loc) else: hidden_states = causal_conv1d_update( hidden_states.transpose(0, 1), - conv_state, + mamba_cache_params.conv_state, conv_weights, self.conv1d.bias, self.activation, - ) + conv_state_indices=mamba_cache_params.state_indices_tensor) hidden_states = hidden_states.transpose(0, 1) # 3. State Space Model sequence transformation @@ -188,7 +181,7 @@ def forward(self, hidden_states: torch.Tensor, and attn_metadata.context_lens_tensor is not None: scan_outputs = selective_scan_fn( hidden_states, - ssm_state, + mamba_cache_params.ssm_state, discrete_time_step, self.A, B.transpose(-2, -1), @@ -197,11 +190,12 @@ def forward(self, hidden_states: torch.Tensor, gate, time_proj_bias, delta_softplus=True, + cache_indices=mamba_cache_params.state_indices_tensor, has_initial_state=attn_metadata.context_lens_tensor > 0, query_start_loc=attn_metadata.query_start_loc) else: scan_outputs = selective_state_update( - ssm_state, + mamba_cache_params.ssm_state, hidden_states.transpose(0, 1), discrete_time_step.transpose(0, 1), self.A, @@ -211,7 +205,7 @@ def forward(self, hidden_states: torch.Tensor, gate.transpose(0, 1), time_proj_bias, dt_softplus=True, - ) + state_batch_indices=mamba_cache_params.state_indices_tensor) scan_outputs = scan_outputs.transpose(0, 1) # 4. Final linear projection @@ -292,7 +286,7 @@ def __init__(self, super().__init__() self.layer_idx = layer_idx self.config = config - self.mamba = JambaMambaMixer(config, layer_idx) + self.mamba = JambaMambaMixer(config) num_experts = config.layers_num_experts[layer_idx] ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP @@ -307,8 +301,7 @@ def forward( hidden_states: torch.Tensor, attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], - conv_state: torch.Tensor, - ssm_state: torch.Tensor, + mamba_cache_params: MambaCacheParams, **kwargs, ): if residual is None: @@ -318,8 +311,8 @@ def forward( hidden_states, residual = self.input_layernorm( hidden_states, residual) - hidden_states = self.mamba(hidden_states, attn_metadata, conv_state, - ssm_state) + hidden_states = self.mamba(hidden_states, attn_metadata, + mamba_cache_params) # Fully Connected hidden_states, residual = self.pre_ff_layernorm( hidden_states, residual) @@ -476,17 +469,14 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - conv_state: torch.Tensor, - ssm_state: torch.Tensor, + mamba_cache_params: MambaCacheParams, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) residual = None - for i in range(len(self.layers)): layer = self.layers[i] kv_cache = None - current_ssm_state = None - current_conv_state = None + layer_mamba_cache_params = None if isinstance(layer, JambaAttentionDecoderLayer): kv_cache = kv_caches[(i - self.config.attn_layer_offset) // self.config.attn_layer_period] @@ -494,8 +484,8 @@ def forward( current_state_layer = i - (1 + (i - self.config.attn_layer_offset) // self.config.attn_layer_period) - current_ssm_state = ssm_state[current_state_layer] - current_conv_state = conv_state[current_state_layer] + layer_mamba_cache_params = mamba_cache_params.at_layer_idx( + current_state_layer) hidden_states, residual = layer( positions=positions, @@ -503,9 +493,7 @@ def forward( kv_cache=kv_cache, attn_metadata=attn_metadata, residual=residual, - conv_state=current_conv_state, - ssm_state=current_ssm_state, - ) + mamba_cache_params=layer_mamba_cache_params) hidden_states, _ = self.final_layernorm(hidden_states, residual) return hidden_states @@ -588,13 +576,16 @@ def forward(self, self.mamba_cache = MambaCacheManager( self.lm_head.weight.dtype, num_mamba_layers, max_batch_size, *self._get_mamba_cache_shape()) - - mamba_cache_tensors = self.mamba_cache.current_run_tensors( - input_ids, attn_metadata, **kwargs) - + ( + mamba_cache_tensors, + state_indices_tensor, + ) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata, + **kwargs) + mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0], + mamba_cache_tensors[1], + state_indices_tensor) hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, mamba_cache_tensors[0], - mamba_cache_tensors[1]) + attn_metadata, mamba_cache_params) return hidden_states def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index b86b687a9c361..7f2efb9895f25 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -27,7 +27,8 @@ composed_weight_loader, default_weight_loader, sharded_weight_loader) from vllm.model_executor.models.interfaces import (HasInnerState, IsAttentionFree) -from vllm.model_executor.models.mamba_cache import MambaCacheManager +from vllm.model_executor.models.mamba_cache import (MambaCacheManager, + MambaCacheParams) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors @@ -110,8 +111,8 @@ def __init__(self, config: MambaConfig, layer_idx): self.activation = config.hidden_act def forward(self, hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, conv_state: torch.Tensor, - ssm_state: torch.Tensor): + attn_metadata: AttentionMetadata, + mamba_cache_params: MambaCacheParams): # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) @@ -134,17 +135,18 @@ def forward(self, hidden_states: torch.Tensor, conv_weights, self.conv1d.bias, activation=self.activation, - conv_states=conv_state, + conv_states=mamba_cache_params.conv_state, has_initial_state=attn_metadata.context_lens_tensor > 0, + cache_indices=mamba_cache_params.state_indices_tensor, query_start_loc=attn_metadata.query_start_loc) else: hidden_states = causal_conv1d_update( hidden_states.transpose(0, 1), - conv_state, + mamba_cache_params.conv_state, conv_weights, self.conv1d.bias, self.activation, - ) + conv_state_indices=mamba_cache_params.state_indices_tensor) hidden_states = hidden_states.transpose(0, 1) # 3. State Space Model sequence transformation @@ -168,7 +170,7 @@ def forward(self, hidden_states: torch.Tensor, and attn_metadata.context_lens_tensor is not None: scan_outputs = selective_scan_fn( hidden_states, - ssm_state, + mamba_cache_params.ssm_state, discrete_time_step, self.A, B.transpose(-2, -1), @@ -177,11 +179,12 @@ def forward(self, hidden_states: torch.Tensor, gate, time_proj_bias, delta_softplus=True, + cache_indices=mamba_cache_params.state_indices_tensor, has_initial_state=attn_metadata.context_lens_tensor > 0, query_start_loc=attn_metadata.query_start_loc) else: scan_outputs = selective_state_update( - ssm_state, + mamba_cache_params.ssm_state, hidden_states.transpose(0, 1), discrete_time_step.transpose(0, 1), self.A, @@ -191,7 +194,7 @@ def forward(self, hidden_states: torch.Tensor, gate.transpose(0, 1), time_proj_bias, dt_softplus=True, - ) + state_batch_indices=mamba_cache_params.state_indices_tensor) scan_outputs = scan_outputs.transpose(0, 1) # 4. Final linear projection @@ -221,8 +224,7 @@ def forward( hidden_states: torch.Tensor, attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], - conv_state: torch.Tensor, - ssm_state: torch.Tensor, + mamba_cache_params: MambaCacheParams, **kwargs, ): if residual is None: @@ -231,8 +233,8 @@ def forward( else: hidden_states, residual = self.norm(hidden_states, residual) - hidden_states = self.mixer(hidden_states, attn_metadata, conv_state, - ssm_state) + hidden_states = self.mixer(hidden_states, attn_metadata, + mamba_cache_params) return hidden_states, residual @@ -275,25 +277,20 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, attn_metadata: AttentionMetadata, - conv_state: torch.Tensor, - ssm_state: torch.Tensor, + mamba_cache_params: MambaCacheParams, ) -> torch.Tensor: + hidden_states = self.embeddings(input_ids) residual = None for i in range(len(self.layers)): layer = self.layers[i] - current_ssm_state = ssm_state[i] - current_conv_state = conv_state[i] - hidden_states, residual = layer( positions=positions, hidden_states=hidden_states, attn_metadata=attn_metadata, residual=residual, - conv_state=current_conv_state, - ssm_state=current_ssm_state, - ) + mamba_cache_params=mamba_cache_params.at_layer_idx(i)) hidden_states, _ = self.norm_f(hidden_states, residual) return hidden_states @@ -347,12 +344,18 @@ def forward(self, self.lm_head.weight.dtype, self.config.num_hidden_layers, max_batch_size, *self._get_mamba_cache_shape()) - mamba_cache_tensors = self.mamba_cache.current_run_tensors( - input_ids, attn_metadata, **kwargs) + ( + mamba_cache_tensors, + state_indices_tensor, + ) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata, + **kwargs) + + mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0], + mamba_cache_tensors[1], + state_indices_tensor) hidden_states = self.backbone(input_ids, positions, attn_metadata, - mamba_cache_tensors[0], - mamba_cache_tensors[1]) + mamba_cache_params) return hidden_states diff --git a/vllm/model_executor/models/mamba_cache.py b/vllm/model_executor/models/mamba_cache.py index 8d1ba3737d4a5..79393421f3ae9 100644 --- a/vllm/model_executor/models/mamba_cache.py +++ b/vllm/model_executor/models/mamba_cache.py @@ -1,8 +1,22 @@ -from typing import Dict, List, Optional +from dataclasses import dataclass +from typing import Dict, List import torch from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.backends.utils import PAD_SLOT_ID + + +@dataclass +class MambaCacheParams: + conv_state: torch.Tensor = torch.Tensor() + ssm_state: torch.Tensor = torch.Tensor() + state_indices_tensor: torch.Tensor = torch.Tensor() + + def at_layer_idx(self, layer_idx): + return MambaCacheParams(self.conv_state[layer_idx], + self.ssm_state[layer_idx], + self.state_indices_tensor) class MambaCacheManager: @@ -24,6 +38,7 @@ def __init__(self, dtype, num_mamba_layers, max_batch_size, # Maps between the request id and a dict that maps between the seq_id # and its index inside the self.mamba_cache self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {} + self.free_cache_indices = list(range(max_batch_size)) def current_run_tensors(self, input_ids: torch.Tensor, attn_metadata: AttentionMetadata, **kwargs): @@ -36,30 +51,43 @@ def current_run_tensors(self, input_ids: torch.Tensor, finished_requests_ids = kwargs["finished_requests_ids"] self._release_finished_requests(finished_requests_ids) - mamba_cache_tensors = self._prepare_current_run_mamba_cache( + state_indices = self._prepare_current_run_mamba_cache( request_ids_to_seq_ids, finished_requests_ids) + state_indices_tensor = torch.as_tensor(state_indices, + dtype=torch.int32, + device="cuda") + mamba_cache_tensors = self.mamba_cache + else: # CUDA graph capturing runs - mamba_cache_tensors = kwargs["seqlen_agnostic_capture_inputs"] + (mamba_cache_tensors, + state_indices_tensor) = kwargs["seqlen_agnostic_capture_inputs"] - return mamba_cache_tensors + return (mamba_cache_tensors, state_indices_tensor) def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): """ - Copy the relevant Mamba cache into the CUDA graph input buffer - that was provided during the capture runs - (JambaForCausalLM.mamba_gc_cache_buffer). + Copy the relevant state_indices into the CUDA graph input buffer """ assert all( key in kwargs for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) finished_requests_ids = kwargs["finished_requests_ids"] request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] + assert "seqlen_agnostic_capture_inputs" in input_buffers + _, input_state_indices_buffer = input_buffers[ + "seqlen_agnostic_capture_inputs"] self._release_finished_requests(finished_requests_ids) - self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, - finished_requests_ids) + state_indices = self._prepare_current_run_mamba_cache( + request_ids_to_seq_ids, finished_requests_ids) + cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len( + state_indices) + state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len) + + input_state_indices_buffer.copy_( + torch.as_tensor(state_indices, dtype=torch.int32, device="cuda")) def get_seqlen_agnostic_capture_inputs(self, batch_size: int): """ @@ -67,13 +95,10 @@ def get_seqlen_agnostic_capture_inputs(self, batch_size: int): The buffer is used to maintain the Mamba Cache during the CUDA graph replay runs. """ - return tuple(buffer[:, :batch_size] for buffer in self.mamba_cache) - - def _swap_mamba_cache(self, from_index: int, to_index: int): - assert len(self.mamba_cache) > 0 - for cache_t in self.mamba_cache: - cache_t[:, [to_index,from_index]] = \ - cache_t[:, [from_index,to_index]] + state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size, + dtype=torch.int32, + device="cuda") + return (self.mamba_cache, state_indices_tensor) def _copy_mamba_cache(self, from_index: int, to_index: int): assert len(self.mamba_cache) > 0 @@ -81,142 +106,53 @@ def _copy_mamba_cache(self, from_index: int, to_index: int): cache_t[:, to_index].copy_(cache_t[:, from_index], non_blocking=True) - def _move_out_if_already_occupied(self, index: int, - all_occupied_indices: List[int]): - if index in all_occupied_indices: - first_free_index = self._first_free_index_in_mamba_cache() - # In case occupied, move the occupied to a new empty block - self._move_cache_index_and_mappings(from_index=index, - to_index=first_free_index) - - def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str, - seq_id: int, - destination_index: int): + def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int, + finished_requests_ids) -> int: """ Assign (req_id,seq_id) pair to a `destination_index` index, if already occupied, move the occupying index to a free index. """ - all_occupied_indices = self._get_all_occupied_indices() - if cur_rid not in self.mamba_cache_indices_mapping: - self._move_out_if_already_occupied( - index=destination_index, - all_occupied_indices=all_occupied_indices) + if cur_rid in finished_requests_ids: + # set as pad, do not allocate destination index + return PAD_SLOT_ID + elif cur_rid not in self.mamba_cache_indices_mapping: + destination_index = self.free_cache_indices.pop() self.mamba_cache_indices_mapping[cur_rid] = { seq_id: destination_index } + return destination_index elif seq_id not in (seq_ids2indices := self.mamba_cache_indices_mapping[cur_rid]): # parallel sampling , where n > 1, assume prefill have - # already happened now we only need to copy the already + # already happened, so we copy the # existing cache into the siblings seq_ids caches - self._move_out_if_already_occupied( - index=destination_index, - all_occupied_indices=all_occupied_indices) - index_exists = list(seq_ids2indices.values())[0] + index_exists = next(iter(seq_ids2indices.values())) # case of decoding n>1, copy prefill cache to decoding indices + destination_index = self.free_cache_indices.pop() self._copy_mamba_cache(from_index=index_exists, to_index=destination_index) self.mamba_cache_indices_mapping[cur_rid][ seq_id] = destination_index + return destination_index else: # already exists - cache_index_already_exists = self.mamba_cache_indices_mapping[ - cur_rid][seq_id] - if cache_index_already_exists != destination_index: - # In case the seq id already exists but not in - # the right destination, swap it with what's occupying it - self._swap_pair_indices_and_mappings( - from_index=cache_index_already_exists, - to_index=destination_index) + return self.mamba_cache_indices_mapping[cur_rid][seq_id] def _prepare_current_run_mamba_cache( self, request_ids_to_seq_ids: Dict[str, list[int]], - finished_requests_ids: List[str]): - running_indices = [] - request_ids_to_seq_ids_flatten = [ - (req_id, seq_id) + finished_requests_ids: List[str]) -> List[int]: + return [ + self._assign_seq_id_to_cache_index(req_id, seq_id, + finished_requests_ids) for req_id, seq_ids in request_ids_to_seq_ids.items() for seq_id in seq_ids ] - batch_size = len(request_ids_to_seq_ids_flatten) - for dest_index, (request_id, - seq_id) in enumerate(request_ids_to_seq_ids_flatten): - if request_id in finished_requests_ids: - # Do not allocate cache index for requests that run - # and finish right after - continue - self._assign_seq_id_to_mamba_cache_in_specific_dest( - request_id, seq_id, dest_index) - running_indices.append(dest_index) - - self._clean_up_first_bs_blocks(batch_size, running_indices) - conv_state = self.mamba_cache[0][:, :batch_size] - temporal_state = self.mamba_cache[1][:, :batch_size] - - return (conv_state, temporal_state) - - def _get_all_occupied_indices(self): - return [ - cache_idx - for seq_ids2indices in self.mamba_cache_indices_mapping.values() - for cache_idx in seq_ids2indices.values() - ] - - def _clean_up_first_bs_blocks(self, batch_size: int, - indices_for_current_run: List[int]): - # move out all of the occupied but currently not running blocks - # outside of the first n blocks - destination_indices = range(batch_size) - max_possible_batch_size = self.mamba_cache[0].shape[1] - for destination_index in destination_indices: - if destination_index in self._get_all_occupied_indices() and \ - destination_index not in indices_for_current_run: - # move not running indices outside of the batch - all_other_indices = list( - range(batch_size, max_possible_batch_size)) - first_avail_index = self._first_free_index_in_mamba_cache( - all_other_indices) - self._swap_indices(from_index=destination_index, - to_index=first_avail_index) - - def _move_cache_index_and_mappings(self, from_index: int, to_index: int): - self._copy_mamba_cache(from_index=from_index, to_index=to_index) - self._update_mapping_index(from_index=from_index, to_index=to_index) - - def _swap_pair_indices_and_mappings(self, from_index: int, to_index: int): - self._swap_mamba_cache(from_index=from_index, to_index=to_index) - self._swap_mapping_index(from_index=from_index, to_index=to_index) - - def _swap_mapping_index(self, from_index: int, to_index: int): - for seq_ids2index in self.mamba_cache_indices_mapping.values(): - for seq_id, index in seq_ids2index.items(): - if from_index == index: - seq_ids2index.update({seq_id: to_index}) - elif to_index == index: - seq_ids2index.update({seq_id: from_index}) - - def _update_mapping_index(self, from_index: int, to_index: int): - for seq_ids2index in self.mamba_cache_indices_mapping.values(): - for seq_id, index in seq_ids2index.items(): - if from_index == index: - seq_ids2index.update({seq_id: to_index}) - return def _release_finished_requests(self, finished_seq_groups_req_ids: List[str]): for req_id in finished_seq_groups_req_ids: if req_id in self.mamba_cache_indices_mapping: + for seq_id in self.mamba_cache_indices_mapping[req_id]: + self.free_cache_indices.append( + self.mamba_cache_indices_mapping[req_id][seq_id]) self.mamba_cache_indices_mapping.pop(req_id) - - def _first_free_index_in_mamba_cache( - self, indices_range: Optional[List[int]] = None) -> int: - assert self.mamba_cache is not None - if indices_range is None: - max_possible_batch_size = self.mamba_cache[0].shape[1] - indices_range = list(range(max_possible_batch_size)) - all_occupied_indices = self._get_all_occupied_indices() - for i in indices_range: - if i not in all_occupied_indices: - return i - raise Exception("Couldn't find a free spot in the mamba cache! This" - "should never happen") diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 41c2877194bb2..decd90b682a1e 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -33,7 +33,7 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) -from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.activation import FatreluAndMul, SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -152,6 +152,7 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, + hidden_act_param: float, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -163,10 +164,13 @@ def __init__( hidden_size, bias=False, quant_config=quant_config) - if hidden_act != "silu": + if hidden_act == "silu": + self.act_fn = SiluAndMul() + elif hidden_act == "fatrelu": + self.act_fn = FatreluAndMul(threshold=hidden_act_param) + else: raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") - self.act_fn = SiluAndMul() + "Only silu and fatrelu are supported for now.") def forward(self, x): gate_up, _ = self.gate_up_proj(x) @@ -304,6 +308,7 @@ def _init_ffn_block(self): hidden_size=self.hidden_size, intermediate_size=self.config.intermediate_size, hidden_act=self.config.hidden_act, + hidden_act_param=getattr(self.config, "hidden_act_param", 0.), quant_config=self.quant_config, ) else: diff --git a/vllm/profiler/__init__.py b/vllm/profiler/__init__.py new file mode 100644 index 0000000000000..3e25f5cc283f2 --- /dev/null +++ b/vllm/profiler/__init__.py @@ -0,0 +1,5 @@ +from .layerwise_profile import layerwise_profile + +__all__ = [ + "layerwise_profile", +] diff --git a/vllm/profiler/layerwise_profile.py b/vllm/profiler/layerwise_profile.py new file mode 100644 index 0000000000000..9d9f427e807f6 --- /dev/null +++ b/vllm/profiler/layerwise_profile.py @@ -0,0 +1,354 @@ +import copy +from collections import defaultdict +from dataclasses import asdict, dataclass, field +from typing import Callable, Dict, List, Optional, Tuple, TypeAlias, Union + +import pandas as pd +from torch._C._autograd import DeviceType, _KinetoEvent, _ProfilerResult +from torch._C._profiler import _EventType, _ExperimentalConfig, _ProfilerEvent +from torch.autograd.profiler import FunctionEvent +from torch.profiler import ProfilerActivity, profile + +from vllm.profiler.utils import (TablePrinter, event_has_module, + event_is_torch_op, event_module_repr, + event_torch_op_stack_trace, indent_string) + + +@dataclass +class _ModuleTreeNode: + event: _ProfilerEvent + parent: Optional['_ModuleTreeNode'] = None + children: List['_ModuleTreeNode'] = field(default_factory=list) + trace: str = "" + + @property + def is_leaf(self): + return (self.event.children is None or len(self.event.children) == 0) + + @property + def is_torch_op(self): + return event_is_torch_op(self.event) + + @property + def is_cuda(self): + return (self.event.tag == _EventType.Kineto + and self.event.typed[1].device_type == DeviceType.CUDA) + + +@dataclass +class SummaryStatsEntry: + name: str + cuda_time_us: float + pct_cuda_time: float + invocations: int + + +@dataclass +class ModelStatsEntry: + name: str + cpu_time_us: float + cuda_time_us: float + pct_cuda_time: float + trace: str + + +StatsEntry: TypeAlias = Union[ModelStatsEntry, SummaryStatsEntry] + + +@dataclass +class _StatsTreeNode: + entry: StatsEntry + children: List[StatsEntry] + parent: Optional[StatsEntry] + + +@dataclass +class LayerwiseProfileResults(profile): + _kineto_results: _ProfilerResult + _kineto_event_correlation_map: Dict[int, + List[_KinetoEvent]] = field(init=False) + _event_correlation_map: Dict[int, List[FunctionEvent]] = field(init=False) + _module_tree: List[_ModuleTreeNode] = field(init=False) + _model_stats_tree: List[_StatsTreeNode] = field(init=False) + _summary_stats_tree: List[_StatsTreeNode] = field(init=False) + + def __post_init__(self): + self._build_correlation_map() + self._build_module_tree() + self._build_stats_trees() + + def print_model_table(self, column_widths: Dict[str, int] = None): + _column_widths = dict(name=60, + cpu_time_us=12, + cuda_time_us=12, + pct_cuda_time=12, + trace=60) + if column_widths: + _column_widths.update(**column_widths) + filtered_model_table = [ + (depth, row) + for depth, row in self._flatten_stats_tree(self._model_stats_tree) + if row.cuda_time_us > 0 or row.cpu_time_us > 0 + ] + TablePrinter(ModelStatsEntry, _column_widths).print_table( + self._indent_row_names_based_on_depth( + filtered_model_table, + indent_style=lambda indent: "|" + "-" * indent + " ")) + + def print_summary_table(self, column_widths: Dict[str, int] = None): + _column_widths = dict(name=80, + cuda_time_us=12, + pct_cuda_time=12, + invocations=15) + if column_widths: + _column_widths.update(**column_widths) + filtered_summary_table = [(depth, row) + for depth, row in self._flatten_stats_tree( + self._summary_stats_tree) + if row.cuda_time_us > 0] + TablePrinter(SummaryStatsEntry, _column_widths).print_table( + self._indent_row_names_based_on_depth( + filtered_summary_table, + indent_style=lambda indent: "|" + "-" * indent + " ")) + + def export_model_stats_table_csv(self, filename: str): + df = pd.DataFrame([ + asdict(row) + for _, row in self._flatten_stats_tree(self._model_stats_tree) + ]) + df.to_csv(filename) + + def export_summary_stats_table_csv(self, filename: str): + df = pd.DataFrame([ + asdict(row) + for _, row in self._flatten_stats_tree(self._summary_stats_tree) + ]) + df.to_csv(filename) + + def convert_stats_to_dict(self) -> str: + return { + "summary_stats": + self._convert_stats_tree_to_dict(self._summary_stats_tree), + "model_stats": + self._convert_stats_tree_to_dict(self._model_stats_tree) + } + + @staticmethod + def _indent_row_names_based_on_depth(depths_rows: List[Tuple[int, + StatsEntry]], + indent_style: Union[Callable[[int], + str], + str] = " "): + indented_rows = [] + for depth, row in depths_rows: + if row.cuda_time_us == 0: + continue + indented_row = copy.deepcopy(row) + indented_row.name = indent_string(indented_row.name, depth, + indent_style) + indented_rows.append(indented_row) + return indented_rows + + def _build_correlation_map(self): + self._kineto_event_correlation_map = defaultdict(list) + for event in self._kineto_results.events(): + self._kineto_event_correlation_map[event.correlation_id()].append( + event) + + def _build_module_tree(self): + self._module_tree = [] + event_tree = self._kineto_results.experimental_event_tree() + + def _df_traversal(event: _ProfilerEvent, + curr_node: Optional[_ModuleTreeNode] = None): + + # For the tensor parallel case for now only look at task 1 + if event.start_tid != 1: + return + + if event_has_module(event): + node = _ModuleTreeNode(event=event, parent=curr_node) + if curr_node: + curr_node.children.append(node) + else: + self._module_tree.append(node) + curr_node = node + + is_leaf = (event.children is None or len(event.children) == 0) + if is_leaf and curr_node: + node = _ModuleTreeNode( + event=event, + parent=curr_node, + trace=event_torch_op_stack_trace( + event, until=lambda x: event_has_module(x))) + curr_node.children.append(node) + curr_node = node + + for child in event.children: + _df_traversal(child, curr_node) + + for root in event_tree: + _df_traversal(root) + + def _get_kineto_gpu_event(self, node: _ModuleTreeNode): + if node.event.tag != _EventType.Kineto: + return None + correlated_kineto_events = self._kineto_event_correlation_map.get( + node.event.correlation_id, []) + iterator = (x for x in correlated_kineto_events + if x.device_type() == DeviceType.CUDA + and x.name() == node.event.name) + return next(iterator, None) + + def _cumulative_cuda_time(self, node: _ModuleTreeNode): + 'Return cuda time in microseconds' + + def _cumulative_cuda_time_recursive(node: _ModuleTreeNode): + if node.is_leaf and (gpu_kineto_event := + self._get_kineto_gpu_event(node)): + return gpu_kineto_event.duration_ns() / 1000.0 + else: + cumulative_cuda_time = 0 + for child in node.children: + cumulative_cuda_time += _cumulative_cuda_time_recursive( + child) + return cumulative_cuda_time + + return _cumulative_cuda_time_recursive(node) + + def _total_cuda_time(self): + return sum( + [self._cumulative_cuda_time(root) for root in self._module_tree]) + + def _build_stats_trees(self): + summary_dict: Dict[str, self.StatsTreeNode] = {} + total_cuda_time = self._total_cuda_time() + + def pct_cuda_time(cuda_time_us): + return (cuda_time_us / total_cuda_time) * 100 + + def build_summary_stats_tree_df( + node: _ModuleTreeNode, + parent: Optional[_StatsTreeNode] = None, + summary_trace: Tuple[str] = ()): + + if event_has_module(node.event): + name = event_module_repr(node.event) + cuda_time_us = self._cumulative_cuda_time(node) + elif (gpu_kineto_event := self._get_kineto_gpu_event(node)): + name = gpu_kineto_event.name() + cuda_time_us = gpu_kineto_event.duration_ns() / 1000.0 + else: + return None + + summary_trace = summary_trace + (name, ) + if summary_trace in summary_dict: + entry = summary_dict[summary_trace].entry + entry.cuda_time_us += cuda_time_us + entry.invocations += 1 + entry.pct_cuda_time = pct_cuda_time(entry.cuda_time_us) + else: + new_node = _StatsTreeNode(entry=SummaryStatsEntry( + name=name, + cuda_time_us=cuda_time_us, + pct_cuda_time=pct_cuda_time(cuda_time_us), + invocations=1), + children=[], + parent=parent) + if parent: + parent.children.append(new_node) + summary_dict[summary_trace] = new_node + + for child in node.children: + build_summary_stats_tree_df(child, summary_dict[summary_trace], + summary_trace) + + return summary_dict[summary_trace] + + self._summary_stats_tree = [] + for root in self._module_tree: + self._summary_stats_tree.append(build_summary_stats_tree_df(root)) + + def build_model_stats_tree_df(node: _ModuleTreeNode, + parent: Optional[_StatsTreeNode] = None): + if event_has_module(node.event, ): + name = event_module_repr(node.event) + cuda_time_us = self._cumulative_cuda_time(node) + cpu_time_us = node.event.duration_time_ns / 1000 + trace = "" + elif (gpu_kineto_event := self._get_kineto_gpu_event(node)): + name = gpu_kineto_event.name() + cuda_time_us = gpu_kineto_event.duration_ns() / 1000.0 + cpu_time_us = 0 + trace = node.trace + else: + return None + + new_node = _StatsTreeNode(entry=ModelStatsEntry( + name=name, + cpu_time_us=cpu_time_us, + cuda_time_us=cuda_time_us, + pct_cuda_time=pct_cuda_time(cuda_time_us), + trace=trace), + parent=parent, + children=[]) + if parent: + parent.children.append(new_node) + + for child in node.children: + build_model_stats_tree_df(child, new_node) + + return new_node + + self._model_stats_tree = [] + for root in self._module_tree: + self._model_stats_tree.append(build_model_stats_tree_df(root)) + + def _flatten_stats_tree( + self, tree: List[_StatsTreeNode]) -> List[Tuple[int, StatsEntry]]: + entries: List[Tuple[int, StatsEntry]] = [] + + def df_traversal(node: _StatsTreeNode, depth=0): + entries.append((depth, node.entry)) + for child in node.children: + df_traversal(child, depth=depth + 1) + + for root in tree: + df_traversal(root) + + return entries + + def _convert_stats_tree_to_dict(self, + tree: List[_StatsTreeNode]) -> List[Dict]: + root_dicts: List[Dict] = [] + + def df_traversal(node: _StatsTreeNode, curr_json_list: List[Dict]): + curr_json_list.append({ + "entry": asdict(node.entry), + "children": [] + }) + for child in node.children: + df_traversal(child, curr_json_list[-1]["children"]) + + for root in tree: + df_traversal(root, root_dicts) + + return root_dicts + + +class layerwise_profile(profile): + + def __init__(self): + super().__init__( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + record_shapes=True, + with_stack=True, + with_modules=True, + experimental_config=_ExperimentalConfig(verbose=True)) + + def __enter__(self): + return super().__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + super().__exit__(exc_type, exc_val, exc_tb) + self.results = LayerwiseProfileResults(self.profiler.kineto_results) diff --git a/vllm/profiler/utils.py b/vllm/profiler/utils.py new file mode 100644 index 0000000000000..033035e434325 --- /dev/null +++ b/vllm/profiler/utils.py @@ -0,0 +1,145 @@ +import dataclasses +from typing import Callable, Dict, List, Type, Union + +from torch._C._profiler import _EventType, _ProfilerEvent, _TensorMetadata + +# +# String / Print Manipulation +# + + +def trim_string_front(string, width): + if len(string) > width: + offset = len(string) - width + 3 + string = string[offset:] + if len(string) > 3: + string = "..." + string[3:] + return string + + +def trim_string_back(string, width): + if len(string) > width: + offset = len(string) - width + 3 + string = string[:-offset] + if len(string) > 3: + string = string + "..." + return string + + +class TablePrinter: + + def __init__(self, row_cls: Type[dataclasses.dataclass], + column_widths: Dict[str, int]): + self.row_cls = row_cls + self.fieldnames = [x.name for x in dataclasses.fields(row_cls)] + self.column_widths = column_widths + assert set(self.column_widths.keys()) == set(self.fieldnames) + + def print_table(self, rows: List[dataclasses.dataclass]): + self._print_header() + self._print_line() + for row in rows: + self._print_row(row) + + def _print_header(self): + for i, f in enumerate(self.fieldnames): + last = (i == len(self.fieldnames) - 1) + col_width = self.column_widths[f] + print(trim_string_back(f, col_width).ljust(col_width), + end=" | " if not last else "\n") + + def _print_row(self, row): + assert isinstance(row, self.row_cls) + + for i, f in enumerate(self.fieldnames): + last = (i == len(self.fieldnames) - 1) + col_width = self.column_widths[f] + val = getattr(row, f) + + val_str = "" + if isinstance(val, str): + val_str = trim_string_back(val, col_width).ljust(col_width) + elif type(val) in [float, int]: + val_str = f"{float(val):>.2f}".rjust(col_width) + else: + val_str = f"{val}".rjust(col_width) + print(val_str, end=" | " if not last else "\n") + + def _print_line(self): + total_col_width = 0 + for column_width in self.column_widths.values(): + total_col_width += column_width + print("=" * (total_col_width + 3 * (len(self.column_widths) - 1))) + + +def indent_string(string: str, + indent: int, + indent_style: Union[Callable[[int], str], str] = " ") -> str: + if indent: + if isinstance(indent_style, str): + return indent_style * indent + string + else: + return indent_style(indent) + string + else: + return string + + +# +# _ProfilerEvent utils +# + + +def event_has_module(event: _ProfilerEvent) -> bool: + event_type, typed_event = event.typed + if event_type == _EventType.PyCall: + return typed_event.module is not None + return False + + +def event_is_torch_op(event: _ProfilerEvent) -> bool: + return event.tag == _EventType.TorchOp + + +def event_arg_repr(arg) -> str: + if arg is None or type(arg) in [float, int, bool, str]: + return f"{arg}" + elif isinstance(arg, list): + return f"[{', '.join([event_arg_repr(x) for x in arg])}]" + elif isinstance(arg, tuple): + return f"({', '.join([event_arg_repr(x) for x in arg])})" + else: + assert isinstance(arg, + _TensorMetadata), f"Unsupported type: {type(arg)}" + sizes_str = ', '.join([str(x) for x in arg.sizes]) + return f"{str(arg.dtype).replace('torch.', '')}[{sizes_str}]" + + +def event_torch_op_repr(event: _ProfilerEvent) -> str: + assert event.tag == _EventType.TorchOp + args_str = ', '.join([event_arg_repr(x) for x in event.typed[1].inputs]) + return f"{event.name}({args_str})".replace("aten::", "") + + +def event_module_repr(event: _ProfilerEvent) -> str: + assert event_has_module(event) + module = event.typed[1].module + if module.parameters and len(module.parameters) > 0: + args_str = ', '.join( + [f'{x[0]}={event_arg_repr(x[1])}' for x in module.parameters]) + return f"{module.cls_name}({args_str})" + else: + return module.cls_name + + +def event_torch_op_stack_trace(curr_event: _ProfilerEvent, + until: Callable[[_ProfilerEvent], bool]) -> str: + trace = "" + curr_event = curr_event.parent + while curr_event and not until(curr_event): + if event_is_torch_op(curr_event): + if len(trace) > 0: + trace += " <- " + trace += event_torch_op_repr(curr_event) + curr_event = curr_event.parent + + return trace diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 0f3c379cee8f0..36753b8580f6f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1742,10 +1742,13 @@ def execute_model( return [output] -class CUDAGraphRunner: +# NOTE: this is nn.Module so the profiler can properly capture/group +# kernels calls made within the graph +class CUDAGraphRunner(nn.Module): def __init__(self, model: nn.Module, backend_name: str, attn_state: AttentionState, is_encoder_decoder_model: bool): + super().__init__() self.model = model self.backend_name = backend_name self.attn_state = attn_state @@ -1892,9 +1895,6 @@ def forward( return self.output_buffers - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) - def _get_graph_batch_size(batch_size: int) -> int: """Returns the padded batch size given actual batch size.