JIT Compilation in Machine Learning

Just-in-time compilation in machine learning takes high-level array code written in Python and compiles it, at runtime, into fast machine code for the target hardware, whether a CPU, GPU, or other accelerator. Plain vectorized (eager) code is already fast per operation, but it still launches one kernel per operation and shuttles intermediate results in and out of memory between them. A JIT compiler can look at a whole sequence of operations at once and generate a single optimized program, eliminating that per-operation overhead and the redundant memory traffic.
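To make the contrast concrete, here is a minimal sketch in JAX, the framework discussed next, that times a short chain of elementwise operations run eagerly (one dispatch per operation) against the same function wrapped in jax.jit. The function f, the array size, and the bench helper are illustrative choices, not an official benchmark. The measured ratio will vary with hardware and backend, so no particular speedup is implied.

```python
import time

import jax
import jax.numpy as jnp


def f(x):
    # Three elementwise steps. Run eagerly, each jnp call is dispatched
    # separately and materializes its own intermediate array.
    y = jnp.sin(x)
    y = y * 2.0 + 1.0
    return jnp.tanh(y) - 0.5 * y


f_jit = jax.jit(f)  # same function, compiled as one XLA program on first call

x = jnp.ones((1_000_000,), dtype=jnp.float32)

# Warm-up: the first jitted call pays for tracing and compilation, and the
# first eager call compiles each individual primitive.
f_jit(x).block_until_ready()
f(x).block_until_ready()


def bench(fn, n: int = 10) -> float:
    """Average wall-clock seconds per call (hypothetical helper for this sketch)."""
    start = time.perf_counter()
    for _ in range(n):
        fn(x).block_until_ready()  # JAX dispatches asynchronously; wait for the result
    return (time.perf_counter() - start) / n


print(f"eager: {bench(f):.2e} s/call")
print(f"jit:   {bench(f_jit):.2e} s/call")
```

The warm-up calls matter: without them the jitted timing would include compilation. The block_until_ready() calls matter too, because otherwise the loop would only measure how fast work is enqueued.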

JAX makes this the centerpiece of its design. When a function wrapped in jit is first called, JAX traces it into an intermediate representation called a jaxpr, and, as the documentation describes, “the jaxpr is then compiled using XLA into very efficient code optimized for your GPU or TPU.” The trace uses abstract stand-ins for the inputs (their shapes and dtypes) rather than concrete values, so it captures the computation as a graph that XLA can optimize as a unit before any numbers are crunched. The compiled program is then cached and reused for later calls with the same input shapes and dtypes. JAX’s own example reports roughly an 8x speedup from JIT on a small function. The exact figure depends on the hardware, but the gain comes from the compiler optimizing the whole computation rather than running each operation in isolation.

The compiler doing the heavy lifting in the JAX and TensorFlow worlds is XLA. The XLA project describes itself as taking “models from popular frameworks such as PyTorch, TensorFlow, and JAX” and optimizing them “for high-performance execution across different hardware platforms including GPUs, CPUs, and ML accelerators.” Its signature optimization is kernel fusion, which combines many small operations into one larger kernel. For instance, a chain of elementwise math becomes a single pass over memory instead of several, with intermediate values kept in registers rather than written out and read back. Removing that redundant memory traffic closes much of the gap between naive array code and hand-tuned code, which is exactly what the JIT is built to do.
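XLA performs fusion on its own HLO graph, but the idea can be sketched in plain Python without any framework. In this toy model, each op in a chain is a scalar function. This is an illustrative assumption, not how XLA is implemented. The unfused path runs one "kernel" per op and materializes an intermediate list each time. The fused path composes the chain into one function and makes a single pass.

```python
import math
from typing import Callable, List, Sequence, Tuple

Op = Callable[[float], float]  # an elementwise operation on one scalar


def run_unfused(ops: Sequence[Op], xs: List[float]) -> Tuple[List[float], int]:
    """Eager-style execution: one pass ("kernel") per op, each writing a new buffer."""
    buf = xs
    passes = 0
    for op in ops:
        buf = [op(v) for v in buf]  # intermediate result materialized in memory
        passes += 1
    return buf, passes


def fuse(ops: Sequence[Op]) -> Op:
    """Compose the whole chain into a single per-element function."""
    def fused(v: float) -> float:
        for op in ops:
            v = op(v)  # intermediate stays in a local variable, never in a buffer
        return v
    return fused


def run_fused(ops: Sequence[Op], xs: List[float]) -> Tuple[List[float], int]:
    """Fused execution: a single pass over the input and one output buffer."""
    kernel = fuse(ops)
    return [kernel(v) for v in xs], 1


ops: List[Op] = [math.sin, lambda v: 2.0 * v + 1.0, math.tanh]
xs = [0.0, 0.5, 1.0, 1.5]
unfused, n_unfused = run_unfused(ops, xs)
fused, n_fused = run_fused(ops, xs)
print(n_unfused, n_fused)  # 3 1
print(unfused == fused)    # True
```

For n elements and k ops, the unfused path writes k full-size buffers while the fused path writes one. Because both paths apply the same operations in the same order per element, the results match exactly. On real accelerators, the saved memory round trips, rather than the arithmetic, are often where the speedup for such elementwise chains comes from.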

The pattern recurs throughout the ecosystem. PyTorch added torch.compile, which captures eager-mode code as graphs (through TorchDynamo) and compiles them (by default with the TorchInductor backend). This gives PyTorch graph-level optimization without giving up its define-by-run programming model. Numba JIT-compiles decorated Python numerical functions down to machine code via LLVM. In every case the bargain is the same: keep the convenience of writing in Python, but defer the heavy computation to a compiler that sees the whole function or graph at once.

The deeper point is that these systems turn a dynamic, interpreted language into a front end for an optimizing compiler. Tracing converts ordinary Python into a computational graph. The JIT then treats that graph the way a traditional compiler treats source code, analyzing, fusing, and scheduling it for the hardware. This is how modern ML frameworks approach compiled-language performance from code that reads like an ordinary Python script, and it continues a long line of compiler technology, now applied to a new domain.
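That tracing step can be illustrated with a deliberately tiny tracer in plain Python. This is a hypothetical sketch, not JAX's implementation. Tracer objects overload + and *, so running an ordinary Python function on them records each operation into a Graph instead of computing a number. A small evaluate function then executes the recorded graph on concrete inputs without re-running the Python code. A real JIT would compile the graph at that point instead of interpreting it.

```python
from typing import Dict, List, Tuple, Union

Node = Tuple[str, str, Tuple[str, ...]]  # (output name, op name, input names)


class Graph:
    """A flat list of operations in the order they were traced."""

    def __init__(self) -> None:
        self.nodes: List[Node] = []

    def emit(self, op: str, *args: str) -> str:
        name = f"t{len(self.nodes)}"
        self.nodes.append((name, op, args))
        return name


class Tracer:
    """Stands in for a value; arithmetic on it records nodes instead of computing."""

    def __init__(self, graph: Graph, name: str) -> None:
        self.graph = graph
        self.name = name

    def _binary(self, op: str, other: Union["Tracer", float]) -> "Tracer":
        # Python constants are baked into the graph as literals.
        rhs = other.name if isinstance(other, Tracer) else repr(float(other))
        return Tracer(self.graph, self.graph.emit(op, self.name, rhs))

    def __add__(self, other):
        return self._binary("add", other)

    def __mul__(self, other):
        return self._binary("mul", other)

    # add and mul are commutative, so the reflected versions can reuse them.
    __radd__ = __add__
    __rmul__ = __mul__


def trace(fn, *arg_names: str) -> Tuple[Graph, str]:
    """Run fn once on tracers; return the recorded graph and its output name."""
    graph = Graph()
    out = fn(*(Tracer(graph, n) for n in arg_names))
    return graph, out.name


def evaluate(graph: Graph, out: str, env: Dict[str, float]) -> float:
    """Interpret the graph on concrete inputs (a real JIT would compile it instead)."""
    vals = dict(env)
    for name, op, args in graph.nodes:
        a, b = (vals[x] if x in vals else float(x) for x in args)
        vals[name] = a + b if op == "add" else a * b
    return vals[out]


def f(x, y):
    return x * y + x * 2.0


graph, out = trace(f, "x", "y")
for name, op, args in graph.nodes:
    print(name, "=", op, *args)
# t0 = mul x y
# t1 = mul x 2.0
# t2 = add t0 t1
print(evaluate(graph, out, {"x": 3.0, "y": 4.0}))  # 18.0
```

The same f still works on plain floats, which is the point of the design: user code is unchanged, and only the type of value flowing through it decides whether it computes or records.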

Sources

Last verified June 8, 2026