
Cross-Attention

Encoder-decoder, Q from decoder, K/V from encoder

Difficulty: Medium · Topic: Attention

Problem Description

Implement multi-head cross-attention (encoder-decoder attention).

Signature

```python
class MultiHeadCrossAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int): ...

    def forward(self, x_q: Tensor, x_kv: Tensor) -> Tensor:
        # x_q:  (B, S_q, D)  - decoder queries
        # x_kv: (B, S_kv, D) - encoder keys/values
        ...
```

Key Differences from Self-Attention

• Q comes from the decoder, K and V come from the encoder

• No causal mask (all encoder positions visible)
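The same Q-from-decoder, K/V-from-encoder split appears in PyTorch's built-in nn.MultiheadAttention, which takes query, key, and value as separate arguments. A minimal sketch (for intuition only, not part of the required solution, which must use basic operations):

```python
import torch
import torch.nn as nn

# Built-in layer: query may come from one sequence, key/value from another
mha = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)

dec = torch.randn(2, 6, 64)   # decoder states (queries)
enc = torch.randn(2, 10, 64)  # encoder states (keys/values)

# No attn_mask is passed: every encoder position is visible to every query
out, weights = mha(query=dec, key=enc, value=enc)
print(out.shape)      # torch.Size([2, 6, 64])  - one output per query position
print(weights.shape)  # torch.Size([2, 6, 10]) - weights over encoder positions
```

Note that the output length follows the query sequence (6), while the attention weights span the key/value sequence (10).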

Template

Complete the class below. Use only basic PyTorch operations.

# โœ๏ธ YOUR IMPLEMENTATION HERE class MultiHeadCrossAttention(nn.Module): def __init__(self, d_model, num_heads): super().__init__() pass # W_q, W_k, W_v, W_o def forward(self, x_q, x_kv): pass # Q from x_q, K/V from x_kv, no causal mask

Test Your Implementation

Use this code to debug before submitting.

```python
# 🧪 Debug
attn = MultiHeadCrossAttention(64, 4)
x_q = torch.randn(2, 6, 64)
x_kv = torch.randn(2, 10, 64)
print('Output:', attn(x_q, x_kv).shape)  # expected: torch.Size([2, 6, 64])
```

Reference Solution

Try solving it yourself before reading the solution below.

```python
# ✅ SOLUTION
import math
import torch
import torch.nn as nn

class MultiHeadCrossAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x_q, x_kv):
        B, S_q, _ = x_q.shape
        S_kv = x_kv.shape[1]
        # Q from the decoder stream; K and V from the encoder stream
        q = self.W_q(x_q).view(B, S_q, self.num_heads, self.d_k).transpose(1, 2)
        k = self.W_k(x_kv).view(B, S_kv, self.num_heads, self.d_k).transpose(1, 2)
        v = self.W_v(x_kv).view(B, S_kv, self.num_heads, self.d_k).transpose(1, 2)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        weights = torch.softmax(scores, dim=-1)  # no causal mask
        attn = torch.matmul(weights, v)
        return self.W_o(attn.transpose(1, 2).contiguous().view(B, S_q, -1))
```
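One extra sanity check worth knowing (a standalone sketch, not required by the grader): the softmax weights for each query position should sum to 1 over the S_kv encoder positions. The core of the computation can be reproduced with random tensors:

```python
import math
import torch

B, H, S_q, S_kv, d_k = 2, 4, 6, 10, 16
q = torch.randn(B, H, S_q, d_k)
k = torch.randn(B, H, S_kv, d_k)

scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # (B, H, S_q, S_kv)
weights = torch.softmax(scores, dim=-1)

# Each query position distributes exactly 1.0 of attention over encoder positions
print(torch.allclose(weights.sum(dim=-1), torch.ones(B, H, S_q)))  # True
```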

Tips

Run Locally

For interactive practice with auto-grading, run TorchCode locally:
`pip install torch-judge`, then use `check("cross_attention")`

Key Concepts

Cross-attention connects an encoder and a decoder: queries come from the decoder, keys and values from the encoder, and no causal mask is applied.
