Skip to content

A new blog post about Pytrees for Scientific Python #250

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 31 commits into
base: main
Choose a base branch
from

Conversation

pfackeldey
Copy link
Member

@pfackeldey pfackeldey commented May 14, 2025

  • The main subject relates to at least one project affiliated to the Scientific Python Ecosystem.
  • I have the right to publish the content under BSD 3-Clause License for the code and Creative Common CC-BY-4.0 License for the text.
  • Images have been compressed using a tool like pngquant.

This is a new blog post about "PyTrees for Scientific Python" by @mihaimaruseac @matthewfeickert and me. It explains what PyTrees are, where they come from, and how they can be useful for scientific Python.

It's currently in draft mode because we're missing a last section about "final thoughts" (pinging @matthewfeickert for this 🙏).

@pfackeldey
Copy link
Member Author

I've added a short "final thought" section myself - so I'm marking this blogpost as ready for review.
Let me know what you think!

@pfackeldey pfackeldey marked this pull request as ready for review May 18, 2025 20:07
@matthewfeickert
Copy link
Member

Thanks @pfackeldey. 🚀 Currently in transit to Houston so I will review tonight once I land, but as we discussed together I think this is all fine at the high level and any comments I have will be nitpicks.

Copy link
Member

@matthewfeickert matthewfeickert left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pfackeldey looks really good to me. As I said, the only thing I have to suggest is nitpicky typo fixes and possible rephrasing, but I think that the examples are very strong. Thank you for writing this!

@pfackeldey
Copy link
Member Author

Thanks @matthewfeickert , you're comments are great, I accepted all of them :)

Copy link
Member

@rossbar rossbar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice writeup @pfackeldey ! The post flows beautifully and IMO does a great job both motivating and demonstrating what PyTrees has to offer. A few grammatical nits that are by no means blockers; thanks for writing this up!

@stefanv
Copy link
Member

stefanv commented May 20, 2025

I'll review also, and then we can hopefully get this merged ASAP.

Copy link
Member

@stefanv stefanv left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @pfackeldey! Feel free to do with the comments what you want; let me know when you are ready to merge and I'll push the button.

## Manipulating Tree-like Data using Functional Programming Paradigms

A "PyTree" is a nested collection of Python containers (e.g. dicts, (named) tuples, lists, ...), where the leafs are of interest.
As you can imagine (or even experienced in the past), such arbitrary nested collections can be cumbersome to manipulate _efficiently_.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This motivation could be fleshed out a bit. I've never had to manipulate trees of dicts, but if I had to do so I'd probably pop them into a NetworkX graph. So, perhaps helpful to tell me what the types of issues are I might have run into, rather than expect me to know it already?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I've not thought about using NetworkX for such cases. This could probably be a solution(?), but the difference here is that here we're not interested in the graph properties. We just want to access the leaves - no matter the structure of the tree - and manipulate / work with them. I would not know how to do this with NetworkX (but I'm also not very familiar with the library).
PyTrees are mainly about separating the leaves from a generic structure. The generic structure itself is not of interest.

So, perhaps helpful to tell me what the types of issues are I might have run into, rather than expect me to know it already?

That's a good point! I'll try to rephrase this.

From what I've seen in code snippets written by scientists, it's common to collect/track arrays/lists of numerical data in python containers (dicts, tuples, lists), e.g. if they measure multiple quantities in multiple experiment runs (this may become nested and complex fast).

Copy link
Member

@stefanv stefanv May 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I've not thought about using NetworkX for such cases. This could probably be a solution(?), but the difference here is that here we're not interested in the graph properties. We just want to access the leaves - no matter the structure of the tree - and manipulate / work with them. I would not know how to do this with NetworkX (but I'm also not very familiar with the library).

@dschult Can NetworkX efficiently grab leaf nodes and do operations on them? And if not, why not ;)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To grab the nodes of a directed network in NetworkX use:
[node for node, deg in G.out_degree if deg == 0]

I think the tricky part for NetworkX would be handling unhashable objects. It might be possible to identify each container with an integer index to a list... It essentially flattens the PyTree while holding the relationships between containers in the networkx graph. From the list of tree containers along with a NetworkX graph, you could form a list of the leaf node indices as above, and then pass through the leaf level containers to manipulate the data inside each. If there are both values to be manipulated and containers in the same container it might be good to use two networkx graphs -- one with the container-container relationships and the other with container-value relationships. But I haven't thought much about how to identify containers vs values, etc.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the explanation! I think that helps me understand even better why PyTrees can be useful.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @stefanv,
I've improved the motivation in 9bde52e. I hope this motivates the use-cases for pytrees more :)

# )
```

The new `minimize` still works, because a `tuple` of `Params` is just _another_ PyTree!
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm looking at this example, trying to imagine where a person would use such a thing. Not because I doubt it's useful, but because my experience is limited. A few more examples (not inline code examples, just a sentence or two listing them) may be helpful here.

Like, the generalization above required significant changes to the rosenbrock function itself, so it is unclear what the benefit is to not having to modify the minimizer at the same time. But it hints that there may be families of problems that can more easily be solved.

And that ties into the question:

Wouldn't it be nice to make workflows in the scientific Python ecosystem just work with any PyTree?

What would we need to do to make that happen, and under which circumstances would that be useful?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think hierarchical models are something where this comes in very handy. It's not uncommon to fit a superposition of multiple functions (or pdfs, e.g. multiple gaussians) to data. Then it's nice to group those parameters semantically, instead of passing a flat list of floats / array of floats, where you have to remember what each value stands for.

What would we need to do to make that happen, and under which circumstances would that be useful?

I'm not suggesting we should change APIs in scipy and co, it's rather that PyTrees (with tools like optree) enable to make APIs like scipy.optimize.minimize work generically with any data structure (not just a flat list/array of floats) without too much code changes. It's something that users can opt-in to make their code more expressive as they're the ones that define the "model", here the rosenbrock_modified function.

Does that make sense? (Do you have a proposal to write this more clearly here?)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wasn't sure how I could be more precise about hierarchical models, so I've added an example in the text that I think is a common in scientific research: e07a2ac

This allows for the application of generic transformations. For example, on a PyTree with NumPy arrays as leaves, taking the square root of each leaf with `tree_map(np.sqrt, pytree)`:

```python
import optree as pt

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI, optree 0.16.0 is released. I wonder if it is worth mentioning optree.pytree.reexeport(...).

>>> import optree
>>> pytree = optree.pytree.reexport(namespace='my-pkg', module='my_pkg.pytree')
>>> pytree.flatten({'a': 1, 'b': 2})
([1, 2], PyTreeSpec({'a': *, 'b': *}))
# foo/__init__.py
import optree
pytree = optree.pytree.reexport(namespace='foo')

# foo/bar.py
from foo import pytree

@pytree.dataclasses.dataclass
class Bar:
    a: int
    b: float

print(pytree.flatten({'a': 1, 'b': 2, 'c': Bar(3, 4.0)}))
# Output:
#   ([1, 2, 3, 4.0], PyTreeSpec({'a': *, 'b': *, 'c': CustomTreeNode(Bar[()], [*, *])}, namespace='foo'))

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's an awesome feature (which I'll certainly use for my libraries!).
I think this blog post is more targeted towards users that may want to try optree directly, and less towards library developers that may want to reexport with additional functionality. That's why I'm leaning towards not adding it to the blog posts. Does that make sense to you @XuehaiPan ?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It makes sense.

@pfackeldey
Copy link
Member Author

Dear @stefanv, @rossbar, @XuehaiPan, @matthewfeickert, @mihaimaruseac,

thank you very much again for reviewing this blog post!

I've addressed the remaining comments now, could you give this blog post another look?
From my side this PR is now ready :)

Thanks, Peter

Copy link
Member

@stefanv stefanv left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the edits, Peter. I'm happy to merge this whenever you are ready. I highlighted two points of confusion for me, but happy to let you decide whether you want to handle those or not.

prediction = neural_network(layers=layers, x=jnp.array(...))
```

Here, `layers` is a PyTree — a `list` of multiple `Layer` — and the JIT compiled `neural_network` function _just works_ with this data structure as input.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does not seem to be an example of a nested structure, just the final layer of a tree. Wouldn't it be pretty straightforward to handle a list using regular methods?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a ('3-level') nested datastructure in the sense: layers is a list (1st level) of Layers (2nd level) of W (3rd level [leaf]) and b (3rd level [leaf]). Here, one could also create a list for weights and one for biases and make a zip-loop over both, but then weights and biases are not logically grouped anymore in a Layer object.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants