From 84297dfafa79aed786744859cac5a030aff0710d Mon Sep 17 00:00:00 2001
From: masahi
Date: Tue, 7 Nov 2023 07:29:13 +0900
Subject: [PATCH] Enable `StagingEngine` in `test_engine_paged_cache_model.py`
 (#50)

* enable StagingEngine in test_engine_paged_cache_model.py

* simplify param_bytes calc

* set random seed in test
---
 serve/mlc_serve/model/paged_cache_model.py   |  5 +-
 serve/tests/test_engine_paged_cache_model.py | 81 ++++++++++++++------
 2 files changed, 58 insertions(+), 28 deletions(-)

diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py
index 138732d577..4807429bf0 100644
--- a/serve/mlc_serve/model/paged_cache_model.py
+++ b/serve/mlc_serve/model/paged_cache_model.py
@@ -470,10 +470,7 @@ def get_used_memory(self):
         )
 
         peak_memory = get_used_memory_func(self.dev)
 
-        param_bytes = 0
-
-        for param in params:
-            param_bytes += param.numpy().nbytes
+        param_bytes = sum(math.prod(param.shape) * np.dtype(param.dtype).itemsize for param in params)
 
         return peak_memory + param_bytes
diff --git a/serve/tests/test_engine_paged_cache_model.py b/serve/tests/test_engine_paged_cache_model.py
index 8a4cb3997d..53c949b854 100644
--- a/serve/tests/test_engine_paged_cache_model.py
+++ b/serve/tests/test_engine_paged_cache_model.py
@@ -1,3 +1,5 @@
+import torch
+
 import argparse
 import json
 import random
@@ -10,8 +12,9 @@
     SamplingParams,
     StoppingCriteria,
 )
+from mlc_serve.engine.staging_engine import StagingInferenceEngine
 from mlc_serve.engine.sync_engine import SynchronousInferenceEngine
-from mlc_serve.model.paged_cache_model import PagedCacheModelModule
+from mlc_serve.model.paged_cache_model import HfTokenizerModule, PagedCacheModelModule
 
 
 def test(args: argparse.Namespace):
@@ -26,29 +29,52 @@ def test(args: argparse.Namespace):
     # Disco:
     # python serve/tests/test_engine_paged_cache_model.py --local-id vicuna-v1-7b-q0f16 --num-shards 2
 
-    model_module = PagedCacheModelModule(
-        args.model,
-        args.artifact_path,
-        args.quantization.name,
-        args.num_shards,
-        max_num_batched_tokens=args.max_num_batched_tokens,
-        max_input_len=args.max_input_len,
-    )
-
-    engine = SynchronousInferenceEngine(
-        model_module,
-        max_batched_tokens=args.max_num_batched_tokens,
-    )
+    if args.use_staging_engine:
+        tokenizer_module = HfTokenizerModule(args.model, args.artifact_path)
+        engine = StagingInferenceEngine(
+            tokenizer_module=tokenizer_module,
+            model_module_loader=PagedCacheModelModule,
+            model_module_loader_kwargs={
+                "model_name": args.model,
+                "artifact_path": args.artifact_path,
+                "quantization": args.quantization.name,
+                "num_shards": args.num_shards,
+                "max_num_batched_tokens": args.max_num_batched_tokens,
+                "max_input_len": args.max_input_len,
+            },
+            max_batched_tokens=args.max_num_batched_tokens,
+            min_decode_steps=args.min_decode_steps,
+            max_decode_steps=args.max_decode_steps,
+        )
+        engine.start()
+    else:
+        model_module = PagedCacheModelModule(
+            args.model,
+            args.artifact_path,
+            args.quantization.name,
+            args.num_shards,
+            max_num_batched_tokens=args.max_num_batched_tokens,
+            max_input_len=args.max_input_len,
+        )
 
-    sampling_params_random = SamplingParams(
-        temperature=0.9,
-        top_p=1.0,
-    )
+        engine = SynchronousInferenceEngine(
+            model_module,
+            max_batched_tokens=args.max_num_batched_tokens,
+        )
 
     sampling_params_greedy = SamplingParams(
         temperature=0.0,
     )
 
+    if args.use_random_sampling:
+        sampling_params_random = SamplingParams(
+            temperature=0.9,
+            top_p=1.0,
+        )
+        sampling_params_choices = [sampling_params_random, sampling_params_greedy]
+    else:
+        sampling_params_choices = [sampling_params_greedy]
+
     if args.long_prompt:
         with open("serve/tests/data/long_prompts.json", "r") as f:
             prompts = json.load(f)["prompts"]
@@ -62,16 +88,12 @@ def test(args: argparse.Namespace):
         ]
 
     for i, prompt in enumerate(prompts):
-        sampling_params = random.choice(
-            [sampling_params_random, sampling_params_greedy]
-        )
-
         engine.add(
             [
                 Request(
                     request_id=str(i),
                     messages=[ChatMessage(role="user", content=prompt)],
-                    sampling_params=sampling_params,
+                    sampling_params=random.choice(sampling_params_choices),
                     stopping_criteria=StoppingCriteria(max_tokens=args.max_output_len),
                     debug_options=DebugOptions(prompt=prompt),
                 )
@@ -80,7 +102,7 @@ def test(args: argparse.Namespace):
 
     generated = ["" for _ in range(len(prompts))]
 
-    while engine._has_request_to_process():
+    while engine.has_pending_requests():
         results = engine.step()
         for res in results.outputs:
             seq = res.sequences[0]
@@ -94,6 +116,9 @@ def test(args: argparse.Namespace):
     for p, g in zip(prompts, generated):
         print(f"Prompt = '{p}', generated text = '{g}'")
 
+    if args.use_staging_engine:
+        engine.stop()
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -104,9 +129,17 @@ def test(args: argparse.Namespace):
     parser.add_argument("--max-input-len", type=int, default=-1)
     parser.add_argument("--max-output-len", type=int, default=20)
    parser.add_argument("--long-prompt", action="store_true")
+    parser.add_argument("--use-random-sampling", action="store_true")
+    parser.add_argument("--use-staging-engine", action="store_true")
+    parser.add_argument("--min-decode-steps", type=int, default=12)
+    parser.add_argument("--max-decode-steps", type=int, default=16)
+    parser.add_argument("--seed", type=int, default=0)
     args = parser.parse_args()
 
     args.model, args.quantization = args.local_id.rsplit("-", 1)
     utils.argparse_postproc_common(args)
 
+    torch.manual_seed(args.seed)
+    torch.cuda.manual_seed(args.seed)
+
     test(args)
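
For reviewers trying the new path, below is a minimal sketch of the staging-engine flow this patch makes the test exercise under --use-staging-engine. It is not part of the patch: it only reuses the constructors and engine calls visible in the diff above, the import path for the request types is assumed from the package layout, and every concrete model, path, and size value is a placeholder.

# Minimal sketch of the staging-engine flow exercised by the updated test when
# --use-staging-engine is passed. Only calls visible in the diff are used; the
# request-type import path and all concrete values below are placeholders.
from mlc_serve.engine import (  # assumed import path for the request types
    ChatMessage,
    DebugOptions,
    Request,
    SamplingParams,
    StoppingCriteria,
)
from mlc_serve.engine.staging_engine import StagingInferenceEngine
from mlc_serve.model.paged_cache_model import HfTokenizerModule, PagedCacheModelModule

model, artifact_path = "vicuna-v1-7b", "dist"  # placeholder model name / artifact dir

engine = StagingInferenceEngine(
    tokenizer_module=HfTokenizerModule(model, artifact_path),
    model_module_loader=PagedCacheModelModule,
    model_module_loader_kwargs={
        "model_name": model,
        "artifact_path": artifact_path,
        "quantization": "q0f16",         # placeholder, taken from the comment in the test
        "num_shards": 1,
        "max_num_batched_tokens": 4096,  # placeholder sizes
        "max_input_len": -1,
    },
    max_batched_tokens=4096,
    min_decode_steps=12,  # same defaults the patch adds to the argument parser
    max_decode_steps=16,
)
engine.start()  # the staging engine needs an explicit start before requests are added

prompt = "Hello, my name is"  # placeholder prompt
engine.add(
    [
        Request(
            request_id="0",
            messages=[ChatMessage(role="user", content=prompt)],
            sampling_params=SamplingParams(temperature=0.0),  # greedy sampling
            stopping_criteria=StoppingCriteria(max_tokens=20),
            debug_options=DebugOptions(prompt=prompt),
        )
    ]
)

# Drain the engine the same way the test does: step until nothing is pending.
while engine.has_pending_requests():
    results = engine.step()
    for res in results.outputs:
        seq = res.sequences[0]  # per-request generation state; inspect or accumulate as needed

engine.stop()  # explicit shutdown, mirroring the test's cleanup for the staging engine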