Great article! Found a small bug here:

adjusted_probs = torch.clamp(target_probs - torch.softmax(all_logits[position], dim=0), min=0.0)

It should be target_probs - draft_probs. Since torch.softmax(all_logits[position], dim=0) is essentially the target distribution, you're computing target_probs - target_probs, which clamps to all zeros. That sends us into your else case, so we end up sampling from target_probs directly, including the overlap region the draft model already covered and that was just rejected. Done correctly, we sample from the residual distribution p'(t) = max(0, p(t) - q(t)) / Σ_t' max(0, p(t') - q(t')), which only puts mass on tokens where the target assigns more probability than the draft, i.e. outside the overlap with the draft model.
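For reference, here's a minimal sketch of the residual step in plain Python, assuming p and q are lists of target and draft probabilities over the same vocabulary. The function names residual_distribution and buggy_residual are just for illustration, not from the article:

```python
def residual_distribution(p, q):
    """Return p'(t) = max(0, p(t) - q(t)) / Z for target probs p and draft probs q."""
    # Keep only the mass where the target assigns more probability than the draft.
    diff = [max(0.0, pt - qt) for pt, qt in zip(p, q)]
    z = sum(diff)
    if z == 0.0:
        # Only possible when p == q, in which case the draft token is never rejected,
        # so falling back to the target distribution is harmless.
        return list(p)
    return [d / z for d in diff]


def buggy_residual(p, q):
    # What the current code effectively does: subtract the target from itself.
    return [max(0.0, pt - pt) for pt in p]


p = [0.5, 0.3, 0.2]  # target
q = [0.2, 0.6, 0.2]  # draft
print(residual_distribution(p, q))  # [1.0, 0.0, 0.0]
print(buggy_residual(p, q))         # [0.0, 0.0, 0.0]
```

With the fix, all resampled mass lands on token 0, the only token the draft under-weighted; the buggy version always yields zeros and falls through to sampling from the full target.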