Tracing a 13x PyTorch Slowdown to a Hidden NumPy Synchronization
TL;DR
A .cpu().numpy() call buried inside a forward pass was forcing a full CPU-GPU synchronization once per batch, on every loop iteration. The GPU would finish its work in milliseconds, then sit idle.
ingero.hashnode.dev
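Here is a minimal sketch of the pattern the TL;DR describes. It assumes a PyTorch model whose forward pass pulls a scalar statistic back to the host. SyncingModel, DeviceOnlyModel, time_forward and the normalize-by-mean computation are hypothetical illustrations, not the article's actual model. CUDA kernels launch asynchronously, so the CPU normally runs ahead and queues work. A device-to-host copy such as .cpu() makes the Python thread wait until the queued work finishes, and the GPU then idles while the CPU catches up. Keeping the reduction as a device tensor removes that wait. Note that .numpy() refuses tensors that require grad, so the host path needs .detach(). The device path detaches too, so that both versions compute the same forward and backward.

```python
import time

import torch
import torch.nn as nn


class SyncingModel(nn.Module):
    """Hypothetical model reproducing the anti-pattern: a host round-trip inside forward()."""

    def __init__(self, dim: int = 64):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.linear(x)
        # .cpu() copies to host memory. On CUDA this blocks the Python thread until
        # all work queued ahead of it on the current stream has finished (implicit sync).
        scale = float(h.abs().mean().detach().cpu().numpy())
        return h / (scale + 1e-6)


class DeviceOnlyModel(nn.Module):
    """Same math, but the statistic stays on the device, so no host sync is forced."""

    def __init__(self, dim: int = 64):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.linear(x)
        # 0-dim tensor that remains on h.device; detached to match SyncingModel's gradients.
        scale = h.abs().mean().detach()
        return h / (scale + 1e-6)


def time_forward(model: nn.Module, x: torch.Tensor, iters: int = 50) -> float:
    """Average seconds per forward pass, synchronizing only at the edges of the timed region."""
    with torch.no_grad():
        model(x)  # warm-up pass
        if x.is_cuda:
            torch.cuda.synchronize()  # drain queued work before starting the clock
        start = time.perf_counter()
        for _ in range(iters):
            model(x)
        if x.is_cuda:
            torch.cuda.synchronize()  # include still-queued GPU work in the measurement
        return (time.perf_counter() - start) / iters


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.manual_seed(0)
    slow = SyncingModel().to(device)
    fast = DeviceOnlyModel().to(device)
    fast.load_state_dict(slow.state_dict())  # identical weights for a fair comparison
    x = torch.randn(4096, 64, device=device)
    print(f"host round-trip: {time_forward(slow, x) * 1e3:.3f} ms/iter")
    print(f"device-only:     {time_forward(fast, x) * 1e3:.3f} ms/iter")
```

Without a GPU the script falls back to the CPU, where there is nothing to synchronize with and the two timings should be close. It makes no attempt to reproduce the 13x figure, which depends on the real model and hardware. To find such calls in a larger codebase, torch.cuda.set_sync_debug_mode("warn") is an experimental PyTorch setting that warns on many, though not all, synchronizing CUDA operations.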