
BPE Tokenizer

Byte-pair encoding, merge rules, subword splits

Hard Advanced

Problem Description

Implement a simple BPE tokenizer, the foundation of GPT/LLaMA tokenization.

Signature

```python
class SimpleBPE:
    def __init__(self): ...
    def train(self, corpus: list[str], num_merges: int): ...
    def encode(self, text: str) -> list[str]: ...
```

Algorithm (training)

1. Split each word into characters + </w> end marker

2. Count all adjacent pairs across the corpus

3. Merge the most frequent pair into a single token

4. Repeat for num_merges iterations
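The steps above can be sketched on a toy vocabulary. This is a standalone illustration of one iteration of steps 2 and 3 (variable names are ours, not part of the required interface); the full `train` method wraps this in a `num_merges` loop:

```python
from collections import Counter

# Step 1 output: words as symbol tuples ending in </w>, mapped to frequency
vocab = {('l', 'o', 'w', '</w>'): 3,
         ('l', 'o', 'w', 'e', 'r', '</w>'): 1}

# Step 2: count adjacent symbol pairs, weighted by word frequency
pairs = Counter()
for word, freq in vocab.items():
    for a, b in zip(word, word[1:]):
        pairs[(a, b)] += freq

# Step 3: merge the most frequent pair (ties break by first-seen order)
best = max(pairs, key=pairs.get)
new_vocab = {}
for word, freq in vocab.items():
    merged, i = [], 0
    while i < len(word):
        if i < len(word) - 1 and (word[i], word[i + 1]) == best:
            merged.append(word[i] + word[i + 1])  # fuse the pair into one token
            i += 2
        else:
            merged.append(word[i])
            i += 1
    new_vocab[tuple(merged)] = freq
```

Here ('l', 'o') and ('o', 'w') both occur 4 times; the merge fuses one of them, so 'low' gradually collapses into larger subword units over successive iterations.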

Template

Implement the class below. Use only basic Python operations.

# โœ๏ธ YOUR IMPLEMENTATION HERE class SimpleBPE: def __init__(self): self.merges = [] def train(self, corpus, num_merges): pass # iteratively find & merge most frequent pairs def encode(self, text): pass # apply learned merges to split text

Test Your Implementation

Use this code to debug before submitting.

```python
# 🧪 Debug
bpe = SimpleBPE()
bpe.train(['low', 'low', 'low', 'lower', 'newest', 'widest'], num_merges=10)
print('Merges:', bpe.merges[:5])
print('Encode:', bpe.encode('low lower'))
```

Reference Solution

Try solving it yourself first! The reference solution follows.

```python
# ✅ SOLUTION
class SimpleBPE:
    def __init__(self):
        self.merges = []

    def train(self, corpus, num_merges):
        # Step 1: represent each word as a tuple of symbols ending in </w>,
        # mapped to its frequency in the corpus
        vocab = {}
        for word in corpus:
            symbols = tuple(word) + ('</w>',)
            vocab[symbols] = vocab.get(symbols, 0) + 1
        self.merges = []
        for _ in range(num_merges):
            # Step 2: count all adjacent symbol pairs, weighted by word frequency
            pairs = {}
            for word, freq in vocab.items():
                for i in range(len(word) - 1):
                    pair = (word[i], word[i + 1])
                    pairs[pair] = pairs.get(pair, 0) + freq
            if not pairs:
                break
            # Step 3: record and apply the most frequent pair
            best = max(pairs, key=pairs.get)
            self.merges.append(best)
            new_vocab = {}
            for word, freq in vocab.items():
                new_word = []
                i = 0
                while i < len(word):
                    if i < len(word) - 1 and (word[i], word[i + 1]) == best:
                        new_word.append(word[i] + word[i + 1])
                        i += 2
                    else:
                        new_word.append(word[i])
                        i += 1
                new_vocab[tuple(new_word)] = freq
            vocab = new_vocab

    def encode(self, text):
        all_tokens = []
        for word in text.split():
            symbols = list(word) + ['</w>']
            # Replay the learned merges in training order
            for a, b in self.merges:
                i = 0
                while i < len(symbols) - 1:
                    if symbols[i] == a and symbols[i + 1] == b:
                        symbols = symbols[:i] + [a + b] + symbols[i + 2:]
                    else:
                        i += 1
            all_tokens.extend(symbols)
        return all_tokens
```

Tips

Run Locally

For interactive practice with auto-grading, run TorchCode locally:
`pip install torch-judge`, then use `check("bpe")`

Key Concepts

Byte-pair encoding, merge rules, subword splits
