A new blog post about Pytrees for Scientific Python #250

Open: wants to merge 21 commits into base: main (changes from 18 commits)
210 changes: 210 additions & 0 deletions content/posts/optree/pytrees/index.md
@@ -0,0 +1,210 @@
---
title: "Pytrees for Scientific Python"
date: 2025-05-14T10:27:59-07:00
draft: false
description: "
Introducing PyTrees for Scientific Python. We discuss what PyTrees are, how they're useful in the realm of scientific Python, and how to work _efficiently_ with them.
"
tags: ["PyTrees", "Functional Programming", "Tree-like data manipulation"]
displayInList: true
author: ["Peter Fackeldey", "Mihai Maruseac", "Matthew Feickert"]
summary: |
This blog post introduces PyTrees, nested Python data structures (such as lists, dicts, and tuples) with numerical leaf values, designed to simplify working with complex, hierarchically organized data.
While such structures are often cumbersome to manipulate, PyTrees make them more manageable by allowing them to be flattened into a list of leaves along with a reusable structure blueprint in a _generic_ way.
This enables flexible, generic operations like mapping and reducing from functional programming.
By bringing those functional paradigms to structured data, PyTrees let you focus on what transformations to apply, not how to traverse the structure — no matter how deeply nested or complex it is.
---

## Manipulating Tree-like Data using Functional Programming Paradigms

A "PyTree" is a nested collection of Python containers (e.g. dicts, (named) tuples, lists, ...) in which the leaves are the values of interest.
As you can imagine (or may have experienced), such arbitrarily nested collections can be cumbersome to manipulate _efficiently_.

Comment (Member):

This motivation could be fleshed out a bit. I've never had to manipulate trees of dicts, but if I had to do so I'd probably pop them into a NetworkX graph. So, perhaps helpful to tell me what the types of issues are I might have run into, rather than expect me to know it already?

Comment (Member Author):

Ah, I've not thought about using NetworkX for such cases. This could probably be a solution(?), but the difference here is that here we're not interested in the graph properties. We just want to access the leaves - no matter the structure of the tree - and manipulate / work with them. I would not know how to do this with NetworkX (but I'm also not very familiar with the library).
PyTrees are mainly about separating the leaves from a generic structure. The generic structure itself is not of interest.

> So, perhaps helpful to tell me what the types of issues are I might have run into, rather than expect me to know it already?

That's a good point! I'll try to rephrase this.

From what I've seen in code snippets written by scientists, it's common to collect/track arrays/lists of numerical data in python containers (dicts, tuples, lists), e.g. if they measure multiple quantities in multiple experiment runs (this may become nested and complex fast).

Comment (Member, @stefanv, May 21, 2025):

> Ah, I've not thought about using NetworkX for such cases. This could probably be a solution(?), but the difference here is that here we're not interested in the graph properties. We just want to access the leaves - no matter the structure of the tree - and manipulate / work with them. I would not know how to do this with NetworkX (but I'm also not very familiar with the library).

@dschult Can NetworkX efficiently grab leaf nodes and do operations on them? And if not, why not ;)

Comment (Member):

To grab the leaf nodes of a directed network in NetworkX, use:
`[node for node, deg in G.out_degree if deg == 0]`

I think the tricky part for NetworkX would be handling unhashable objects. It might be possible to identify each container with an integer index to a list... It essentially flattens the PyTree while holding the relationships between containers in the networkx graph. From the list of tree containers along with a NetworkX graph, you could form a list of the leaf node indices as above, and then pass through the leaf level containers to manipulate the data inside each. If there are both values to be manipulated and containers in the same container it might be good to use two networkx graphs -- one with the container-container relationships and the other with container-value relationships. But I haven't thought much about how to identify containers vs values, etc.

Comment (Member):

Thanks for the explanation! I think that helps me understand even better why PyTrees can be useful.

It often requires complex recursive logic which usually does not generalize to other nested Python containers (PyTrees).
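
To make the "complex recursive logic" concrete, here is a minimal sketch of what a hand-written traversal tends to look like. The helper name `sqrt_leaves` is hypothetical, and plain floats stand in for arrays so that the snippet only needs the standard library:

```python
import math


def sqrt_leaves(obj):
    """Hypothetical hand-written helper: square root of every numeric leaf."""
    if isinstance(obj, dict):
        return {key: sqrt_leaves(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [sqrt_leaves(item) for item in obj]
    if isinstance(obj, tuple):
        # a NamedTuple would come back as a plain tuple here
        return tuple(sqrt_leaves(item) for item in obj)
    # anything else is treated as a numeric leaf
    return math.sqrt(obj)


data = ([[{"foo": 4.0}], 9.0],)
result = sqrt_leaves(data)
print(result)
# >> ([[{'foo': 2.0}], 3.0],)
```

The traversal and the operation are tangled together: every additional container type needs another branch, and every additional operation needs another copy of the recursion.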

The core concept of PyTrees is being able to flatten them into a flat collection of leaves and a "blueprint" of the tree structure, and then being able to unflatten them back into the original PyTree.
This makes it possible to apply generic transformations, e.g. taking the square root of each leaf of a PyTree with a `tree_map(np.sqrt, pytree)` operation:

```python
import operator

import numpy as np
import optree as pt

# tuple of a list of a dict with an array as value, and an array
pytree = ([[{"foo": np.array([4.0])}], np.array([9.0])],)

# sqrt of each leaf array
sqrt_pytree = pt.tree_map(np.sqrt, pytree)
print(f"{sqrt_pytree=}")
# >> sqrt_pytree=([[{'foo': array([2.])}], array([3.])],)

# reductions
all_positive = pt.tree_all(pt.tree_map(lambda x: x > 0.0, pytree))
print(f"{all_positive=}")
# >> all_positive=True

# combine all leaves pairwise with a binary function
summed = pt.tree_reduce(operator.add, pytree)
print(f"{summed=}")
# >> summed=array([13.])
```

Comment on `import optree as pt`:

FYI, optree 0.16.0 is released. I wonder if it is worth mentioning `optree.pytree.reexport(...)`.

```python
>>> import optree
>>> pytree = optree.pytree.reexport(namespace='my-pkg', module='my_pkg.pytree')
>>> pytree.flatten({'a': 1, 'b': 2})
([1, 2], PyTreeSpec({'a': *, 'b': *}))
```

```python
# foo/__init__.py
import optree
pytree = optree.pytree.reexport(namespace='foo')

# foo/bar.py
from foo import pytree

@pytree.dataclasses.dataclass
class Bar:
    a: int
    b: float

print(pytree.flatten({'a': 1, 'b': 2, 'c': Bar(3, 4.0)}))
# Output:
#   ([1, 2, 3, 4.0], PyTreeSpec({'a': *, 'b': *, 'c': CustomTreeNode(Bar[()], [*, *])}, namespace='foo'))
```

Comment on `pytree = ([[{"foo": np.array([4.0])}], np.array([9.0])],)`:

nit: I suggest using `tree` over `pytree` as variable names.

Comment on `all_positive = pt.tree_all(pt.tree_map(lambda x: x > 0.0, pytree))`:

Suggested change: replace `all_positive = pt.tree_all(pt.tree_map(lambda x: x > 0.0, pytree))` with `all_positive = all(x > 0.0 for x in pt.tree_iter(pytree))`.

The trick here is that these operations can be implemented in three steps, e.g. `tree_map`:

```python
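# step 1: flatten the PyTree into a list of leaves and a structure blueprint
leafs, treedef = pt.tree_flatten(pytree)

# step 2: apply the function (`fun`) to every leaf
new_leafs = tuple(map(fun, leafs))

# step 3: rebuild a PyTree with the original structure from the new leaves
result_pytree = pt.tree_unflatten(treedef, new_leafs)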
```
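
The three steps can also be illustrated without any library. The following is a toy sketch for `dict`, `list`, and `tuple` only; the names `toy_flatten`, `toy_unflatten`, and `toy_tree_map` are made up for this illustration, and this is not how `optree` is implemented:

```python
def toy_flatten(tree):
    """Return (leaves, blueprint) for nested dicts, lists, and tuples."""
    if isinstance(tree, dict):
        # sort keys so the leaf order does not depend on insertion order
        keys = sorted(tree)
        children = [toy_flatten(tree[key]) for key in keys]
        kind = "dict"
    elif isinstance(tree, (list, tuple)):
        keys = None
        children = [toy_flatten(item) for item in tree]
        kind = "list" if isinstance(tree, list) else "tuple"
    else:
        # anything else is a leaf; its blueprint is just a placeholder
        return [tree], None
    leaves = [leaf for child_leaves, _ in children for leaf in child_leaves]
    return leaves, (kind, keys, [spec for _, spec in children])


def toy_unflatten(blueprint, leaves):
    """Rebuild the nested structure described by `blueprint` from flat leaves."""
    leaves = iter(leaves)

    def build(spec):
        if spec is None:
            return next(leaves)
        kind, keys, child_specs = spec
        children = [build(child) for child in child_specs]
        if kind == "dict":
            return dict(zip(keys, children))
        return children if kind == "list" else tuple(children)

    return build(blueprint)


def toy_tree_map(fun, tree):
    leaves, blueprint = toy_flatten(tree)        # step 1
    new_leaves = [fun(leaf) for leaf in leaves]  # step 2
    return toy_unflatten(blueprint, new_leaves)  # step 3


tree = ([[{"foo": 4.0}], 9.0],)
leaves, blueprint = toy_flatten(tree)
roundtrip = toy_unflatten(blueprint, leaves)
doubled = toy_tree_map(lambda x: 2 * x, tree)
print(leaves)   # >> [4.0, 9.0]
print(doubled)  # >> ([[{'foo': 8.0}], 18.0],)
```

Only `toy_flatten` and `toy_unflatten` know about container types; `toy_tree_map` (and any other leaf-wise operation) is written once against the flat list of leaves.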

Here, we use [`optree`](https://github.com/metaopt/optree/tree/main/optree), a standalone PyTree library that enables all these manipulations. It focuses on performance, is feature-rich, has minimal dependencies, and has been adopted by [PyTorch](https://pytorch.org), [Keras](https://keras.io), and [TensorFlow](https://github.com/tensorflow/tensorflow) (through Keras) as a core dependency.

Comment (Member):

Should optree be mentioned earlier, since it has already been used in the code?

I usually get stuck when I read something I don't know or understand, since I'm trying to figure out what I missed. import optree as pt 🤔

Comment (Member):

Also, if optree gives you all this, what is pytree all about?

Comment (Member Author):

I'll see if there's a good place to mention optree earlier 👍

optree is the library whereas PyTree is the concept. A similar comparison would be NumPy <-> Array.


### PyTree Origins

Originally, the concept of PyTrees was developed by the [JAX](https://docs.jax.dev/en/latest/) project to make nested collections of JAX arrays work transparently at the "JIT-boundary" (the JAX JIT toolchain does not know about Python containers, only about JAX Arrays).
However, PyTrees were quickly adopted by AI researchers for broader use-cases: semantically grouping layers of weights and biases in a list of named tuples (or dictionaries) is a common pattern in the JAX-AI-world. See the following (pseudo) Python snippet:

```python
from typing import NamedTuple, Callable
import jax
import jax.numpy as jnp


class Layer(NamedTuple):
    W: jax.Array
    b: jax.Array


layers = [
    Layer(W=jnp.array(...), b=jnp.array(...)),  # first layer
    Layer(W=jnp.array(...), b=jnp.array(...)),  # second layer
    ...,
]


@jax.jit
def neural_network(layers: list[Layer], x: jax.Array) -> jax.Array:
    for layer in layers:
        x = jnp.tanh(layer.W @ x + layer.b)
    return x


prediction = neural_network(layers=layers, x=jnp.array(...))
```

Here, `layers` is a PyTree &mdash; a `list` of multiple `Layer` &mdash; and the JIT compiled `neural_network` function _just works_ with this data structure as input.

### PyTrees in Scientific Python

Wouldn't it be nice to make workflows in the scientific Python ecosystem _just work_ with any PyTree?

Giving semantic meaning to numeric data through PyTrees can be useful for applications outside of AI as well.
Consider the following minimization of the [Rosenbrock](https://en.wikipedia.org/wiki/Rosenbrock_function) function:

```python
from scipy.optimize import minimize


def rosenbrock(params: tuple[float, float]) -> float:
    """
    Rosenbrock function. Minimum: f(1, 1) = 0.

    https://en.wikipedia.org/wiki/Rosenbrock_function
    """
    x, y = params
    return (1 - x) ** 2 + 100 * (y - x**2) ** 2


x0 = (0.9, 1.2)
res = minimize(rosenbrock, x0)
print(res.x)
# >> [0.99999569 0.99999137]
```

Now, let's consider a minimization that uses a more complex type for the parameters &mdash; a NamedTuple that describes our fit parameters:

```python
import optree as pt
from typing import NamedTuple, Callable
from scipy.optimize import minimize as sp_minimize


class Params(NamedTuple):
    x: float
    y: float


def rosenbrock(params: Params) -> float:
    """
    Rosenbrock function. Minimum: f(1, 1) = 0.

    https://en.wikipedia.org/wiki/Rosenbrock_function
    """
    return (1 - params.x) ** 2 + 100 * (params.y - params.x**2) ** 2


def minimize(fun: Callable, params: Params) -> Params:
    # flatten and store PyTree definition
    flat_params, treedef = pt.tree_flatten(params)

    # wrap fun to work with flat_params
    def wrapped_fun(flat_params):
        params = pt.tree_unflatten(treedef, flat_params)
        return fun(params)

    # actual minimization
    res = sp_minimize(wrapped_fun, flat_params)

    # re-wrap the bestfit values into Params with stored PyTree definition
    return pt.tree_unflatten(treedef, res.x)


# scipy minimize that works with any PyTree
x0 = Params(x=0.9, y=1.2)
bestfit_params = minimize(rosenbrock, x0)
print(bestfit_params)
# >> Params(x=np.float64(0.999995688776513), y=np.float64(0.9999913673387226))
```

This new `minimize` function works with _any_ PyTree!

Let's now consider a modified and more complex version of the Rosenbrock function that relies on two sets of `Params` as input &mdash; a common pattern for hierarchical models:

```python
import numpy as np


def rosenbrock_modified(two_params: tuple[Params, Params]) -> float:
    """
    Modified Rosenbrock where the x and y parameters are determined by
    non-linear transformations of two versions of each, i.e.:
        x = arcsin(min(x1, x2) / max(x1, x2))
        y = sigmoid(y1 / y2)
    """
    p1, p2 = two_params

    # calculate `x` and `y` from two sources:
    x = np.arcsin(min(p1.x, p2.x) / max(p1.x, p2.x))
    y = 1 / (1 + np.exp(-(p1.y / p2.y)))

    return (1 - x) ** 2 + 100 * (y - x**2) ** 2


x0 = (Params(x=0.9, y=1.2), Params(x=0.8, y=1.3))
bestfit_params = minimize(rosenbrock_modified, x0)
print(bestfit_params)
# >> (
# Params(x=np.float64(4.686181110201706), y=np.float64(0.05129869722505759)),
# Params(x=np.float64(3.9432263101976073), y=np.float64(0.005146110126174016)),
# )
```

The new `minimize` still works, because a `tuple` of `Params` is just _another_ PyTree!

Comment (Member):

I'm looking at this example, trying to imagine where a person would use such a thing. Not because I doubt it's useful, but because my experience is limited. A few more examples (not inline code examples, just a sentence or two listing them) may be helpful here.

Like, the generalization above required significant changes to the rosenbrock function itself, so it is unclear what the benefit is to not having to modify the minimizer at the same time. But it hints that there may be families of problems that can more easily be solved.

And that ties into the question:

> Wouldn't it be nice to make workflows in the scientific Python ecosystem just work with any PyTree?

What would we need to do to make that happen, and under which circumstances would that be useful?

Comment (Member Author):

I think hierarchical models are something where this comes in very handy. It's not uncommon to fit a superposition of multiple functions (or pdfs, e.g. multiple gaussians) to data. Then it's nice to group those parameters semantically, instead of passing a flat list of floats / array of floats, where you have to remember what each value stands for.

> What would we need to do to make that happen, and under which circumstances would that be useful?

I'm not suggesting we should change APIs in scipy and co; rather, PyTrees (with tools like optree) make it possible to have APIs like scipy.optimize.minimize work generically with any data structure (not just a flat list/array of floats) without many code changes. It's something users can opt in to in order to make their code more expressive, as they're the ones who define the "model", here the rosenbrock_modified function.

Does that make sense? (Do you have a proposal to write this more clearly here?)


### Final Thought

Working with nested data structures doesn’t have to be messy.
PyTrees let you focus on the data and the transformations you want to apply, in a generic manner.
Whether you're building neural networks, optimizing scientific models, or just dealing with complex nested Python containers, PyTrees can make your code cleaner, more flexible, and just nicer to work with.