Enable StagingEngine in test_engine_paged_cache_model.py (#50)
* enable StagingEngine in test_engine_paged_cache_model.py

* simplify param_bytes calc

* set random seed in test
masahi authored Nov 6, 2023
1 parent 1d3d12a commit 84297df
Showing 2 changed files with 58 additions and 28 deletions.
5 changes: 1 addition & 4 deletions serve/mlc_serve/model/paged_cache_model.py
@@ -470,10 +470,7 @@ def get_used_memory(self):
         )
         peak_memory = get_used_memory_func(self.dev)
 
-        param_bytes = 0
-
-        for param in params:
-            param_bytes += param.numpy().nbytes
+        param_bytes = sum(math.prod(param.shape) * np.dtype(param.dtype).itemsize for param in params)
 
         return peak_memory + param_bytes

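The simplification is worth a second look: the old loop copied every parameter to host memory via `.numpy()` just to read `.nbytes`, while the new expression derives the same byte count from shape and dtype metadata alone. A minimal, self-contained sketch of the equivalence; `FakeParam` is an illustrative stand-in for a device tensor, not part of mlc_serve:

```python
import math

import numpy as np


class FakeParam:
    """Stand-in for a device tensor: shape/dtype metadata plus a host-copying .numpy()."""

    def __init__(self, shape, dtype):
        self.shape, self.dtype = shape, dtype

    def numpy(self):
        # Models the device-to-host copy the old code paid for each parameter.
        return np.zeros(self.shape, dtype=self.dtype)


def param_bytes_via_copy(params):
    # Old approach: materialize each tensor on the host, then read .nbytes.
    return sum(param.numpy().nbytes for param in params)


def param_bytes_via_metadata(params):
    # New approach: element count times element size; no host copy needed.
    return sum(math.prod(p.shape) * np.dtype(p.dtype).itemsize for p in params)


weights = [FakeParam((4096, 4096), "float16"), FakeParam((32000, 4096), "float32")]
assert param_bytes_via_copy(weights) == param_bytes_via_metadata(weights) == 557_842_432
```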
81 changes: 57 additions & 24 deletions serve/tests/test_engine_paged_cache_model.py
@@ -1,3 +1,5 @@
+import torch
+
 import argparse
 import json
 import random
@@ -10,8 +12,9 @@
     SamplingParams,
     StoppingCriteria,
 )
+from mlc_serve.engine.staging_engine import StagingInferenceEngine
 from mlc_serve.engine.sync_engine import SynchronousInferenceEngine
-from mlc_serve.model.paged_cache_model import PagedCacheModelModule
+from mlc_serve.model.paged_cache_model import HfTokenizerModule, PagedCacheModelModule
 
 
 def test(args: argparse.Namespace):
@@ -26,29 +29,52 @@ def test(args: argparse.Namespace):
     # Disco:
     # python serve/tests/test_engine_paged_cache_model.py --local-id vicuna-v1-7b-q0f16 --num-shards 2
 
-    model_module = PagedCacheModelModule(
-        args.model,
-        args.artifact_path,
-        args.quantization.name,
-        args.num_shards,
-        max_num_batched_tokens=args.max_num_batched_tokens,
-        max_input_len=args.max_input_len,
-    )
-
-    engine = SynchronousInferenceEngine(
-        model_module,
-        max_batched_tokens=args.max_num_batched_tokens,
-    )
+    if args.use_staging_engine:
+        tokenizer_module = HfTokenizerModule(args.model, args.artifact_path)
+        engine = StagingInferenceEngine(
+            tokenizer_module=tokenizer_module,
+            model_module_loader=PagedCacheModelModule,
+            model_module_loader_kwargs={
+                "model_name": args.model,
+                "artifact_path": args.artifact_path,
+                "quantization": args.quantization.name,
+                "num_shards": args.num_shards,
+                "max_num_batched_tokens": args.max_num_batched_tokens,
+                "max_input_len": args.max_input_len,
+            },
+            max_batched_tokens=args.max_num_batched_tokens,
+            min_decode_steps=args.min_decode_steps,
+            max_decode_steps=args.max_decode_steps,
+        )
+        engine.start()
+    else:
+        model_module = PagedCacheModelModule(
+            args.model,
+            args.artifact_path,
+            args.quantization.name,
+            args.num_shards,
+            max_num_batched_tokens=args.max_num_batched_tokens,
+            max_input_len=args.max_input_len,
+        )
 
-    sampling_params_random = SamplingParams(
-        temperature=0.9,
-        top_p=1.0,
-    )
+        engine = SynchronousInferenceEngine(
+            model_module,
+            max_batched_tokens=args.max_num_batched_tokens,
+        )
 
     sampling_params_greedy = SamplingParams(
         temperature=0.0,
     )
 
+    if args.use_random_sampling:
+        sampling_params_random = SamplingParams(
+            temperature=0.9,
+            top_p=1.0,
+        )
+        sampling_params_choices = [sampling_params_random, sampling_params_greedy]
+    else:
+        sampling_params_choices = [sampling_params_greedy]
 
     if args.long_prompt:
         with open("serve/tests/data/long_prompts.json", "r") as f:
             prompts = json.load(f)["prompts"]
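Constructing the staging engine from a loader class plus kwargs, then bracketing it with `engine.start()` here and `engine.stop()` in a later hunk, suggests the model module is instantiated inside a worker that the engine manages; that reading is inferred from the API shape, not stated on this page. Below is a self-contained toy of the lifecycle contract the test now drives through both branches; these classes are illustrative stand-ins, not mlc_serve code:

```python
class ToySyncEngine:
    # Mirrors the synchronous path: ready to serve right after construction.
    def __init__(self):
        self.pending = []

    def add(self, requests):
        self.pending.extend(requests)

    def has_pending_requests(self):
        return bool(self.pending)

    def step(self):
        # Pretend one request finishes per step.
        return [self.pending.pop(0)]


class ToyStagingEngine(ToySyncEngine):
    # Mirrors the staging path: must be started before use and stopped after.
    def __init__(self):
        super().__init__()
        self.started = False

    def start(self):
        self.started = True  # the real engine would spawn its worker here

    def step(self):
        assert self.started, "call start() before stepping"
        return super().step()

    def stop(self):
        self.started = False  # the real engine would shut its worker down


engine = ToyStagingEngine()
engine.start()
engine.add(["req-0", "req-1"])
while engine.has_pending_requests():
    print(engine.step())
engine.stop()
```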
@@ -62,16 +88,12 @@
         ]
 
     for i, prompt in enumerate(prompts):
-        sampling_params = random.choice(
-            [sampling_params_random, sampling_params_greedy]
-        )
-
         engine.add(
             [
                 Request(
                     request_id=str(i),
                     messages=[ChatMessage(role="user", content=prompt)],
-                    sampling_params=sampling_params,
+                    sampling_params=random.choice(sampling_params_choices),
                     stopping_criteria=StoppingCriteria(max_tokens=args.max_output_len),
                     debug_options=DebugOptions(prompt=prompt),
                 )
@@ -80,7 +102,7 @@

generated = ["" for _ in range(len(prompts))]

while engine._has_request_to_process():
while engine.has_pending_requests():
results = engine.step()
for res in results.outputs:
seq = res.sequences[0]
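Renaming the private `_has_request_to_process()` to a public `has_pending_requests()` gives this driver loop a single engine-agnostic way to poll for unfinished work, which both the synchronous and staging engines expose; the toy classes sketched above model the same loop.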
@@ -94,6 +116,9 @@
     for p, g in zip(prompts, generated):
         print(f"Prompt = '{p}', generated text = '{g}'")
 
+    if args.use_staging_engine:
+        engine.stop()
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -104,9 +129,17 @@
parser.add_argument("--max-input-len", type=int, default=-1)
parser.add_argument("--max-output-len", type=int, default=20)
parser.add_argument("--long-prompt", action="store_true")
parser.add_argument("--use-random-sampling", action="store_true")
parser.add_argument("--use-staging-engine", action="store_true")
parser.add_argument("--min-decode-steps", type=int, default=12)
parser.add_argument("--max-decode-steps", type=int, default=16)
parser.add_argument("--seed", type=int, default=0)
args = parser.parse_args()

args.model, args.quantization = args.local_id.rsplit("-", 1)
utils.argparse_postproc_common(args)

torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)

test(args)
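Two reproducibility notes on the seeding change: `torch.manual_seed`/`torch.cuda.manual_seed` pin down the torch-side token sampling (presumably why the test now imports torch at all), but `random.choice` over `sampling_params_choices` draws from Python's `random` module, which the torch seed does not touch; and with `--use-random-sampling` left off, every request is greedy, keeping output stable across runs. Example invocations, extending the command quoted in the test's own comment with the new flags; the local id is the placeholder used there:

```bash
# Default: synchronous engine, greedy-only sampling.
python serve/tests/test_engine_paged_cache_model.py --local-id vicuna-v1-7b-q0f16

# Staging engine, mixed greedy/random sampling, fixed seed.
python serve/tests/test_engine_paged_cache_model.py --local-id vicuna-v1-7b-q0f16 \
    --use-staging-engine --use-random-sampling --seed 42
```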
