From 70885d707d97f224caeb30f79172ef8f0b7f3e1c Mon Sep 17 00:00:00 2001 From: Zhongkai Zhao Date: Fri, 10 Nov 2023 10:49:50 +0800 Subject: [PATCH] [hotfix] Suport extra_kwargs in ShardConfig (#5031) * [refactor]: replace inference args with extra_kwargs in ShardConfig * modify shardconfig * polish code * fix policy bug in llama * fix bug in auto policy * remove setattr in ShardConfig --- .../dynamic_batching/ray_dist_init.py | 4 +- .../inference/hybridengine/polices/llama.py | 3 +- .../inference/tensor_parallel/engine.py | 16 +++-- .../tensor_parallel/policies/bloom.py | 61 ++++++++++--------- .../tensor_parallel/policies/llama.py | 4 +- colossalai/shardformer/README.md | 3 - .../shardformer/policies/auto_policy.py | 5 +- colossalai/shardformer/shard/shard_config.py | 9 +-- examples/inference/bench_bloom.py | 4 +- examples/inference/bench_chatglm2.py | 4 +- examples/inference/bench_llama.py | 4 +- examples/inference/gptq_bloom.py | 7 ++- examples/inference/gptq_llama.py | 3 +- .../ray_serve/Colossal_Inference_rayserve.py | 4 +- .../torch_serve/Colossal_Inference_Handler.py | 4 +- tests/test_infer/_utils.py | 2 +- tests/test_infer/test_bloom_infer.py | 10 +-- tests/test_infer/test_chatglm2_infer.py | 2 +- .../test_dynamic_batching_manager.py | 2 +- .../test_offline_dynamic_batching.py | 4 +- tests/test_infer/test_infer_engine.py | 2 +- tests/test_infer/test_llama2_infer.py | 10 +-- tests/test_infer/test_llama_infer.py | 8 ++- 23 files changed, 98 insertions(+), 77 deletions(-) diff --git a/colossalai/inference/dynamic_batching/ray_dist_init.py b/colossalai/inference/dynamic_batching/ray_dist_init.py index 70ef489d3b70..3e40bb0eeb9d 100644 --- a/colossalai/inference/dynamic_batching/ray_dist_init.py +++ b/colossalai/inference/dynamic_batching/ray_dist_init.py @@ -67,7 +67,9 @@ def setup(self, world_size, rank, port): self.model = AutoModelForCausalLM.from_pretrained( self.model_path, pad_token_id=self.tokenizer.pad_token_id, torch_dtype=torch.float16 ) - shard_config = ShardConfig(enable_tensor_parallelism=True if world_size > 1 else False, inference_only=True) + shard_config = ShardConfig( + enable_tensor_parallelism=True if world_size > 1 else False, extra_kwargs={"inference_only": True} + ) self.infer_engine = TPInferEngine( self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len ) diff --git a/colossalai/inference/hybridengine/polices/llama.py b/colossalai/inference/hybridengine/polices/llama.py index 992299714bd1..5caaaa978336 100644 --- a/colossalai/inference/hybridengine/polices/llama.py +++ b/colossalai/inference/hybridengine/polices/llama.py @@ -45,8 +45,7 @@ def __init__(self) -> None: def module_policy(self): policy = super().module_policy() - - if self.shard_config.inference_gptq: + if self.shard_config.extra_kwargs.get("inference_gptq", False): from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear decoder_attribute_replacement = { diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index 2478b574d307..a8fd3ca9e5ba 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -44,7 +44,7 @@ class TPInferEngine: >>> # define model and shard config for your inference >>> model = ... >>> generate_kwargs = ... 
- >>> shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True) + >>> shard_config = ShardConfig(enable_tensor_parallelism=True, extra_kwargs={"inference_only": True}) >>> infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) >>> outputs = infer_engine.generate(input_ids, **generate_kwargs) """ @@ -181,7 +181,7 @@ def _optimize_model(self, model: nn.Module) -> None: In further generation, use the sharded model instead of original model. """ # NOTE we will change to use an inference config later with additional attrs we want - assert self.shard_config.inference_only is True + assert self.shard_config.extra_kwargs["inference_only"] is True shardformer = ShardFormer(shard_config=self.shard_config) self._prepare_with_shard_config(shard_config=self.shard_config) self._shard_model_by(shardformer, model) @@ -203,10 +203,10 @@ def _prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None) enable_all_optimization=False, enable_flash_attention=False, enable_jit_fused=False, - inference_only=True, + extra_kwargs={"inference_only": True}, ) else: - shard_config.inference_only = True + shard_config.extra_kwargs = {"inference_only": True} shard_config.pipeline_stage_manager = None if shard_config.enable_tensor_parallelism: self.tp_size = shard_config.tensor_parallel_size @@ -221,13 +221,11 @@ def _shard_model_by(self, shardformer: ShardFormer, model: nn.Module) -> None: ), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config" model_name = model.__class__.__name__ assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference." - - model = model.model if self.shard_config.inference_gptq else model + if self.shard_config.extra_kwargs.get("inference_gptq", False): + model = model.model policy = get_autopolicy(model, shard_config=self.shard_config) - self.model, _ = shardformer.optimize(model, policy) - - if self.shard_config.inference_gptq: + if self.shard_config.extra_kwargs.get("inference_gptq", False): self._post_init_gptq_buffer(self.model) self.model = self.model.cuda() diff --git a/colossalai/inference/tensor_parallel/policies/bloom.py b/colossalai/inference/tensor_parallel/policies/bloom.py index 3d6df2097000..f980bdb53add 100644 --- a/colossalai/inference/tensor_parallel/policies/bloom.py +++ b/colossalai/inference/tensor_parallel/policies/bloom.py @@ -4,7 +4,6 @@ from torch.nn import LayerNorm import colossalai.shardformer.layer as col_nn -from colossalai.shardformer.modeling.bloom import build_bloom_alibi_tensor_fn from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy @@ -38,35 +37,39 @@ def module_policy(self): from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel policy = super().module_policy() - if self.shard_config.inference_gptq: + + if self.shard_config.extra_kwargs.get("inference_gptq", False): from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear - policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={ - "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size, - }, - 
sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="self_attention.query_key_value", - target_module=ColCaiQuantLinear, - kwargs={'split_num': 3}), - SubModuleReplacementDescription( - suffix="self_attention.dense", - target_module=RowCaiQuantLinear, - kwargs={'split_num': 1}), - SubModuleReplacementDescription( - suffix="self_attention.attention_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="mlp.dense_h_to_4h", - target_module=ColCaiQuantLinear, - kwargs={'split_num': 1}), - SubModuleReplacementDescription( - suffix="mlp.dense_4h_to_h", - target_module=RowCaiQuantLinear, - kwargs={'split_num': 1}), - ]) + + policy[BloomBlock] = ModulePolicyDescription( + attribute_replacement={ + "self_attention.hidden_size": self.model.config.hidden_size + // self.shard_config.tensor_parallel_size, + "self_attention.split_size": self.model.config.hidden_size + // self.shard_config.tensor_parallel_size, + "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attention.query_key_value", + target_module=ColCaiQuantLinear, + kwargs={"split_num": 3}, + ), + SubModuleReplacementDescription( + suffix="self_attention.dense", target_module=RowCaiQuantLinear, kwargs={"split_num": 1} + ), + SubModuleReplacementDescription( + suffix="self_attention.attention_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dense_h_to_4h", target_module=ColCaiQuantLinear, kwargs={"split_num": 1} + ), + SubModuleReplacementDescription( + suffix="mlp.dense_4h_to_h", target_module=RowCaiQuantLinear, kwargs={"split_num": 1} + ), + ], + ) # NOTE set inference mode to shard config self.shard_config._infer() diff --git a/colossalai/inference/tensor_parallel/policies/llama.py b/colossalai/inference/tensor_parallel/policies/llama.py index d6c072c747b7..896d55712254 100644 --- a/colossalai/inference/tensor_parallel/policies/llama.py +++ b/colossalai/inference/tensor_parallel/policies/llama.py @@ -13,6 +13,7 @@ try: from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward + HAS_TRITON_RMSNORM = True except: print("you should install triton from https://github.com/openai/triton") @@ -21,6 +22,7 @@ def get_triton_rmsnorm_forward(): if HAS_TRITON_RMSNORM: + def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): return lightllm_rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon) @@ -36,7 +38,7 @@ def __init__(self) -> None: def module_policy(self): policy = super().module_policy() - if self.shard_config.inference_gptq: + if self.shard_config.extra_kwargs.get("inference_gptq", False): from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear decoder_attribute_replacement = { diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index cabd10bba49e..e3f3d84bc9c2 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -81,8 +81,6 @@ Following are the description `ShardConfig`'s arguments: - `enable_all_optimization`: Whether to turn on all optimization tools including `fused normalizaion`, `flash attention`, `JIT fused operators`, `sequence parallelism` and `sequence overlap`. Defaults to False. -- `inference_only`: Whether only doing forward passing. Defaults to False. 
- ### Write your own policy If you have a custom model, you can also use Shardformer to parallelize it by writing your own sharding policy. More information about the sharding policy can be found in [API Design](#-api-design). @@ -185,7 +183,6 @@ class ShardConfig: # Some possible future config fields tensor_parallel_mode: Choice['1d', '2d', '2.5d', '3d'] # support different tensor parallel mode - inference_only: bool # only inject inference-suitable sharding policy use_flash_attention: bool # whether to use flash attention to speed up attention ``` diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index 3014f1cf3663..d3dfeff22a66 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -209,7 +209,8 @@ def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy :class:`Policy`: The auto policy for the model """ full_name = _fullname(model) - if shard_config.inference_only: + inference_only = shard_config.extra_kwargs.get("inference_only", False) + if inference_only: policy_location = _INFER_POLICY_LIST.get(full_name, None) else: policy_location = _POLICY_LIST.get(full_name, None) @@ -219,5 +220,5 @@ def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())} and {list(_INFER_POLICY_LIST.keys())}" ) else: - policy = import_policy(policy_location, shard_config.inference_only) + policy = import_policy(policy_location, inference_only) return policy() diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 2aa6139836a5..f654c2e83539 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -1,5 +1,5 @@ -from dataclasses import dataclass -from typing import Optional +from dataclasses import dataclass, field +from typing import Dict, Optional import torch.distributed as dist from torch.distributed import ProcessGroup @@ -24,7 +24,6 @@ class ShardConfig: enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False. enable_sequence_overlap (bool): Whether to turn on sequence overlap, wheich overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False. enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalizaion', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False. - inference_only (bool): Whether only doing forward passing. Defaults to False. """ tensor_parallel_process_group: Optional[ProcessGroup] = None pipeline_stage_manager: Optional[PipelineStageManager] = None @@ -33,10 +32,9 @@ class ShardConfig: enable_flash_attention: bool = False enable_jit_fused: bool = False enable_all_optimization: bool = False - inference_only: bool = False - inference_gptq: bool = False enable_sequence_parallelism: bool = False enable_sequence_overlap: bool = False + extra_kwargs: Dict[str, bool] = field(default_factory=dict) # pipeline_parallel_size: int # data_parallel_size: int # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] @@ -77,4 +75,3 @@ def _infer(self): Set default params for inference. 
""" # assert self.pipeline_stage_manager is None, "pipeline parallelism is not supported in inference for now" - pass diff --git a/examples/inference/bench_bloom.py b/examples/inference/bench_bloom.py index 054641f6eebf..5c7af6ed5aef 100644 --- a/examples/inference/bench_bloom.py +++ b/examples/inference/bench_bloom.py @@ -28,7 +28,9 @@ def bench_bloom(args): # init TPInferEngine and shard the original model # To benchmark torch original, comment out the line of optimizing model - shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True) + shard_config = ShardConfig( + enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True} + ) infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len) # prepare data for generation diff --git a/examples/inference/bench_chatglm2.py b/examples/inference/bench_chatglm2.py index f3678d29ff93..3892d98ba743 100644 --- a/examples/inference/bench_chatglm2.py +++ b/examples/inference/bench_chatglm2.py @@ -30,7 +30,9 @@ def run_chatglm2_test(args): model = model.half() model.config - shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True) + shard_config = ShardConfig( + enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True} + ) infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len) generate_kwargs = dict(max_new_tokens=1, do_sample=False) diff --git a/examples/inference/bench_llama.py b/examples/inference/bench_llama.py index 56bf062e2e68..4db32c71af30 100644 --- a/examples/inference/bench_llama.py +++ b/examples/inference/bench_llama.py @@ -30,7 +30,9 @@ def run_llama_test(args): model = model.half() model.config - shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True) + shard_config = ShardConfig( + enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True} + ) infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len) generate_kwargs = dict(max_new_tokens=1, do_sample=False) diff --git a/examples/inference/gptq_bloom.py b/examples/inference/gptq_bloom.py index cfa3171374dd..61e3829b940b 100644 --- a/examples/inference/gptq_bloom.py +++ b/examples/inference/gptq_bloom.py @@ -34,7 +34,9 @@ def bench_bloom(args): model = model.half() model_config = model.config - shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True) + shard_config = ShardConfig( + enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True} + ) infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len) generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) @@ -46,7 +48,8 @@ def bench_bloom(args): # init TPInferEngine and shard the original model # To benchmark torch original, comment out the line of optimizing model shard_config = ShardConfig( - enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True + enable_tensor_parallelism=True if args.tp_size > 1 else False, + extra_kwargs={"inference_only": True, "inference_gptq": True}, ) infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len) diff --git a/examples/inference/gptq_llama.py 
b/examples/inference/gptq_llama.py index 35a6049ad409..79b2a45b9c86 100644 --- a/examples/inference/gptq_llama.py +++ b/examples/inference/gptq_llama.py @@ -33,7 +33,8 @@ def run_llama_test(args): model_config = model.config shard_config = ShardConfig( - enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True + enable_tensor_parallelism=True if args.tp_size > 1 else False, + extra_kwargs={"inference_only": True, "inference_gptq": True}, ) infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len) diff --git a/examples/inference/serving/ray_serve/Colossal_Inference_rayserve.py b/examples/inference/serving/ray_serve/Colossal_Inference_rayserve.py index 51d520ebbcf6..d758b467c730 100644 --- a/examples/inference/serving/ray_serve/Colossal_Inference_rayserve.py +++ b/examples/inference/serving/ray_serve/Colossal_Inference_rayserve.py @@ -68,7 +68,9 @@ def setup(self, world_size, rank, port): self.model_path, pad_token_id=self.tokenizer.pad_token_id, torch_dtype=torch.float16 ) - shard_config = ShardConfig(enable_tensor_parallelism=True if world_size > 1 else False, inference_only=True) + shard_config = ShardConfig( + enable_tensor_parallelism=True if world_size > 1 else False, extra_kwargs={"inference_only": True} + ) self.infer_engine = TPInferEngine( self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len ) diff --git a/examples/inference/serving/torch_serve/Colossal_Inference_Handler.py b/examples/inference/serving/torch_serve/Colossal_Inference_Handler.py index c0d30501efea..e07494b8a1a9 100644 --- a/examples/inference/serving/torch_serve/Colossal_Inference_Handler.py +++ b/examples/inference/serving/torch_serve/Colossal_Inference_Handler.py @@ -100,7 +100,9 @@ def initialize(self, ctx): colossalai.launch(config={}, rank=rank, world_size=world_size, host=host, port=port, backend="nccl") logger.info("Initializing TPInferEngine ...") - shard_config = ShardConfig(enable_tensor_parallelism=True if self.tp_size > 1 else False, inference_only=True) + shard_config = ShardConfig( + enable_tensor_parallelism=True if self.tp_size > 1 else False, extra_kwargs={"inference_only": True} + ) self.infer_engine = TPInferEngine( self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len ) diff --git a/tests/test_infer/_utils.py b/tests/test_infer/_utils.py index 2ddc8b6e68e4..0bd791cc878d 100644 --- a/tests/test_infer/_utils.py +++ b/tests/test_infer/_utils.py @@ -19,7 +19,7 @@ def build_model( enable_tensor_parallelism=enable_tensor_parallelism, enable_flash_attention=enable_flash_attention, enable_jit_fused=enable_jit_fused, - inference_only=True, + extra_kwargs={"inference_only": True}, ) model_copy = copy.deepcopy(org_model) shard_former = ShardFormer(shard_config=shard_config) diff --git a/tests/test_infer/test_bloom_infer.py b/tests/test_infer/test_bloom_infer.py index d4366758d6a3..8f1e5cbbff44 100644 --- a/tests/test_infer/test_bloom_infer.py +++ b/tests/test_infer/test_bloom_infer.py @@ -11,11 +11,10 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn try: - import lightllm HAS_LIGHTLLM_KERNEL = True except: HAS_LIGHTLLM_KERNEL = False - + TP_SIZE = 2 MAX_BATCH_SIZE = 4 MAX_INPUT_LEN = 16 @@ -38,7 +37,7 @@ def run(test_config): model = model.half() shard_config = ShardConfig( - enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True + enable_tensor_parallelism=True if 
test_config["tp_size"] > 1 else False, extra_kwargs={"inference_only": True} ) infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False) @@ -58,7 +57,10 @@ def check_bloom(rank, world_size, port): run() -@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5") +@pytest.mark.skipif( + not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, + reason="kv-cache manager engine requires cuda version to be higher than 11.5", +) @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() diff --git a/tests/test_infer/test_chatglm2_infer.py b/tests/test_infer/test_chatglm2_infer.py index a2ec35dcdb8a..6f8ed038d1ed 100644 --- a/tests/test_infer/test_chatglm2_infer.py +++ b/tests/test_infer/test_chatglm2_infer.py @@ -49,7 +49,7 @@ def run_chatglm2_test(test_config): model = model.half() shard_config = ShardConfig( - enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True + enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, extra_kwargs={"inference_only": True} ) infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False) diff --git a/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py b/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py index 78df0d304096..9f741e6cac4b 100644 --- a/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py +++ b/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py @@ -34,7 +34,7 @@ def run(): model = LlamaForCausalLM(llama_config) model = model.half() - shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True) + shard_config = ShardConfig(enable_tensor_parallelism=False, extra_kwargs={"inference_only": True}) infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) dynamic_batch_manager = DynamicBatchManager( diff --git a/tests/test_infer/test_dynamic_batching/test_offline_dynamic_batching.py b/tests/test_infer/test_dynamic_batching/test_offline_dynamic_batching.py index 9925a80b6e77..176199f9ae99 100644 --- a/tests/test_infer/test_dynamic_batching/test_offline_dynamic_batching.py +++ b/tests/test_infer/test_dynamic_batching/test_offline_dynamic_batching.py @@ -57,7 +57,9 @@ def run(): model = LlamaForCausalLM(llama_config) model = model.half() - shard_config = ShardConfig(enable_tensor_parallelism=True if TP_SIZE > 1 else False, inference_only=True) + shard_config = ShardConfig( + enable_tensor_parallelism=True if TP_SIZE > 1 else False, extra_kwargs={"inference_only": True} + ) infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) batch_manager = start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list) diff --git a/tests/test_infer/test_infer_engine.py b/tests/test_infer/test_infer_engine.py index f24160820e71..cc5b4c263700 100644 --- a/tests/test_infer/test_infer_engine.py +++ b/tests/test_infer/test_infer_engine.py @@ -36,7 +36,7 @@ def run(test_config): # 1. 
check TPInferEngine init and model optimization shard_config = ShardConfig( - enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True + enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, extra_kwargs={"inference_only": True} ) infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) diff --git a/tests/test_infer/test_llama2_infer.py b/tests/test_infer/test_llama2_infer.py index 13e7a61826ab..8fd0cebe9ff7 100644 --- a/tests/test_infer/test_llama2_infer.py +++ b/tests/test_infer/test_llama2_infer.py @@ -13,11 +13,10 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn try: - import lightllm HAS_LIGHTLLM_KERNEL = True except: HAS_LIGHTLLM_KERNEL = False - + os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" TPSIZE = 2 BATCH_SIZE = 8 @@ -43,7 +42,7 @@ def run_llama_test(test_config): model = model.half() shard_config = ShardConfig( - enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True + enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, extra_kwargs={"inference_only": True} ) infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False) @@ -63,7 +62,10 @@ def check_llama(rank, world_size, port): run_llama_test() -@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5") +@pytest.mark.skipif( + not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, + reason="kv-cache manager engine requires cuda version to be higher than 11.5", +) @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py index a4f54d197065..2b0eb6c6b98a 100644 --- a/tests/test_infer/test_llama_infer.py +++ b/tests/test_infer/test_llama_infer.py @@ -13,7 +13,6 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn try: - import lightllm HAS_LIGHTLLM_KERNEL = True except: HAS_LIGHTLLM_KERNEL = False @@ -41,7 +40,7 @@ def run_llama_test(test_config): model = model.half() shard_config = ShardConfig( - enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True + enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, extra_kwargs={"inference_only": True} ) infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False) @@ -61,7 +60,10 @@ def check_llama(rank, world_size, port): run_llama_test() -@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5") +@pytest.mark.skipif( + not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, + reason="kv-cache manager engine requires cuda version to be higher than 11.5", +) @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run()
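
The commit above replaces the top-level `inference_only` / `inference_gptq` fields of `ShardConfig` with an `extra_kwargs` dict. For call sites outside this diff, the sketch below summarizes the migration. It is a minimal illustration, not part of the patch: `tp_size` is a placeholder (the examples in the patch derive it from `args.tp_size` or the world size), and it assumes a ColossalAI build that already includes this change; enabling tensor parallelism additionally requires a launched distributed environment (e.g. `colossalai.launch`), as shown in the example scripts above.

```python
# Illustrative sketch of the ShardConfig call-site migration in this patch.
# Assumes ColossalAI with this commit applied; names below are placeholders.
from colossalai.shardformer import ShardConfig

tp_size = 1  # placeholder; real scripts take this from args or the world size

# Old style (removed by this patch):
#   shard_config = ShardConfig(enable_tensor_parallelism=tp_size > 1, inference_only=True)

# New style: inference-specific switches are passed through extra_kwargs.
shard_config = ShardConfig(
    enable_tensor_parallelism=tp_size > 1,
    extra_kwargs={"inference_only": True, "inference_gptq": False},
)

# Policies and TPInferEngine now read the flags back defensively, so a config
# that never sets a key falls back to False instead of raising AttributeError:
inference_only = shard_config.extra_kwargs.get("inference_only", False)
inference_gptq = shard_config.extra_kwargs.get("inference_gptq", False)
```

Reading the flags with `dict.get(..., False)` mirrors what the updated policies in this patch do, which is why existing training-only configurations keep working without passing `extra_kwargs` at all.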