Custom training loop from scratch in JAX
For the past three weeks, we've been building up to this moment.
Week 1 taught us that JAX is fast. Week 2 showed us how to eliminate loops with vmap and compute gradients with grad. Week 3 gave us Fl
kambale.dev · 10 min read