diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 3da99bcbee9ae..4529cf27ef565 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -1,5 +1,6 @@ import enum import json +import os from pathlib import Path from typing import Any, Dict, Optional, Type, Union @@ -41,6 +42,7 @@ from transformers import AutoConfig MISTRAL_CONFIG_NAME = "params.json" +HF_TOKEN = os.getenv('HF_TOKEN', None) logger = init_logger(__name__) @@ -77,8 +79,8 @@ class ConfigFormat(str, enum.Enum): MISTRAL = "mistral" -def file_or_path_exists(model: Union[str, Path], config_name, revision, - token) -> bool: +def file_or_path_exists(model: Union[str, Path], config_name: str, + revision: Optional[str]) -> bool: if Path(model).exists(): return (Path(model) / config_name).is_file() @@ -93,7 +95,10 @@ def file_or_path_exists(model: Union[str, Path], config_name, revision, # NB: file_exists will only check for the existence of the config file on # hf_hub. This will fail in offline mode. try: - return file_exists(model, config_name, revision=revision, token=token) + return file_exists(model, + config_name, + revision=revision, + token=HF_TOKEN) except huggingface_hub.errors.OfflineModeIsEnabled: # Don't raise in offline mode, all we know is that we don't have this # file cached. @@ -161,7 +166,6 @@ def get_config( revision: Optional[str] = None, code_revision: Optional[str] = None, config_format: ConfigFormat = ConfigFormat.AUTO, - token: Optional[str] = None, **kwargs, ) -> PretrainedConfig: # Separate model folder from file path for GGUF models @@ -173,19 +177,20 @@ def get_config( if config_format == ConfigFormat.AUTO: if is_gguf or file_or_path_exists( - model, HF_CONFIG_NAME, revision=revision, token=token): + model, HF_CONFIG_NAME, revision=revision): config_format = ConfigFormat.HF - elif file_or_path_exists(model, - MISTRAL_CONFIG_NAME, - revision=revision, - token=token): + elif file_or_path_exists(model, MISTRAL_CONFIG_NAME, + revision=revision): config_format = ConfigFormat.MISTRAL else: # If we're in offline mode and found no valid config format, then # raise an offline mode error to indicate to the user that they # don't have files cached and may need to go online. # This is conveniently triggered by calling file_exists(). - file_exists(model, HF_CONFIG_NAME, revision=revision, token=token) + file_exists(model, + HF_CONFIG_NAME, + revision=revision, + token=HF_TOKEN) raise ValueError(f"No supported config format found in {model}") @@ -194,7 +199,7 @@ def get_config( model, revision=revision, code_revision=code_revision, - token=token, + token=HF_TOKEN, **kwargs, ) @@ -206,7 +211,7 @@ def get_config( model, revision=revision, code_revision=code_revision, - token=token, + token=HF_TOKEN, **kwargs, ) else: @@ -216,7 +221,7 @@ def get_config( trust_remote_code=trust_remote_code, revision=revision, code_revision=code_revision, - token=token, + token=HF_TOKEN, **kwargs, ) except ValueError as e: @@ -234,7 +239,7 @@ def get_config( raise e elif config_format == ConfigFormat.MISTRAL: - config = load_params_config(model, revision, token=token, **kwargs) + config = load_params_config(model, revision, token=HF_TOKEN, **kwargs) else: raise ValueError(f"Unsupported config format: {config_format}") @@ -256,8 +261,7 @@ def get_config( def get_hf_file_to_dict(file_name: str, model: Union[str, Path], - revision: Optional[str] = 'main', - token: Optional[str] = None): + revision: Optional[str] = 'main'): """ Downloads a file from the Hugging Face Hub and returns its contents as a dictionary. @@ -266,7 +270,6 @@ def get_hf_file_to_dict(file_name: str, - file_name (str): The name of the file to download. - model (str): The name of the model on the Hugging Face Hub. - revision (str): The specific version of the model. - - token (str): The Hugging Face authentication token. Returns: - config_dict (dict): A dictionary containing @@ -276,8 +279,7 @@ def get_hf_file_to_dict(file_name: str, if file_or_path_exists(model=model, config_name=file_name, - revision=revision, - token=token): + revision=revision): if not file_path.is_file(): try: @@ -296,9 +298,7 @@ def get_hf_file_to_dict(file_name: str, return None -def get_pooling_config(model: str, - revision: Optional[str] = 'main', - token: Optional[str] = None): +def get_pooling_config(model: str, revision: Optional[str] = 'main'): """ This function gets the pooling and normalize config from the model - only applies to @@ -315,8 +315,7 @@ def get_pooling_config(model: str, """ modules_file_name = "modules.json" - modules_dict = get_hf_file_to_dict(modules_file_name, model, revision, - token) + modules_dict = get_hf_file_to_dict(modules_file_name, model, revision) if modules_dict is None: return None @@ -332,8 +331,7 @@ def get_pooling_config(model: str, if pooling: pooling_file_name = "{}/config.json".format(pooling["path"]) - pooling_dict = get_hf_file_to_dict(pooling_file_name, model, revision, - token) + pooling_dict = get_hf_file_to_dict(pooling_file_name, model, revision) pooling_type_name = next( (item for item, val in pooling_dict.items() if val is True), None) @@ -368,8 +366,8 @@ def get_pooling_config_name(pooling_name: str) -> Union[str, None]: def get_sentence_transformer_tokenizer_config(model: str, - revision: Optional[str] = 'main', - token: Optional[str] = None): + revision: Optional[str] = 'main' + ): """ Returns the tokenization configuration dictionary for a given Sentence Transformer BERT model. @@ -379,7 +377,6 @@ def get_sentence_transformer_tokenizer_config(model: str, BERT model. - revision (str, optional): The revision of the m odel to use. Defaults to 'main'. - - token (str): A Hugging Face access token. Returns: - dict: A dictionary containing the configuration parameters @@ -394,7 +391,7 @@ def get_sentence_transformer_tokenizer_config(model: str, "sentence_xlm-roberta_config.json", "sentence_xlnet_config.json", ]: - encoder_dict = get_hf_file_to_dict(config_name, model, revision, token) + encoder_dict = get_hf_file_to_dict(config_name, model, revision) if encoder_dict: break @@ -474,16 +471,14 @@ def _reduce_config(config: VllmConfig): exc_info=e) -def load_params_config(model: Union[str, Path], - revision: Optional[str], - token: Optional[str] = None, +def load_params_config(model: Union[str, Path], revision: Optional[str], **kwargs) -> PretrainedConfig: # This function loads a params.json config which # should be used when loading models in mistral format config_file_name = "params.json" - config_dict = get_hf_file_to_dict(config_file_name, model, revision, token) + config_dict = get_hf_file_to_dict(config_file_name, model, revision) assert isinstance(config_dict, dict) config_mapping = {