Incremental decoding, cache K/V, prefill vs decode
Hard | Attention
Implement multi-head attention with KV caching for efficient autoregressive generation.
During LLM inference, recomputing all key/value projections at every step is wasteful.
A KV cache stores previously computed K and V tensors so only the new token(s) need projection.
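The cache idea can be checked on plain tensors: projecting the first tokens once, caching the results, and projecting only the newest token yields exactly the same K/V as reprojecting the whole sequence. A minimal single-matrix sketch (tensor names are illustrative, not part of the exercise):

```python
import torch

torch.manual_seed(0)
d = 8
W_k = torch.randn(d, d)
W_v = torch.randn(d, d)

# Full recompute: project K/V for all 5 tokens at every step.
x = torch.randn(5, d)
K_full, V_full = x @ W_k, x @ W_v

# Cached: project the first 4 tokens once, then only the new 5th token.
K_cache, V_cache = x[:4] @ W_k, x[:4] @ W_v
k_new, v_new = x[4:] @ W_k, x[4:] @ W_v
K_inc = torch.cat([K_cache, k_new], dim=0)
V_inc = torch.cat([V_cache, v_new], dim=0)

# Identical results, but the cached path projects one token, not five.
assert torch.allclose(K_full, K_inc) and torch.allclose(V_full, V_inc)
```

The projections are per-token (no mixing across positions), which is why caching K/V is exact rather than an approximation.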
• Inherit from nn.Module
• self.W_q, self.W_k, self.W_v, self.W_o: nn.Linear projections
• When cache=None (prefill): apply causal mask, return all K/V as cache
• When cache provided (decode): concatenate the new K/V with the cached tensors along the sequence dimension; no causal mask is needed for single-token decode, since the new token may attend to every past position
• Incremental decode must produce identical results to full forward pass
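One way the requirements above can be sketched is the module below. The class name and shape conventions are assumptions (the exercise's starter signature is not shown here); only the attribute names W_q, W_k, W_v, W_o and the cache=None prefill / cache-provided decode behavior come from the spec:

```python
import torch
import torch.nn as nn

class MultiHeadAttentionKV(nn.Module):
    """Sketch of multi-head attention with a (K, V) cache. Hypothetical
    class name; cache is a (K, V) tuple of shape (B, n_heads, T_past, d_head)."""

    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, cache=None):
        B, T, _ = x.shape  # x: (batch, seq, d_model)

        def split(t):  # (B, T, d_model) -> (B, n_heads, T, d_head)
            return t.view(B, T, self.n_heads, self.d_head).transpose(1, 2)

        q, k, v = split(self.W_q(x)), split(self.W_k(x)), split(self.W_v(x))
        if cache is not None:
            # Decode: prepend cached K/V along the sequence dimension.
            k = torch.cat([cache[0], k], dim=2)
            v = torch.cat([cache[1], v], dim=2)

        scores = q @ k.transpose(-2, -1) / self.d_head ** 0.5
        if cache is None:
            # Prefill: causal mask over the full sequence.
            mask = torch.triu(
                torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1
            )
            scores = scores.masked_fill(mask, float("-inf"))
        # Single-token decode (T == 1) needs no mask: the new token may
        # attend to every cached position.

        out = torch.softmax(scores, dim=-1) @ v
        out = out.transpose(1, 2).reshape(B, T, self.n_heads * self.d_head)
        return self.W_o(out), (k, v)

# Check the last requirement: incremental decode matches the full pass.
torch.manual_seed(0)
mha = MultiHeadAttentionKV(16, 4).eval()
x = torch.randn(2, 6, 16)
full, _ = mha(x)                          # one causal pass over all 6 tokens
prefill, cache = mha(x[:, :5])            # prefill the first 5
step, _ = mha(x[:, 5:6], cache=cache)     # decode token 6 from the cache
assert torch.allclose(full[:, 5:], step, atol=1e-5)
```

Returning the concatenated (k, v) as the new cache means the caller threads the cache through each decode step; the closing assertion is exactly the equivalence the last bullet demands.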
Implement the function below. Use only basic PyTorch operations.
Use this code to debug before submitting.
Try solving it yourself first! Click below to reveal the solution.
For interactive practice with auto-grading, run TorchCode locally: pip install torch-judge, then use check("kv_cache").