Mixtral-style, top-k routing, expert MLPs
Hard · Architecture — Implement a Mixture of Experts (MoE) layer (Mixtral / Switch Transformer style).
• self.router: nn.Linear(d_model, num_experts) — gating network
• self.experts: nn.ModuleList of MLPs (Linear→ReLU→Linear)
• For each token: select top-k experts, compute weighted sum of their outputs
Implement the function below using only basic PyTorch operations, and debug your implementation before submitting.
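One possible shape for a solution, as a hedged sketch: the class name `MoELayer`, the hidden width `d_ff`, and the `top_k` hyperparameter are assumptions not fixed by the problem statement. Per-token top-k routing is done by looping over experts and gathering the tokens routed to each one, which keeps every operation a basic PyTorch call.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoELayer(nn.Module):
    """Sketch of a Mixtral-style MoE layer; names and sizes are illustrative."""

    def __init__(self, d_model: int, d_ff: int, num_experts: int, top_k: int):
        super().__init__()
        self.top_k = top_k
        # Gating network: one logit per expert for each token
        self.router = nn.Linear(d_model, num_experts)
        # Expert MLPs: Linear -> ReLU -> Linear
        self.experts = nn.ModuleList(
            nn.Sequential(
                nn.Linear(d_model, d_ff),
                nn.ReLU(),
                nn.Linear(d_ff, d_model),
            )
            for _ in range(num_experts)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, S, D = x.shape
        tokens = x.reshape(-1, D)                              # (N, D), N = B*S
        logits = self.router(tokens)                           # (N, num_experts)
        topk_logits, topk_idx = logits.topk(self.top_k, -1)    # (N, k) each
        # Renormalize the gate weights over the selected experts only
        weights = F.softmax(topk_logits, dim=-1)               # (N, k)

        out = torch.zeros_like(tokens)
        for e, expert in enumerate(self.experts):
            # Which (token, slot) pairs routed to expert e
            token_ids, slot = (topk_idx == e).nonzero(as_tuple=True)
            if token_ids.numel() == 0:
                continue
            gate = weights[token_ids, slot].unsqueeze(-1)      # (n_e, 1)
            out.index_add_(0, token_ids, gate * expert(tokens[token_ids]))
        return out.reshape(B, S, D)
```

Looping over experts rather than over tokens is the standard trick: each expert processes its routed tokens in one batched matmul, and `index_add_` scatters the gated outputs back into place.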