From 2298e69b5f1dc77f00aee687a3843a4dae12cb91 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 18 Nov 2024 15:29:37 -0800 Subject: [PATCH] [ci][bugfix] fix kernel tests (#10431) Signed-off-by: youkaichao --- vllm/plugins/__init__.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index c5182139db50b..fdc848cedf054 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -6,9 +6,6 @@ if TYPE_CHECKING: from vllm.config import CompilationConfig, VllmConfig -else: - CompilationConfig = None - VllmConfig = None logger = logging.getLogger(__name__) @@ -50,23 +47,23 @@ def load_general_plugins(): logger.exception("Failed to load plugin %s", plugin.name) -_compilation_config: Optional[CompilationConfig] = None +_compilation_config: Optional["CompilationConfig"] = None -def set_compilation_config(config: Optional[CompilationConfig]): +def set_compilation_config(config: Optional["CompilationConfig"]): global _compilation_config _compilation_config = config -def get_compilation_config() -> Optional[CompilationConfig]: +def get_compilation_config() -> Optional["CompilationConfig"]: return _compilation_config -_current_vllm_config: Optional[VllmConfig] = None +_current_vllm_config: Optional["VllmConfig"] = None @contextmanager -def set_current_vllm_config(vllm_config: VllmConfig): +def set_current_vllm_config(vllm_config: "VllmConfig"): """ Temporarily set the current VLLM config. Used during model initialization. @@ -87,6 +84,12 @@ def set_current_vllm_config(vllm_config: VllmConfig): _current_vllm_config = old_vllm_config -def get_current_vllm_config() -> VllmConfig: - assert _current_vllm_config is not None, "Current VLLM config is not set." +def get_current_vllm_config() -> "VllmConfig": + if _current_vllm_config is None: + # in ci, usually when we test custom ops/modules directly, + # we don't set the vllm config. In that case, we set a default + # config. + logger.warning("Current VLLM config is not set.") + from vllm.config import VllmConfig + return VllmConfig() return _current_vllm_config