def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module:
    """Instantiate the model class selected by ``vllm_config.model_config``.

    The constructor signature of the resolved class decides how it is called:

    * New-style: ``__init__`` names both ``vllm_config`` and ``prefix``; the
      class is built with exactly those two keyword arguments.
    * Old-style: anything else. Two warnings are logged and the constructor
      is called with whichever of the legacy keyword names it declares
      (``prefix``, ``config``, ``cache_config``, ``quant_config``,
      ``lora_config``, ``scheduler_config``).

    Args:
        vllm_config: Configuration object; its ``model_config`` picks the
            architecture and its sub-configs feed the old-style path.
        prefix: Forwarded as ``prefix`` to constructors that declare it.

    Returns:
        The object produced by calling the resolved model class.
    """
    model_config = vllm_config.model_config
    model_class, _ = get_model_architecture(model_config)

    # Only parameter *names* matter here; their kind (positional, kw-only,
    # ...) is deliberately not inspected.
    accepted = set(inspect.signature(model_class.__init__).parameters)

    if {"vllm_config", "prefix"} <= accepted:
        # New-style model class.
        return model_class(vllm_config=vllm_config, prefix=prefix)

    logger.warning(
        "vLLM model class should accept `vllm_config` and `prefix` as "
        "input arguments. Possibly you have an old-style model class "
        "registered from out of tree and it is used for new vLLM version. "
        "Check "
        "https://docs.vllm.ai/en/latest/design/class_hierarchy.html "
        "for the design and update the model class accordingly.")
    logger.warning(
        "Trying to guess the arguments for old-style model class %s",
        model_class)

    # Old-style compatibility: (constructor kwarg, owning object, attribute).
    # The attribute is read only when the constructor declares the kwarg.
    legacy_sources = (
        ("config", model_config, "hf_config"),
        ("cache_config", vllm_config, "cache_config"),
        ("quant_config", vllm_config, "quant_config"),
        ("lora_config", vllm_config, "lora_config"),
        ("scheduler_config", vllm_config, "scheduler_config"),
    )
    kwargs = {"prefix": prefix} if "prefix" in accepted else {}
    for kwarg_name, owner, attr_name in legacy_sources:
        if kwarg_name in accepted:
            kwargs[kwarg_name] = getattr(owner, attr_name)
    return model_class(**kwargs)