diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index d5c8d6a376961..54558fc2d7e53 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -10,8 +10,7 @@
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils import (STR_BACKEND_ENV_VAR, is_cpu, is_hip, is_openvino,
-                        is_tpu, is_xpu)
+from vllm.utils import STR_BACKEND_ENV_VAR, is_cpu, is_hip, is_openvino, is_xpu
 
 logger = init_logger(__name__)
 
@@ -194,7 +193,7 @@ def which_attn_to_use(
             logger.info("Cannot use %s backend on XPU.", selected_backend)
         return _Backend.IPEX
 
-    if is_tpu():
+    if current_platform.is_tpu():
         if selected_backend != _Backend.PALLAS:
             logger.info("Cannot use %s backend on TPU.", selected_backend)
         return _Backend.PALLAS
diff --git a/vllm/config.py b/vllm/config.py
index 809d6370763dc..a39f5307931e5 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -10,11 +10,12 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.model_executor.models import ModelRegistry
+from vllm.platforms import current_platform
 from vllm.tracing import is_otel_installed
 from vllm.transformers_utils.config import get_config, get_hf_text_config
 from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, GiB_bytes,
                         cuda_device_count_stateless, get_cpu_memory, is_cpu,
-                        is_hip, is_neuron, is_openvino, is_tpu, is_xpu,
+                        is_hip, is_neuron, is_openvino, is_xpu,
                         print_warning_once)
 
 if TYPE_CHECKING:
@@ -282,7 +283,7 @@ def _verify_quantization(self) -> None:
                 raise ValueError(
                     f"{self.quantization} quantization is currently not "
                     f"supported in ROCm.")
-        if is_tpu(
+        if current_platform.is_tpu(
         ) and self.quantization not in tpu_supported_quantization:
             raise ValueError(
                 f"{self.quantization} quantization is currently not "
@@ -910,7 +911,7 @@ def __init__(self, device: str = "auto") -> None:
                 self.device_type = "neuron"
             elif is_openvino():
                 self.device_type = "openvino"
-            elif is_tpu():
+            elif current_platform.is_tpu():
                 self.device_type = "tpu"
             elif is_cpu():
                 self.device_type = "cpu"
diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py
index ac948331e81e0..ab283467d4783 100644
--- a/vllm/executor/ray_utils.py
+++ b/vllm/executor/ray_utils.py
@@ -2,8 +2,9 @@
 
 from vllm.config import ParallelConfig
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.sequence import ExecuteModelRequest, IntermediateTensors
-from vllm.utils import get_ip, is_hip, is_tpu, is_xpu
+from vllm.utils import get_ip, is_hip, is_xpu
 from vllm.worker.worker_base import WorkerWrapperBase
 
 logger = init_logger(__name__)
@@ -111,7 +112,7 @@ def initialize_ray_cluster(
         # Placement group is already set.
         return
 
-    device_str = "GPU" if not is_tpu() else "TPU"
+    device_str = "GPU" if not current_platform.is_tpu() else "TPU"
     # Create placement group for worker processes
    current_placement_group = ray.util.get_current_placement_group()
    if current_placement_group:
diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py
index 0db72d8d95f24..51f3ef5dbb325 100644
--- a/vllm/model_executor/custom_op.py
+++ b/vllm/model_executor/custom_op.py
@@ -1,6 +1,7 @@
 import torch.nn as nn
 
-from vllm.utils import is_cpu, is_hip, is_tpu, is_xpu
+from vllm.platforms import current_platform
+from vllm.utils import is_cpu, is_hip, is_xpu
 
 
 class CustomOp(nn.Module):
@@ -54,7 +55,7 @@ def dispatch_forward(self):
             return self.forward_hip
         elif is_cpu():
             return self.forward_cpu
-        elif is_tpu():
+        elif current_platform.is_tpu():
             return self.forward_tpu
         elif is_xpu():
             return self.forward_xpu
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index aecba0ae74911..95888e7976ad3 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -28,7 +28,7 @@
 import torch.nn as nn
 
 from vllm.model_executor.custom_op import CustomOp
-from vllm.utils import is_tpu
+from vllm.platforms import current_platform
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -78,7 +78,7 @@ def __init__(
         self.dtype = dtype
 
         cache = self._compute_cos_sin_cache()
-        self.use_native2 = is_tpu() and is_neox_style
+        self.use_native2 = current_platform.is_tpu() and is_neox_style
         if not self.use_native2:
             cache = cache.to(dtype)
         self.register_buffer("cos_sin_cache", cache, persistent=False)
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index ba9c8af88f864..493d3dd29b376 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -41,7 +41,7 @@
                                            supports_vision)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
-from vllm.utils import is_pin_memory_available, is_tpu
+from vllm.utils import is_pin_memory_available
 
 
 @contextmanager
@@ -94,7 +94,7 @@ def _get_quantization_config(
     """Get the quantization config."""
     if model_config.quantization is not None:
         quant_config = get_quant_config(model_config, load_config)
-        if not is_tpu():
+        if not current_platform.is_tpu():
             capability = current_platform.get_device_capability()
             capability = capability[0] * 10 + capability[1]
             if capability < quant_config.get_min_capability():
@@ -320,7 +320,7 @@ def _get_weights_iterator(
         else:
             weights_iterator = pt_weights_iterator(hf_weights_files)
 
-        if is_tpu():
+        if current_platform.is_tpu():
             # In PyTorch XLA, we should call `xm.mark_step` frequently so that
             # not too many ops are accumulated in the XLA program.
             import torch_xla.core.xla_model as xm
diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py
index eac917786bd6b..99ba940e5d2ab 100644
--- a/vllm/platforms/__init__.py
+++ b/vllm/platforms/__init__.py
@@ -1,22 +1,25 @@
-from typing import Optional
-
 import torch
 
-from vllm.utils import is_tpu
-
 from .interface import Platform, PlatformEnum, UnspecifiedPlatform
 
-current_platform: Optional[Platform]
+current_platform: Platform
 
-if torch.version.cuda is not None:
+try:
+    import libtpu
+except ImportError:
+    libtpu = None
+
+if libtpu is not None:
+    # people might install pytorch built with cuda but run on tpu
+    # so we need to check tpu first
+    from .tpu import TpuPlatform
+    current_platform = TpuPlatform()
+elif torch.version.cuda is not None:
     from .cuda import CudaPlatform
     current_platform = CudaPlatform()
 elif torch.version.hip is not None:
     from .rocm import RocmPlatform
     current_platform = RocmPlatform()
-elif is_tpu():
-    from .tpu import TpuPlatform
-    current_platform = TpuPlatform()
 else:
     current_platform = UnspecifiedPlatform()
 
diff --git a/vllm/utils.py b/vllm/utils.py
index 30bb81722aa04..a758c78dc9c25 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -333,15 +333,6 @@ def is_neuron() -> bool:
     return transformers_neuronx is not None
 
 
-@lru_cache(maxsize=None)
-def is_tpu() -> bool:
-    try:
-        import libtpu
-    except ImportError:
-        libtpu = None
-    return libtpu is not None
-
-
 @lru_cache(maxsize=None)
 def is_xpu() -> bool:
     from importlib.metadata import PackageNotFoundError, version
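
For context, a minimal sketch of the detection order this patch establishes in `vllm/platforms/__init__.py` and of the call-site pattern the other files switch to. The standalone helper `detect_device_type` below is hypothetical and only illustrates the logic; `current_platform`, `TpuPlatform`, `CudaPlatform`, `RocmPlatform`, and `UnspecifiedPlatform` are the names used in the diff above.

```python
# Sketch only (not vLLM's actual module): platform detection mirroring the
# new vllm/platforms/__init__.py -- probe libtpu before CUDA/ROCm, then let
# call sites ask the platform object instead of vllm.utils.is_tpu().
import torch


def detect_device_type() -> str:  # hypothetical helper for illustration
    try:
        import libtpu  # present only when a TPU runtime is installed
    except ImportError:
        libtpu = None

    if libtpu is not None:
        # A CUDA build of PyTorch may still be running on a TPU host,
        # so the TPU check has to come first.
        return "tpu"
    if torch.version.cuda is not None:
        return "cuda"
    if torch.version.hip is not None:
        return "rocm"
    return "unspecified"


# Call sites then move from:
#     from vllm.utils import is_tpu
#     if is_tpu(): ...
# to:
#     from vllm.platforms import current_platform
#     if current_platform.is_tpu(): ...
```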