diff --git a/awq/modules/fused/attn.py b/awq/modules/fused/attn.py
index 010f039b..04bb6fa3 100644
--- a/awq/modules/fused/attn.py
+++ b/awq/modules/fused/attn.py
@@ -146,7 +146,8 @@ def _get_attention_shapes(self, attention_shapes, max_seq_len):
 
     def forward(
         self,
-        hidden_states:torch.Tensor, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False
+        hidden_states:torch.Tensor, past_key_value=None, attention_mask=None, position_ids=None,
+        output_attentions=False, use_cache=False, *args, **kwargs
     ):
         bsz, seqlen, _ = hidden_states.shape
         if bsz != self.cache_batch_size:
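
For context, a minimal sketch of the failure mode the added `*args, **kwargs` guards against (the caller and the `cache_position` keyword below are hypothetical, chosen only to illustrate the pattern): without a catch-all, any upstream code that passes an extra keyword argument to the fused attention module raises a `TypeError`; with it, unknown arguments are accepted and ignored.

```python
import torch


class StrictAttention(torch.nn.Module):
    # No catch-all: an unexpected keyword argument raises TypeError.
    def forward(self, hidden_states: torch.Tensor, attention_mask=None):
        return hidden_states


class TolerantAttention(torch.nn.Module):
    # Mirrors the patched signature: *args/**kwargs absorb extra
    # arguments that this module does not use.
    def forward(self, hidden_states: torch.Tensor, attention_mask=None,
                *args, **kwargs):
        return hidden_states


x = torch.zeros(1, 4, 8)

try:
    # Hypothetical extra kwarg a newer caller might pass through.
    StrictAttention()(x, cache_position=torch.arange(4))
except TypeError as e:
    print("strict:", e)

TolerantAttention()(x, cache_position=torch.arange(4))  # accepted, ignored
print("tolerant: ok")
```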