
Grouped Query Attention

GQA (LLaMA 2), KV sharing across heads

Difficulty: Hard · Tag: Attention

Problem Description

Implement Grouped Query Attention — used in LLaMA 2, Mistral, etc. to reduce KV cache size.

Like MHA, but with fewer KV heads than Q heads. Each group of Q heads shares the same K/V head.
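The grouping can be made concrete as a simple index mapping. A minimal sketch, assuming num_heads=8 and num_kv_heads=2 (the same sizes used in the debug example below):

```python
# Which KV head serves each query head, for assumed sizes 8 Q heads / 2 KV heads
num_heads, num_kv_heads = 8, 2
group_size = num_heads // num_kv_heads  # 4 query heads per KV head
mapping = [q_head // group_size for q_head in range(num_heads)]
print(mapping)  # [0, 0, 0, 0, 1, 1, 1, 1]
```

Query heads 0-3 all attend using KV head 0, and heads 4-7 using KV head 1.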

Signature

class GroupQueryAttention:
    def __init__(self, d_model: int, num_heads: int, num_kv_heads: int): ...
    def forward(self, x) -> torch.Tensor: ...  # self-attention

Requirements

self.W_q: nn.Linear(d_model, d_model) — full Q projection

self.W_k: nn.Linear(d_model, num_kv_heads * d_k) — reduced K projection

self.W_v: nn.Linear(d_model, num_kv_heads * d_k) — reduced V projection

self.W_o: nn.Linear(d_model, d_model) — output projection

d_k = d_model // num_heads

• Expand KV heads with repeat_interleave to match Q heads

• When num_kv_heads == num_heads, should behave like standard MHA
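The repeat_interleave expansion from the requirements can be sketched on dummy tensors. This assumes the head-first layout (B, heads, S, d_k) used after the transpose, with the same sizes as the debug example (2 KV heads expanded to match 8 query heads):

```python
import torch

# Dummy K tensor: batch=2, 2 KV heads, seq=6, d_k=4
k = torch.randn(2, 2, 6, 4)                 # (B, num_kv_heads, S, d_k)
k_expanded = k.repeat_interleave(4, dim=1)  # repeats = num_heads // num_kv_heads
print(k_expanded.shape)                     # torch.Size([2, 8, 6, 4])
# Each KV head is repeated 'repeats' times consecutively, so query heads 0-3
# all see the original KV head 0, and heads 4-7 see KV head 1
assert torch.equal(k_expanded[:, 0], k[:, 0])
assert torch.equal(k_expanded[:, 3], k[:, 0])
assert torch.equal(k_expanded[:, 4], k[:, 1])
```

Note that repeat_interleave repeats each head in place (0,0,0,0,1,1,1,1), which is what the grouping requires; a plain repeat would interleave them incorrectly (0,1,0,1,...).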

Template

Implement the class below. Use only basic PyTorch operations.

```python
# ✏️ YOUR IMPLEMENTATION HERE
class GroupQueryAttention:
    def __init__(self, d_model, num_heads, num_kv_heads):
        pass  # Initialize projections

    def forward(self, x):
        pass  # Self-attention with grouped KV
```

Test Your Implementation

Use this code to debug before submitting.

```python
# 🧪 Debug
import torch

torch.manual_seed(0)
gqa = GroupQueryAttention(d_model=32, num_heads=8, num_kv_heads=2)
print("W_q shape:", gqa.W_q.weight.shape)  # (32, 32)
print("W_k shape:", gqa.W_k.weight.shape)  # (8, 32) — only 2 KV heads * d_k=4
x = torch.randn(2, 6, 32)
out = gqa.forward(x)
print("Output shape:", out.shape)  # (2, 6, 32)
```

Reference Solution

Try solving it yourself first! Click below to reveal the solution.

```python
# ✅ SOLUTION
import math

import torch
import torch.nn as nn


class GroupQueryAttention:
    def __init__(self, d_model, num_heads, num_kv_heads):
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)                    # full Q projection
        self.W_k = nn.Linear(d_model, num_kv_heads * self.d_k)    # reduced K projection
        self.W_v = nn.Linear(d_model, num_kv_heads * self.d_k)    # reduced V projection
        self.W_o = nn.Linear(d_model, d_model)                    # output projection

    def forward(self, x):
        B, S, _ = x.shape
        # Project and split into heads: Q gets num_heads, K/V only num_kv_heads
        q = self.W_q(x).view(B, S, self.num_heads, self.d_k).transpose(1, 2)
        k = self.W_k(x).view(B, S, self.num_kv_heads, self.d_k).transpose(1, 2)
        v = self.W_v(x).view(B, S, self.num_kv_heads, self.d_k).transpose(1, 2)
        # Expand KV heads so each group of query heads shares one K/V head
        repeats = self.num_heads // self.num_kv_heads
        k = k.repeat_interleave(repeats, dim=1)
        v = v.repeat_interleave(repeats, dim=1)
        # Scaled dot-product attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        weights = torch.softmax(scores, dim=-1)
        attn = torch.matmul(weights, v)
        # Merge heads and project back to d_model
        out = attn.transpose(1, 2).contiguous().view(B, S, -1)
        return self.W_o(out)
```
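To see why GQA reduces the KV cache, a back-of-envelope comparison helps. The sizes below are illustrative assumptions, not tied to any specific model; per token and per layer, the cache stores K and V for each KV head:

```python
# KV cache per token per layer: 2 (K and V) * num_kv_heads * d_k values
d_model, num_heads = 4096, 32   # illustrative sizes
num_kv_heads = 8                # grouped: 4 query heads per KV head
d_k = d_model // num_heads      # 128

mha_kv = 2 * num_heads * d_k     # 8192 values with full MHA
gqa_kv = 2 * num_kv_heads * d_k  # 2048 values with GQA
print(mha_kv // gqa_kv)          # 4x smaller KV cache
```

The Q and output projections are unchanged, so the compute per token is nearly identical; only the cached K/V tensors shrink by num_heads / num_kv_heads.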

Tips

Run Locally

For interactive practice with auto-grading, run TorchCode locally:
pip install torch-judge then use check("gqa")
