Fix num kv heads per device calculation for num devices > num kv heads
Signed-off-by: Salar <[email protected]>
skhorasganiTT committed Jan 3, 2025
1 parent 8240715 commit 7246419
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions vllm/worker/tt_worker.py
@@ -127,12 +127,12 @@ def get_num_kv_heads(
         device_config: DeviceConfig,
     ) -> int:
         '''
-        Returns the number of KV heads per attention layer. Makes the assumption
-        that we are tensor parallel by the number of devices.
+        Returns the number of KV heads per attention layer (per device). Makes the assumption
+        that we are tensor parallel by min(number of devices, number of KV heads).
         '''
         num_devices = len(device_config.device.get_devices())
         num_kv_heads = model_config.get_num_kv_heads(parallel_config)
-        num_kv_heads //= num_devices  # TP = num_devices
+        num_kv_heads //= min(num_devices, num_kv_heads)  # TP = num_devices if num_devices < num_kv_heads
         return num_kv_heads

     @staticmethod
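The effect of the change can be sketched in isolation (the standalone function name below is hypothetical; in the patch this computation happens inside `get_num_kv_heads` in `vllm/worker/tt_worker.py`). The old code divided by `num_devices` unconditionally, so when the device count exceeded the number of KV heads the integer division produced 0 heads per device; the fix caps the effective tensor-parallel degree at the number of KV heads.

```python
def num_kv_heads_per_device(num_kv_heads: int, num_devices: int) -> int:
    # Effective TP degree = num_devices if num_devices < num_kv_heads,
    # else num_kv_heads (extra devices replicate heads rather than split them).
    return num_kv_heads // min(num_devices, num_kv_heads)

# Old behavior for comparison: num_kv_heads // num_devices,
# which is 0 whenever num_devices > num_kv_heads.
```

For example, with 8 KV heads: 4 devices give 2 heads per device, 8 devices give 1, and 32 devices still give 1 (instead of 0 under the old calculation).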
