Building Neural Networks with Flax NNX
Over the past two weeks, we've learned that JAX is fast (jit), that it eliminates loops (vmap), and that it computes gradients automatically (grad). These are powerful primitives.
But if you've been f
kambale.dev · 11 min read