[PoC]: Support encode only models by Workflow Defined Engine #8452

Draft
wants to merge 43 commits into base: main
Changes from all commits (43 commits)
a6a003b  wde core, xlm_roberta, bge-m3  (noooop, Sep 13, 2024)
fdfa0ff  tests xlm_roberta, bge-m3  (noooop, Sep 13, 2024)
62bf07d  demos  (noooop, Sep 13, 2024)
15348da  update tests  (noooop, Sep 14, 2024)
046b4e9  support bge-reranker-v2-m3  (noooop, Sep 14, 2024)
6c9ca52  yapf  (noooop, Sep 18, 2024)
6d7fef0  yapf  (noooop, Sep 18, 2024)
dd976e7  yapf  (noooop, Sep 18, 2024)
93503b7  support torch spda backend  (noooop, Sep 18, 2024)
81c6936  support bert  (noooop, Sep 18, 2024)
41214ab  yapf & ruff  (noooop, Sep 19, 2024)
b06fc40  yapf & ruff  (noooop, Sep 19, 2024)
1059e65  yapf & ruff  (noooop, Sep 19, 2024)
6000cdf  isort  (noooop, Sep 19, 2024)
e838179  yapf & ruff  (noooop, Sep 19, 2024)
f2f4a09  isort  (noooop, Sep 19, 2024)
a4d01fb  yapf & ruff  (noooop, Sep 19, 2024)
d421ccd  isort  (noooop, Sep 19, 2024)
7f6675d  support xformers  (noooop, Sep 19, 2024)
f953e86  support bge v1.5 family  (noooop, Sep 19, 2024)
b30e9f2  ruff & isort & yapf  (noooop, Sep 20, 2024)
d3ceeb0  support xformers  (noooop, Sep 20, 2024)
0ce2482  support Flashinfer backend  (noooop, Sep 20, 2024)
7cede53  support Torch naive backend  (noooop, Sep 20, 2024)
6736e0d  support Snowflake Arctic Embed (Family)  (noooop, Sep 20, 2024)
a992ba0  ready to support enable_bidirectional  (noooop, Sep 23, 2024)
ed5f433  ruff & isort & yapf  (noooop, Sep 24, 2024)
b14d864  fix decoder only attention basic_correctness  (noooop, Sep 24, 2024)
aa3ef30  support output_last_hidden_states  (noooop, Sep 24, 2024)
5032def  support e5-mistral-7b  (noooop, Sep 24, 2024)
4b12e30  misc  (noooop, Sep 24, 2024)
2649945  misc  (noooop, Sep 24, 2024)
3a5b30c  mv e5-mistral-7 to retriever  (noooop, Sep 25, 2024)
5532932  support gte-Qwen2  (noooop, Sep 25, 2024)
bf477a9  ruff  (noooop, Sep 25, 2024)
4820daa  support gte-Qwen2-7B-instruct  (noooop, Sep 26, 2024)
aa302ba  misc  (noooop, Sep 26, 2024)
60931e8  support data parallelism  (noooop, Sep 26, 2024)
efef0b3  dirty fix distributed environment  (noooop, Sep 27, 2024)
a45e867  dirty fix destroy model parallel  (noooop, Sep 27, 2024)
236980d  dirty fix pp environment  (noooop, Sep 27, 2024)
653da13  Merge branch 'vllm-project:main' into wde_encode_only  (noooop, Sep 29, 2024)
653794e  catch up vllm main  (noooop, Sep 29, 2024)
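For orientation, every demo script added in this PR drives the new encode-only engine through the same request/step loop. Below is a minimal usage sketch distilled from those benchmark files, assuming the vllm.wde entry points this PR introduces; the model name, dtype and scheduling values are simply the ones the demos use, and since the scripts do not show what step() returns for encode-only requests, the return value is ignored here.

from vllm.wde.encode_only.arg_utils import EncodeOnlyEngineArgs as EngineArgs
from vllm.wde.entrypoints.llm import LLMEngine

# Engine arguments mirror the demo scripts; any encode-only model supported by
# this PR should be usable the same way.
engine_args = EngineArgs(model="BAAI/bge-m3",
                         dtype="half",
                         device="cuda",
                         max_num_seqs=32,
                         scheduling="async")
engine = LLMEngine.from_engine_args(engine_args)

# Enqueue prompts, then step the engine until all requests have finished.
for request_id, prompt in enumerate(["an example sentence", "another one"]):
    engine.add_request(str(request_id), prompt)

while engine.has_unfinished_requests():
    engine.step()

# The demo scripts shut down the executor's execute loop before cleanup.
engine.executor.shutdown_execute_loop()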
Empty file added demo_temporary/__init__.py
102 changes: 102 additions & 0 deletions demo_temporary/benchmarks/benchmark_attention_impl.py
import os
import random
import time


def benchmark_vllm(args):
random.seed(args.seed)
os.environ["VLLM_ATTENTION_BACKEND"] = args.attention_impl

import gc

import torch

from vllm.wde.encode_only.arg_utils import ( # noqa: E501
EncodeOnlyEngineArgs as EngineArgs)
from vllm.wde.entrypoints.llm import LLMEngine

prompt = "if" * args.input_len
requests = [prompt for _ in range(args.num_prompts)]

engine_args = EngineArgs(model=args.model,
tokenizer=args.tokenizer,
seed=args.seed,
trust_remote_code=args.trust_remote_code,
dtype=args.dtype,
max_model_len=args.max_model_len,
device=args.device,
max_num_seqs=32,
scheduling=args.scheduling)

engine = LLMEngine.from_engine_args(engine_args)

for batchsize in args.batchsize:
engine.engine_config.scheduler_config.set_args(max_num_seqs=batchsize)

start = time.perf_counter()
for request_id, prompt in enumerate(requests):
engine.add_request(str(request_id), prompt)

n_step = 0
while engine.has_unfinished_requests():
engine.step()
n_step += 1
end = time.perf_counter()

elapsed_time = end - start
delay = elapsed_time / n_step

print(f"Batchsize {batchsize}, Throughput: "
f"{len(requests) / elapsed_time:.4f} requests/s, "
f"Delay {delay * 1000:0.2f} ms, n_step {n_step}")

engine.executor.shutdown_execute_loop()
gc.collect()
torch.cuda.empty_cache()


if __name__ == '__main__':
from easydict import EasyDict as edict
args = edict()

args.input_len = 256
args.num_prompts = 10000

args.model = 'BAAI/bge-m3'

args.trust_remote_code = False
args.tokenizer = args.model
args.seed = 0
args.max_model_len = None
args.device = "cuda"
args.batchsize = [1, 2, 4, 8, 16, 32, 64]
args.scheduling = "double_buffer"

from concurrent.futures import ProcessPoolExecutor

def run_vllm(args):
with ProcessPoolExecutor(1) as executor:
f = executor.submit(benchmark_vllm, args)
f.result()

AttentionImpls_fp32 = ["TORCH_SDPA", "XFORMERS", "TORCH_NAIVE"]
AttentionImpls_fp16 = [
"FLASH_ATTN", "TORCH_SDPA", "XFORMERS", "FLASHINFER", "TORCH_NAIVE"
]
AttentionImpls_bf16 = [
"FLASH_ATTN", "TORCH_SDPA", "XFORMERS", "FLASHINFER", "TORCH_NAIVE"
]

AttentionImpls = {
"float": AttentionImpls_fp32,
"half": AttentionImpls_fp16,
"bfloat16": AttentionImpls_bf16,
}

for dtype, attention_impls in AttentionImpls.items():
print("dtype:", dtype)
for attention_impl in attention_impls:
print("attention_impl:", attention_impl)
args.attention_impl = attention_impl
args.dtype = dtype
run_vllm(args)
119 changes: 119 additions & 0 deletions demo_temporary/benchmarks/benchmark_bge-m3.py
import random
import time


def benchmark_hf(args):
random.seed(args.seed)

import torch
from FlagEmbedding import BGEM3FlagModel

model = BGEM3FlagModel(args.model, use_fp16=True)

prompt = "if" * args.input_len
requests = [prompt for _ in range(args.num_prompts)]

with torch.no_grad():
for batchsize in args.batchsize:
start = time.perf_counter()
n_step = 0
for i in range(0, len(requests), batchsize):
batch = requests[i:i + batchsize]
model.encode(batch, batch_size=batchsize)
n_step += 1
end = time.perf_counter()

elapsed_time = end - start
delay = elapsed_time / n_step

print(f"Batchsize {batchsize}, Throughput: "
f"{len(requests) / elapsed_time:.4f} requests/s, "
f"Delay {delay * 1000:0.2f} ms, n_step {n_step}")


def benchmark_vllm(args):
random.seed(args.seed)

import gc

import torch

from vllm.wde.encode_only.arg_utils import ( # noqa: E501
EncodeOnlyEngineArgs as EngineArgs)
from vllm.wde.entrypoints.llm import LLMEngine

prompt = "if" * args.input_len
requests = [prompt for _ in range(args.num_prompts)]

engine_args = EngineArgs(model=args.model,
tokenizer=args.tokenizer,
seed=args.seed,
trust_remote_code=args.trust_remote_code,
dtype=args.dtype,
max_model_len=args.max_model_len,
device=args.device,
max_num_seqs=32,
scheduling=args.scheduling)

engine = LLMEngine.from_engine_args(engine_args)

for batchsize in args.batchsize:
engine.engine_config.scheduler_config.set_args(max_num_seqs=batchsize)

start = time.perf_counter()
for request_id, prompt in enumerate(requests):
engine.add_request(str(request_id), prompt)

n_step = 0
while engine.has_unfinished_requests():
engine.step()
n_step += 1
end = time.perf_counter()

elapsed_time = end - start
delay = elapsed_time / n_step

print(f"Batchsize {batchsize}, Throughput: "
f"{len(requests) / elapsed_time:.4f} requests/s, "
f"Delay {delay * 1000:0.2f} ms, n_step {n_step}")

engine.executor.shutdown_execute_loop()
gc.collect()
torch.cuda.empty_cache()


if __name__ == '__main__':
from easydict import EasyDict as edict
args = edict()

args.input_len = 256
args.num_prompts = 10000

args.model = 'BAAI/bge-m3'

args.trust_remote_code = False
args.tokenizer = args.model
args.seed = 0
args.max_model_len = None
args.dtype = "half"
args.device = "cuda"
args.batchsize = [1, 2, 4, 8, 16, 32, 64]

from concurrent.futures import ProcessPoolExecutor

def run_hf(args):
with ProcessPoolExecutor(1) as executor:
f = executor.submit(benchmark_hf, args)
f.result()

run_hf(args)

def run_vllm(args):
with ProcessPoolExecutor(1) as executor:
f = executor.submit(benchmark_vllm, args)
f.result()

for scheduling in ["sync", "async", "double_buffer"]:
print(scheduling)
args.scheduling = scheduling
run_vllm(args)
89 changes: 89 additions & 0 deletions demo_temporary/benchmarks/benchmark_data_parallelism.py
import random
import time


def benchmark(args):
random.seed(args.seed)

import gc

import torch

from vllm.wde.encode_only.arg_utils import ( # noqa: E501
EncodeOnlyEngineArgs as EngineArgs)
from vllm.wde.entrypoints.llm import LLMEngine

prompt = "if" * args.input_len
requests = [prompt for _ in range(args.num_prompts)]

engine_args = EngineArgs(model=args.model,
tokenizer=args.tokenizer,
seed=args.seed,
trust_remote_code=args.trust_remote_code,
dtype=args.dtype,
max_model_len=args.max_model_len,
device=args.device,
max_num_seqs=32,
scheduling=args.scheduling,
data_parallel_size=args.data_parallel_size)

engine = LLMEngine.from_engine_args(engine_args)

for batchsize in args.batchsize:
engine.engine_config.scheduler_config.set_args(max_num_seqs=batchsize)
engine.executor.ensure_start_execute_loop()

        # Each data-parallel worker thread loads the model separately, so
        # loading may not have completed by this point. Timing only a single
        # pass would therefore understate the data parallel speed, so the
        # loop below runs three times.

for i in range(3):
start = time.perf_counter()
for request_id, prompt in enumerate(requests):
engine.add_request(str(request_id), prompt)

while engine.has_unfinished_requests():
engine.step()
end = time.perf_counter()
elapsed_time = end - start

print(f"Batchsize {batchsize}, Throughput: "
f"{len(requests) / elapsed_time:.4f} requests/s")

engine.executor.shutdown_execute_loop()
gc.collect()
torch.cuda.empty_cache()


if __name__ == '__main__':
from easydict import EasyDict as edict
args = edict()

args.input_len = 256
args.num_prompts = 10000

args.model = 'BAAI/bge-m3'

args.trust_remote_code = False
args.tokenizer = args.model
args.seed = 0
args.max_model_len = None
args.dtype = "half"
args.device = "cuda"
args.batchsize = [1, 2, 4, 8, 16]
args.max_data_parallel_size = 1

from concurrent.futures import ProcessPoolExecutor

def run_vllm(args):
with ProcessPoolExecutor(1) as executor:
f = executor.submit(benchmark, args)
f.result()

for scheduling in ["async", "double_buffer"]:
for data_parallel_size in range(args.max_data_parallel_size + 1):
print("scheduling:", scheduling, "data_parallel_size",
data_parallel_size)
args.data_parallel_size = data_parallel_size
args.scheduling = scheduling
run_vllm(args)