nanoDFT computes forces on the CPU using `def grad(..)` on line 230. To run `def grad(..)` on the IPU, it is sufficient to port lines 269-273 and line 283.
Different strategies for porting lines 269-273:

- Compile libcint to Poplar and replace all `mol.intor(..)` calls with corresponding Poplar calls (the ERI is the only problematic part).
- Use the JAX implementation from D4FT for the forward pass of `mol.intor(..)` and match up the `jax.grad(..)` of those forward passes with lines 269-273 (pyscfad matched up libcint with `jax.grad` on the CPU, so their code may be helpful).
- Reimplement all integrals from first principles in JAX/tessellate.
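To illustrate what "matching up `jax.grad(..)` of the forward pass" with an analytic derivative integral could look like, here is a minimal sketch for the simplest case: the overlap of two unnormalized s-type Gaussians, written from first principles in JAX. The functions `overlap_ss` and `overlap_ss_grad_A` are hypothetical illustrations, not code from nanoDFT, D4FT, or pyscfad. One thing to watch when comparing against libcint: as far as we understand, its `ip` integrals (e.g. `int1e_ipovlp`) differentiate with respect to the electron coordinate, which equals minus the derivative with respect to the nuclear center, so a sign flip may be needed.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def overlap_ss(A, B, a, b):
    """Overlap of exp(-a|r-A|^2) and exp(-b|r-B|^2) (unnormalized s-type Gaussians)."""
    p = a + b
    mu = a * b / p
    AB2 = jnp.sum((A - B) ** 2)
    return (jnp.pi / p) ** 1.5 * jnp.exp(-mu * AB2)


def overlap_ss_grad_A(A, B, a, b):
    """Closed-form derivative dS/dA, playing the role of an analytic derivative integral."""
    p = a + b
    mu = a * b / p
    return -2.0 * mu * (A - B) * overlap_ss(A, B, a, b)


A = jnp.array([0.0, 0.0, 0.0])
B = jnp.array([0.0, 0.0, 1.4])
a, b = 0.8, 1.2

# Autodiff of the forward pass vs. the hand-derived derivative.
auto = jax.grad(overlap_ss, argnums=0)(A, B, a, b)
analytic = overlap_ss_grad_A(A, B, a, b)
print(jnp.allclose(auto, analytic))
```

The same check, applied integral by integral, would be one way to validate a JAX forward pass against the derivative integrals used on lines 269-273.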
Note: line 230 uses this theorem to compute gradients. We could instead use `jax.grad(_nanoDFT)`, but that would require fixing every call in `_nanoDFT(..)` that does not support differentiation. We currently believe this is the same amount of work as fixing `def grad(..)` (see the strategies above). In other words, the non-differentiable calls made by `_nanoDFT` are exactly the calls whose derivatives are computed on lines 269-273 and 283.
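One way to "fix" a call that does not support derivatives, without reimplementing it, is to keep the forward computation as an opaque host call and attach the analytic derivative via `jax.custom_jvp`. The sketch below assumes a toy 1D energy `E(R) = 0.5*K*(R - R0)^2` computed in NumPy via `jax.pure_callback`; `_host_energy`, `_host_denergy`, `K` and `R0` are hypothetical stand-ins for a host-only routine such as `mol.intor(..)` and its derivative integral, not nanoDFT code.

```python
import numpy as np
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical toy parameters for a harmonic "bond" energy.
K, R0 = 0.5, 1.4


def _host_energy(R):
    # Stand-in for a host-only routine that JAX cannot trace or differentiate.
    return np.asarray(0.5 * K * (R - R0) ** 2, dtype=R.dtype)


def _host_denergy(R):
    # Stand-in for the corresponding analytic derivative routine.
    return np.asarray(K * (R - R0), dtype=R.dtype)


@jax.custom_jvp
def energy(R):
    # pure_callback runs NumPy on the host; it has no differentiation rule of its own.
    return jax.pure_callback(_host_energy, jax.ShapeDtypeStruct(R.shape, R.dtype), R)


@energy.defjvp
def energy_jvp(primals, tangents):
    (R,), (dR,) = primals, tangents
    E = energy(R)
    # Supply dE/dR from the analytic routine instead of autodiff.
    dEdR = jax.pure_callback(_host_denergy, jax.ShapeDtypeStruct(R.shape, R.dtype), R)
    return E, dEdR * dR


# The force is the negative gradient of the energy; jax.grad now works end to end.
force = -jax.grad(energy)(jnp.float64(1.6))
print(force)
```

With every such call wrapped this way, `jax.grad` over the whole SCF function would become possible, which is why the amount of work looks comparable to porting `def grad(..)` directly.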