
added configs
robertgshaw2-neuralmagic committed Dec 28, 2024
1 parent b55c584 commit 06c1f26
Showing 2 changed files with 16 additions and 20 deletions.
30 changes: 13 additions & 17 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -339,15 +339,15 @@ def invoke_fused_moe_kernel(A: torch.Tensor,


# Adapted from: https://github.com/sgl-project/sglang/pull/2628
-def get_config_file_name(
-    E: int, N: int, dtype: Optional[str], block_shape: Optional[int] = None
-) -> str:
+def get_config_file_name(E: int,
+                         N: int,
+                         dtype: Optional[str],
+                         block_shape: List[Optional[int]] = None) -> str:

Check failure on line 345 in vllm/model_executor/layers/fused_moe/fused_moe.py

GitHub Actions / mypy (3.9, 3.10, 3.11, 3.12)

Incompatible default for argument "block_shape" (default has type "None", argument has type "list[Optional[int]]") [assignment]
(mypy under Python 3.10-3.12 renders the argument type as "list[int | None]".)
     device_name = current_platform.get_device_name().replace(" ", "_")
     dtype_selector = "" if not dtype else f",dtype={dtype}"
-    block_shape_selector = (
-        "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}"
-    )
-    return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json"
+    block_shape_selector = ("" if not block_shape or not all(block_shape) else
+                            f",block_shape={block_shape}")
+    return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json"  # noqa: E501


# Adapted from: https://github.com/sgl-project/sglang/pull/2628
@@ -419,8 +419,8 @@ def get_default_config(
"num_stages": 4,
}
else:
# Block-wise quant: BLOCK_SIZE_N must be divisable by block_shape[0]
# BLOCK_SIZE_K must be divisable by block_shape[1]
# Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
# BLOCK_SIZE_K must be divisible by block_shape[1]
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": block_shape[0],
@@ -463,7 +463,9 @@ def try_get_optimal_moe_config(
     else:
         # First try to load optimal config from the file
         E, _, N = w2_shape
-        configs = get_moe_configs(E, N, dtype)
+        block_n = block_shape[0] if block_shape else 0
+        block_k = block_shape[1] if block_shape else 0
+        configs = get_moe_configs(E, N, dtype, block_n, block_k)

         if configs:
             # If an optimal configuration map has been found, look up the
@@ -472,13 +474,7 @@
         else:
             # Else use the default config
             config = get_default_config(M, E, N, w1_shape[2], top_k, dtype,
-                                        is_marlin)
-            # NOTE: For block-wise quant,
-            # BLOCK_K must be divisible by block_shape[1]
-            # BLOCK_N and BLOCK_M has no requirements
-            if block_shape is not None:
-                config["BLOCK_SIZE_N"] = block_shape[0]
-                config["BLOCK_SIZE_K"] = block_shape[1]
+                                        is_marlin, block_shape)
     return config
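
The net effect of the two hunks above is that block_shape now flows into get_default_config instead of being patched onto the returned dict afterwards. A self-contained sketch of that behavior (a simplified stand-in, not vLLM's actual helper):

from typing import Dict, List, Optional

def default_moe_config(block_shape: Optional[List[int]] = None) -> Dict[str, int]:
    # Simplified stand-in for get_default_config: under block-wise quant,
    # BLOCK_SIZE_N and BLOCK_SIZE_K are taken directly from block_shape,
    # so the divisibility constraints noted in the diff hold by construction.
    if not block_shape:
        return {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}
    return {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": block_shape[0],
        "BLOCK_SIZE_K": block_shape[1],
    }

# With a [128, 128] block shape the tile sizes match the quant blocks exactly.
assert default_moe_config([128, 128])["BLOCK_SIZE_K"] == 128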


6 changes: 3 additions & 3 deletions vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -303,7 +303,7 @@ def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int,
     # First look up if an optimized configuration is available in the configs
     # directory
     device_name = current_platform.get_device_name().replace(" ", "_")
-    json_file_name = f"N={N},K={K},device_name={device_name},dtype=fp8_w8a8,block_shape=[{block_n}, {block_k}].json"
+    json_file_name = f"N={N},K={K},device_name={device_name},dtype=fp8_w8a8,block_shape=[{block_n}, {block_k}].json"  # noqa: E501

     config_file_path = os.path.join(
         os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
@@ -319,8 +319,8 @@ def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int,
     # If no optimized configuration is available, we will use the default
     # configuration
     logger.warning(
-        ("Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! "
-         "Config file not found at %s"),
+        "Using default W8A8 Block FP8 kernel config. Performance might "
+        "be sub-optimal! Config file not found at %s",
         config_file_path,
     )
     return None
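
For reference, a runnable sketch of the lookup this filename feeds into (the device name is hard-coded for illustration; the real code derives it from current_platform and resolves the package's configs/ directory):

import json
import os
from typing import Any, Dict, Optional

def load_block_fp8_config(N: int, K: int, block_n: int, block_k: int,
                          configs_dir: str) -> Optional[Dict[str, Any]]:
    # Mirrors the naming scheme used by get_w8a8_block_fp8_configs.
    device_name = "NVIDIA_H100_80GB_HBM3"  # illustrative only
    json_file_name = (f"N={N},K={K},device_name={device_name},"
                      f"dtype=fp8_w8a8,block_shape=[{block_n}, {block_k}].json")
    config_file_path = os.path.join(configs_dir, json_file_name)
    if not os.path.exists(config_file_path):
        return None  # caller falls back to the default kernel config
    with open(config_file_path) as f:
        return json.load(f)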
