GQA (LLaMA 2), KV sharing across heads
Hard · Attention
Implement Grouped Query Attention (GQA) — used in LLaMA 2, Mistral, etc. to reduce KV cache size.
Like MHA, but with fewer KV heads than Q heads. Each group of Q heads shares the same K/V head.
• self.W_q: nn.Linear(d_model, d_model) — full Q projection
• self.W_k: nn.Linear(d_model, num_kv_heads * d_k) — reduced K projection
• self.W_v: nn.Linear(d_model, num_kv_heads * d_k) — reduced V projection
• self.W_o: nn.Linear(d_model, d_model) — output projection
• d_k = d_model // num_heads
• Expand KV heads with repeat_interleave to match Q heads
• When num_kv_heads == num_heads, should behave like standard MHA
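The spec above can be sketched as follows. This is a minimal reference sketch, not the official solution: the class name, constructor signature, and the absence of masking/dropout are assumptions, and the exercise's actual scaffold may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, num_kv_heads: int):
        super().__init__()
        assert num_heads % num_kv_heads == 0, "Q heads must divide evenly into KV groups"
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)                    # full Q projection
        self.W_k = nn.Linear(d_model, num_kv_heads * self.d_k)    # reduced K projection
        self.W_v = nn.Linear(d_model, num_kv_heads * self.d_k)    # reduced V projection
        self.W_o = nn.Linear(d_model, d_model)                    # output projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, _ = x.shape
        # Project and split into heads: (B, heads, T, d_k)
        q = self.W_q(x).view(B, T, self.num_heads, self.d_k).transpose(1, 2)
        k = self.W_k(x).view(B, T, self.num_kv_heads, self.d_k).transpose(1, 2)
        v = self.W_v(x).view(B, T, self.num_kv_heads, self.d_k).transpose(1, 2)
        # Expand KV heads so each group of Q heads shares its K/V head
        group_size = self.num_heads // self.num_kv_heads
        k = k.repeat_interleave(group_size, dim=1)
        v = v.repeat_interleave(group_size, dim=1)
        # Standard scaled dot-product attention (no mask, for brevity)
        scores = q @ k.transpose(-2, -1) / self.d_k ** 0.5
        attn = F.softmax(scores, dim=-1)
        out = (attn @ v).transpose(1, 2).contiguous().view(B, T, -1)
        return self.W_o(out)
```

With num_kv_heads == num_heads, group_size is 1 and repeat_interleave is a no-op, so the module reduces to standard MHA as the spec requires.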
Implement the function below. Use only basic PyTorch operations.
Use this code to debug before submitting.
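The debug snippet referenced above is not shown here. As a stand-in, a quick sanity check of the trickiest step — expanding KV heads with repeat_interleave — might look like this (shapes chosen for illustration):

```python
import torch

# Pretend K has shape (B=1, num_kv_heads=2, T=1, d_k=1) with distinct values per KV head.
k = torch.arange(2).view(1, 2, 1, 1).float()

# Expand 2 KV heads to match 8 Q heads: each KV head is repeated 4 times in a row,
# so consecutive groups of Q heads share the same K/V head.
k_expanded = k.repeat_interleave(4, dim=1)

print(k_expanded.shape)                  # torch.Size([1, 8, 1, 1])
print(k_expanded.flatten().tolist())     # [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]
```

Note repeat_interleave groups the copies contiguously ([0,0,0,0,1,1,1,1]), unlike torch.repeat, which would tile them ([0,1,0,1,...]) and mis-assign heads to groups.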
Try solving it yourself first!
For interactive practice with auto-grading, run TorchCode locally: pip install torch-judge, then use check("gqa").