[Inference] Finish dynamic batching offline test (#4948)
* test

* fix test
CjhHa1 authored Oct 19, 2023
1 parent 3f6af12 commit 4867561
Showing 5 changed files with 25 additions and 12 deletions.
2 changes: 1 addition & 1 deletion colossalai/inference/dynamic_batching/ray_dist_init.py
@@ -145,7 +145,7 @@ def step(self):
outputs = results[0] # get any one of the copies
return outputs

-def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str):
+def add_req(self, request_id: str, prompt_ids: List[int], sampling_params: SamplingParams, prompt: str):
ray.get([w.add_req.remote(prompt_ids, sampling_params, request_id, prompt) for w in self.workers])

def is_running(self):
8 changes: 4 additions & 4 deletions colossalai/inference/manager.py
@@ -58,7 +58,7 @@ def __init__(
self.mem_usage_interval = log_stats_interval * 2
self.tokenizer = get_tokenizer(tokenizer_name=self.model)

-def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompts: str = ""):
+def add_req(self, request_id: str, prompt_ids: List[int], sampling_params: SamplingParams, prompts: str = ""):
"""
Add new request to req queue, during initialization all requests are held in waiting list.
"""
@@ -75,7 +75,7 @@ def add_input(self, request_id, prompts, sampling_params):
if prompt_len > self.engine.max_input_len:
raise ValueError(f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}")
sampling_params.stop_sentences_to_token_ids(self.tokenizer)
-self.add_req(prompt_ids, sampling_params, request_id, prompts)
+self.add_req(request_id, prompt_ids, sampling_params, prompts)
return

def abort(self, request_id):
@@ -258,11 +258,11 @@ def clean_up(self):
# this logic should be implemented in the future.
pass

-def generate(self, prompts, sampling_params, request_id):
+def generate(self, request_id, prompts, sampling_params):
"""
Generate the output of a request.
"""
-self.add_input(request_id, sampling_params, prompts)
+self.add_input(request_id, prompts, sampling_params)
return self.loop_for_fwd()

def is_running(self):
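For callers, the net effect of these changes is that `request_id` now leads the argument list of `add_req`, and `generate` takes `(request_id, prompts, sampling_params)`. A minimal caller-side sketch of the new order, assuming an already-initialized batch manager (named `batch_manager` here, as returned by `start_dynamic_batching` in the updated test below) and using keyword arguments for clarity:

```python
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams

# Assumption: `batch_manager` is the object returned by start_dynamic_batching(...)
# (see the updated test at the bottom of this commit); it is not constructed here.
sampling_params = SamplingParams()

# Queue a pre-tokenized request: request_id is now the first argument.
batch_manager.add_req(
    request_id="req-0",
    prompt_ids=[0, 10, 6, 8],
    sampling_params=sampling_params,
    prompts="hello",
)

# Or hand over a raw prompt: generate() tokenizes it via add_input() and then
# yields results from the forward loop.
for output in batch_manager.generate(request_id="req-1", prompts="hello", sampling_params=sampling_params):
    print(output)
```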
2 changes: 1 addition & 1 deletion colossalai/inference/tensor_parallel/engine.py
@@ -400,8 +400,8 @@ def forward(self, batch_id, is_prefill):
model = self.model.model
elif isinstance(model, BloomForCausalLM):
model = self.model.transformer

setattr(model, "infer_state", infer_state)

output = self.model.forward(input_ids=input_)
logits = output.logits
# bsz, seq_len, vocab_size
3 changes: 1 addition & 2 deletions colossalai/inference/tensor_parallel/modeling/llama.py
@@ -76,14 +76,13 @@ def llama_model_forward(
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

# NOT READY FOR PRIME TIME
# dummy but work, revise it
if infer_state.is_context_stage:
past_key_values_length = 0
else:
past_key_values_length = infer_state.max_len_in_batch - 1

# NOTE: differentiate with prefill stage
# block_loc require different value-assigning method for two different stage
if use_cache and seq_length != 1:
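The two hunks above work together: the engine attaches the current `infer_state` to the wrapped transformer module with `setattr` right before calling `forward`, and the patched Llama forward reads it back to pick `past_key_values_length` (0 in the context/prefill stage, `max_len_in_batch - 1` while decoding). A self-contained toy sketch of that pattern in plain PyTorch, not ColossalAI code; `InferState` and `TinyModel` are invented for illustration:

```python
from dataclasses import dataclass

import torch
from torch import nn


@dataclass
class InferState:
    # Stand-in for the engine's real infer_state; only the two fields used here.
    is_context_stage: bool
    max_len_in_batch: int


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        # The patched forward reads the injected attribute instead of taking a new
        # argument, so the call signature of forward() stays unchanged.
        state = getattr(self, "infer_state")
        past_key_values_length = 0 if state.is_context_stage else state.max_len_in_batch - 1
        return self.linear(x), past_key_values_length


model = TinyModel()

# Prefill (context) stage: no cached tokens yet.
setattr(model, "infer_state", InferState(is_context_stage=True, max_len_in_batch=8))
_, past_len = model(torch.randn(1, 4))
print(past_len)  # 0

# Decode stage: everything except the newest token comes from the KV cache.
setattr(model, "infer_state", InferState(is_context_stage=False, max_len_in_batch=8))
_, past_len = model(torch.randn(1, 4))
print(past_len)  # 7
```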
@@ -1,11 +1,12 @@
+from dataclasses import dataclass

import pytest
import torch
from packaging import version
from transformers import LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig

import colossalai
-from dataclasses import dataclass
from colossalai.inference.dynamic_batching.io_struct import Req
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
from colossalai.inference.manager import start_dynamic_batching
@@ -19,17 +20,26 @@
MAX_OUTPUT_LEN = 16
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")


@dataclass
class args:
max_total_token_num: int
batch_max_tokens: int
model: str
eos_id: int
disable_log_stats: bool
log_stats_interval: int


def run():
-arg = args(max_total_token_num=42, batch_max_tokens=42, eos_id=0, disable_log_stats=False, log_stats_interval=10)
+arg = args(
+    max_total_token_num=42,
+    model="llama",
+    batch_max_tokens=42,
+    eos_id=0,
+    disable_log_stats=False,
+    log_stats_interval=10,
+)
sampling_params = SamplingParams()

req1 = Req(0, [0, 0, 10, 6, 8], sampling_params)
@@ -43,14 +53,18 @@ def run():
waiting_list.append(req3)
waiting_list.append(req4)

-llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024)
+llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=30000, hidden_size=1024)
model = LlamaForCausalLM(llama_config)
model = model.half()

shard_config = ShardConfig(enable_tensor_parallelism=True if TP_SIZE > 1 else False, inference_only=True)

infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
-start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list)
+batch_manager = start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list)

+ans_gen = batch_manager.generate(request_id=5, prompts="hello", sampling_params=sampling_params)
+for result in ans_gen:
+    assert result is not None


def check_dynamic_forward(rank, world_size, port):
