
[Core]: (2/N) Support prefill only models by Workflow Defined Engine - Prefill only attention #9124

Draft · noooop wants to merge 1 commit into main
Conversation

@noooop (Contributor) commented Oct 7, 2024

Previously

What is this PR going to do?

This PR focuses on the following new features.

attention

  • Prefill-only models require simpler attention implementations: there is no KV cache and no decoding phase.
  • We need to support enabling bidirectional attention, either via a manual enable_bidirectional flag or by reading the HF config automatically.

Proposed Change

  1. Prefill-only attention backend implementations
  2. Attention gains an optional attn_backend parameter
  3. Show how to support enabling bidirectional attention

Prefill only attention backend implementations

  • Flash Attention Backend
  • Torch SDPA Backend
  • XFormers Backend
  • FlashInfer Backend (because prefill-only models do not involve a KV cache, selecting the FlashInfer backend for a prefill-only model actually uses the FLASH_ATTN backend)
  • Torch naive backend
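
To make "no KV cache, no decoding phase" concrete, here is a minimal sketch of what a prefill-only forward pass can look like with the Torch SDPA backend. This is an illustration, not the backend code in this PR; the PrefillOnlyAttentionMetadata name and its fields are assumptions, and the causal flag is what the bidirectional switch would control.

# Illustrative sketch only (not the backend implementation in this PR):
# prefill-only attention over packed (varlen) tensors, with no KV cache
# and no decode path. Requires PyTorch >= 2.1 for the `scale` argument.
from dataclasses import dataclass
from typing import List

import torch
import torch.nn.functional as F


@dataclass
class PrefillOnlyAttentionMetadata:
    # Lengths of the individual sequences packed into the batch.
    seq_lens: List[int]


def prefill_only_sdpa_forward(
    query: torch.Tensor,   # [num_tokens, num_heads, head_size]
    key: torch.Tensor,     # [num_tokens, num_heads, head_size]
    value: torch.Tensor,   # [num_tokens, num_heads, head_size]
    attn_metadata: PrefillOnlyAttentionMetadata,
    scale: float,
    causal: bool,          # False when bidirectional attention is enabled
) -> torch.Tensor:
    outputs = []
    start = 0
    for seq_len in attn_metadata.seq_lens:
        end = start + seq_len
        # [num_heads, seq_len, head_size] for scaled_dot_product_attention.
        q = query[start:end].transpose(0, 1)
        k = key[start:end].transpose(0, 1)
        v = value[start:end].transpose(0, 1)
        out = F.scaled_dot_product_attention(
            q, k, v, is_causal=causal, scale=scale)
        outputs.append(out.transpose(0, 1))
        start = end
    return torch.cat(outputs, dim=0)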

Adding the attn_backend parameter

class Attention(nn.Module):
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        prefix: str = "",
        attn_backend: Optional[AttentionBackend] = None,    # <- newly added parameter
    ) -> None:

Why do you need this?

  • We need to support enabling bidirectional attention, either via a manual enable_bidirectional flag or by reading the HF config automatically.
  • We need all attention layers to share the same attn_backend instance, so that a single enable-bidirectional switch can be installed on that one instance.
  • Flipping that switch then guarantees that every attention layer is switched consistently (see the sketch below).
  • The current get_attn_backend is too magical for fine-grained control.
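
The sketch below illustrates the "one backend instance, one switch" intent. The class and method names here (SharedAttentionBackend, enable_bidirectional) are assumptions for illustration, not the API introduced by this PR.

# Illustrative sketch of "one backend instance, one switch". All names here
# are placeholders, not the API introduced by this PR.
from enum import Enum, auto


class AttentionType(Enum):
    DECODER = auto()   # causal (masked) attention
    ENCODER = auto()   # bidirectional attention


class SharedAttentionBackend:
    """A single backend object shared by every attention layer."""

    def __init__(self, attn_type: AttentionType) -> None:
        self.attn_type = attn_type

    @property
    def causal(self) -> bool:
        return self.attn_type == AttentionType.DECODER

    def enable_bidirectional(self) -> None:
        # Flipping the switch here affects every layer that holds a
        # reference to this instance.
        self.attn_type = AttentionType.ENCODER


class AttentionLayer:
    def __init__(self, attn_backend: SharedAttentionBackend) -> None:
        self.attn_backend = attn_backend  # shared reference, not a copy


backend = SharedAttentionBackend(AttentionType.DECODER)
layers = [AttentionLayer(backend) for _ in range(4)]

backend.enable_bidirectional()
assert all(not layer.attn_backend.causal for layer in layers)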

How to support enabling bidirectional attention

The following is not the focus of this PR, but we will need to discuss it sooner or later, so it is best to lay it out here first.

In the proof-of-concept branch #8452:

1. https://github.com/noooop/vllm/blob/653794e37db5af6a8951d927f2a231a67531bea0/vllm/wde/retriever/modelzoo/gte_qwen/arg_utils.py#L23-L25

        elif "gte-Qwen2-7B-instruct" in self.model:
            self.output_last_hidden_states = True
            self.enable_bidirectional = True

2. https://github.com/noooop/vllm/blob/653794e37db5af6a8951d927f2a231a67531bea0/vllm/wde/decode_only/workflow.py#L15C1-L17C47

            if engine.engine_config.model_config.enable_bidirectional:
                workflow.attn_type = "ENCODER"
            else:
                workflow.attn_type = "DECODER"

3. https://github.com/noooop/vllm/blob/653794e37db5af6a8951d927f2a231a67531bea0/vllm/wde/prefill_only/layers/attention/selector.py#L51-L59

        backend = cls.which_attn_to_use(num_heads, head_size, num_kv_heads,
                                        sliding_window, dtype)

        backend_cls = cls.get_backend_cls(backend)

        attn_type = AttentionType.attn_type_name_to_enum(
            engine.workflow.attn_type)

        return backend_cls(attn_type)

4. https://github.com/noooop/vllm/blob/653794e37db5af6a8951d927f2a231a67531bea0/vllm/wde/core/llm_engine.py#L83C1-L84C57

        self.attn_backend = lazy_import(
            self.workflow.AttnBackend).from_engine(self)

5. https://github.com/noooop/vllm/blob/653794e37db5af6a8951d927f2a231a67531bea0/vllm/wde/prefill_only/runner/model_runner.py#L35C1-L50C1

    def load_model(self) -> None:
        from vllm.wde.core.loader.loader import (get_model_loader,
                                                 initialize_model)

        logger.info("Starting to load model %s...", self.model_config.model)
        with DeviceMemoryProfiler() as m:
            loader = get_model_loader(self.load_config)
            self.model = initialize_model(model_config=self.model_config,
                                          load_config=self.load_config,
                                          device_config=self.device_config,
                                          attn_backend=self.attn_backend)

            loader.load_model(self.model,
                              model_config=self.model_config,
                              device_config=self.device_config)

6. https://github.com/noooop/vllm/blob/653794e37db5af6a8951d927f2a231a67531bea0/vllm/wde/decode_only/modelzoo/qwen2.py#L313C1-L319C15

    def __init__(
        self,
        config: Qwen2Config,
        attn_backend: AttentionBackend,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:

7. https://github.com/noooop/vllm/blob/653794e37db5af6a8951d927f2a231a67531bea0/vllm/wde/decode_only/modelzoo/qwen2.py#L141C1-L147C57

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              attn_backend=attn_backend)
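
Taken together: step 1 decides from the model name (or config) that bidirectional attention should be enabled, step 2 maps that to an ENCODER attention type on the workflow, steps 3 and 4 build a single backend instance on the engine with that attention type, and steps 5 through 7 pass this one instance through the model loader down to every Attention layer.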

Brief introduction

What new models need to be supported

These models all come from user-filed issues and are also quite well known:

  • xlm_roberta
  • bge-m3
  • bge-reranker-v2-m3
  • bert
  • bge v1.5 family
  • Snowflake Arctic Embed (Family)
  • gte-Qwen2
  • This list is still growing

These models are roughly divided into three categories:

  • Encode only models (bidirectional transformers, causal=False), often fine-tuned as retrievers, rerankers, etc.
  • Decode only models (masked multi-head attention, causal=True). There are two interesting uses:
    • Output the last hidden states as a feature extractor.
    • Decode only retrievers (I don't know of a better name), e.g. e5-mistral-7b (the only embedding model currently supported by vLLM).
    • Whether or not the model has been fine-tuned, there is almost no difference in the code.
  • Enable bidirectional. LLM2Vec proposes a simple unsupervised approach that can transform any decoder-only LLM into a strong text encoder (see the mask sketch after this list).
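
The practical difference between the causal and bidirectional cases is just the attention mask. A minimal sketch of the contrast (illustrative, not code from this PR):

# Causal vs. bidirectional attention masks for a 4-token sequence (sketch).
import torch

seq_len = 4

# Causal mask (decode-only models): token i may only attend to tokens <= i.
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

# Bidirectional mask (encode-only models, or a decode-only model with
# enable_bidirectional): every token may attend to every other token.
bidirectional_mask = torch.ones(seq_len, seq_len, dtype=torch.bool)

print(causal_mask.int())
print(bidirectional_mask.int())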

What new features these new models have

What the above three categories have in common is that there is only a prefill stage. To make the terminology more precise, "prefill only" is used below.

You can think of "prefill only" as a fancier way of writing "encode only".

New features:

  1. attention
    • Prefill-only models require simpler attention implementations: no KV cache and no decoding phase.
    • We need to support enabling bidirectional attention, either via a manual enable_bidirectional flag or by reading the HF config automatically.
  2. scheduler
    • Prefill-only models require a simpler scheduler: there is no KV cache and no preemption to consider.
    • With prefill-only models there is no correlation between tasks, so async scheduling is easy to implement.
  3. executor
    • To support async scheduling, the model_input_builder needs to be separated from the runner.
    • The main thread executes scheduling and all CPU processing; the GPU thread only executes H2D copies, model execution, and D2H copies (a sketch of this split follows the list).
    • Once async scheduling and async execution are implemented, data parallelism is also easy to add, and data parallelism is more efficient for small models.
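
A minimal sketch of the CPU/GPU thread split described in the executor item, assuming placeholder build_model_input and execute_model functions (these names are illustrative, not APIs from this PR series): the main thread schedules and builds inputs while a dedicated worker thread stands in for the GPU thread.

# Sketch of async scheduling: main thread does CPU work, a worker thread
# stands in for the GPU thread (H2D, model execution, D2H). Function names
# are placeholders, not APIs from this PR series.
import queue
import threading


def build_model_input(request):
    # Scheduling + CPU-side preprocessing (tokenization, attention metadata, ...)
    return {"request": request}


def execute_model(model_input):
    # H2D copy, model forward, and D2H copy would happen here.
    return f"output for {model_input['request']}"


work_queue = queue.Queue()
results = queue.Queue()


def gpu_worker():
    while True:
        model_input = work_queue.get()
        if model_input is None:  # shutdown signal
            break
        results.put(execute_model(model_input))


gpu_thread = threading.Thread(target=gpu_worker, daemon=True)
gpu_thread.start()

# Main thread keeps scheduling and building inputs while the GPU thread runs.
for request in ["req-0", "req-1", "req-2"]:
    work_queue.put(build_model_input(request))

work_queue.put(None)
gpu_thread.join()

while not results.empty():
    print(results.get())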

How the engine architecture needs to support these features flexibly and efficiently

If we directly add new functions to existing modules, these modules become increasingly complex, and sometimes new features must be compromised for compatibility, ultimately leading to suboptimal results.

The most flexible and efficient way to support prefill-only models is to implement different modules for models of different architectures and load the required modules on demand.

I call this architecture Workflow Defined Engine, or WDE for short.

I divided the engine into the following modules:

  • InputProcessor: LLM models take strings as input, rerankers take pairs, and multimodal model inputs are more complex...
  • OutputProcessor: retriever (embedding) models output embeddings, while reranker and classification models output scores...
  • ModelInputBuilder: builds model inputs and attention metadata
  • AttnBackend: supports different attention backends and enabling bidirectional attention
  • Tokenizer: there may be different tokenizers
  • Executor: sync/async/TP/PP/DP/maybe more
  • Worker & runner: support different devices/maybe more
  • EngineArgs: different models and different configs may accept different parameters
  • maybe more

With WDE, no single module needs to be compatible with every feature. You can use Python's dynamic loading to load different modules at the highest level, for different models and different needs (see the Workflow sketch below).

  • Modules can be configured through a Workflow, plug and play.
  • Plug-ins are supported flexibly, and developers can load their own modules.
  • A Workflow is really the best place to hide dirty code.
    Some models cannot use the common Workflow. When you don't know where to put the dirty code, you can always create a new Workflow and link the model architecture to it, instead of leaving dirty code everywhere for the sake of compatibility.
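
To illustrate the idea, here is a minimal sketch of a Workflow as a bundle of dotted module paths that the engine lazy-imports on demand. The class layout, the "module:Class" path format, and all paths below are assumptions for illustration, not code from the PoC branch.

# Sketch: a Workflow as dotted module paths, resolved lazily. All names and
# paths here are placeholders.
import importlib
from dataclasses import dataclass


def lazy_import(dotted_path: str):
    """Import 'package.module:ClassName' only when it is actually needed."""
    module_path, _, attr = dotted_path.partition(":")
    module = importlib.import_module(module_path)
    return getattr(module, attr)


@dataclass
class Workflow:
    InputProcessor: str
    OutputProcessor: str
    ModelInputBuilder: str
    AttnBackend: str
    Executor: str
    attn_type: str = "DECODER"


# A prefill-only retriever workflow wires in prefill-only modules; another
# model architecture can point at entirely different modules, plug and play.
PrefillOnlyWorkflow = Workflow(
    InputProcessor="my_engine.core.processor:TextInputProcessor",
    OutputProcessor="my_engine.retriever.processor:EmbeddingOutputProcessor",
    ModelInputBuilder="my_engine.prefill_only.builder:ModelInputBuilder",
    AttnBackend="my_engine.prefill_only.attention.selector:AttnBackend",
    Executor="my_engine.prefill_only.executor:GPUExecutor",
    attn_type="ENCODER",
)

# The engine then resolves only what it needs, e.g.:
#   attn_backend_cls = lazy_import(PrefillOnlyWorkflow.AttnBackend)
#   attn_backend = attn_backend_cls.from_engine(engine)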

github-actions bot commented Oct 7, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@noooop noooop marked this pull request as ready for review October 7, 2024 12:14
@DarkLight1337 DarkLight1337 marked this pull request as draft October 9, 2024 08:44
@noooop noooop changed the title [Core]: (1/N) Support prefill only models by Workflow Defined Engine - Prefill only attention [Core]: (2/N) Support prefill only models by Workflow Defined Engine - Prefill only attention Oct 9, 2024