Image to patches to linear projection
Medium | Architecture
Implement the patch embedding layer from the Vision Transformer (ViT).
1. Reshape image into non-overlapping patches: (B, C, H, W) → (B, N, C*P*P)
2. Project each patch: nn.Linear(C*P*P, embed_dim)
3. num_patches = (img_size // patch_size) ** 2
Implement the function below. Use only basic PyTorch operations.
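The three steps above can be sketched roughly as follows. This is one possible approach, not the reference solution: the class name PatchEmbedding, the constructor arguments, and their defaults are assumptions, since the exercise's actual signature is not shown here. It assumes a square image whose side is divisible by the patch size, and uses only reshape, permute, and nn.Linear.

```python
import torch
import torch.nn as nn


class PatchEmbedding(nn.Module):
    """Hypothetical name and signature; the exercise's own stub may differ."""

    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        # Assumption: square images whose side is a multiple of patch_size.
        assert img_size % patch_size == 0, "img_size must be divisible by patch_size"
        self.patch_size = patch_size
        # Step 3: number of patches per image.
        self.num_patches = (img_size // patch_size) ** 2
        # Step 2: one linear projection shared by all patches.
        self.proj = nn.Linear(in_channels * patch_size * patch_size, embed_dim)

    def forward(self, x):
        B, C, H, W = x.shape
        P = self.patch_size
        # Step 1: split each spatial axis into (number of patches, patch size).
        # (B, C, H, W) -> (B, C, H/P, P, W/P, P)
        x = x.reshape(B, C, H // P, P, W // P, P)
        # Move the patch-grid axes forward: (B, H/P, W/P, C, P, P)
        x = x.permute(0, 2, 4, 1, 3, 5)
        # Flatten the grid into N and each patch into C*P*P: (B, N, C*P*P)
        x = x.reshape(B, (H // P) * (W // P), C * P * P)
        # Project every patch: (B, N, embed_dim)
        return self.proj(x)


# Small usage example: 32x32 images with 8x8 patches give (32 // 8) ** 2 = 16 patches.
layer = PatchEmbedding(img_size=32, patch_size=8, in_channels=3, embed_dim=64)
x = torch.randn(2, 3, 32, 32)
out = layer(x)
print(out.shape)  # torch.Size([2, 16, 64])
```

The permute before the final reshape matters: reshaping (B, C, H, W) directly to (B, N, C*P*P) would produce the right shape but mix pixels from different patches.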
For interactive practice with auto-grading, run TorchCode locally: pip install torch-judge, then use check("vit_patch").