---
title: "Pytrees for Scientific Python"
date: 2025-05-14T10:27:59-07:00
draft: false
description: "
  Introducing PyTrees for Scientific Python. We discuss what PyTrees are, how they're useful in the realm of scientific Python, and how to work _efficiently_ with them.
"
tags: ["PyTrees", "Functional Programming", "Tree-like data manipulation"]
displayInList: true
author: ["Peter Fackeldey", "Mihai Maruseac", "Matthew Feickert"]
summary: |
  This blog post introduces PyTrees — nested Python data structures (such as lists, dicts, and tuples) with numerical leaf values — designed to simplify working with complex, hierarchically organized data.
  While such structures are often cumbersome to manipulate, PyTrees make them more manageable by allowing them to be flattened into a list of leaves along with a reusable structure blueprint in a _generic_ way.
  This enables flexible, generic operations like mapping and reducing from functional programming.
  By bringing those functional paradigms to structured data, PyTrees let you focus on what transformations to apply, not how to traverse the structure — no matter how deeply nested or complex it is.
---
## Manipulating Tree-like Data using Functional Programming Paradigms
A "PyTree" is a nested collection of Python containers (e.g. dicts, (named) tuples, lists, ...), where the leafs are of interest. | ||||||
As you can imagine (or may even have experienced), such arbitrarily nested collections can be cumbersome to manipulate _efficiently_. In scientific code it is common to collect numerical arrays in Python containers (dicts, tuples, lists), for example when tracking several measured quantities across multiple experiment runs, and these collections can become deeply nested and complex quickly.
Doing so often requires complex recursive logic that usually does not generalize to other nested Python containers (PyTrees).
The core concept of PyTrees is the ability to flatten them into a flat collection of leaves plus a "blueprint" of the tree structure (a _treedef_), and to unflatten them back into the original PyTree.
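As a quick illustration, here is a minimal sketch of that round trip using `optree` (the values are made up, and the exact `PyTreeSpec` repr may differ slightly between versions):

```python
import optree as pt

# a small PyTree: a dict containing a list and a plain float
tree = {"a": [1.0, 2.0], "b": 3.0}

# flatten into the leaves and a reusable structure blueprint
leaves, treedef = pt.tree_flatten(tree)
print(leaves)   # [1.0, 2.0, 3.0]
print(treedef)  # PyTreeSpec({'a': [*, *], 'b': *})

# the blueprint rebuilds the original structure around (possibly new) leaves
print(pt.tree_unflatten(treedef, [10 * leaf for leaf in leaves]))
# {'a': [10.0, 20.0], 'b': 30.0}
```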
This makes it possible to apply generic transformations, e.g. taking the square root of each leaf of a PyTree with a `tree_map(np.sqrt, pytree)` operation:
```python
import optree as pt
import numpy as np

# arbitrarily nested Python containers (tuple, lists, dict) with NumPy arrays as leaves
pytree = ([[{"foo": np.array([4.0])}], np.array([9.0])],)

# sqrt of each leaf array
sqrt_pytree = pt.tree_map(np.sqrt, pytree)
print(f"{sqrt_pytree=}")
# >> sqrt_pytree=([[{'foo': array([2.])}], array([3.])],)

# reductions
all_positive = pt.tree_all(pt.tree_map(lambda x: x > 0.0, pytree))
print(f"{all_positive=}")
# >> all_positive=True

summed = pt.tree_reduce(sum, pytree)
print(f"{summed=}")
# >> summed=array([13.])
```
The trick here is that these operations can be implemented in three steps, e.g. `tree_map`:
```python
# step 1: flatten the PyTree into its leaves and a treedef (the structure blueprint)
leaves, treedef = pt.tree_flatten(pytree)

# step 2: apply the function `fun` to every leaf
new_leaves = tuple(map(fun, leaves))

# step 3: rebuild the original structure around the new leaves
result_pytree = pt.tree_unflatten(treedef, new_leaves)
```
Here, we use [`optree`](https://github.com/metaopt/optree/tree/main/optree), a standalone PyTree library that enables all of these manipulations. It focuses on performance, is feature rich, has minimal dependencies, and has been adopted by [PyTorch](https://pytorch.org), [Keras](https://keras.io), and [TensorFlow](https://github.com/tensorflow/tensorflow) (through Keras) as a core dependency.
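As one example of that feature set, containers that `optree` does not know about can be registered as PyTree nodes. The following is a small sketch based on `optree`'s registration API; the namespace name is made up for illustration:

```python
import optree as pt

# `set` is not treated as a PyTree node by default; register it under a
# custom namespace (the namespace name "blog-demo" is made up here)
pt.register_pytree_node(
    set,
    lambda s: (sorted(s), None, None),  # flatten: (children, metadata, entries)
    lambda _, children: set(children),  # unflatten: rebuild a set from children
    namespace="blog-demo",
)

leaves, treedef = pt.tree_flatten({"values": {1, 2, 3}}, namespace="blog-demo")
print(leaves)
# >> [1, 2, 3]
print(pt.tree_unflatten(treedef, [leaf + 10 for leaf in leaves]))
# >> {'values': {11, 12, 13}}
```

Built-in containers such as tuples, lists, dicts, and named tuples are handled out of the box; the namespace mechanism only comes into play for additional custom types.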
### PyTree Origins
Originally, the concept of PyTrees was developed by the [JAX](https://docs.jax.dev/en/latest/) project to make nested collections of JAX arrays work transparently at the "JIT boundary" (the JAX JIT toolchain does not know about Python containers, only about JAX arrays).
However, PyTrees were quickly adopted by AI researchers for broader use cases: semantically grouping layers of weights and biases in a list of named tuples (or dictionaries) is a common pattern in the JAX AI world, as the following (pseudo) Python snippet shows:
```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Layer(NamedTuple):
    W: jax.Array
    b: jax.Array


layers = [
    Layer(W=jnp.array(...), b=jnp.array(...)),  # first layer
    Layer(W=jnp.array(...), b=jnp.array(...)),  # second layer
    ...,
]


@jax.jit
def neural_network(layers: list[Layer], x: jax.Array) -> jax.Array:
    for layer in layers:
        x = jnp.tanh(layer.W @ x + layer.b)
    return x


prediction = neural_network(layers=layers, x=jnp.array(...))
```
Here, `layers` is a PyTree — a `list` of multiple `Layer` instances — and the JIT-compiled `neural_network` function _just works_ with this data structure as input.
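It also means the whole parameter set can be manipulated generically. Here is a small, self-contained sketch (with made-up toy values, not part of the snippet above) using `jax.tree_util` on such a `layers` PyTree:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Layer(NamedTuple):
    W: jax.Array
    b: jax.Array


# two tiny layers with arbitrary toy values
layers = [
    Layer(W=jnp.ones((2, 2)), b=jnp.zeros(2)),
    Layer(W=2.0 * jnp.ones((2, 2)), b=jnp.ones(2)),
]

# scale every weight and bias at once, regardless of nesting
scaled = jax.tree_util.tree_map(lambda p: 0.5 * p, layers)

# count all parameters by flattening the PyTree into its leaves
n_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(layers))
print(n_params)
# >> 12
```

Optimizer libraries in the JAX ecosystem rely on the same pattern: a parameter update is essentially a `tree_map` over matching PyTrees of parameters and gradients.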
### PyTrees in Scientific Python
Wouldn't it be nice to make workflows in the scientific Python ecosystem _just work_ with any PyTree?
Giving semantic meaning to numeric data through PyTrees can be useful for applications outside of AI as well.
Consider the following minimization of the [Rosenbrock](https://en.wikipedia.org/wiki/Rosenbrock_function) function:
```python
from scipy.optimize import minimize


def rosenbrock(params: tuple[float, float]) -> float:
    """
    Rosenbrock function. Minimum: f(1, 1) = 0.

    https://en.wikipedia.org/wiki/Rosenbrock_function
    """
    x, y = params
    return (1 - x) ** 2 + 100 * (y - x**2) ** 2


x0 = (0.9, 1.2)
res = minimize(rosenbrock, x0)
print(res.x)
# >> [0.99999569 0.99999137]
```
Now, let's consider a minimization that uses a more complex type for the parameters — a `NamedTuple` that describes our fit parameters:
```python
import optree as pt
from typing import NamedTuple, Callable
from scipy.optimize import minimize as sp_minimize


class Params(NamedTuple):
    x: float
    y: float


def rosenbrock(params: Params) -> float:
    """
    Rosenbrock function. Minimum: f(1, 1) = 0.

    https://en.wikipedia.org/wiki/Rosenbrock_function
    """
    return (1 - params.x) ** 2 + 100 * (params.y - params.x**2) ** 2


def minimize(fun: Callable, params: Params) -> Params:
    # flatten and store PyTree definition
    flat_params, treedef = pt.tree_flatten(params)

    # wrap fun to work with flat_params
    def wrapped_fun(flat_params):
        params = pt.tree_unflatten(treedef, flat_params)
        return fun(params)

    # actual minimization
    res = sp_minimize(wrapped_fun, flat_params)

    # re-wrap the bestfit values into Params with stored PyTree definition
    return pt.tree_unflatten(treedef, res.x)


# scipy minimize that works with any PyTree
x0 = Params(x=0.9, y=1.2)
bestfit_params = minimize(rosenbrock, x0)
print(bestfit_params)
# >> Params(x=np.float64(0.999995688776513), y=np.float64(0.9999913673387226))
```
This new `minimize` function works with _any_ PyTree!
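To make that concrete, here is a short sketch (not part of the original example) that reuses the `minimize` wrapper above with a plain dict of parameters instead of `Params`; the `rosenbrock_dict` helper is made up and the printed numbers are indicative:

```python
# the same wrapper, now with a dict as the parameter PyTree
def rosenbrock_dict(params: dict) -> float:
    return (1 - params["x"]) ** 2 + 100 * (params["y"] - params["x"] ** 2) ** 2


x0 = {"x": 0.9, "y": 1.2}
bestfit_params = minimize(rosenbrock_dict, x0)
print(bestfit_params)
# >> {'x': np.float64(0.99999...), 'y': np.float64(0.99999...)}
```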
Let's now consider a modified and more complex version of the Rosenbrock function that relies on two sets of `Params` as input — a common pattern for hierarchical models:
```python
import numpy as np


def rosenbrock_modified(two_params: tuple[Params, Params]) -> float:
    """
    Modified Rosenbrock where the x and y parameters are determined by
    non-linear transformations of two versions of each, i.e.:
        x = arcsin(min(x1, x2) / max(x1, x2))
        y = sigmoid(y1 / y2)
    """
    p1, p2 = two_params

    # calculate `x` and `y` from two sources:
    x = np.arcsin(min(p1.x, p2.x) / max(p1.x, p2.x))
    y = 1 / (1 + np.exp(-(p1.y / p2.y)))

    return (1 - x) ** 2 + 100 * (y - x**2) ** 2


x0 = (Params(x=0.9, y=1.2), Params(x=0.8, y=1.3))
bestfit_params = minimize(rosenbrock_modified, x0)
print(bestfit_params)
# >> (
#   Params(x=np.float64(4.686181110201706), y=np.float64(0.05129869722505759)),
#   Params(x=np.float64(3.9432263101976073), y=np.float64(0.005146110126174016)),
# )
```
The new `minimize` still works, because a `tuple` of `Params` is just _another_ PyTree! This comes in handy for hierarchical models, for example when fitting a superposition of functions (such as multiple Gaussians) to data: grouping each component's parameters semantically beats keeping track of positions in one flat array of floats, as the sketch below shows.
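The following is a sketch of that use case; the `Gaussian` container, the toy data, and the starting values are made up for illustration, and the PyTree-aware `minimize` wrapper from above is reused:

```python
from typing import NamedTuple

import numpy as np


class Gaussian(NamedTuple):
    mu: float
    sigma: float
    amplitude: float


def model(components: list[Gaussian], x: np.ndarray) -> np.ndarray:
    # superposition of Gaussian bumps, one per semantically grouped component
    return sum(
        c.amplitude * np.exp(-((x - c.mu) ** 2) / (2 * c.sigma**2))
        for c in components
    )


# toy data generated from two known components plus a bit of noise
rng = np.random.default_rng(0)
xdata = np.linspace(-5.0, 5.0, 201)
truth = [Gaussian(-1.5, 0.5, 1.0), Gaussian(2.0, 0.8, 0.6)]
ydata = model(truth, xdata) + rng.normal(scale=0.01, size=xdata.size)


def loss(components: list[Gaussian]) -> float:
    # least-squares distance between the model and the toy data
    return float(np.sum((model(components, xdata) - ydata) ** 2))


# the parameters stay grouped per component instead of living in one flat array
x0 = [Gaussian(-1.0, 1.0, 0.8), Gaussian(1.5, 1.0, 0.5)]
bestfit = minimize(loss, x0)
print(bestfit)
# a list of two Gaussian components with best-fit mu, sigma, and amplitude
```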
### Final Thought
Working with nested data structures doesn’t have to be messy.
PyTrees let you focus on the data and the transformations you want to apply, in a generic manner.
Whether you're building neural networks, optimizing scientific models, or just dealing with complex nested Python containers, PyTrees can make your code cleaner, more flexible, and just nicer to work with.