I’ve always been fascinated by how libraries like Hugging Face `transformers` work under the hood. So, I decided to build one myself from scratch, and I picked JAX/Flax mostly because it seemed cool. This project, ‘Banhxeo’, is the result: a deep dive into the guts of an NLP pipeline, especially the tokenizer.
1. The Tokenizer Pipeline: A (Hugging) Face-lift
The first goal was to replicate the Hugging Face `tokenizers` library. It’s a 4-stage process (normalize, pre-tokenize, model, post-process) that takes a raw string and turns it into model-ready inputs.
My main `Tokenizer` class implements this exact flow in its `__call__` method:
```python
def __call__(self, texts: Union[str, List[str]], ...):
    pre_tokenized_strs = []
    for text in texts:
        # Step 1: Normalize the string (e.g., lowercase, NFC)
        normalized_string = NormalizedString.from_str(text)
        normalized_string = self.normalizer.normalize(normalized_string)

        # Step 2: Pre-tokenize (e.g., split on whitespace, bytes)
        pre_tokenized_str = self.pre_tokenizer.pre_tokenize(
            PreTokenizedString(splits=[Split(normalized=normalized_string)])
        )

        # Step 3: Model (turn pre-tokenized splits into token IDs)
        self.model.tokenize(pre_tokenized_str)
        pre_tokenized_strs.append(pre_tokenized_str)

    # Step 4: Post-process (add special tokens [CLS], [SEP], padding, etc.)
    post_process_config = ProcessConfig(...)
    post_process_result = self.post_processor.process_batch(
        pre_tokenized_strs, config=post_process_config
    )

    # Convert to JAX arrays
    return {
        key: jnp.array(value, dtype=jnp.int32)
        for key, value in post_process_result.items()
    }
```

I implemented several options for each stage, like `ByteLevelPreTokenizer` (for GPT-2 style) and `BertPostProcessor`/`GPTPostProcessor` to handle the different special tokens.
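To make the flow concrete, here is roughly what using the pipeline looks like end to end. Treat it as a sketch: the constructor arguments, the `LowercaseNormalizer` name, and the exact output keys are my assumptions about the API, not guaranteed by the snippets above.

```python
# Sketch only -- the wiring and output keys here are assumptions, not Banhxeo's exact API.
tokenizer = Tokenizer(
    normalizer=LowercaseNormalizer(),       # Step 1: e.g. lowercase + NFC
    pre_tokenizer=ByteLevelPreTokenizer(),  # Step 2: GPT-2 style byte-level splits
    model=bpe_model,                        # Step 3: a trained BPEModel
    post_processor=GPTPostProcessor(),      # Step 4: special tokens, padding
)

batch = tokenizer(["I love bánh xèo.", "JAX is fun!"])
# Expected: a dict of jnp.int32 arrays, e.g. batch["input_ids"] with shape (2, seq_len).
```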
2. From-Scratch BPE: Heaps of Fun
The most complex part was building the BPEModel (Byte-Pair Encoding) trainer from scratch. BPE works by counting all character pairs and iteratively merging the most frequent one.
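Before getting to the efficient version, here is a tiny, throwaway illustration of that core loop on a toy corpus (my own demo, not Banhxeo code): count adjacent pairs, merge the winner, repeat.

```python
from collections import Counter
import itertools

# Toy corpus: each word is a list of symbols, starting from single characters
words = [["h", "u", "g"], ["h", "u", "g"], ["p", "u", "g"]]

for step in range(2):
    # Count adjacent pairs across all words
    pairs = Counter(p for w in words for p in itertools.pairwise(w))
    best = max(pairs, key=pairs.get)  # most frequent pair, e.g. ('u', 'g')
    merged = "".join(best)

    # Merge every occurrence of the winning pair
    new_words = []
    for w in words:
        out, i = [], 0
        while i < len(w):
            if i + 1 < len(w) and (w[i], w[i + 1]) == best:
                out.append(merged)
                i += 2
            else:
                out.append(w[i])
                i += 1
        new_words.append(out)
    words = new_words
    print(f"merge {step}: {best} -> {merged}, words = {words}")
```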
To do this efficiently, you need to:

- Count all pairs in the corpus.
- Put them in a priority queue (a min-heap in Python) so you can always get the most frequent pair in O(log n) time.
- When you merge a pair (e.g., `('h', 'u')` -> `'hu'`), update the counts for all pairs affected by the merge: if `'p h u g'` becomes `'p hu g'`, the counts for `('p', 'h')` and `('u', 'g')` are removed, and counts for `('p', 'hu')` and `('hu', 'g')` are added.
This logic is all in `bpe.py`:
```python
def get_pair_stats(word_freqs: BPEWord) -> Tuple[PairStats, PairHeap]:
    pair_stats = defaultdict(int)
    for _, (frequency, split) in word_freqs.items():
        # Add each pair in this word to the pair stats
        for pair in set(itertools.pairwise(split)):
            pair_stats[pair] += frequency

    # Then create a heap based on pair_stats.
    # We use negative frequency because heapq is a min-heap.
    pair_heap = [(-freq, pair) for pair, freq in pair_stats.items()]
    heapq.heapify(pair_heap)
    return pair_stats, pair_heap
```
```python
# Inside BPEModel.train()
def train(
    self,
    corpus: Iterable[PreTokenizedString],
    **kwargs,
):
    ...
    # 1. Get initial stats and heap
    pair_stats, pair_heap = get_pair_stats(word_freqs)

    # 2. Loop until vocab size is reached
    for _ in progress_bar(
        range(0, (vocab_size - initial_vocab_size)),
        desc="Training BPE",
    ):
        ...
        # 3. Pop the most frequent pair from the heap
        most_freq_pair = None
        while pair_heap:
            nfreq, most_freq_pair = heapq.heappop(pair_heap)
            # This check is crucial to handle stale entries in the heap
            if (freq := pair_stats.get(most_freq_pair)) is None or freq != -nfreq:
                continue
            else:
                break  # Found a valid, most-frequent pair
        ...

        # 4. Add to merge rules
        self.merges[most_freq_pair] = rank

        # 5. Merge this pair everywhere and update stats
        merge_pair(
            most_freq_pair, word_freqs, inverted_word_freqs, pair_stats, pair_heap
        )
        ...
```
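To see why that stale-entry check matters, here is a tiny standalone demo of the same lazy-deletion pattern (toy data, not Banhxeo's code): the dict holds the live counts, the heap holds possibly outdated snapshots, and popping simply skips entries whose count no longer matches.

```python
import heapq
import itertools
from collections import defaultdict

# Toy stats: word -> (frequency, current split)
word_freqs = {
    "hug": (10, ["h", "u", "g"]),
    "pug": (5, ["p", "u", "g"]),
    "hugs": (4, ["h", "u", "g", "s"]),
}

pair_stats = defaultdict(int)
for freq, split in word_freqs.values():
    for pair in itertools.pairwise(split):
        pair_stats[pair] += freq

# Min-heap of (-count, pair); negating the count turns it into a max-heap.
pair_heap = [(-count, pair) for pair, count in pair_stats.items()]
heapq.heapify(pair_heap)

# Pretend a merge changed the stats, leaving a stale heap entry behind.
pair_stats[("u", "g")] -= 5

# Pop until the heap entry agrees with the live stats (lazy deletion).
while pair_heap:
    ncount, pair = heapq.heappop(pair_heap)
    if pair_stats.get(pair) == -ncount:
        print("most frequent pair:", pair, pair_stats[pair])
        break
```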
3. The Full Stack: DataLoaders and Flax Models

A tokenizer isn’t much good without data and models. I expanded the library to include a data pipeline and some simple Flax models.
a. A PyTorch-powered JAX DataLoader
The JAX ecosystem doesn’t have a multi-process DataLoader as mature as PyTorch’s. Instead of reinventing the wheel, I made a wrapper. My `DataLoader` class:

- Uses `torch.utils.data.DataLoader` under the hood if `USE_TORCH=True` and `num_workers > 0`.
- Feeds the Torch loader a “dummy” dataset and uses a custom `collate_fn` that calls my actual JAX-based dataset’s `__getitems__` method.
- If PyTorch isn’t available or `num_workers=0`, it falls back to a simple, single-process `NaiveDataLoader`.
```python
class DataLoader:
    def __init__(
        self,
        dataset,
        batch_size: int,
        ...
        num_workers: int = 0,
        **kwargs,
    ):
        self.dataset = dataset
        if USE_TORCH and num_workers > 0:
            # 1. Create a dummy dataset with just the length
            adapter = TorchDummyDataset(len(dataset))

            # 2. Create a collator that calls our *real* dataset
            collate_fn = TorchCollator(dataset)
            ...

            # 3. Use the powerful Torch DataLoader for multiprocessing
            self._loader = TorchDataLoader(
                dataset=adapter,
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=num_workers,
                collate_fn=collate_fn,
                **kwargs,
            )
        else:
            # 4. Fallback to a simple single-process loader
            self._loader = NaiveDataLoader(
                dataset=dataset,
                batch_size=batch_size,
                ...
            )
```
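The `TorchCollator` is the piece that bridges the two worlds: PyTorch’s worker processes hand it the indices produced by the dummy dataset, and it looks the real examples up on the JAX side. Its implementation isn’t shown above, but under the assumption that it simply forwards those indices to the dataset’s `__getitems__`, it could look something like this:

```python
class TorchCollator:
    """Sketch of a collator that maps dummy indices back to the real dataset."""

    def __init__(self, dataset):
        self.dataset = dataset  # the real, JAX-friendly dataset

    def __call__(self, indices):
        # `indices` is the list of integers the Torch workers pulled from
        # TorchDummyDataset; the actual batch construction happens here.
        return self.dataset.__getitems__(indices)
```

The nice side effect of this design is that PyTorch’s machinery only ever shuffles and batches integer indices; the actual data loading stays in my own dataset code.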
b. Flax Models

Finally, I implemented some standard models in Flax, like a simple MLP for text classification. It handles embedding aggregation (mean, sum, max) and builds a simple feed-forward network.
```python
from typing import List

import einops
import flax.linen as nn
import jax
import jax.nn as jnn
import jax.numpy as jnp
from jaxtyping import Integer


class MLP(nn.Module):
    vocab_size: int
    output_size: int
    embedding_dim: int
    hidden_sizes: List[int]
    aggregate_strategy: str = "mean"
    activation_fn: str = "relu"  # referenced below; default assumed
    dropout_rate: float = 0.0    # referenced below; default assumed

    @nn.compact
    def __call__(
        self,
        input_ids: Integer[jax.Array, "batch seq"],
        attention_mask: Integer[jax.Array, "batch seq"],
        dropout: bool = True,
    ):
        embeddings = nn.Embed(
            num_embeddings=self.vocab_size,
            features=self.embedding_dim,
        )(input_ids)

        # attention_mask: (batch, seq) -> (batch, seq, 1)
        mask_expanded = attention_mask[:, :, None].astype(jnp.float32)

        # Aggregate embeddings based on strategy
        match self.aggregate_strategy:
            case "mean":
                summed = einops.reduce(
                    embeddings * mask_expanded, "batch seq dim -> batch dim", "sum"
                )
                count = einops.reduce(
                    mask_expanded, "batch seq 1 -> batch 1", "sum"
                ).clip(min=1e-9)
                x = summed / count
            case "sum":
                ...
            case "max":
                ...

        # Pass through hidden layers
        for hidden_dim in self.hidden_sizes:
            x = nn.Dense(hidden_dim)(x)
            x = getattr(jnn, self.activation_fn.lower())(x)
            if self.dropout_rate > 0:
                x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not dropout)

        logits = nn.Dense(self.output_size)(x)
        return logits
```
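And a quick smoke test, roughly how I’d initialize and apply the module (the hyperparameters here are arbitrary):

```python
import jax
import jax.numpy as jnp

model = MLP(
    vocab_size=1000,
    output_size=2,
    embedding_dim=64,
    hidden_sizes=[128, 64],
    aggregate_strategy="mean",
)

rng = jax.random.PRNGKey(0)
input_ids = jnp.ones((4, 16), dtype=jnp.int32)       # (batch, seq)
attention_mask = jnp.ones((4, 16), dtype=jnp.int32)  # 1 = real token, 0 = padding

params = model.init(rng, input_ids, attention_mask, dropout=False)
logits = model.apply(params, input_ids, attention_mask, dropout=False)
print(logits.shape)  # (4, 2)
```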
This project was a fantastic learning experience in building a modular, testable NLP library. Digging into the tokenizer logic was surprisingly complex but incredibly rewarding. Now, to actually finish that `GPT2` model… 😅