A new blog post about Pytrees for Scientific Python #250

Open: pfackeldey wants to merge 32 commits into main

32 commits:
4d71fbe  wip (pfackeldey, May 14, 2025)
dd6c7ca  first complete version (pfackeldey, May 14, 2025)
e4804e1  first round of review improvements (pfackeldey, May 14, 2025)
d71c8e8  add final thought section (pfackeldey, May 18, 2025)
653296b  improve code snippet for modified rosenbrock (pfackeldey, May 18, 2025)
628d645  remove obsolete comment (pfackeldey, May 18, 2025)
7d6b287  Update content/posts/optree/pytrees/index.md (pfackeldey, May 19, 2025)
99fe18c  Update content/posts/optree/pytrees/index.md (pfackeldey, May 19, 2025)
cd443a3  Update content/posts/optree/pytrees/index.md (pfackeldey, May 19, 2025)
13d1700  Update content/posts/optree/pytrees/index.md (pfackeldey, May 19, 2025)
17f2f95  Update content/posts/optree/pytrees/index.md (pfackeldey, May 19, 2025)
b08dcb4  Update content/posts/optree/pytrees/index.md (pfackeldey, May 19, 2025)
76bbff4  Update content/posts/optree/pytrees/index.md (pfackeldey, May 19, 2025)
7a124b1  Update content/posts/optree/pytrees/index.md (pfackeldey, May 19, 2025)
19262dd  Update content/posts/optree/pytrees/index.md (pfackeldey, May 20, 2025)
bc98572  add summary (pfackeldey, May 20, 2025)
11fde89  Revert "add summary" (pfackeldey, May 20, 2025)
0555d93  add summary (pfackeldey, May 20, 2025)
4db8fba  Update content/posts/optree/pytrees/index.md (pfackeldey, May 20, 2025)
50a3b2e  Update content/posts/optree/pytrees/index.md (pfackeldey, May 20, 2025)
4dbf7fd  Update content/posts/optree/pytrees/index.md (pfackeldey, May 21, 2025)
ae7764e  Update content/posts/optree/pytrees/index.md (pfackeldey, Jun 17, 2025)
8c33d24  pytree -> tree (pfackeldey, Jun 17, 2025)
9bde52e  be more specific about the motivation for scientific data (pfackeldey, Jun 17, 2025)
5e59e37  mention optree earlier (pfackeldey, Jun 17, 2025)
e07a2ac  be a bit more specific about the use case with hierarchical models (pfackeldey, Jun 17, 2025)
89d464f  Merge branch 'main' into pytree (pfackeldey, Jun 17, 2025)
4c548bc  leafs -> leaves (pfackeldey, Jun 18, 2025)
254ad90  Update content/posts/optree/pytrees/index.md (pfackeldey, Jul 7, 2025)
b5acc95  be more explicit about reducing arrays with more than 1 element (pfackeldey, Jul 7, 2025)
4e14a11  add a sentence of how pytrees and jax.jit work together (pfackeldey, Jul 7, 2025)
aa851f0  clarify what 'compiler' is meant with (pfackeldey, Jul 8, 2025)

212 changes: 212 additions & 0 deletions content/posts/optree/pytrees/index.md
---
title: "Pytrees for Scientific Python"
date: 2025-05-14T10:27:59-07:00
draft: false
description: "
Introducing PyTrees for Scientific Python. We discuss what PyTrees are, how they're useful in the realm of scientific Python, and how to work _efficiently_ with them.
"
tags: ["PyTrees", "Functional Programming", "Tree-like data manipulation"]
displayInList: true
author: ["Peter Fackeldey", "Mihai Maruseac", "Matthew Feickert"]
summary: |
This blog introduces PyTrees — nested Python data structures (such as lists, dicts, and tuples) with numerical leaf values — designed to simplify working with complex, hierarchically organized data.
While such structures are often cumbersome to manipulate, PyTrees make them more manageable by allowing them to be flattened into a list of leaves along with a reusable structure blueprint in a _generic_ way.
This enables flexible, generic operations like mapping and reducing from functional programming.
By bringing those functional paradigms to structured data, PyTrees let you focus on what transformations to apply, not how to traverse the structure — no matter how deeply nested or complex it is.
---

## Manipulating Tree-like Data using Functional Programming Paradigms

A "PyTree" is a nested collection of Python containers (e.g. dicts, (named) tuples, lists, ...), where the leaves are of interest.
In the scientific world, such a PyTree could consist of experimental measurements of different properties at different timestamps and measurement settings, resulting in a highly complex, nested, and not necessarily rectangular data structure.
Such collections can be cumbersome to manipulate _efficiently_, especially if they are nested to arbitrary depth.
Manipulating them often requires complex recursive logic that usually does not generalize to other nested Python containers (PyTrees), e.g. when new measurements are added.
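For instance, such a measurement collection might look like the following (a hypothetical, made-up example):

```python
# a non-rectangular, nested collection of measurements
# (all names and numbers are made up for illustration)
measurements = {
    "run_1": {"temperature_K": [273.1, 273.4], "pressure_bar": [1.01]},
    "run_2": {"temperature_K": [274.0], "voltage_V": [4.2, 4.3, 4.1]},
}
```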

The core concept of PyTrees is being able to flatten them into a flat collection of leaves and a "blueprint" of the tree structure, and then being able to unflatten them back into the original PyTree.
This allows for the application of generic transformations.
In this blog post, we use [`optree`](https://github.com/metaopt/optree/tree/main/optree) — a standalone PyTree library — which enables these transformations. It focuses on performance, is feature-rich, has minimal dependencies, and has been adopted by [PyTorch](https://pytorch.org), [Keras](https://keras.io), and [TensorFlow](https://github.com/tensorflow/tensorflow) (through Keras) as a core dependency.
For example, here is how to take the square root of each leaf of a PyTree with NumPy arrays as leaves using `optree.tree_map(np.sqrt, tree)`:

```python
import optree as pt

import numpy as np

# tuple of a list of a dict with an array as value, and an array
tree = ([[{"foo": np.array([4.0])}], np.array([9.0])],)

# sqrt of each leaf array
sqrt_tree = pt.tree_map(np.sqrt, tree)
print(f"{sqrt_tree=}")
# >> sqrt_tree=([[{'foo': array([2.])}], array([3.])],)

# reductions
all_positive = all(np.all(x > 0.0) for x in pt.tree_iter(tree))
print(f"{all_positive=}")
# >> all_positive=True

summed = np.sum(pt.tree_reduce(sum, tree))
print(f"{summed=}")
# >> summed=np.float64(13.0)
```

Review comment (XuehaiPan): FYI, optree 0.16.0 is released. I wonder if it is worth mentioning `optree.pytree.reexport(...)`.

```python
>>> import optree
>>> pytree = optree.pytree.reexport(namespace='my-pkg', module='my_pkg.pytree')
>>> pytree.flatten({'a': 1, 'b': 2})
([1, 2], PyTreeSpec({'a': *, 'b': *}))
```

```python
# foo/__init__.py
import optree

pytree = optree.pytree.reexport(namespace='foo')
```

```python
# foo/bar.py
from foo import pytree


@pytree.dataclasses.dataclass
class Bar:
    a: int
    b: float


print(pytree.flatten({'a': 1, 'b': 2, 'c': Bar(3, 4.0)}))
# Output:
#   ([1, 2, 3, 4.0], PyTreeSpec({'a': *, 'b': *, 'c': CustomTreeNode(Bar[()], [*, *])}, namespace='foo'))
```

Reply (pfackeldey): That's an awesome feature (which I'll certainly use for my libraries!).
I think this blog post is more targeted towards users who may want to try optree directly, and less towards library developers who may want to reexport it with additional functionality. That's why I'm leaning towards not adding it to the blog post. Does that make sense to you @XuehaiPan?

Reply (XuehaiPan): It makes sense.

The trick here is that these operations can be implemented in three steps, e.g. `tree_map`:

```python
# step 1: flatten the tree into its leaves and a "blueprint" of the structure (the treedef)
leaves, treedef = pt.tree_flatten(tree)

# step 2: apply the function (e.g. fun = np.sqrt) to each leaf
new_leaves = tuple(map(fun, leaves))

# step 3: unflatten the transformed leaves back into the original structure
result_tree = pt.tree_unflatten(treedef, new_leaves)
```
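The same flatten step also powers the reductions shown earlier; here is a minimal sketch of the principle, reusing `tree`, `pt`, and `np` from the example above (optree already ships `pt.tree_reduce`, so this is purely illustrative):

```python
import functools

# flatten once, then reduce over the flat list of leaves
leaves, _ = pt.tree_flatten(tree)
total = functools.reduce(lambda acc, leaf: acc + np.sum(leaf), leaves, 0.0)
print(total)
# >> 13.0 (i.e. 4.0 + 9.0 for the example tree above)
```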

### PyTree Origins

Originally, the concept of PyTrees was developed by the [JAX](https://docs.jax.dev/en/latest/) project to make nested collections of JAX arrays work transparently at the "JIT-boundary" (the JAX JIT toolchain does not know about Python containers, only about JAX Arrays).
However, PyTrees were quickly adopted by AI researchers for broader use cases: semantically grouping layers of weights and biases in a list of named tuples (or dictionaries) is a common pattern in the JAX AI world, as shown in the following (pseudo) Python snippet:

```python
from typing import NamedTuple
import jax
import jax.numpy as jnp


class Layer(NamedTuple):
    W: jax.Array
    b: jax.Array


layers = [
    Layer(W=jnp.array(...), b=jnp.array(...)),  # first layer
    Layer(W=jnp.array(...), b=jnp.array(...)),  # second layer
    ...,
]


@jax.jit
def neural_network(layers: list[Layer], x: jax.Array) -> jax.Array:
    for layer in layers:
        x = jnp.tanh(layer.W @ x + layer.b)
    return x


prediction = neural_network(layers=layers, x=jnp.array(...))
```

Here, `layers` is a PyTree — a `list` of multiple `Layer` — and the JIT-compiled `neural_network` function _just works_ with this data structure as input.
Review comment: This does not seem to be an example of a nested structure, just the final layer of a tree. Wouldn't it be pretty straightforward to handle a list using regular methods?

Reply (pfackeldey): It's a (three-level) nested data structure in the sense that `layers` is a list (first level) of `Layer`s (second level) of `W` (third level, a leaf) and `b` (third level, a leaf). Here, one could also create one list for weights and one for biases and loop over both with `zip`, but then weights and biases would no longer be logically grouped in a `Layer` object.

Although you cannot see what happens inside of `jax.jit`, `layers` is automatically flattened by the `jax.jit` decorator into a flat iterable of arrays, which the JAX JIT toolchain understands, in contrast to a Python `list` of `NamedTuple`s.
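This flattening can also be made explicit with JAX's own tree utilities; a small self-contained sketch with toy values (the snippet above elides the actual arrays):

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple


class Layer(NamedTuple):
    W: jax.Array
    b: jax.Array


# two tiny layers with concrete toy values
layers = [
    Layer(W=jnp.ones((2, 2)), b=jnp.zeros(2)),
    Layer(W=jnp.ones((2, 2)), b=jnp.zeros(2)),
]

# conceptually, this is the flattening that happens at the jax.jit boundary
leaves, treedef = jax.tree_util.tree_flatten(layers)
print(len(leaves))  # 4 leaf arrays: W and b of each Layer
print(treedef)  # records the list-of-Layer structure
```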

### PyTrees in Scientific Python

Wouldn't it be nice to make workflows in the scientific Python ecosystem _just work_ with any PyTree?

Giving semantic meaning to numeric data through PyTrees can be useful for applications outside of AI as well.
Consider the following minimization of the [Rosenbrock](https://en.wikipedia.org/wiki/Rosenbrock_function) function:

```python
from scipy.optimize import minimize


def rosenbrock(params: tuple[float, float]) -> float:
    """
    Rosenbrock function. Minimum: f(1, 1) = 0.

    https://en.wikipedia.org/wiki/Rosenbrock_function
    """
    x, y = params
    return (1 - x) ** 2 + 100 * (y - x**2) ** 2


x0 = (0.9, 1.2)
res = minimize(rosenbrock, x0)
print(res.x)
# >> [0.99999569 0.99999137]
```

Now, let's consider a minimization that uses a more complex type for the parameters — a NamedTuple that describes our fit parameters:

```python
import optree as pt
from typing import NamedTuple, Callable
from scipy.optimize import minimize as sp_minimize


class Params(NamedTuple):
    x: float
    y: float


def rosenbrock(params: Params) -> float:
    """
    Rosenbrock function. Minimum: f(1, 1) = 0.

    https://en.wikipedia.org/wiki/Rosenbrock_function
    """
    return (1 - params.x) ** 2 + 100 * (params.y - params.x**2) ** 2


def minimize(fun: Callable, params: Params) -> Params:
    # flatten and store PyTree definition
    flat_params, treedef = pt.tree_flatten(params)

    # wrap fun to work with flat_params
    def wrapped_fun(flat_params):
        params = pt.tree_unflatten(treedef, flat_params)
        return fun(params)

    # actual minimization
    res = sp_minimize(wrapped_fun, flat_params)

    # re-wrap the bestfit values into Params with stored PyTree definition
    return pt.tree_unflatten(treedef, res.x)


# scipy minimize that works with any PyTree
x0 = Params(x=0.9, y=1.2)
bestfit_params = minimize(rosenbrock, x0)
print(bestfit_params)
# >> Params(x=np.float64(0.999995688776513), y=np.float64(0.9999913673387226))
```

This new `minimize` function works with _any_ PyTree!
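For instance, the same `minimize` also accepts a plain `dict` of parameters without any changes; a quick sketch (the `rosenbrock_dict` variant is hypothetical and only changes how the objective reads its parameters):

```python
def rosenbrock_dict(params: dict[str, float]) -> float:
    # same Rosenbrock function, but the parameters now live in a dict
    return (1 - params["x"]) ** 2 + 100 * (params["y"] - params["x"] ** 2) ** 2


bestfit = minimize(rosenbrock_dict, {"x": 0.9, "y": 1.2})
print(bestfit)
# >> a dict again, with both values close to 1.0
```

The optimizer wrapper itself stays untouched; only the objective decides how its parameters are structured.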

Let's now consider a modified and more complex version of the Rosenbrock function that relies on two sets of `Params` as input — a common pattern for hierarchical models (e.g. a superposition of various probability density functions):

```python
import numpy as np


def rosenbrock_modified(two_params: tuple[Params, Params]) -> float:
    """
    Modified Rosenbrock where the x and y parameters are determined by
    non-linear transformations of two versions of each, i.e.:
        x = arcsin(min(x1, x2) / max(x1, x2))
        y = sigmoid(y1 / y2)
    """
    p1, p2 = two_params

    # calculate `x` and `y` from two sources:
    # (np.asin is the array-API alias of np.arcsin, available in NumPy >= 2.0)
    x = np.asin(min(p1.x, p2.x) / max(p1.x, p2.x))
    y = 1 / (1 + np.exp(-(p1.y / p2.y)))

    return (1 - x) ** 2 + 100 * (y - x**2) ** 2


x0 = (Params(x=0.9, y=1.2), Params(x=0.8, y=1.3))
bestfit_params = minimize(rosenbrock_modified, x0)
print(bestfit_params)
# >> (
# >>   Params(x=np.float64(4.686181110201706), y=np.float64(0.05129869722505759)),
# >>   Params(x=np.float64(3.9432263101976073), y=np.float64(0.005146110126174016)),
# >> )
```

The new `minimize` still works, because a `tuple` of `Params` is just _another_ PyTree!
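To make the hierarchical-model use case a bit more concrete: the parameters of, say, a superposition of two Gaussian densities can be grouped in exactly the same way (a hypothetical sketch; the component names and values are made up for illustration):

```python
import numpy as np
from typing import NamedTuple


class Gaussian(NamedTuple):
    mu: float
    sigma: float


def gaussian_pdf(x: float, g: Gaussian) -> float:
    return np.exp(-0.5 * ((x - g.mu) / g.sigma) ** 2) / (g.sigma * np.sqrt(2 * np.pi))


def mixture_density(x: float, components: dict[str, Gaussian]) -> float:
    # superposition of Gaussians; each component's parameters stay semantically grouped
    return sum(gaussian_pdf(x, g) for g in components.values())


# `mixture_params` is a PyTree (a dict of NamedTuples), so the generic `minimize`
# from above could fit it to data, e.g. through a negative log-likelihood
mixture_params = {
    "signal": Gaussian(mu=0.0, sigma=1.0),
    "background": Gaussian(mu=2.0, sigma=3.0),
}

print(mixture_density(0.5, mixture_params))
```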
Review comment: I'm looking at this example, trying to imagine where a person would use such a thing. Not because I doubt it's useful, but because my experience is limited. A few more examples (not inline code examples, just a sentence or two listing them) may be helpful here.

Like, the generalization above required significant changes to the rosenbrock function itself, so it is unclear what the benefit is of not having to modify the minimizer at the same time. But it hints that there may be families of problems that can be solved more easily.

And that ties into the question:

> Wouldn't it be nice to make workflows in the scientific Python ecosystem just work with any PyTree?

What would we need to do to make that happen, and under which circumstances would that be useful?

Reply (pfackeldey): I think hierarchical models are something where this comes in very handy. It's not uncommon to fit a superposition of multiple functions (or PDFs, e.g. multiple Gaussians) to data. Then it's nice to group those parameters semantically, instead of passing a flat list or array of floats where you have to remember what each value stands for.

> What would we need to do to make that happen, and under which circumstances would that be useful?

I'm not suggesting we should change APIs in SciPy and co; it's rather that PyTrees (with tools like optree) make it possible for APIs like scipy.optimize.minimize to work generically with any data structure (not just a flat list/array of floats) without too many code changes. It's something that users can opt into to make their code more expressive, as they're the ones who define the "model", here the rosenbrock_modified function.

Does that make sense? (Do you have a proposal to write this more clearly here?)

Follow-up (pfackeldey): I wasn't sure how I could be more precise about hierarchical models, so I've added an example in the text that I think is common in scientific research: e07a2ac


### Final Thought

Working with nested data structures doesn’t have to be messy.
PyTrees let you focus on the data and the transformations you want to apply, in a generic manner.
Whether you're building neural networks, optimizing scientific models, or just dealing with complex nested Python containers, PyTrees can make your code cleaner, more flexible, and just nicer to work with.