
Yet another Prefill-Decode separation in vllm #9079

Open · wants to merge 6 commits into main

Conversation

@chenqianfzh (Contributor) commented Oct 4, 2024

This PoC demonstrates an implementation of Prefill-Decode separation within vLLM, leveraging a memory pool to store KV caches accessed via RDMA connections. The primary goal is to enhance data transfer efficiency by optimizing how KV caches are managed and transmitted.

@KuntaiDu @youkaichao @thesues, we wonder if you could take a moment to review our implementation and share your valuable feedback. Your insights and opinions would be greatly appreciated.

Key Changes to vLLM

Compared to #8498 (thanks to your seminal work, @KuntaiDu, which inspired us a lot), this implementation optimizes data transfer efficiency. Instead of using a single large tensor for each layer's KV cache, the caches are divided into smaller blocks. This granular approach enables parallel transmission of multiple layers as well as multiple blocks within one layer.

To coordinate the different layers effectively, the changes are made at the model level rather than in the entire model_runner. This implementation showcases the modifications for the OPT and LLaMA models, establishing a pattern that can be easily extended to other models.
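As a rough illustration of this pattern, the sketch below hooks a per-layer save into a decoder layer's forward pass. The kv_transporter object, the save_kv_layer helper, and the block-naming scheme are illustrative assumptions, not the actual code in this PR.

import torch
from typing import List, Tuple

def save_kv_layer(kv_transporter, layer_idx: int, layer_kv: torch.Tensor,
                  block_ids: List[int], page_size: int) -> None:
    # One (key, block) entry per KV-cache block instead of one big per-layer
    # tensor, so many blocks -- and many layers -- can be in flight at once.
    blocks: List[Tuple[str, int]] = [
        (f"layer{layer_idx}_block{bid}", bid) for bid in block_ids
    ]
    kv_transporter.write_cache(layer_kv, blocks, page_size)

# Inside the modified OPT/LLaMA decoder layers, a call like this is issued as
# soon as a layer's attention has produced its KV cache, so transferring
# layer i can overlap with computing layer i + 1.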


Introducing InfiniteStore

Main contributor: @thesues
The core functionality for sending and receiving KV caches via RDMA is encapsulated in our separate package, InfiniteStore. We are more than happy to donate this project to the community.

InfiniteStore Features:

  • High Throughput and Low Latency:
    • Utilizes RDMA for memory copying between GPU and CPU across servers.
    • Maximizes data transfer throughput while minimizing latency.
  • Dynamic Cluster Management:
    • Detects the hottest memory pages and dynamically distributes them across nodes to balance the load.
  • Asynchronous and Batched APIs:
    • Designed to handle asynchronous operations and support batching, tailored specifically for vLLM’s requirements.
  • Python Integration:
    • Integrates with Python's uvloop library, facilitating easy extension of management functionalities in Python.
    • Implements data communication in C++ for performance-critical operations.

API design
The load/store API is meticulously crafted to align with vLLM’s KV cache characteristics, aiming to parallelize computation with KV cache saving and loading operations. The APIs are asynchronous and support batch processing. The definitions are as follows:

def write_cache(self, cache: torch.Tensor, blocks: List[Tuple[str, int]], page_size: int)
def read_cache(self, cache: torch.Tensor, blocks: List[Tuple[str, int]], page_size: int)

Once the load/store requests have all been issued, the blocking sync function is called to ensure that the processing has completed before moving on.

def sync(self)

As shown in the diagram below, during the prefill computation only the final call to the sync API is needed to ensure that the KV cache has been fully written to remote storage.
(diagram: infinitystore)
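To make the intended usage concrete, here is a minimal sketch of a prefill-side save and decode-side load built on these APIs. The connection object conn, the request/block key scheme, and the per-layer loop are assumptions for illustration, not the exact wiring in this PR.

# Prefill side: enqueue asynchronous per-block writes layer by layer,
# then issue a single sync() barrier once everything has been submitted.
def save_kv_after_prefill(conn, kv_caches, block_ids, page_size):
    for layer_idx, layer_cache in enumerate(kv_caches):
        blocks = [(f"req0_layer{layer_idx}_blk{bid}", bid) for bid in block_ids]
        conn.write_cache(layer_cache, blocks, page_size)  # returns immediately
    conn.sync()  # all pending writes are in remote storage once this returns

# Decode side: mirror the same keys to pull the caches back before decoding.
def load_kv_before_decode(conn, kv_caches, block_ids, page_size):
    for layer_idx, layer_cache in enumerate(kv_caches):
        blocks = [(f"req0_layer{layer_idx}_blk{bid}", bid) for bid in block_ids]
        conn.read_cache(layer_cache, blocks, page_size)
    conn.sync()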

github-actions bot commented Oct 4, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@KuntaiDu (Collaborator) commented Oct 4, 2024

@zeroorhero implemented a similar KV-store-based solution, and I myself also believe that a KV store is the way to go for disaggregated prefilling. So yes!

@chenqianfzh (Contributor, Author) commented

@zeroorhero implemented a similar KV-store-based solution, and I myself also believe that a KV store is the way to go for disaggregated prefilling. So yes!

Thanks for your reply.

We would appreciate any comments on how to improve it.

So far, only an RDMA-based KV cache transporter has been created. Do you think adding a transporter that uses CPU memory would make this PR more likely to be accepted?

@WangErXiao commented

I have a question. How can we guarantee that the model updates correctly when using Prefill-Decode separation in production, @chenqianfzh? And how do we install InfiniteStore using pip?

@kuangdao commented Oct 9, 2024

m

@thesues (Contributor) commented Oct 9, 2024

I have a question. How can we guarantee that the model updates correctly when using Prefill-Decode separation in production, @chenqianfzh? And how do we install InfiniteStore using pip?

I am working on 'pip install InfiniteStore'; it should be available this week.

@liweiqing1997 commented

Hello, could you provide the way to start the program?

@chenqianfzh (Contributor, Author) commented

Hello, could you provide the way to start the program?

The script examples/infinitestore_pd_separate.sh in this PR demonstrates how to start the two vLLM instances, for prefill and decode respectively, on a single host with multiple GPUs.

Before that, make sure InfiniteStore is installed and started by running start.sh.

Please let me know if you run into any problems.

@chenqianfzh (Contributor, Author) commented

I have a question. How can we guarantee that the model updates correctly when using Prefill-Decode separation in production, @chenqianfzh? And how do we install InfiniteStore using pip?

Sorry, I missed your first question.

I guess you are asking how I verified that the KV caches are updated correctly across the separate prefill and decode vLLM instances. I verified it by comparing the generated results with vanilla vLLM; the results are exactly the same.
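For reference, a minimal sketch of that kind of correctness check against two OpenAI-compatible vLLM servers; the ports, model name, and prompts below are made-up examples, not the actual test setup.

import requests

PROMPTS = ["The capital of France is", "Explain RDMA in one sentence:"]

def complete(base_url, prompt):
    # Greedy decoding so both deployments are deterministic and comparable.
    resp = requests.post(f"{base_url}/v1/completions", json={
        "model": "facebook/opt-125m",
        "prompt": prompt,
        "max_tokens": 64,
        "temperature": 0,
    })
    resp.raise_for_status()
    return resp.json()["choices"][0]["text"]

for prompt in PROMPTS:
    vanilla = complete("http://localhost:8000", prompt)  # vanilla vLLM
    disagg = complete("http://localhost:8100", prompt)   # prefill/decode-separated deployment
    assert vanilla == disagg, f"mismatch for {prompt!r}"
print("outputs match")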

@thesues (Contributor) commented Oct 13, 2024

Hello, could you provide the way to start the program?

You can pip install infinistore and then run infinistore or python -m infinistore.server to start the mempool daemon. Note that this requires the nv_peer_mem kernel module.

@lixiaolx commented

You can pip install infinistore and then run infinistore or python -m infinistore.server to start the mempool daemon. Note that this requires the nv_peer_mem kernel module.

@chenqianfzh @thesues I have three questions:

First: after installing with the above command, a circular import occurred when importing the package. Can you help me find out the reason? (screenshot attached)

Second: in this PR, do the places where infinity is used need to be replaced with infinistore?

Third: is multi-machine deployment currently supported? Or a single machine with multiple cards, for example prefill on 2 GPUs and decode on 2 GPUs?

@thesues (Contributor) commented Oct 14, 2024

You can pip install infinistore and then run infinistore or python -m infinistore.server to start the mempool daemon. Note that this requires the nv_peer_mem kernel module.

@chenqianfzh @thesues I have three questions:

First: after installing with the above command, a circular import occurred when importing the package. Can you help me find out the reason? (screenshot attached)

Second: in this PR, do the places where infinity is used need to be replaced with infinistore?

Third: is multi-machine deployment currently supported? Or a single machine with multiple cards, for example prefill on 2 GPUs and decode on 2 GPUs?

  1. This is what I did:
conda create -n test python=3.11
conda activate test
pip install infinistore==0.1.73
python -m infinistore.server

Could you provide more information about this error? Or check whether the package contains _infinistore.cpython-311-x86_64-linux-gnu.so?

  2. I updated infinistore's API recently, so this PR should be updated as well; I will ping you as soon as it is ready.
  3. This PR does not support that yet.

@lixiaolx commented

  1. This is what I did:
conda create -n test python=3.11
conda activate test
pip install infinistore==0.1.73
python -m infinistore.server

Could you provide more information about this error? Or check whether the package contains _infinistore.cpython-311-x86_64-linux-gnu.so?

First, I checked my installation directory and found the .so file you mentioned (screenshots attached).

Second, I am using Python 3.10. Can I build this package myself in a 3.10 environment?

@lixiaolx commented

  3. This PR does not support that yet.

@thesues Do you have any plans to do this? Or is it already being done?

@thesues (Contributor) commented Oct 15, 2024

  1. This is what I did:
conda create -n test python=3.11
conda activate test
pip install infinistore==0.1.73
python -m infinistore.server

Could you provide more information about this error? Or check whether the package contains _infinistore.cpython-311-x86_64-linux-gnu.so?

First, I checked my installation directory and found the .so file you mentioned (screenshots attached).

Second, I am using Python 3.10. Can I build this package myself in a 3.10 environment?

(screenshot attached)
infinistore now supports Python 3.10, 3.11, and 3.12; you can pip install infinistore==0.1.74 to try this on Python 3.10.

As for cluster-level deployment, it is still in the design phase. We could use Python asyncio for the management API or for heartbeats. Do you have any ideas or suggestions to share?
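Not part of this PR, but to sketch the asyncio heartbeat idea mentioned above (the /health endpoint, the node list, and the 5-second interval are all made up for illustration):

import asyncio
import aiohttp

NODES = ["http://node1:18080", "http://node2:18080"]  # hypothetical manage_port of each infinistore node

async def heartbeat(session, node):
    # Periodically probe one node and report whether it responds.
    while True:
        try:
            async with session.get(f"{node}/health", timeout=aiohttp.ClientTimeout(total=2)) as resp:
                alive = resp.status == 200
        except Exception:
            alive = False
        print(f"{node}: {'alive' if alive else 'unreachable'}")
        await asyncio.sleep(5)

async def main():
    async with aiohttp.ClientSession() as session:
        await asyncio.gather(*(heartbeat(session, node) for node in NODES))

asyncio.run(main())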

@Luis-xu commented Oct 17, 2024

Hi, I encountered the following problem when running the infinistore service:
Server config: ServerConfig(service_port=22345, manage_port=18080, log_level='warning')
Traceback (most recent call last):
  File "/usr/local/bin/infinistore", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.10/dist-packages/infinistore/server.py", line 82, in main
    check_supported()
  File "/usr/local/lib/python3.10/dist-packages/infinistore/lib.py", line 123, in check_supported
    raise Exception("nv_peer_mem module is not loaded")
Exception: nv_peer_mem module is not loaded
I checked the relevant files and found that the NVIDIA driver module name in /proc/modules is “nvidia_peermem”. How should I solve this problem? Should I install a new version of “nv_peer_mem” on the physical machine?

@Luis-xu commented Oct 17, 2024

Hi, I encountered the following problem when running the infinistore service: ... Exception: nv_peer_mem module is not loaded. I checked the relevant files and found that the NVIDIA driver module name in /proc/modules is “nvidia_peermem”. How should I solve this problem? Should I install a new version of “nv_peer_mem” on the physical machine?

I have solved this problem by modifying the check_supported function in lib.py (adding the nvidia_peermem keyword to the check):

def check_supported():
    # Accept either module name: older GPUDirect RDMA stacks load nv_peer_mem,
    # while newer drivers provide nvidia_peermem instead.
    if "nv_peer_mem" not in _kernel_modules() and "nvidia_peermem" not in _kernel_modules():
        raise Exception("nv_peer_mem module is not loaded")
    _check_rdma_devices_ibv()

@liweiqing1997 commented

My machine does not have a Mellanox ConnectX-3 VPI or Connect-IB InfiniBand adapter. Can I still use this PR on a single machine with 4 A100 GPUs?

@chenkaiyue commented Oct 23, 2024

Have you encountered this problem before? It occurred when request pressure was high. Thanks! (screenshot attached)

@thesues (Contributor) commented Oct 23, 2024

Have you encountered this problem before? It occurred when request pressure was high.

What is your infinistore version? I think the newer version, 0.1.82, solved this. Before this version, infinistore had a limited RDMA queue length; the new version has a thread that drains the CQ (completion queue) asynchronously.
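Conceptually (and not the actual InfiniStore C++ implementation), the fix looks like a dedicated drainer thread that keeps consuming completions so a bounded queue never backs up under bursty load:

import queue
import threading

# Stand-in for the RDMA completion queue (CQ); producers would block or fail
# once it fills up, which is roughly what the pre-0.1.82 limit looked like.
completion_queue = queue.Queue(maxsize=1024)

def drain_cq():
    # Runs in the background and pops completions as soon as they arrive,
    # releasing buffers and marking the matching work requests as done.
    while True:
        wr_id = completion_queue.get()
        # ... mark work request `wr_id` finished, recycle its registered buffer ...
        completion_queue.task_done()

threading.Thread(target=drain_cq, daemon=True).start()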

@thesues (Contributor) commented Oct 23, 2024

    if "nv_peer_mem" not in _kernel_modules() and "nvidia_peermem" not in _kernel_modules():

Thank you, I made this update in infinistore ;-)

@Luis-xu commented Oct 31, 2024

Hi @chenqianfzh @thesues, could you share some test results for this PR?

@WilliamEricCheung commented

Hi, do you plan to support other transporters like #8498 (torch distributed pipe)? Your layer-wise idea seems very efficient. Could you give me a way to contact you? (WeChat / Slack / email are all fine.)

@thesues (Contributor) commented Nov 5, 2024

Hi, do you plan to support other transporters like #8498 (torch distributed pipe)? Your layer-wise idea seems very efficient. Could you give me a way to contact you? (WeChat / Slack / email are all fine.)

We could discuss this on https://vllm-dev.slack.com/archives/C07VCUQLE1F or at https://github.com/bd-iaas-us/infiniStore.
