I've written a lot of numerical JAX and PyTorch, now used in diverse applications across science (simulation of black holes, soil moisture, ...) and ML (large language models, large protein models, ...). I would particularly highlight:
- Equinox: elegant neural networks.
- Diffrax: numerical ODE/SDE solvers.
- jaxtyping: shape/dtype annotations for arrays. (Also supports PyTorch etc., despite the name!)
A full list of other libraries:
- Lineax: linear/least-squares solvers.
- Optimistix: root finding, least squares, etc.
- sympy2jax: optimise your symbolic expressions via gradient descent!
- Quax: multiple dispatch in JAX!
- ESM2quinox: ESM2 implemented in JAX. new!
- Wadler-Lindig: A better Python pretty-printer, based upon the theory of Wadler and Lindig.
- MkPosters: Write academic posters in Markdown, style them with CSS, save them to PDF. No wrestling with LaTeX.
- typst_pyimage: A Typst extension adding support for generating figures using inline Python code.
I am currently a tech lead on ML for protein engineering (lead optimization) at Cradle Bio, and founded much of the open-source scientific JAX ecosystem. I also hold an honorary lectureship at Imperial College London. I previously worked at Google X, and received my PhD from Oxford on neural differential equations.
My current interests include pretty much anything related to scientific machine learning and scientific computing! I've now worked across diverse parts of the field, from modern deep learning (protein language models) to classical methods (numerics), to everything in between (neural differential equations).
I am also known for having strong opinions on the importance of good software development! :)
Other links:
- Bluesky:
- Twitter:
- Google Scholar: here
- Personal website: kidger.site
- Neural ODE/SDE textbook: arXiv/2202.02435