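`torch.inference_mode` is supposed to provide a small speedup over `torch.no_grad` for inference-only code, since it also disables view tracking and version counter bumps. https://pytorch.org/docs/stable/generated/torch.inference_mode.html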
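
A minimal sketch of how it could be used. The `Linear` model and the input shape are placeholders made up for illustration, not taken from any particular codebase:

```python
import torch

# Placeholder model and input (assumptions for illustration only).
model = torch.nn.Linear(4, 2)
model.eval()
x = torch.randn(8, 4)

# Run the forward pass with autograd recording disabled.
with torch.inference_mode():
    out = model(x)

# Tensors created inside the context are inference tensors
# and do not require gradients.
print(out.requires_grad)
print(out.is_inference())
```

One caveat worth keeping in mind: tensors created under `inference_mode` cannot later be used in computations recorded by autograd, so this only fits code paths that never need a backward pass on those outputs.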