---
title: "Pytrees for Scientific Python"
date: 2025-05-14T10:27:59-07:00
draft: false
description: "
  Introducing PyTrees for Scientific Python. We discuss what PyTrees are, how they're useful in the realm of scientific Python, and how to work _efficiently_ with them.
"
tags: ["PyTrees", "Functional Programming", "Tree-like data manipulation"]
displayInList: true
author: ["Peter Fackeldey", "Mihai Maruseac", "Matthew Feickert"]
summary: |
  This blog introduces PyTrees — nested Python data structures (such as lists, dicts, and tuples) with numerical leaf values — designed to simplify working with complex, hierarchically organized data.
  While such structures are often cumbersome to manipulate, PyTrees make them more manageable by allowing them to be flattened into a list of leaves along with a reusable structure blueprint in a _generic_ way.
  This enables flexible, generic operations like mapping and reducing from functional programming.
  By bringing those functional paradigms to structured data, PyTrees let you focus on what transformations to apply, not how to traverse the structure — no matter how deeply nested or complex it is.
---

## Manipulating Tree-like Data using Functional Programming Paradigms

A "PyTree" is a nested collection of Python containers (e.g. dicts, (named) tuples, lists, ...), where the leaves are of interest. | ||
In the scientific world, such a PyTree could consist of experimental measurements of different properties at different timestamps and measurement settings resulting in a highly complex, nested and not necessarily rectangular data structure. | ||
Such collections can be cumbersome to manipulate _efficiently_, especially if they are nested any depth. | ||
It often requires complex recursive logic which usually does not generalize to other nested Python containers (PyTrees), e.g. for new measurements. | ||
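
As a purely illustrative (made-up) example, such a collection of measurements might look like this:

```python
import numpy as np

# hypothetical measurements: nested dicts and lists with NumPy arrays of
# different lengths as leaves (the structure is not rectangular)
measurements = {
    "run_1": {
        "temperature": np.array([300.1, 300.4, 301.0]),
        "pressure": np.array([1.01, 1.02]),
    },
    "run_2": [
        {"temperature": np.array([299.8])},
        {"temperature": np.array([300.2, 300.3])},
    ],
}
```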

The core concept of PyTrees is being able to flatten them into a flat collection of leaves plus a "blueprint" of the tree structure, and then being able to unflatten them back into the original PyTree.
This allows for the application of generic transformations.
In this blog post, we use [`optree`](https://github.com/metaopt/optree/tree/main/optree) — a standalone PyTree library — to perform these transformations. It focuses on performance, is feature-rich, has minimal dependencies, and has been adopted by [PyTorch](https://pytorch.org), [Keras](https://keras.io), and [TensorFlow](https://github.com/tensorflow/tensorflow) (through Keras) as a core dependency.
For example, on a PyTree with NumPy arrays as leaves, you can take the square root of each leaf with `optree.tree_map(np.sqrt, tree)`:

```python
import optree as pt
import numpy as np

# tuple of a list of a dict with an array as value, and an array
tree = ([[{"foo": np.array([4.0])}], np.array([9.0])],)

# sqrt of each leaf array
sqrt_tree = pt.tree_map(np.sqrt, tree)
print(f"{sqrt_tree=}")
# >> sqrt_tree=([[{'foo': array([2.])}], array([3.])],)

# reductions
all_positive = all(np.all(x > 0.0) for x in pt.tree_iter(tree))
print(f"{all_positive=}")
# >> all_positive=True

summed = np.sum(pt.tree_reduce(sum, tree))
print(f"{summed=}")
# >> summed=np.float64(13.0)
```

The trick is that these operations can all be implemented in three generic steps; `tree_map`, for example, decomposes into:

```python
# step 1: flatten the tree into its leaves and a treedef (the structure "blueprint")
leaves, treedef = pt.tree_flatten(tree)

# step 2: apply a function (e.g. `fun = np.sqrt`) to each leaf
new_leaves = tuple(map(fun, leaves))

# step 3: unflatten the transformed leaves back into the original structure
result_tree = pt.tree_unflatten(treedef, new_leaves)
```

### PyTree Origins

Originally, the concept of PyTrees was developed by the [JAX](https://docs.jax.dev/en/latest/) project to make nested collections of JAX arrays work transparently at the "JIT boundary" (the JAX JIT toolchain does not know about Python containers, only about JAX arrays).
However, PyTrees were quickly adopted by AI researchers for broader use cases: semantically grouping layers of weights and biases in a list of named tuples (or dictionaries) is a common pattern in the JAX AI world, as shown in the following (pseudo) Python snippet:

```python
from typing import NamedTuple
import jax
import jax.numpy as jnp


class Layer(NamedTuple):
    W: jax.Array
    b: jax.Array


layers = [
    Layer(W=jnp.array(...), b=jnp.array(...)),  # first layer
    Layer(W=jnp.array(...), b=jnp.array(...)),  # second layer
    ...,
]


@jax.jit
def neural_network(layers: list[Layer], x: jax.Array) -> jax.Array:
    for layer in layers:
        x = jnp.tanh(layer.W @ x + layer.b)
    return x


prediction = neural_network(layers=layers, x=jnp.array(...))
```

Here, `layers` is a PyTree — a `list` of multiple `Layer` instances — and the JIT-compiled `neural_network` function _just works_ with this data structure as input.

Although you cannot see what happens inside of `jax.jit`, `layers` is automatically flattened by the `jax.jit` decorator into a flat iterable of arrays, which the JAX JIT toolchain understands, in contrast to a Python `list` of `NamedTuple` objects.
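
To make the flattening step concrete, here is a minimal, self-contained sketch (with made-up layer shapes) of what `jax.jit` conceptually does with such a `layers` PyTree, using `jax.tree_util` directly; this is an illustration, not what `jax.jit` literally executes:

```python
from typing import NamedTuple
import jax
import jax.numpy as jnp


class Layer(NamedTuple):
    W: jax.Array
    b: jax.Array


# two small layers with hypothetical shapes, just for illustration
layers = [
    Layer(W=jnp.ones((2, 2)), b=jnp.zeros(2)),
    Layer(W=jnp.ones((2, 2)), b=jnp.zeros(2)),
]

# flatten the PyTree of parameters into a flat list of arrays plus a treedef
flat_arrays, treedef = jax.tree_util.tree_flatten(layers)
print(len(flat_arrays))  # >> 4 (a W and a b for each of the two layers)

# the treedef is enough to rebuild the original nested structure
rebuilt = jax.tree_util.tree_unflatten(treedef, flat_arrays)
```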

### PyTrees in Scientific Python

Wouldn't it be nice to make workflows in the scientific Python ecosystem _just work_ with any PyTree?

Giving semantic meaning to numeric data through PyTrees can be useful for applications outside of AI as well.
Consider the following minimization of the [Rosenbrock](https://en.wikipedia.org/wiki/Rosenbrock_function) function:

```python
from scipy.optimize import minimize


def rosenbrock(params: tuple[float, float]) -> float:
    """
    Rosenbrock function. Minimum: f(1, 1) = 0.

    https://en.wikipedia.org/wiki/Rosenbrock_function
    """
    x, y = params
    return (1 - x) ** 2 + 100 * (y - x**2) ** 2


x0 = (0.9, 1.2)
res = minimize(rosenbrock, x0)
print(res.x)
# >> [0.99999569 0.99999137]
```

Now, let's consider a minimization that uses a more complex type for the parameters — a `NamedTuple` that describes our fit parameters:

```python
import optree as pt
from typing import NamedTuple, Callable
from scipy.optimize import minimize as sp_minimize


class Params(NamedTuple):
    x: float
    y: float


def rosenbrock(params: Params) -> float:
    """
    Rosenbrock function. Minimum: f(1, 1) = 0.

    https://en.wikipedia.org/wiki/Rosenbrock_function
    """
    return (1 - params.x) ** 2 + 100 * (params.y - params.x**2) ** 2


def minimize(fun: Callable, params: Params) -> Params:
    # flatten the parameters and store the PyTree definition
    flat_params, treedef = pt.tree_flatten(params)

    # wrap fun to work with flat_params
    def wrapped_fun(flat_params):
        params = pt.tree_unflatten(treedef, flat_params)
        return fun(params)

    # actual minimization
    res = sp_minimize(wrapped_fun, flat_params)

    # re-wrap the best-fit values into Params with the stored PyTree definition
    return pt.tree_unflatten(treedef, res.x)


# scipy minimize that works with any PyTree
x0 = Params(x=0.9, y=1.2)
bestfit_params = minimize(rosenbrock, x0)
print(bestfit_params)
# >> Params(x=np.float64(0.999995688776513), y=np.float64(0.9999913673387226))
```

This new `minimize` function works with _any_ PyTree!

Let's now consider a modified and more complex version of the Rosenbrock function that relies on two sets of `Params` as input — a common pattern for hierarchical models (e.g. a superposition of various probability density functions):

```python
import numpy as np


def rosenbrock_modified(two_params: tuple[Params, Params]) -> float:
    """
    Modified Rosenbrock where the x and y parameters are determined by
    non-linear transformations of two versions of each, i.e.:
    x = arcsin(min(x1, x2) / max(x1, x2))
    y = sigmoid(y1 / y2)
    """
    p1, p2 = two_params

    # calculate `x` and `y` from two sources:
    x = np.arcsin(min(p1.x, p2.x) / max(p1.x, p2.x))
    y = 1 / (1 + np.exp(-(p1.y / p2.y)))

    return (1 - x) ** 2 + 100 * (y - x**2) ** 2


x0 = (Params(x=0.9, y=1.2), Params(x=0.8, y=1.3))
bestfit_params = minimize(rosenbrock_modified, x0)
print(bestfit_params)
# >> (
#   Params(x=np.float64(4.686181110201706), y=np.float64(0.05129869722505759)),
#   Params(x=np.float64(3.9432263101976073), y=np.float64(0.005146110126174016)),
# )
```

The new `minimize` still works, because a `tuple` of `Params` is just _another_ PyTree!
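
Connecting this to the hierarchical-model use case above, one could group the parameters of, say, a superposition of two components semantically in a dict of `Params` and still reuse the same `minimize`. A minimal sketch with a made-up toy objective (`two_bowls` is hypothetical, purely for illustration):

```python
def two_bowls(params: dict[str, Params]) -> float:
    # toy "hierarchical" objective: two parameter groups, each with its own minimum
    s, b = params["signal"], params["background"]
    return (s.x - 1) ** 2 + (s.y - 2) ** 2 + (b.x + 1) ** 2 + (b.y + 2) ** 2


mixture_params = {
    "signal": Params(x=0.0, y=0.0),
    "background": Params(x=0.0, y=0.0),
}
print(minimize(two_bowls, mixture_params))
# >> a dict of Params with 'signal' close to (1, 2) and 'background' close to (-1, -2)
```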

### Final Thought

Working with nested data structures doesn’t have to be messy.
PyTrees let you focus on the data and the transformations you want to apply, in a generic manner.
Whether you're building neural networks, optimizing scientific models, or just dealing with complex nested Python containers, PyTrees can make your code cleaner, more flexible, and just nicer to work with.