
Adam Optimizer

Difficulty: Medium · Category: Training

Problem Description

Implement the Adam optimizer from scratch.

Signature

```python
class MyAdam:
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8): ...
    def step(self): ...
    def zero_grad(self): ...
```

Algorithm (per parameter)

```
m = β1 * m + (1 - β1) * grad
v = β2 * v + (1 - β2) * grad²
m̂ = m / (1 - β1ᵗ)    # bias correction
v̂ = v / (1 - β2ᵗ)
p -= lr * m̂ / (√v̂ + ε)
```
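
To see why the bias correction matters, consider the very first step (t = 1), where m and v start at zero: the raw averages are shrunk by factors of (1 - β1) and (1 - β2), and dividing by (1 - β1¹) and (1 - β2¹) undoes exactly that shrinkage. A minimal numeric sketch in plain Python (the gradient value is arbitrary):

```python
# One Adam step at t = 1, traced by hand (illustrative values only).
beta1, beta2, eps, lr = 0.9, 0.999, 1e-8, 0.01
grad = 2.0                                  # pretend gradient
m = (1 - beta1) * grad                      # 0.2   -- biased toward zero
v = (1 - beta2) * grad ** 2                 # 0.004 -- biased toward zero
m_hat = m / (1 - beta1 ** 1)                # 2.0   -- correction recovers grad
v_hat = v / (1 - beta2 ** 1)                # 4.0   -- correction recovers grad²
update = lr * m_hat / (v_hat ** 0.5 + eps)  # ≈ 0.01, i.e. about one lr
print(m_hat, v_hat, update)
```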

Template

Implement the class below. Use only basic PyTorch operations.

```python
# ✏️ YOUR IMPLEMENTATION HERE
class MyAdam:
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
        pass  # store params, init m and v to zeros

    def step(self):
        pass  # update params using Adam rule

    def zero_grad(self):
        pass  # zero all gradients
```

Test Your Implementation

Use this code to debug before submitting.

```python
# 🧪 Debug
import torch

torch.manual_seed(0)
w = torch.randn(4, 3, requires_grad=True)
opt = MyAdam([w], lr=0.01)
for i in range(5):
    loss = (w ** 2).sum()
    loss.backward()
    opt.step()
    opt.zero_grad()
    print(f'Step {i}: loss={loss.item():.4f}')
```
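
A stricter check is to run the same optimization twice, once with your class and once with torch.optim.Adam, and assert that the parameters stay numerically identical. A sketch, assuming your MyAdam follows the signature above (the tolerance is a judgment call):

```python
# 🧪 Cross-check against torch.optim.Adam (defaults match this problem's).
import torch

torch.manual_seed(0)
w1 = torch.randn(4, 3, requires_grad=True)
w2 = w1.detach().clone().requires_grad_(True)

mine = MyAdam([w1], lr=0.01)
ref = torch.optim.Adam([w2], lr=0.01)

for _ in range(10):
    for w, opt in ((w1, mine), (w2, ref)):
        loss = (w ** 2).sum()
        loss.backward()
        opt.step()
        opt.zero_grad()

assert torch.allclose(w1, w2, atol=1e-6), "update rules diverge"
print("matches torch.optim.Adam")
```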

Reference Solution

Try solving it yourself first!

```python
# ✅ SOLUTION
import torch

class MyAdam:
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
        self.params = list(params)
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.t = 0  # global step counter, shared by all parameters
        self.m = [torch.zeros_like(p) for p in self.params]  # first moment
        self.v = [torch.zeros_like(p) for p in self.params]  # second moment

    def step(self):
        self.t += 1
        with torch.no_grad():  # updates must not be traced by autograd
            for i, p in enumerate(self.params):
                if p.grad is None:
                    continue
                # Exponential moving averages of the gradient and its square
                self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * p.grad
                self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * p.grad ** 2
                # Bias correction: the averages start at zero, so rescale them
                m_hat = self.m[i] / (1 - self.beta1 ** self.t)
                v_hat = self.v[i] / (1 - self.beta2 ** self.t)
                # In-place update keeps p the same leaf tensor
                p -= self.lr * m_hat / (torch.sqrt(v_hat) + self.eps)

    def zero_grad(self):
        for p in self.params:
            if p.grad is not None:
                p.grad.zero_()
```
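
Two details are easy to get wrong. The update must run under torch.no_grad(), since otherwise autograd would record the parameter change as part of the computation graph. And the update must be in-place (p -= ...): writing p = p - ... would bind a new tensor to the local name p and leave the actual parameter untouched.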

Tips

Run Locally

For interactive practice with auto-grading, run TorchCode locally:
pip install torch-judge, then call check("adam")

Key Concepts

Momentum + RMSProp, bias correction: Adam keeps an exponential moving average of gradients (momentum, the first moment) and of squared gradients (RMSProp, the second moment), and rescales both early in training to correct their bias toward the zero initialization.
