Skip to content

Commit

Permalink
[4/N][torch.compile] clean up set_torch_compile_backend (vllm-project…
Browse files Browse the repository at this point in the history
…#10401)

Signed-off-by: youkaichao <[email protected]>
  • Loading branch information
youkaichao authored Nov 18, 2024
1 parent 47826ca commit 51bb12d
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 42 deletions.
16 changes: 2 additions & 14 deletions vllm/compilation/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
import dataclasses
import operator
from contextlib import ExitStack
from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Tuple,
Union)
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
from unittest.mock import patch

import torch
import torch.fx as fx

import vllm.envs as envs
from vllm.config import CompilationConfig, CompilationLevel
from vllm.config import CompilationConfig
from vllm.logger import init_logger
from vllm.utils import combine_fx_passes, weak_ref_tensors

Expand Down Expand Up @@ -684,14 +683,3 @@ def __call__(self, *args) -> Any:

entry.cudagraph.replay()
return entry.output


def select_default_backend(level: int) -> Union[str, Callable]:
if level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
backend_str = "eager"
return backend_str
assert level == CompilationLevel.PIECEWISE

from vllm.plugins import get_current_vllm_config
compilation_config = get_current_vllm_config().compilation_config
return VllmBackend(compilation_config)
11 changes: 3 additions & 8 deletions vllm/compilation/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,9 @@ def __init__(self,
# default compilation settings
# compiling the forward method

# choose the compile backend

# if the user has set the backend, use it
from vllm.plugins import get_torch_compile_backend
backend = get_torch_compile_backend()
if backend is None:
from vllm.compilation.backends import select_default_backend
backend = select_default_backend(compilation_level)
from vllm.plugins import get_current_vllm_config
backend = get_current_vllm_config(
).compilation_config.init_backend()

compiled_callable = torch.compile(
self.forward,
Expand Down
31 changes: 30 additions & 1 deletion vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
get_hf_text_config, get_pooling_config,
get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
identity, print_warning_once)
identity, print_warning_once, resolve_obj_by_qualname)

if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
Expand Down Expand Up @@ -2072,6 +2072,13 @@ class CompilationConfig(BaseModel):
- 1: dynamo as is.
- 2: dynamo once.
- 3: piecewise compilation.
- backend: the backend for compilation. It needs to be a string.
- "" (empty string): use the default backend.
- "eager"/"openxla"/...: use the specified backend registered in PyTorch.
- "full.module.name": a qualified name which can be used to import the backend function.
We use string to avoid serialization issues when using compilation in a distributed setting.
When the compilation level is 1 or 2, the backend is used for the compilation directly (it sees the whole graph).
When the compilation level is 3, the backend is used for the piecewise compilation (it sees a part of the graph).
- custom_ops: fine-grained control over which custom ops to enable/disable.
Use 'all' to enable all, 'none' to disable all.
Also specify a list of custom op names to enable (prefixed with a '+'),
Expand Down Expand Up @@ -2139,6 +2146,7 @@ class CompilationConfig(BaseModel):
certain small batchsizes, where inductor is good at optimizing.
""" # noqa
level: int = 0
backend: str = ""
custom_ops: List[str] = Field(default_factory=list)

use_inductor: bool = True
Expand Down Expand Up @@ -2182,6 +2190,27 @@ def model_post_init(self, __context: Any) -> None:
func = __import__(module).__dict__[func_name]
self.inductor_compile_config[k] = func

def init_backend(self) -> Union[str, Callable]:
if self.level == CompilationLevel.NO_COMPILATION:
raise ValueError("No compilation level is set.")

from torch._dynamo.backends.registry import list_backends
torch_backends = list_backends(exclude_tags=tuple())
if self.level in [
CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE
]:
if self.backend == "":
return "eager"
if self.backend in torch_backends:
return self.backend
return resolve_obj_by_qualname(self.backend)

# TODO: pass user-specified backend to piecewise compilation
# merge with the config use_inductor
assert self.level == CompilationLevel.PIECEWISE
from vllm.compilation.backends import VllmBackend
return VllmBackend(self)

def init_during_runtime(self):
"""To complete the initialization of config,
we need to know the compile context, which is only available
Expand Down
7 changes: 3 additions & 4 deletions vllm/platforms/tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,13 @@

import torch

from vllm.plugins import set_torch_compile_backend

from .interface import Platform, PlatformEnum

if TYPE_CHECKING:
from vllm.config import VllmConfig
else:
VllmConfig = None

set_torch_compile_backend("openxla")


class TpuPlatform(Platform):
_enum = PlatformEnum.TPU
Expand All @@ -38,3 +34,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
compilation_config.level = CompilationLevel.DYNAMO_ONCE
assert compilation_config.level < CompilationLevel.PIECEWISE,\
"TPU does not support Inductor."

if compilation_config.backend == "":
compilation_config.backend = "openxla"
14 changes: 1 addition & 13 deletions vllm/plugins/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from contextlib import contextmanager
from typing import TYPE_CHECKING, Callable, Optional, Union
from typing import TYPE_CHECKING, Optional

import vllm.envs as envs

Expand Down Expand Up @@ -50,18 +50,6 @@ def load_general_plugins():
logger.exception("Failed to load plugin %s", plugin.name)


_torch_compile_backend: Optional[Union[Callable, str]] = None


def set_torch_compile_backend(backend: Union[Callable, str]):
global _torch_compile_backend
_torch_compile_backend = backend


def get_torch_compile_backend() -> Optional[Union[Callable, str]]:
return _torch_compile_backend


_compilation_config: Optional[CompilationConfig] = None


Expand Down
9 changes: 9 additions & 0 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1600,3 +1600,12 @@ def direct_register_custom_op(
my_lib.impl(op_name, op_func, "CUDA")
if fake_impl is not None:
my_lib._register_fake(op_name, fake_impl)


def resolve_obj_by_qualname(qualname: str) -> Any:
"""
Resolve an object by its fully qualified name.
"""
module_name, obj_name = qualname.rsplit(".", 1)
module = importlib.import_module(module_name)
return getattr(module, obj_name)
3 changes: 1 addition & 2 deletions vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,8 +1143,7 @@ def load_model(self) -> None:

if self.vllm_config.compilation_config.level ==\
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
from vllm.plugins import get_torch_compile_backend
backend = get_torch_compile_backend() or "eager"
backend = self.vllm_config.compilation_config.init_backend()
self.model = torch.compile(
self.model,
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
Expand Down

0 comments on commit 51bb12d

Please sign in to comment.