While torch.autocast(dtype=torch.float16) can nearly halve VRAM usage, during experimentation we found that it sometimes triggers overflow in the RK4 integrator, causing the output to be ``nan`` in one or more channels.
Fixing this problem is likely beyond the scope of this project: it is well known in the numerical analysis community that loss of precision and accumulation of rounding error can trigger numerical instability in higher-order integrators.
We therefore advise testing on your specific problem before enabling the torch.autocast optimization.