🍐 nguyen

Banhxeo

June 1, 2025 → Present

I’ve always been fascinated by how libraries like Hugging Face transformers work under the hood, so I decided to build one myself from scratch. I picked JAX/Flax mostly for the coolness factor. This project, ‘Banhxeo’, is the result: a deep dive into the guts of an NLP pipeline, especially the tokenizer.

1. The Tokenizer Pipeline: A (Hugging) Face-lift

The first goal was to replicate the Hugging Face tokenizers pipeline: a four-stage process (normalization, pre-tokenization, model, post-processing) that takes a raw string and turns it into model-ready inputs.¹

My main Tokenizer class implements this exact flow in its __call__ method:

src/banhxeo/core/tokenizer/__init__.py
def __call__(
    self,
    texts: Union[str, List[str]],
    ...
):
    pre_tokenized_strs = []
    for text in texts:
        # Step 1: Normalized string (e.g., lowercase, NFC)
        normalized_string = NormalizedString.from_str(text)
        normalized_string = self.normalizer.normalize(normalized_string)

        # Step 2: Pre-Tokenize (e.g., split on whitespace, bytes)
        pre_tokenized_str = self.pre_tokenizer.pre_tokenize(
            PreTokenizedString(splits=[Split(normalized=normalized_string)])
        )

        # Step 3: Model (Turn pre-tokenized splits into token IDs)
        self.model.tokenize(pre_tokenized_str)
        pre_tokenized_strs.append(pre_tokenized_str)

    # Step 4: Post-process (Add special tokens [CLS], [SEP], padding, etc.)
    post_process_config = ProcessConfig(...)
    post_process_result = self.post_processor.process_batch(
        pre_tokenized_strs, config=post_process_config
    )

    # Convert to JAX arrays
    return {
        key: jnp.array(value, dtype=jnp.int32)
        for key, value in post_process_result.items()
    }

I implemented several options for each stage, like a ByteLevelPreTokenizer (for GPT-2-style byte-level splitting) and BertPostProcessor/GPTPostProcessor to handle each family’s special tokens.
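To make the four stages concrete before moving on, here’s a toy, dependency-free walk-through of what each stage does to one sentence. This is just an illustration of the flow, not Banhxeo’s API; the vocabulary and helper functions are made up:

import unicodedata

# Toy word-level vocabulary for the "model" stage (purely illustrative)
vocab = {"[PAD]": 0, "[CLS]": 1, "[SEP]": 2, "[UNK]": 3, "i": 4, "love": 5, "banh": 6, "xeo": 7}

def normalize(text: str) -> str:
    # Step 1: Unicode NFC + lowercase
    return unicodedata.normalize("NFC", text).lower()

def pre_tokenize(text: str) -> list[str]:
    # Step 2: split on whitespace
    return text.split()

def tokenize(tokens: list[str]) -> list[int]:
    # Step 3: the "model" maps each token to an ID (a word-level lookup, for simplicity)
    return [vocab.get(tok, vocab["[UNK]"]) for tok in tokens]

def post_process(ids: list[int], max_len: int = 8) -> dict[str, list[int]]:
    # Step 4: add [CLS]/[SEP], pad to max_len, and build the attention mask
    ids = [vocab["[CLS]"], *ids, vocab["[SEP]"]]
    attention_mask = [1] * len(ids) + [0] * (max_len - len(ids))
    ids = ids + [vocab["[PAD]"]] * (max_len - len(ids))
    return {"input_ids": ids, "attention_mask": attention_mask}

print(post_process(tokenize(pre_tokenize(normalize("I love Banh Xeo")))))
# {'input_ids': [1, 4, 5, 6, 7, 2, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 0, 0]}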

2. From-Scratch BPE: Heaps of Fun

The most complex part was building the BPEModel (Byte-Pair Encoding) trainer from scratch. BPE works by counting every adjacent pair of symbols in the corpus and iteratively merging the most frequent pair into a new token.
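Here’s the naive version of a single counting-and-merging step on a toy three-word corpus (an illustration only, not Banhxeo’s code):

import itertools
from collections import Counter

# A tiny pre-tokenized corpus: word -> frequency, each word split into characters
word_freqs = {"hug": 10, "pug": 5, "hugs": 5}
splits = {word: list(word) for word in word_freqs}   # "hug" -> ["h", "u", "g"]

# Count every adjacent pair, weighted by word frequency
pair_stats = Counter()
for word, freq in word_freqs.items():
    for pair in itertools.pairwise(splits[word]):
        pair_stats[pair] += freq
print(pair_stats.most_common(2))  # [(('u', 'g'), 20), (('h', 'u'), 15)]

# Merge the winner ('u', 'g') -> 'ug' inside every word
best = max(pair_stats, key=pair_stats.get)
for word, split in splits.items():
    new_split, i = [], 0
    while i < len(split):
        if i + 1 < len(split) and (split[i], split[i + 1]) == best:
            new_split.append(split[i] + split[i + 1])
            i += 2
        else:
            new_split.append(split[i])
            i += 1
    splits[word] = new_split
print(splits)  # {'hug': ['h', 'ug'], 'pug': ['p', 'ug'], 'hugs': ['h', 'ug', 's']}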

To do this efficiently, you need to:

  1. Count all pairs in the corpus.
  2. Put them in a priority queue (Python’s heapq min-heap, storing negated frequencies) so the most frequent pair is always at the top, instead of re-scanning every pair on each merge.
  3. When you merge a pair (e.g., ('h', 'u') -> 'hu'), update the counts of every pair the merge destroys or creates: merging ('h', 'u') in 'p h u g' removes counts for ('p', 'h') and ('u', 'g') and adds counts for ('p', 'hu') and ('hu', 'g'). A simplified sketch of this update follows the code below.

This logic is all in bpe.py:

src/banhxeo/core/tokenizer/model/bpe.py
def get_pair_stats(word_freqs: BPEWord) -> Tuple[PairStats, PairHeap]:
    pair_stats = defaultdict(int)
    for _, (frequency, split) in word_freqs.items():
        # add each pair to pair stats
        for pair in set(itertools.pairwise(split)):
            pair_stats[pair] += frequency

    # Then create a heap based on pair_stats
    # We use negative frequency because heapq is a min-heap
    pair_heap = [(-freq, pair) for pair, freq in pair_stats.items()]
    heapq.heapify(pair_heap)
    return pair_stats, pair_heap


# Inside BPEModel.train()
def train(
    self,
    corpus: Iterable[PreTokenizedString],
    **kwargs,
):
    ...
    # 1. Get initial stats and heap
    pair_stats, pair_heap = get_pair_stats(word_freqs)

    # 2. Loop until vocab size is reached
    for _ in progress_bar(
        range(0, (vocab_size - initial_vocab_size)),
        desc="Training BPE",
    ):
        ...
        # 3. Pop most frequent pair from the heap
        most_freq_pair = None
        while pair_heap:
            nfreq, most_freq_pair = heapq.heappop(pair_heap)
            # This check is crucial to handle stale entries in the heap
            if (freq := pair_stats.get(most_freq_pair)) is None or freq != -nfreq:
                continue
            else:
                break  # Found a valid, most-frequent pair
        ...
        # 4. Add to merge rules
        self.merges[most_freq_pair] = rank
        # 5. Merge this pair everywhere and update stats
        merge_pair(
            most_freq_pair, word_freqs, inverted_word_freqs, pair_stats, pair_heap
        )
        ...
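The merge_pair call above does the incremental bookkeeping from step 3 of the list. The real version also uses inverted_word_freqs to visit only the words that actually contain the pair; here is a simplified sketch that scans every word instead. The word_freqs layout (word -> (frequency, split)) and the defaultdict pair_stats follow get_pair_stats above; everything else is my own simplification:

import heapq

def merge_pair_sketch(pair, word_freqs, pair_stats, pair_heap):
    """Replace `pair` with its merged token in every split, updating stats and heap."""
    new_token = pair[0] + pair[1]
    touched = set()  # pairs whose counts changed and need fresh heap entries
    for _, (freq, split) in word_freqs.items():
        i = 0
        while i < len(split) - 1:
            if (split[i], split[i + 1]) != pair:
                i += 1
                continue
            # Pairs that used to touch this occurrence lose `freq` ...
            if i > 0:
                pair_stats[(split[i - 1], split[i])] -= freq
                touched.add((split[i - 1], split[i]))
            if i + 2 < len(split):
                pair_stats[(split[i + 1], split[i + 2])] -= freq
                touched.add((split[i + 1], split[i + 2]))
            # ... the merge happens in place ...
            split[i:i + 2] = [new_token]
            # ... and the pairs created around the new token gain `freq`
            if i > 0:
                pair_stats[(split[i - 1], split[i])] += freq
                touched.add((split[i - 1], split[i]))
            if i + 1 < len(split):
                pair_stats[(split[i], split[i + 1])] += freq
                touched.add((split[i], split[i + 1]))
            i += 1
    # The merged pair no longer exists as a pair
    pair_stats.pop(pair, None)
    touched.discard(pair)
    # Push fresh heap entries; outdated ones are skipped by the stale check in train()
    for p in touched:
        if pair_stats[p] > 0:
            heapq.heappush(pair_heap, (-pair_stats[p], p))
        else:
            pair_stats.pop(p, None)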

3. The Full Stack: DataLoaders and Flax Models

A tokenizer isn’t much good without data and models. I expanded the library to include a data pipeline and some simple Flax models.

a. A PyTorch-powered JAX DataLoader

The JAX ecosystem doesn’t have a multi-process DataLoader as mature as PyTorch’s. Instead of reinventing the wheel, I made a wrapper. My DataLoader class:

  • Uses torch.utils.data.DataLoader under the hood if USE_TORCH is true and num_workers > 0.
  • Feeds the Torch loader a “dummy” dataset and uses a custom collate_fn that calls my actual JAX-based dataset’s __getitems__ method (both helpers are sketched after the code below).
  • Falls back to a simple, single-process NaiveDataLoader if PyTorch isn’t available or num_workers=0.

src/banhxeo/data/loader.py
class DataLoader:
    def __init__(
        self,
        dataset,
        batch_size: int,
        ...
        num_workers: int = 0,
        **kwargs,
    ):
        self.dataset = dataset
        if USE_TORCH and num_workers > 0:
            # 1. Create a dummy dataset with just the length
            adapter = TorchDummyDataset(len(dataset))
            # 2. Create a collator that calls our *real* dataset
            collate_fn = TorchCollator(dataset)
            ...
            # 3. Use the powerful Torch DataLoader for multiprocessing
            self._loader = TorchDataLoader(
                dataset=adapter,
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=num_workers,
                collate_fn=collate_fn,
                **kwargs,
            )
        else:
            # 4. Fallback to a simple single-process loader
            self._loader = NaiveDataLoader(
                dataset=dataset,
                batch_size=batch_size,
                ...
            )
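TorchDummyDataset and TorchCollator aren’t shown above, but conceptually they can be very small. Here’s a sketch of the idea, not Banhxeo’s exact code; I’m assuming the real dataset exposes a __getitems__ that returns a ready batch for a list of indices:

from torch.utils.data import Dataset


class TorchDummyDataset(Dataset):
    """Only knows how long the real dataset is; its items are just indices."""

    def __init__(self, length: int):
        self.length = length

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx: int) -> int:
        return idx  # workers pass around plain indices, not data


class TorchCollator:
    """Turns a batch of indices into a real batch by asking the JAX-side dataset."""

    def __init__(self, dataset):
        self.dataset = dataset  # the real dataset, which implements __getitems__

    def __call__(self, indices: list[int]):
        # All actual loading/tokenization happens here, inside the worker process
        return self.dataset.__getitems__(indices)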

b. Flax Models

Finally, I implemented some standard models in Flax, like an MLP for text classification. It aggregates token embeddings (mean, sum, or max, masked by attention_mask) and feeds the result through a small feed-forward network.

src/banhxeo/model/mlp.py
class MLP(nn.Module):
    vocab_size: int
    output_size: int
    embedding_dim: int
    hidden_sizes: List[int]
    aggregate_strategy: str = "mean"
    activation_fn: str = "relu"   # used below; default assumed for this excerpt
    dropout_rate: float = 0.1     # used below; default assumed for this excerpt

    @nn.compact
    def __call__(
        self,
        input_ids: Integer[jax.Array, "batch seq"],
        attention_mask: Integer[jax.Array, "batch seq"],
        dropout: bool = True,
    ):
        embeddings = nn.Embed(
            num_embeddings=self.vocab_size,
            features=self.embedding_dim,
        )(input_ids)

        # attention_mask: (batch, seq) -> (batch, seq, 1)
        mask_expanded = attention_mask[:, :, None].astype(jnp.float32)

        # Aggregate embeddings based on strategy
        match self.aggregate_strategy:
            case "mean":
                summed = einops.reduce(
                    embeddings * mask_expanded, "batch seq dim -> batch dim", "sum"
                )
                count = einops.reduce(
                    mask_expanded, "batch seq 1 -> batch 1", "sum"
                ).clip(min=1e-9)
                x = summed / count
            case "sum":
                ...
            case "max":
                ...

        # Pass through hidden layers
        for hidden_dim in self.hidden_sizes:
            x = nn.Dense(hidden_dim)(x)
            x = getattr(jnn, self.activation_fn.lower())(x)
            if self.dropout_rate > 0:
                x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not dropout)

        logits = nn.Dense(self.output_size)(x)
        return logits
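For completeness, here’s how a model like this is typically initialized and called with Flax’s standard init/apply pattern. The import path and hyperparameters are assumptions for the example:

import jax
import jax.numpy as jnp

from banhxeo.model.mlp import MLP  # assumed import path, based on the file layout above

model = MLP(
    vocab_size=10_000,
    output_size=2,
    embedding_dim=64,
    hidden_sizes=[128, 64],
    aggregate_strategy="mean",
)

rng = jax.random.PRNGKey(0)
input_ids = jnp.ones((8, 32), dtype=jnp.int32)        # (batch, seq)
attention_mask = jnp.ones((8, 32), dtype=jnp.int32)

# Initialize parameters with dropout off (deterministic), so only a params RNG is needed
variables = model.init(rng, input_ids, attention_mask, dropout=False)

# Training-style forward pass: dropout on, so a "dropout" RNG must be supplied
logits = model.apply(
    variables,
    input_ids,
    attention_mask,
    dropout=True,
    rngs={"dropout": jax.random.PRNGKey(1)},
)
print(logits.shape)  # (8, 2)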

This project was a fantastic learning experience in building a modular, testable NLP library. Digging into the tokenizer logic was surprisingly complex but incredibly rewarding. Now, to actually finish that GPT-2 model… 😅

Footnotes

  1. https://huggingface.co/docs/tokenizers/pipeline