Fixed broken download of models from modelscope
mulin.lyh committed Jun 3, 2024
1 parent 7a64d24 commit 35e0892
Showing 3 changed files with 41 additions and 7 deletions.
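
When VLLM_USE_MODELSCOPE is enabled, vLLM pulls model weights from modelscope.cn, but the model config was still being loaded through transformers' AutoConfig, which resolves the model id against the Hugging Face Hub and fails for ModelScope-only ids. This commit routes config loading through modelscope's AutoConfig instead, and lets the tokenizer fall back to the model revision when no tokenizer revision is given. A minimal sketch of how a boolean flag like this is typically derived from the environment (an assumption for illustration; the exact definition in vllm/envs.py may differ):

import os

# Hedged sketch, not the verbatim vllm/envs.py definition: treat the
# VLLM_USE_MODELSCOPE environment variable as a case-insensitive boolean.
VLLM_USE_MODELSCOPE: bool = os.environ.get("VLLM_USE_MODELSCOPE",
                                           "False").lower() == "true"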
24 changes: 24 additions & 0 deletions tests/entrypoints/test_model_from_modelscope.py
@@ -0,0 +1,24 @@
+import asyncio
+import pytest
+
+from vllm import LLM, SamplingParams
+
+# model: https://modelscope.cn/models/qwen/Qwen1.5-0.5B-Chat/summary
+MODEL_NAME = "qwen/Qwen1.5-0.5B-Chat"
+
+
+def test_offline_inference(monkeypatch):
+    monkeypatch.setenv("VLLM_USE_MODELSCOPE", "True")
+    llm = LLM(model=MODEL_NAME)
+
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+
+    outputs = llm.generate(prompts, sampling_params)
+    assert len(outputs) == 4

2 changes: 1 addition & 1 deletion vllm/engine/llm_engine.py
@@ -402,7 +402,7 @@ def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup:
             max_input_length=None,
             tokenizer_mode=self.model_config.tokenizer_mode,
             trust_remote_code=self.model_config.trust_remote_code,
-            revision=self.model_config.tokenizer_revision)
+            revision=self.model_config.tokenizer_revision if self.model_config.tokenizer_revision is not None else self.model_config.revision)
         init_kwargs.update(tokenizer_init_kwargs)

         return get_tokenizer_group(self.parallel_config.tokenizer_pool_config,
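
The one-line change above prefers an explicit tokenizer revision and otherwise reuses the model revision, so a ModelScope model pinned only at the model level no longer requests a tokenizer revision that does not exist. A minimal sketch of the fallback, using hypothetical values:

# Hypothetical values; mirrors the conditional expression in the hunk above.
tokenizer_revision = None   # user did not pin a tokenizer revision
model_revision = "v1.0.2"   # user pinned the model itself

revision = (tokenizer_revision
            if tokenizer_revision is not None else model_revision)
assert revision == "v1.0.2"  # the tokenizer reuses the model's revision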
22 changes: 16 additions & 6 deletions vllm/transformers_utils/config.py
@@ -1,10 +1,11 @@
 from typing import Dict, Optional

-from transformers import AutoConfig, PretrainedConfig
+from transformers import PretrainedConfig

 from vllm.logger import init_logger
 from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
                                              JAISConfig, MPTConfig, RWConfig)
+from vllm.envs import VLLM_USE_MODELSCOPE

 logger = init_logger(__name__)

@@ -24,11 +25,20 @@ def get_config(model: str,
                code_revision: Optional[str] = None,
                rope_scaling: Optional[dict] = None) -> PretrainedConfig:
     try:
-        config = AutoConfig.from_pretrained(
-            model,
-            trust_remote_code=trust_remote_code,
-            revision=revision,
-            code_revision=code_revision)
+        if VLLM_USE_MODELSCOPE:
+            from modelscope import AutoConfig
+            config = AutoConfig.from_pretrained(
+                model,
+                trust_remote_code=trust_remote_code,
+                revision=revision,
+                code_revision=code_revision)
+        else:
+            from transformers import AutoConfig
+            config = AutoConfig.from_pretrained(
+                model,
+                trust_remote_code=trust_remote_code,
+                revision=revision,
+                code_revision=code_revision)
     except ValueError as e:
         if (not trust_remote_code and
                 "requires you to execute the configuration file" in str(e)):
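
With the conditional import in place, the same get_config() call resolves a model id against either hub, depending on the flag. A hedged usage sketch (the keyword signature is assumed from the hunk above, and VLLM_USE_MODELSCOPE must be set before vllm.envs is first imported):

import os

os.environ["VLLM_USE_MODELSCOPE"] = "True"  # set before any vllm import

from vllm.transformers_utils.config import get_config

# Resolves the id against modelscope.cn rather than the Hugging Face Hub;
# the return value is still a transformers PretrainedConfig instance.
config = get_config("qwen/Qwen1.5-0.5B-Chat", trust_remote_code=True)
print(config.model_type)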
