🍐 nguyen

Banhxeo

February 8, 2026 → Present

I’ve always wanted to understand how deep learning frameworks really work under the hood. Not just the API, but the graph execution, the kernel generation, the autograd. So I built one from scratch in ~2000 lines of Python. It’s called ‘banhxeo’—a lazy tensor library with automatic differentiation and Triton codegen.

1. The Philosophy: Lazy Evaluation

Modern frameworks are bloated. Reading PyTorch source code is like trying to understand a compiler by staring at assembly. banhxeo strips away the magic with one core idea: lazy evaluation.

When you write x + y, nothing is computed. Instead, a computation graph is built:

from banhxeo import Tensor
# 1. Define tensors (Lazy - no memory allocated for data)
x = Tensor.eye(3, requires_grad=True)
y = Tensor([[2.0, 0, -2.0]], requires_grad=True)
# 2. Build graph (nothing is computed yet!)
z = y.matmul(x).sum()
# 3. Backprop (Implicitly realizes the forward pass first)
z.backward()
# 4. Check gradients
print(x.grad.numpy())
print(y.grad.numpy())

The graph is only realized (executed) when you call .backward() or .realize(). This allows us to fuse operations and generate optimal GPU kernels.
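To make the idea concrete, here is a minimal sketch of lazy evaluation in plain Python. The LazyNode class and its op names are illustrative stand-ins, not banhxeo's actual LazyBuffer:

```python
import numpy as np

class LazyNode:
    """Illustrative lazy expression node (not banhxeo's real LazyBuffer)."""
    def __init__(self, op, srcs=(), data=None):
        self.op, self.srcs, self.data = op, srcs, data

    def __add__(self, other):
        return LazyNode("ADD", (self, other))  # record the op, don't compute

    def __mul__(self, other):
        return LazyNode("MUL", (self, other))

    def realize(self):
        # Walk the graph only when a value is actually demanded
        if self.op == "LOAD":
            return self.data
        a, b = (s.realize() for s in self.srcs)
        return a + b if self.op == "ADD" else a * b

x = LazyNode("LOAD", data=np.array([1.0, 2.0]))
y = LazyNode("LOAD", data=np.array([3.0, 4.0]))
z = x * y + x      # graph built; nothing computed yet
print(z.realize()) # evaluation happens here
```

Because the whole expression is visible as a graph before anything runs, a smarter `realize` can rewrite it, e.g. fuse the MUL and ADD into one pass over the data.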

2. The Core Abstractions: 4 Classes That Run Everything

The entire engine fits in your head. There are just four key classes:

Component       Purpose
---------       -------
Tensor          The user-facing API. Handles operator overloading and autograd state.
LazyBuffer      A node in the computation graph. Tracks the operation (ADD, MUL) and its parents.
View            Handles shapes and strides. Enables zero-copy reshapes, permutes, and slices.
TritonCodegen   Walks the LazyBuffer graph and emits a Triton kernel string.

Here’s the core Tensor class that orchestrates everything:

src/banhxeo/tensor.py
class Tensor:
    __slots__ = "lazydata", "requires_grad", "grad", "_ctx"

    def __init__(
        self,
        data: Optional[Union[LazyBuffer, List, np.ndarray, ...]] = None,
        ...
    ):
        if isinstance(data, LazyBuffer):
            self.lazydata = data
        elif isinstance(data, np.ndarray):
            self.lazydata = LazyBuffer(
                LoadOp.FROM_NUMPY,
                view=View.create(shape=data.shape),
                args=[data.flatten()],
                device=device,
            )
        # ... other constructors
        self.grad: Optional[Tensor] = None
        self.requires_grad = requires_grad
        self._ctx: Optional[Function] = None  # for autograd

    def add(self, other) -> "Tensor":
        from banhxeo.core.function import Add
        return Add.apply(self, other)

    def realize(self) -> "Tensor":
        Device.get_backend(self.lazydata.device)().exec(self.lazydata)
        return self

3. Zero-Copy Movement with View

One of the coolest parts is the View system. Operations like reshape, permute, and slice don’t copy any data. They just change how we interpret the underlying memory through strides.

src/banhxeo/core/view.py
@dataclass
class View:
    shape: Tuple[int, ...]
    strides: Tuple[int, ...]
    offset: int = 0

    def permute(self, new_axis: Tuple[int, ...]) -> "View":
        # Just shuffle the strides - no data movement!
        target_shape = [self.shape[ax] for ax in new_axis]
        target_stride = [self.strides[ax] for ax in new_axis]
        return View(tuple(target_shape), tuple(target_stride))

    def broadcast_to(self, target_shape: Tuple[int, ...]) -> "View":
        # Expanding a dimension? Just set its stride to 0!
        # (shape/strides are first left-padded with 1s to match rank; elided)
        new_strides = []
        for dim_s, dim_t, stride_s in zip(padded_shape, target_shape, padded_strides):
            if dim_s == dim_t:
                new_strides.append(stride_s)
            elif dim_s == 1:
                new_strides.append(0)  # Zero stride = repeat this element
        return View(target_shape, tuple(new_strides))

When we finally need to read data, the codegen uses these strides to calculate the correct memory offset. This is how NumPy and PyTorch do it too.

4. Triton Codegen: From Graph to GPU Kernel

This is where the magic happens. The TritonCodegen class walks the LazyBuffer DAG and emits a single, fused Triton kernel string.

src/banhxeo/backend/triton.py
class TritonCodegen:
    def visit_BinaryOp(self, buf: LazyBuffer, name: str):
        src0 = self.get_var_name(buf.src[0])
        src1 = self.get_var_name(buf.src[1])
        op_map = {
            BinaryOp.ADD: "+",
            BinaryOp.MUL: "*",
            BinaryOp.DIV: "/",
        }
        self.code.append(f"    {name} = {src0} {op_map[buf.op]} {src1}")

    def generate(self):
        for buf in self.schedule:
            self.visit(buf, self.get_var_name(buf))
        # Emit the final kernel string
        return "\n".join([
            "@triton.jit",
            f"def generated_kernel({', '.join(args_sig)}, out_ptr, N, ...):",
            "    pid = tl.program_id(0)",
            "    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)",
            *self.code,
            f"    tl.store(out_ptr + offsets, {output_name}, mask=...)",
        ])

Want to see the generated kernel? Set DEBUG=1:

DEBUG=1 python script.py
# --- GENERATED TRITON KERNEL ---
@triton.jit
def generated_kernel(in_0_ptr, in_1_const, out_ptr, N, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_0 = tl.load(in_0_ptr + offsets, mask=offsets < N)
    temp_0 = in_0 * in_1_const
    temp_1 = tl.exp(temp_0)
    tl.store(out_ptr + offsets, temp_1, mask=offsets < N)

5. Autograd: Backprop from Scratch

The autograd engine is surprisingly simple. Each operation is a Function subclass with forward and backward methods:

src/banhxeo/core/function.py
class Mul(Function):
    def forward(self, x: LazyBuffer, y: LazyBuffer):
        self.x, self.y = x, y
        return x.compute_ops(BinaryOp.MUL, y)

    def backward(self, grad_output: LazyBuffer):
        # d(x*y)/dx = y, d(x*y)/dy = x
        return (
            self.y.compute_ops(BinaryOp.MUL, grad_output),
            self.x.compute_ops(BinaryOp.MUL, grad_output),
        )

class Matmul(Function):
    def forward(self, x: LazyBuffer, y: LazyBuffer):
        self.x, self.y = x, y
        return x.matmul(y)

    def backward(self, grad_output: LazyBuffer):
        # d(X@Y)/dX = grad @ Y.T, d(X@Y)/dY = X.T @ grad
        return (
            grad_output.matmul(self.y.t()),
            self.x.t().matmul(grad_output),
        )
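Those matmul formulas are easy to sanity-check against finite differences in plain NumPy. This sketch is independent of banhxeo; the loss here is an arbitrary `sum(G * (X@Y))` chosen so that `G` plays the role of the upstream gradient:

```python
import numpy as np

rng = np.random.default_rng(0)
X, Y = rng.normal(size=(2, 3)), rng.normal(size=(3, 4))
G = rng.normal(size=(2, 4))          # upstream gradient d(loss)/d(X@Y)

# Analytic gradients from the chain rule
dX = G @ Y.T                         # d(loss)/dX = grad @ Y.T
dY = X.T @ G                         # d(loss)/dY = X.T @ grad

# Numeric check via central differences on loss = sum(G * (X@Y))
eps = 1e-6
dX_num = np.zeros_like(X)
for idx in np.ndindex(X.shape):
    Xp, Xm = X.copy(), X.copy()
    Xp[idx] += eps
    Xm[idx] -= eps
    dX_num[idx] = (np.sum(G * (Xp @ Y)) - np.sum(G * (Xm @ Y))) / (2 * eps)

print(np.max(np.abs(dX - dX_num)))   # close to zero: the formula checks out
```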

The backward() method on Tensor does a topological sort of the graph and calls each function’s backward in reverse order:

def backward(self, retain_graph: bool = False):
    # Topological sort
    topo, visited = [], set()
    def build_topo(t):
        if t not in visited:
            visited.add(t)
            if t._ctx:
                for parent in t._ctx.parents:
                    build_topo(parent)
            topo.append(t)
    build_topo(self)
    # (self.grad is first seeded with ones, i.e. d(out)/d(out) = 1; elided here)
    # Backward pass (reverse order)
    for t in reversed(topo):
        if t._ctx is None:  # leaf tensor: nothing to backprop through
            continue
        grads = t._ctx.backward(t.grad.lazydata)
        for parent, g in zip(t._ctx.parents, grads):
            if g is not None and parent.requires_grad:
                parent.grad = Tensor(g) if parent.grad is None else parent.grad + Tensor(g)
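The same topological-sort-then-reverse-accumulate pattern can be seen end to end in a self-contained scalar sketch. Value and its methods are illustrative, not banhxeo's API:

```python
class Value:
    """Tiny scalar autograd node (illustrative, not banhxeo's Tensor)."""
    def __init__(self, data, parents=(), backward_fn=None):
        self.data, self.parents, self._backward_fn = data, parents, backward_fn
        self.grad = 0.0

    def __mul__(self, other):
        out = Value(self.data * other.data, (self, other))
        # d(x*y)/dx = y, d(x*y)/dy = x
        out._backward_fn = lambda g: (g * other.data, g * self.data)
        return out

    def __add__(self, other):
        out = Value(self.data + other.data, (self, other))
        # d(x+y)/dx = d(x+y)/dy = 1
        out._backward_fn = lambda g: (g, g)
        return out

    def backward(self):
        topo, visited = [], set()
        def build(v):
            if v not in visited:
                visited.add(v)
                for p in v.parents:
                    build(p)
                topo.append(v)
        build(self)
        self.grad = 1.0                  # seed: d(out)/d(out) = 1
        for v in reversed(topo):
            if v._backward_fn is None:
                continue                 # leaf node
            for p, g in zip(v.parents, v._backward_fn(v.grad)):
                p.grad += g              # accumulate, as in the Tensor code above

x, y = Value(3.0), Value(4.0)
z = x * y + x                            # z = x*y + x
z.backward()
print(x.grad, y.grad)                    # dz/dx = y + 1 = 5.0, dz/dy = x = 3.0
```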

This project taught me more about deep learning internals than any course ever could. Building the View system was a “eureka” moment—suddenly, broadcasting and striding made complete sense. The Triton codegen is still rough, but seeing a fused kernel pop out of a simple Python expression is pure magic.