From 688606fe104e6f336c76deb4aa216622e236e016 Mon Sep 17 00:00:00 2001
From: Seungduk Kim
Date: Mon, 27 May 2024 23:46:01 +0900
Subject: [PATCH] Add cache_config for DeepseekV2

---
 vllm/model_executor/models/deepseek_v2.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 0b96396417787..918ea8d3b13d5 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -28,6 +28,7 @@
 from transformers import PretrainedConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.config import CacheConfig
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
@@ -197,6 +198,7 @@ def __init__(
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         layer_idx=None,
     ) -> None:
@@ -276,7 +278,8 @@
         self.attn = Attention(self.num_local_heads,
                               256,
                               self.scaling,
-                              num_kv_heads=self.num_local_heads)
+                              num_kv_heads=self.num_local_heads,
+                              cache_config=cache_config)
 
     def forward(
         self,
@@ -333,6 +336,7 @@ def __init__(
         self,
         config: PretrainedConfig,
         layer_idx: int,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -354,6 +358,7 @@ def __init__(
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
+            cache_config=cache_config,
             quant_config=quant_config,
             layer_idx=layer_idx,
         )
@@ -409,6 +414,7 @@ class DeepseekV2Model(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -422,6 +428,7 @@ def __init__(
         self.layers = nn.ModuleList([
             DeepseekV2DecoderLayer(config,
                                    layer_idx,
+                                   cache_config=cache_config,
                                    quant_config=quant_config)
             for layer_idx in range(config.num_hidden_layers)
         ])
@@ -450,12 +457,13 @@ class DeepseekV2ForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = DeepseekV2Model(config, quant_config)
+        self.model = DeepseekV2Model(config, cache_config, quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
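
---

Usage note (not part of the patch, and only a sketch): threading cache_config
down through DeepseekV2Model -> DeepseekV2DecoderLayer -> DeepseekV2Attention
means engine-level KV-cache settings (e.g. kv_cache_dtype, which lands in
CacheConfig.cache_dtype) can now reach this model's Attention modules rather
than falling back to defaults. The snippet below shows one way to exercise
that path through the public vllm.LLM API; the checkpoint name and the
fp8_e5m2 dtype value are illustrative assumptions, so check which checkpoints
and KV-cache dtypes your vLLM build supports.

    from vllm import LLM, SamplingParams

    # kv_cache_dtype is stored in CacheConfig.cache_dtype, which this patch
    # forwards to each DeepseekV2Attention layer's Attention module.
    llm = LLM(model="deepseek-ai/DeepSeek-V2",  # illustrative checkpoint
              trust_remote_code=True,
              kv_cache_dtype="fp8_e5m2")        # assumed-supported dtype
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=16))
    print(outputs[0].outputs[0].text)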