A new blog post about Pytrees for Scientific Python #250

Open
wants to merge 21 commits into
base: main
pfackeldey
Member

@pfackeldey pfackeldey commented May 14, 2025

  • The main subject relates to at least one project affiliated to the Scientific Python Ecosystem.
  • I have the right to publish the content under BSD 3-Clause License for the code and Creative Commons CC-BY-4.0 License for the text.
  • Images have been compressed using a tool like pngquant.

This is a new blog post about "PyTrees for Scientific Python" by @mihaimaruseac @matthewfeickert and me. It explains what PyTrees are, where they come from, and how they can be useful for scientific Python.

It's currently in draft mode because we're missing a last section about "final thoughts" (pinging @matthewfeickert for this 🙏).

@pfackeldey
Member Author

I've added a short "final thought" section myself - so I'm marking this blogpost as ready for review.
Let me know what you think!

@pfackeldey pfackeldey marked this pull request as ready for review May 18, 2025 20:07
@matthewfeickert
Member

Thanks @pfackeldey. 🚀 Currently in transit to Houston so I will review tonight once I land, but as we discussed together I think this is all fine at the high level and any comments I have will be nitpicks.

Member

@matthewfeickert matthewfeickert left a comment

@pfackeldey looks really good to me. As I said, the only thing I have to suggest is nitpicky typo fixes and possible rephrasing, but I think that the examples are very strong. Thank you for writing this!

@pfackeldey
Member Author

Thanks @matthewfeickert, your comments are great, I accepted all of them :)

Member

@rossbar rossbar left a comment

Very nice writeup @pfackeldey! The post flows beautifully and IMO does a great job both motivating and demonstrating what PyTrees have to offer. A few grammatical nits that are by no means blockers; thanks for writing this up!

@stefanv
Member

stefanv commented May 20, 2025

I'll review also, and then we can hopefully get this merged ASAP.

Member

@stefanv stefanv left a comment

Thanks @pfackeldey! Feel free to do with the comments what you want; let me know when you are ready to merge and I'll push the button.

## Manipulating Tree-like Data using Functional Programming Paradigms

A "PyTree" is a nested collection of Python containers (e.g. dicts, (named) tuples, lists, ...), where the leaves are of interest.
As you can imagine (or may even have experienced in the past), such arbitrarily nested collections can be cumbersome to manipulate _efficiently_.
Member

This motivation could be fleshed out a bit. I've never had to manipulate trees of dicts, but if I had to do so I'd probably pop them into a NetworkX graph. So, perhaps helpful to tell me what the types of issues are I might have run into, rather than expect me to know it already?

Member Author

Ah, I've not thought about using NetworkX for such cases. This could probably be a solution(?), but the difference here is that here we're not interested in the graph properties. We just want to access the leaves - no matter the structure of the tree - and manipulate / work with them. I would not know how to do this with NetworkX (but I'm also not very familiar with the library).
PyTrees are mainly about separating the leaves from a generic structure. The generic structure itself is not of interest.

> So, perhaps helpful to tell me what the types of issues are I might have run into, rather than expect me to know it already?

That's a good point! I'll try to rephrase this.

From what I've seen in code snippets written by scientists, it's common to collect/track arrays/lists of numerical data in Python containers (dicts, tuples, lists), e.g. when they measure multiple quantities across multiple experiment runs (this can become nested and complex fast).
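
To make that scenario concrete, here is a minimal pure-Python sketch of such a nested collection and of applying one function to every leaf. It only illustrates the concept and is not how optree implements it; `runs` and `map_leaves` are hypothetical names, and the numbers are made up.

```python
# Hypothetical measurements: two quantities recorded over two experiment runs.
runs = {
    "run_1": {"temperature": [20.0, 21.5], "pressure": (1.0, 1.1)},
    "run_2": {"temperature": [19.0], "pressure": (0.9,)},
}


def map_leaves(func, tree):
    """Apply func to every non-container value, keeping the nesting intact."""
    if isinstance(tree, dict):
        return {key: map_leaves(func, value) for key, value in tree.items()}
    if isinstance(tree, (list, tuple)):
        # Rebuild the same container type (plain list or plain tuple).
        return type(tree)(map_leaves(func, item) for item in tree)
    return func(tree)  # anything else is treated as a leaf


# One call handles every nesting level; no loop per level is needed.
doubled = map_leaves(lambda x: 2 * x, runs)
print(doubled["run_2"])
```

Without a helper like this, each change to the nesting (say, adding a third level per detector) would need another hand-written loop.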

Member

@stefanv stefanv May 21, 2025

> Ah, I've not thought about using NetworkX for such cases. This could probably be a solution(?), but the difference here is that here we're not interested in the graph properties. We just want to access the leaves - no matter the structure of the tree - and manipulate / work with them. I would not know how to do this with NetworkX (but I'm also not very familiar with the library).

@dschult Can NetworkX efficiently grab leaf nodes and do operations on them? And if not, why not ;)

Member

To grab the leaf nodes of a directed network in NetworkX, use:
[node for node, deg in G.out_degree if deg == 0]

I think the tricky part for NetworkX would be handling unhashable objects. It might be possible to identify each container with an integer index to a list... It essentially flattens the PyTree while holding the relationships between containers in the networkx graph. From the list of tree containers along with a NetworkX graph, you could form a list of the leaf node indices as above, and then pass through the leaf level containers to manipulate the data inside each. If there are both values to be manipulated and containers in the same container it might be good to use two networkx graphs -- one with the container-container relationships and the other with container-value relationships. But I haven't thought much about how to identify containers vs values, etc.
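
The indexing idea can be sketched without NetworkX at all. The following is only an illustration under stated assumptions: a plain dict of child indices stands in for the edges of a directed graph, and `build_graph` is a hypothetical helper, not a NetworkX or optree function. Because nodes are integer indices, unhashable containers such as lists and dicts never need to be hashed.

```python
def build_graph(tree):
    """Index every container and value; return (objects, children).

    objects[i] is the i-th container or value that was visited;
    children[i] lists the indices found directly inside objects[i].
    """
    objects, children = [], {}

    def visit(obj):
        index = len(objects)
        objects.append(obj)
        children[index] = []
        if isinstance(obj, dict):
            items = obj.values()
        elif isinstance(obj, (list, tuple)):
            items = obj
        else:
            return index  # a value: no outgoing edges
        for item in items:
            children[index].append(visit(item))
        return index

    visit(tree)
    return objects, children


tree = ([[{"foo": 4.0}], 9.0],)
objects, children = build_graph(tree)

# Same idea as the out-degree query above: leaves have no outgoing edges.
leaf_indices = [node for node, kids in children.items() if not kids]
print([objects[i] for i in leaf_indices])  # [4.0, 9.0]
```

Note that in this sketch an empty container also has no outgoing edges, so it would be reported as a leaf; telling containers from values is exactly the open question raised above.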

Member

Thanks for the explanation! I think that helps me understand even better why PyTrees can be useful.

result_pytree = pt.tree_unflatten(treedef, new_leafs)
```

Here, we use [`optree`](https://github.com/metaopt/optree/tree/main/optree) — a standalone PyTree library — that enables all these manipulations. It focuses on performance, is feature rich, has minimal dependencies, and has been adopted by [PyTorch](https://pytorch.org), [Keras](https://keras.io), and [TensorFlow](https://github.com/tensorflow/tensorflow) (through Keras) as a core dependency.
Member

Should optree be mentioned earlier, since it has already been used in the code?

I usually get stuck when I read something I don't know or understand, since I'm trying to figure out what I missed. import optree as pt 🤔

Member

Also, if optree gives you all this, what is pytree all about?

Member Author

I'll see if there's a good place to mention optree earlier 👍

optree is the library, whereas PyTree is the concept. A similar comparison would be NumPy <-> Array.
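
To illustrate the concept independently of any library, here is a minimal pure-Python sketch of the two core operations: splitting a tree into leaves plus structure, and rebuilding it. This is an assumption-laden toy, not optree's implementation; `LEAF`, `flatten`, and `unflatten` are hypothetical names, only dicts, lists, and tuples count as containers, and sorting dict keys is a choice made for this sketch.

```python
import math

LEAF = object()  # placeholder marking where a leaf used to be


def flatten(tree):
    """Return (leaves, structure); structure mirrors tree with LEAF markers."""
    if isinstance(tree, dict):
        leaves, structure = [], {}
        for key in sorted(tree):
            sub_leaves, structure[key] = flatten(tree[key])
            leaves.extend(sub_leaves)
        return leaves, structure
    if isinstance(tree, (list, tuple)):
        pairs = [flatten(item) for item in tree]
        leaves = [leaf for sub_leaves, _ in pairs for leaf in sub_leaves]
        return leaves, type(tree)(sub for _, sub in pairs)
    return [tree], LEAF


def unflatten(structure, leaves):
    """Rebuild a tree by filling the LEAF markers with leaves, in order."""
    leaves = iter(leaves)

    def fill(node):
        if node is LEAF:
            return next(leaves)
        if isinstance(node, dict):
            return {key: fill(value) for key, value in node.items()}
        return type(node)(fill(item) for item in node)

    return fill(structure)


tree = ([[{"foo": 4.0}], 9.0],)
leaves, structure = flatten(tree)          # leaves == [4.0, 9.0]
new_leaves = [math.sqrt(leaf) for leaf in leaves]
print(unflatten(structure, new_leaves))    # ([[{'foo': 2.0}], 3.0],)
```

A library such as optree provides these operations (and many more) for a much wider set of container types; the sketch only shows why "leaves plus structure" is enough to round-trip a tree.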

# )
```

The new `minimize` still works, because a `tuple` of `Params` is just _another_ PyTree!
Member

I'm looking at this example, trying to imagine where a person would use such a thing. Not because I doubt it's useful, but because my experience is limited. A few more examples (not inline code examples, just a sentence or two listing them) may be helpful here.

Like, the generalization above required significant changes to the rosenbrock function itself, so it is unclear what the benefit is to not having to modify the minimizer at the same time. But it hints that there may be families of problems that can more easily be solved.

And that ties into the question:

> Wouldn't it be nice to make workflows in the scientific Python ecosystem just work with any PyTree?

What would we need to do to make that happen, and under which circumstances would that be useful?

Member Author

I think hierarchical models are something where this comes in very handy. It's not uncommon to fit a superposition of multiple functions (or PDFs, e.g. multiple Gaussians) to data. Then it's nice to group those parameters semantically, instead of passing a flat list or array of floats where you have to remember what each value stands for.

> What would we need to do to make that happen, and under which circumstances would that be useful?

I'm not suggesting we should change APIs in scipy and co. Rather, PyTrees (with tools like optree) make it possible for APIs like scipy.optimize.minimize to work generically with any data structure (not just a flat list/array of floats) without many code changes. It's something that users can opt into to make their code more expressive, as they're the ones who define the "model", here the rosenbrock_modified function.

Does that make sense? (Do you have a proposal to write this more clearly here?)
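
As a rough illustration of the hierarchical-model case, the following pure-Python sketch groups the parameters of a two-Gaussian superposition by meaning and converts them to and from the flat list that a flat-vector optimizer would expect. Everything here is hypothetical (`params`, `model`, `to_flat`, `from_flat`, `flat_objective`, and the numbers); it does not call scipy or optree and only handles nested dicts of floats.

```python
import math


def gaussian(x, mean, width):
    return math.exp(-0.5 * ((x - mean) / width) ** 2)


# Parameters grouped by meaning instead of a flat list of five floats.
params = {
    "signal": {"mean": 0.0, "width": 1.0},
    "background": {"mean": 3.0, "width": 2.0},
    "fraction": 0.3,
}


def model(x, p):
    """Superposition of two Gaussians; parameters are read by name."""
    sig = gaussian(x, **p["signal"])
    bkg = gaussian(x, **p["background"])
    return p["fraction"] * sig + (1.0 - p["fraction"]) * bkg


def to_flat(tree):
    """Collect all floats of a nested dict, depth first."""
    flat = []
    for value in tree.values():
        flat.extend(to_flat(value) if isinstance(value, dict) else [value])
    return flat


def from_flat(template, flat):
    """Rebuild a nested dict shaped like template from a flat sequence."""
    values = iter(flat)

    def rebuild(node):
        return {
            key: rebuild(value) if isinstance(value, dict) else next(values)
            for key, value in node.items()
        }

    return rebuild(template)


# Toy data generated from the model itself, so the true parameters fit exactly.
data = [(x, model(x, params)) for x in (-1.0, 0.0, 1.0, 3.0)]


def flat_objective(flat, data):
    """Sum of squared residuals, taking the flat vector an optimizer passes."""
    p = from_flat(params, flat)
    return sum((model(x, p) - y) ** 2 for x, y in data)


print(to_flat(params))  # [0.0, 1.0, 3.0, 2.0, 0.3]
print(flat_objective(to_flat(params), data))  # 0.0
```

The model author keeps working with named, grouped parameters, while the flat conversion stays in one small wrapper; that wrapper is the part a PyTree library generalizes to arbitrary containers.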

# >> sqrt_pytree=([[{'foo': array([2.])}], array([3.])],)

# reductions
all_positive = pt.tree_all(pt.tree_map(lambda x: x > 0.0, pytree))

Suggested change
- all_positive = pt.tree_all(pt.tree_map(lambda x: x > 0.0, pytree))
+ all_positive = all(x > 0.0 for x in pt.tree_iter(pytree))
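
One plausible motivation for this suggestion is that a generator lets `all` stop at the first failing leaf, whereas mapping first builds a complete tree of booleans. The pure-Python sketch below demonstrates that short-circuiting with a hypothetical `iter_leaves` generator standing in for a lazy leaf iterator; it is not optree code, and `visited` exists only to record which leaves were inspected.

```python
def iter_leaves(tree):
    """Lazily yield the leaves of nested dicts, lists, and tuples."""
    if isinstance(tree, dict):
        for value in tree.values():
            yield from iter_leaves(value)
    elif isinstance(tree, (list, tuple)):
        for item in tree:
            yield from iter_leaves(item)
    else:
        yield tree


tree = ([{"foo": -4.0}], 9.0, 16.0)
visited = []  # records which leaves the predicate actually saw


def is_positive(x):
    visited.append(x)
    return x > 0.0


all_positive = all(is_positive(x) for x in iter_leaves(tree))
print(all_positive, visited)  # False [-4.0]
```

Only the first leaf is inspected because it already fails the check; the remaining leaves are never visited.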

import numpy as np

# tuple of a list of a dict with an array as value, and an array
pytree = ([[{"foo": np.array([4.0])}], np.array([9.0])],)

nit: I suggest using tree over pytree as variable names.

This allows for the application of generic transformations. For example, on a PyTree with NumPy arrays as leaves, taking the square root of each leaf with `tree_map(np.sqrt, pytree)`:

```python
import optree as pt

FYI, optree 0.16.0 has been released. I wonder if it is worth mentioning optree.pytree.reexport(...).

>>> import optree
>>> pytree = optree.pytree.reexport(namespace='my-pkg', module='my_pkg.pytree')
>>> pytree.flatten({'a': 1, 'b': 2})
([1, 2], PyTreeSpec({'a': *, 'b': *}))

# foo/__init__.py
import optree
pytree = optree.pytree.reexport(namespace='foo')

# foo/bar.py
from foo import pytree

@pytree.dataclasses.dataclass
class Bar:
    a: int
    b: float

print(pytree.flatten({'a': 1, 'b': 2, 'c': Bar(3, 4.0)}))
# Output:
#   ([1, 2, 3, 4.0], PyTreeSpec({'a': *, 'b': *, 'c': CustomTreeNode(Bar[()], [*, *])}, namespace='foo'))

7 participants