
Commit

[hardware] unify usage of is_tpu to current_platform.is_tpu() (vllm-p…
youkaichao authored Aug 13, 2024
1 parent 7025b11 commit 4d2dc50
Showing 8 changed files with 29 additions and 33 deletions.
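
At a glance: the commit retires the standalone `is_tpu()` helper in `vllm.utils` and routes every check through the process-wide `current_platform` object, so TPU detection happens once at import time rather than at each call site. A schematic before/after of the call-site change, assuming a vLLM checkout at this commit (the commented-out import only exists before it):

```python
# Before this commit, call sites imported a free helper from vllm.utils:
#
#     from vllm.utils import is_tpu
#     if is_tpu():
#         ...
#
# After this commit, they ask the process-wide platform object instead.
from vllm.platforms import current_platform

if current_platform.is_tpu():
    print("running on a TPU host")
```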
vllm/attention/selector.py (5 changes: 2 additions & 3 deletions)
@@ -10,8 +10,7 @@
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils import (STR_BACKEND_ENV_VAR, is_cpu, is_hip, is_openvino,
-                        is_tpu, is_xpu)
+from vllm.utils import STR_BACKEND_ENV_VAR, is_cpu, is_hip, is_openvino, is_xpu

 logger = init_logger(__name__)

@@ -194,7 +193,7 @@ def which_attn_to_use(
             logger.info("Cannot use %s backend on XPU.", selected_backend)
         return _Backend.IPEX

-    if is_tpu():
+    if current_platform.is_tpu():
         if selected_backend != _Backend.PALLAS:
             logger.info("Cannot use %s backend on TPU.", selected_backend)
         return _Backend.PALLAS

vllm/config.py (7 changes: 4 additions & 3 deletions)
@@ -10,11 +10,12 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.model_executor.models import ModelRegistry
+from vllm.platforms import current_platform
 from vllm.tracing import is_otel_installed
 from vllm.transformers_utils.config import get_config, get_hf_text_config
 from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, GiB_bytes,
                         cuda_device_count_stateless, get_cpu_memory, is_cpu,
-                        is_hip, is_neuron, is_openvino, is_tpu, is_xpu,
+                        is_hip, is_neuron, is_openvino, is_xpu,
                         print_warning_once)

 if TYPE_CHECKING:

@@ -282,7 +283,7 @@ def _verify_quantization(self) -> None:
                 raise ValueError(
                     f"{self.quantization} quantization is currently not "
                     f"supported in ROCm.")
-            if is_tpu(
+            if current_platform.is_tpu(
             ) and self.quantization not in tpu_supported_quantization:
                 raise ValueError(
                     f"{self.quantization} quantization is currently not "

@@ -910,7 +911,7 @@ def __init__(self, device: str = "auto") -> None:
                 self.device_type = "neuron"
             elif is_openvino():
                 self.device_type = "openvino"
-            elif is_tpu():
+            elif current_platform.is_tpu():
                 self.device_type = "tpu"
             elif is_cpu():
                 self.device_type = "cpu"

vllm/executor/ray_utils.py (5 changes: 3 additions & 2 deletions)
@@ -2,8 +2,9 @@

 from vllm.config import ParallelConfig
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.sequence import ExecuteModelRequest, IntermediateTensors
-from vllm.utils import get_ip, is_hip, is_tpu, is_xpu
+from vllm.utils import get_ip, is_hip, is_xpu
 from vllm.worker.worker_base import WorkerWrapperBase

 logger = init_logger(__name__)

@@ -111,7 +112,7 @@ def initialize_ray_cluster(
         # Placement group is already set.
         return

-    device_str = "GPU" if not is_tpu() else "TPU"
+    device_str = "GPU" if not current_platform.is_tpu() else "TPU"
     # Create placement group for worker processes
     current_placement_group = ray.util.get_current_placement_group()
     if current_placement_group:

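
The `device_str` chosen above feeds the Ray placement-group bundles that reserve worker slots, which is why TPU hosts must report `TPU` rather than `GPU`. A rough sketch of that usage, assuming an already-initialized Ray cluster; the bundle shape and the PACK strategy are illustrative choices, not vLLM's exact code:

```python
import ray


def reserve_worker_bundles(world_size: int, device_str: str):
    """Reserve one accelerator of the detected type per worker (illustrative)."""
    # device_str is "TPU" on TPU hosts and "GPU" otherwise, matching the diff.
    bundles = [{device_str: 1.0} for _ in range(world_size)]
    pg = ray.util.placement_group(bundles, strategy="PACK")
    ray.get(pg.ready())  # block until Ray has reserved every bundle
    return pg
```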
vllm/model_executor/custom_op.py (5 changes: 3 additions & 2 deletions)
@@ -1,6 +1,7 @@
 import torch.nn as nn

-from vllm.utils import is_cpu, is_hip, is_tpu, is_xpu
+from vllm.platforms import current_platform
+from vllm.utils import is_cpu, is_hip, is_xpu


 class CustomOp(nn.Module):

@@ -54,7 +55,7 @@ def dispatch_forward(self):
             return self.forward_hip
         elif is_cpu():
             return self.forward_cpu
-        elif is_tpu():
+        elif current_platform.is_tpu():
             return self.forward_tpu
         elif is_xpu():
             return self.forward_xpu

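
`dispatch_forward` is the consumer of the per-platform `forward_*` methods shown above: each `CustomOp` subclass supplies whichever implementations it has, and the dispatcher returns the one matching the detected platform. A minimal sketch of a subclass written against that pattern (the op here is made up for illustration; production ops usually route `forward_cuda` to a fused kernel instead of reusing the eager path):

```python
import torch
import torch.nn.functional as F

from vllm.model_executor.custom_op import CustomOp


class SiluAndMulSketch(CustomOp):
    """Illustrative CustomOp: plain PyTorch math on every platform."""

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # A real op would call a fused CUDA kernel here.
        return self.forward_native(x)

    def forward_tpu(self, x: torch.Tensor) -> torch.Tensor:
        # XLA traces plain PyTorch ops, so the native path is reused on TPU.
        return self.forward_native(x)
```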
vllm/model_executor/layers/rotary_embedding.py (4 changes: 2 additions & 2 deletions)
@@ -28,7 +28,7 @@
 import torch.nn as nn

 from vllm.model_executor.custom_op import CustomOp
-from vllm.utils import is_tpu
+from vllm.platforms import current_platform


 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:

@@ -78,7 +78,7 @@ def __init__(
         self.dtype = dtype

         cache = self._compute_cos_sin_cache()
-        self.use_native2 = is_tpu() and is_neox_style
+        self.use_native2 = current_platform.is_tpu() and is_neox_style
         if not self.use_native2:
             cache = cache.to(dtype)
         self.register_buffer("cos_sin_cache", cache, persistent=False)

vllm/model_executor/model_loader/loader.py (6 changes: 3 additions & 3 deletions)
@@ -41,7 +41,7 @@
                                                     supports_vision)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
-from vllm.utils import is_pin_memory_available, is_tpu
+from vllm.utils import is_pin_memory_available


 @contextmanager

@@ -94,7 +94,7 @@ def _get_quantization_config(
     """Get the quantization config."""
     if model_config.quantization is not None:
         quant_config = get_quant_config(model_config, load_config)
-        if not is_tpu():
+        if not current_platform.is_tpu():
             capability = current_platform.get_device_capability()
             capability = capability[0] * 10 + capability[1]
             if capability < quant_config.get_min_capability():

@@ -320,7 +320,7 @@ def _get_weights_iterator(
         else:
             weights_iterator = pt_weights_iterator(hf_weights_files)

-        if is_tpu():
+        if current_platform.is_tpu():
             # In PyTorch XLA, we should call `xm.mark_step` frequently so that
             # not too many ops are accumulated in the XLA program.
             import torch_xla.core.xla_model as xm

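
The comment in the last hunk is the rationale for the TPU branch: under PyTorch/XLA, operations accumulate into one lazily built graph until `xm.mark_step()` cuts it, so the loader flushes periodically while streaming weights. A minimal sketch of such a wrapper; the helper name and the flush interval are illustrative, not vLLM's actual code:

```python
from typing import Iterable, Iterator, Tuple

import torch


def mark_step_every(weights: Iterable[Tuple[str, torch.Tensor]],
                    every: int = 1) -> Iterator[Tuple[str, torch.Tensor]]:
    """Yield (name, tensor) pairs, flushing the XLA graph every `every` items."""
    import torch_xla.core.xla_model as xm  # only importable on TPU installs

    for i, (name, tensor) in enumerate(weights):
        yield name, tensor
        if (i + 1) % every == 0:
            xm.mark_step()  # cut the accumulated XLA program so it stays small
```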
vllm/platforms/__init__.py (21 changes: 12 additions & 9 deletions)
@@ -1,22 +1,25 @@
-from typing import Optional
-
 import torch

-from vllm.utils import is_tpu
-
 from .interface import Platform, PlatformEnum, UnspecifiedPlatform

-current_platform: Optional[Platform]
+current_platform: Platform

-if torch.version.cuda is not None:
+try:
+    import libtpu
+except ImportError:
+    libtpu = None
+
+if libtpu is not None:
+    # people might install pytorch built with cuda but run on tpu
+    # so we need to check tpu first
+    from .tpu import TpuPlatform
+    current_platform = TpuPlatform()
+elif torch.version.cuda is not None:
     from .cuda import CudaPlatform
     current_platform = CudaPlatform()
 elif torch.version.hip is not None:
     from .rocm import RocmPlatform
     current_platform = RocmPlatform()
-elif is_tpu():
-    from .tpu import TpuPlatform
-    current_platform = TpuPlatform()
 else:
     current_platform = UnspecifiedPlatform()

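
`current_platform.is_tpu()` resolves the way it does because each concrete platform class reports its kind through the shared interface imported above. The real definitions live in `vllm/platforms/interface.py`, which is not part of this diff; the sketch below is an assumption about its shape, limited to names the diff actually uses:

```python
import enum


class PlatformEnum(enum.Enum):
    CUDA = enum.auto()
    ROCM = enum.auto()
    TPU = enum.auto()
    UNSPECIFIED = enum.auto()


class Platform:
    _enum: PlatformEnum

    def is_tpu(self) -> bool:
        return self._enum == PlatformEnum.TPU


class TpuPlatform(Platform):
    _enum = PlatformEnum.TPU


class UnspecifiedPlatform(Platform):
    _enum = PlatformEnum.UNSPECIFIED
```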
vllm/utils.py (9 changes: 0 additions & 9 deletions)
@@ -333,15 +333,6 @@ def is_neuron() -> bool:
     return transformers_neuronx is not None


-@lru_cache(maxsize=None)
-def is_tpu() -> bool:
-    try:
-        import libtpu
-    except ImportError:
-        libtpu = None
-    return libtpu is not None
-
-
 @lru_cache(maxsize=None)
 def is_xpu() -> bool:
     from importlib.metadata import PackageNotFoundError, version
