Norm-based clipping, direction preservation
Easy · Fundamentals

Implement gradient norm clipping, a training stability technique.
1. Compute the total norm: sqrt(sum(p.grad.norm()^2 for p in parameters))
2. If total_norm > max_norm: scale every gradient by max_norm / total_norm
3. Return the original (pre-clipping) total norm
Implement the function below. Use only basic PyTorch operations.
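The steps above can be sketched as follows. This is a minimal reference sketch, not the official solution; the function name clip_grad_norm and the 1e-6 stabilizer in the denominator are assumptions (the stabilizer guards against division by a near-zero total norm, mirroring the convention in torch.nn.utils.clip_grad_norm_).

```python
import torch

def clip_grad_norm(parameters, max_norm: float) -> torch.Tensor:
    # Collect gradients from parameters that actually have them.
    grads = [p.grad for p in parameters if p.grad is not None]
    # Step 1: total norm = sqrt(sum of squared per-tensor L2 norms).
    total_norm = torch.sqrt(sum(g.norm() ** 2 for g in grads))
    # Step 2: if the total exceeds max_norm, scale all gradients in place
    # by max_norm / total_norm -- direction preserved, magnitude capped.
    if total_norm > max_norm:
        scale = max_norm / (total_norm + 1e-6)  # 1e-6 is an assumed stabilizer
        for g in grads:
            g.mul_(scale)
    # Step 3: return the pre-clipping total norm.
    return total_norm
```

A quick sanity check: a single gradient of [3.0, 4.0] has norm 5.0; clipping with max_norm=1.0 should return 5.0 and leave the gradient with norm roughly 1.0, pointing in the same direction.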
For interactive practice with auto-grading, run TorchCode locally: pip install torch-judge, then call check("gradient_clipping").