
GPT-2 Block

Pre-norm, causal MHA + MLP (4x, GELU), residual

Hard · Architecture

Problem Description

Implement a full GPT-2 style Transformer block, combining everything you've learned.

Architecture (Pre-Norm)

x = x + causal_self_attention(ln1(x))
x = x + mlp(ln2(x))

Signature

class GPT2Block(nn.Module):
    def __init__(self, d_model: int, num_heads: int): ...
    def forward(self, x: torch.Tensor) -> torch.Tensor: ...

Requirements

• Inherit from nn.Module

• self.ln1, self.ln2: nn.LayerNorm(d_model)

• self.W_q, self.W_k, self.W_v, self.W_o: nn.Linear layers for attention

• self.mlp: nn.Sequential(Linear(d, 4d), GELU(), Linear(4d, d))

• Attention must be causal (mask future positions)

• Pre-norm architecture (LayerNorm *before* attention and MLP)

• Residual connections around both attention and MLP
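The causal-masking requirement above can be sketched with `torch.triu`; the sequence length below is illustrative:

```python
import torch

S = 4  # example sequence length
# True above the diagonal marks the future positions a query must NOT attend to
mask = torch.triu(torch.ones(S, S, dtype=torch.bool), diagonal=1)
print(mask)
# Row i is True for columns j > i, so masked_fill(mask, float('-inf'))
# drives those attention weights to zero after the softmax.
```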

Template

Implement the class below. Use only basic PyTorch operations.

# โœ๏ธ YOUR IMPLEMENTATION HERE class GPT2Block(nn.Module): def __init__(self, d_model, num_heads): super().__init__() pass # Initialize layers def forward(self, x): pass # Pre-norm + causal attention + MLP with residuals

Test Your Implementation

Use this code to debug before submitting.

# 🧪 Debug
torch.manual_seed(0)
block = GPT2Block(d_model=64, num_heads=4)
x = torch.randn(2, 8, 64)
out = block(x)
print("Output shape:", out.shape)  # (2, 8, 64)
print("Is nn.Module?", isinstance(block, nn.Module))
print("Params:", sum(p.numel() for p in block.parameters()))
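As an extra sanity check, the parameter count the debug script prints can be worked out by hand for d_model=64, assuming all Linear and LayerNorm layers keep their default weights and biases:

```python
d = 64
ln = 2 * d                            # one LayerNorm: weight + bias
proj = d * d + d                      # one nn.Linear(d, d): weight + bias
mlp = (d * 4 * d + 4 * d) + (4 * d * d + d)  # Linear(d, 4d) + Linear(4d, d)

# Two LayerNorms, four attention projections (Q, K, V, O), one MLP
total = 2 * ln + 4 * proj + mlp
print(total)  # 49984
```

If your implementation prints a different number, a bias is likely missing or an extra layer has been added.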

Reference Solution

Try solving it yourself first! Click below to reveal the solution.

# ✅ SOLUTION
import math

import torch
import torch.nn as nn


class GPT2Block(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )

    def _attn(self, x):
        B, S, _ = x.shape
        # Project and split into heads: (B, num_heads, S, d_k)
        q = self.W_q(x).view(B, S, self.num_heads, self.d_k).transpose(1, 2)
        k = self.W_k(x).view(B, S, self.num_heads, self.d_k).transpose(1, 2)
        v = self.W_v(x).view(B, S, self.num_heads, self.d_k).transpose(1, 2)
        # Scaled dot-product scores
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        # Causal mask: True above the diagonal blocks attention to future positions
        mask = torch.triu(torch.ones(S, S, device=x.device, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(mask, float('-inf'))
        weights = torch.softmax(scores, dim=-1)
        attn = torch.matmul(weights, v)
        # Merge heads back to (B, S, d_model) and apply the output projection
        return self.W_o(attn.transpose(1, 2).contiguous().view(B, S, -1))

    def forward(self, x):
        x = x + self._attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x
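A further check worth running on any causal implementation: perturbing a future token must leave all earlier outputs unchanged. The sketch below demonstrates the idea using PyTorch's built-in `F.scaled_dot_product_attention` with `is_causal=True` as a stand-in for the hand-written attention; the same test applies directly to the block above.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(1, 4, 8, 16)  # (batch, heads, seq, d_k)
k, v = torch.randn_like(q), torch.randn_like(q)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Perturb only the LAST position's key and value.
k2, v2 = k.clone(), v.clone()
k2[:, :, -1] += 1.0
v2[:, :, -1] += 1.0
out2 = F.scaled_dot_product_attention(q, k2, v2, is_causal=True)

# Causality: every position before the last must be unaffected.
assert torch.allclose(out[:, :, :-1], out2[:, :, :-1], atol=1e-6)
print("causality check passed")
```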

Tips

Run Locally

For interactive practice with auto-grading, run TorchCode locally:
pip install torch-judge, then use check("gpt2_block")

Key Concepts

Pre-norm, causal MHA + MLP (4x, GELU), residual
