
SwiGLU MLP

Gated FFN, SiLU(gate) * up, LLaMA/Mistral-style

Medium Fundamentals

Problem Description

Implement the SwiGLU MLP (feed-forward network) used in modern LLMs like LLaMA.

$$\text{SwiGLU}(x) = \text{down\_proj}\big(\text{SiLU}(\text{gate\_proj}(x)) \odot \text{up\_proj}(x)\big)$$

where $\text{SiLU}(x) = x \cdot \sigma(x)$ and $\odot$ denotes elementwise multiplication.
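As a quick sanity check (assuming PyTorch is installed), `F.silu` is exactly the manual $x \cdot \sigma(x)$ formulation:

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 8)
manual = x * torch.sigmoid(x)  # SiLU(x) = x * sigma(x)

# F.silu computes the same thing, up to floating-point tolerance
assert torch.allclose(F.silu(x), manual, atol=1e-6)
```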

Signature

```python
class SwiGLUMLP(nn.Module):
    def __init__(self, d_model: int, d_ff: int): ...
    def forward(self, x: torch.Tensor) -> torch.Tensor: ...
```

Requirements

• Inherit from nn.Module

• self.gate_proj: nn.Linear(d_model, d_ff)

• self.up_proj: nn.Linear(d_model, d_ff)

• self.down_proj: nn.Linear(d_ff, d_model)

• Activation: SiLU (a.k.a. Swish), via F.silu or implemented as x * torch.sigmoid(x)

Why SwiGLU?

Unlike the classic Linear → ReLU/GELU → Linear FFN, SwiGLU uses a gating mechanism:

the gate projection controls information flow, while the up projection provides the content.

This consistently outperforms standard FFNs in practice (PaLM, LLaMA, Mistral all use it).
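The gating comes at a cost: SwiGLU has three weight matrices where the classic FFN has two. To keep parameter counts comparable, LLaMA-style models shrink d_ff to roughly 2/3 of the usual 4 × d_model. A back-of-the-envelope check (weights only, biases ignored; the exact d_ff in released models is rounded, e.g. LLaMA-7B uses 11008):

```python
d_model = 4096

# Classic FFN: two matrices, d_ff = 4 * d_model
classic_d_ff = 4 * d_model
classic_params = 2 * d_model * classic_d_ff

# SwiGLU: three matrices, so d_ff is scaled down to ~2/3 * 4 * d_model
swiglu_d_ff = int(2 / 3 * 4 * d_model)
swiglu_params = 3 * d_model * swiglu_d_ff

print(classic_params, swiglu_params)  # roughly equal
```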

Template

Implement the class below. Use only basic PyTorch operations.

```python
# ✏️ YOUR IMPLEMENTATION HERE
class SwiGLUMLP(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        pass  # Initialize gate_proj, up_proj, down_proj

    def forward(self, x):
        pass  # down_proj(silu(gate_proj(x)) * up_proj(x))
```

Test Your Implementation

Use this code to debug before submitting.

```python
# 🧪 Debug
import torch

mlp = SwiGLUMLP(d_model=64, d_ff=128)
x = torch.randn(2, 8, 64)
out = mlp(x)
print("Output shape:", out.shape)  # (2, 8, 64)
print("Params:", sum(p.numel() for p in mlp.parameters()))
```

Reference Solution

Try solving it yourself first! Click below to reveal the solution.

```python
# ✅ SOLUTION
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUMLP(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.gate_proj = nn.Linear(d_model, d_ff)
        self.up_proj = nn.Linear(d_model, d_ff)
        self.down_proj = nn.Linear(d_ff, d_model)

    def forward(self, x):
        # down_proj(SiLU(gate_proj(x)) * up_proj(x))
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
```
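A common efficiency variant, seen in several inference stacks, fuses gate_proj and up_proj into a single matmul and splits the result. This is a sketch, not part of the required signature above:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FusedSwiGLUMLP(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        # One matmul produces both the gate and up halves
        self.gate_up_proj = nn.Linear(d_model, 2 * d_ff)
        self.down_proj = nn.Linear(d_ff, d_model)

    def forward(self, x):
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        return self.down_proj(F.silu(gate) * up)
```

The single larger matmul is typically faster than two separate ones on GPU, while the math is identical to the two-projection version.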

Tips

Run Locally

For interactive practice with auto-grading, run TorchCode locally:
`pip install torch-judge`, then use `check("mlp")`

Key Concepts

Gated FFN, SiLU(gate) * up, LLaMA/Mistral-style
