Micro-batching, loss scaling
Easy · Fundamentals
Implement a training step with gradient accumulation, simulating large batches with limited memory.
1. optimizer.zero_grad()
2. For each (x, y) in micro_batches: loss = loss_fn(model(x), y) / len(micro_batches), then loss.backward()
3. optimizer.step()
4. Return total accumulated loss
The key insight: dividing each micro-batch loss by n (the number of micro-batches) before calling backward makes the accumulated gradients equal the gradient of a single large batch.
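The steps above can be sketched as follows; this is a minimal reference implementation, and the function name and argument layout are illustrative rather than the exercise's required signature:

```python
import torch
import torch.nn as nn


def train_step(model, optimizer, loss_fn, micro_batches):
    """One accumulated optimizer step over a list of (x, y) micro-batches.

    Scaling each loss by 1/n means the gradients summed into .grad by
    repeated backward() calls equal the gradient of one large batch.
    """
    optimizer.zero_grad()
    n = len(micro_batches)
    total_loss = 0.0
    for x, y in micro_batches:
        # Scale before backward so accumulated grads match the large-batch average.
        loss = loss_fn(model(x), y) / n
        loss.backward()  # gradients accumulate into param.grad
        total_loss += loss.item()
    optimizer.step()
    return total_loss
```

With an averaging loss such as `nn.MSELoss` and equal-sized micro-batches, the returned `total_loss` equals the mean loss over the full batch, which is a convenient sanity check while debugging.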
Implement the function below. Use only basic PyTorch operations.
For interactive practice with auto-grading, run TorchCode locally: pip install torch-judge, then use check("gradient_accumulation").