From 35e0892d66394c3ff47b906b1c65ebaeb5ab479e Mon Sep 17 00:00:00 2001
From: "mulin.lyh"
Date: Mon, 3 Jun 2024 18:42:19 +0800
Subject: [PATCH] Fix broken download of models from ModelScope

---
 .../entrypoints/test_model_from_modelscope.py | 24 +++++++++++++++++++
 vllm/engine/llm_engine.py                     |  2 +-
 vllm/transformers_utils/config.py             | 22 ++++++++++++-----
 3 files changed, 41 insertions(+), 7 deletions(-)
 create mode 100644 tests/entrypoints/test_model_from_modelscope.py

diff --git a/tests/entrypoints/test_model_from_modelscope.py b/tests/entrypoints/test_model_from_modelscope.py
new file mode 100644
index 0000000000000..7d38cd973ccb6
--- /dev/null
+++ b/tests/entrypoints/test_model_from_modelscope.py
@@ -0,0 +1,24 @@
+import asyncio
+import pytest
+
+from vllm import LLM, SamplingParams
+
+# model: https://modelscope.cn/models/qwen/Qwen1.5-0.5B-Chat/summary
+MODEL_NAME = "qwen/Qwen1.5-0.5B-Chat"
+
+
+def test_offline_inference(monkeypatch):
+    monkeypatch.setenv("VLLM_USE_MODELSCOPE", "True")
+    llm = LLM(model=MODEL_NAME)
+
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+
+    outputs = llm.generate(prompts, sampling_params)
+    assert len(outputs) == 4
+
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index cb5893e707c8b..df9c47911ced5 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -402,7 +402,7 @@ def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup:
             max_input_length=None,
             tokenizer_mode=self.model_config.tokenizer_mode,
             trust_remote_code=self.model_config.trust_remote_code,
-            revision=self.model_config.tokenizer_revision)
+            revision=self.model_config.tokenizer_revision if self.model_config.tokenizer_revision is not None else self.model_config.revision)
         init_kwargs.update(tokenizer_init_kwargs)
 
         return get_tokenizer_group(self.parallel_config.tokenizer_pool_config,
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index 044eec6410a54..6b23da667abb2 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -1,10 +1,11 @@
 from typing import Dict, Optional
 
-from transformers import AutoConfig, PretrainedConfig
+from transformers import PretrainedConfig
 
 from vllm.logger import init_logger
 from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
                                              JAISConfig, MPTConfig, RWConfig)
+from vllm.envs import VLLM_USE_MODELSCOPE
 
 logger = init_logger(__name__)
 
@@ -24,11 +25,20 @@ def get_config(model: str,
                code_revision: Optional[str] = None,
                rope_scaling: Optional[dict] = None) -> PretrainedConfig:
     try:
-        config = AutoConfig.from_pretrained(
-            model,
-            trust_remote_code=trust_remote_code,
-            revision=revision,
-            code_revision=code_revision)
+        if VLLM_USE_MODELSCOPE:
+            from modelscope import AutoConfig
+            config = AutoConfig.from_pretrained(
+                model,
+                trust_remote_code=trust_remote_code,
+                revision=revision,
+                code_revision=code_revision)
+        else:
+            from transformers import AutoConfig
+            config = AutoConfig.from_pretrained(
+                model,
+                trust_remote_code=trust_remote_code,
+                revision=revision,
+                code_revision=code_revision)
     except ValueError as e:
         if (not trust_remote_code and
                 "requires you to execute the configuration file" in str(e)):