Building Neural Networks with Flax NNX
Feb 23 · 11 min read

Over the past two weeks, we've learned that JAX is fast (jit), that it eliminates loops (vmap), and that it computes gradients automatically (grad). These are powerful primitives.
But if you've been f