-
Notifications
You must be signed in to change notification settings - Fork 14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Apply new OOP pattern compatible with jax transformations #48
Conversation
1b101d8
to
eeec516
Compare
I got back working on this PR. The new OOP decorator works already pretty good for r/o methods, however the r/w logic miserably fails. Our OOP strategy for the high-level resources (
The main problem of this first version of the design is that a jit-compiled -decorated- r/w method of links and joints (like When this logic runs, the I've been experimenting with few solutions similar to those proposed in jax-ml/jax#7919 and jax-ml/jax#17341, but the logic seems complicated to maintain and quite error prone. In a context of data sharing (that could be state like in this case, but also more generic r/w parameters), probably the most simple solution is not to allow any r/w methods on child classes like Note that we have a similar relationship also between |
872d4e7
to
74bd6fa
Compare
Otherwise the pytree structure changes after the first applied jax transformation
Static arguments must be hashable, therefore lists cannot be passed
The link force can either override or be summed with previously set forces. The default behavior is to sum it.
Happy to help adapting my class as shown in jax-ml/jax#7919 (comment) |
74bd6fa
to
7384d41
Compare
Thank you @samskiter for chiming in! I'm definitely interested in your approach using Luckily, we only have two of such methods, and I moved them for now to the parent class (5a975d4). I keep in mind to try playing around with your solution, it could be an excellent companion to our |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Afraid not open source, but I'm happy for you to use ReffableTreeNode - the toy usage in my other comment still stands. Essentially this stores a UUID of every item in the tree structure. There are 4 methods each class must implement:
You don't have to implement any recursion yourself - the super class will call The only change from the toy usage code is to store the UUIDs in the aux_data:
|
Yeah I didn't find anything in your repos and I assumed still being closed source. I opened #52 for collecting resources and attempts we might perform in the future. For now, thanks a lot for the details! |
This PR applies the resources of #44 on the jaxsim classes, solving #43 that is the original problem that triggered the development of this new approach.
Beyond fixing trace leaks, this PR:
jax.jit
andjax.vmap
.jax.numpy
arrays, even when they are scalar quantities (in this case they are 0-dimensional arrays).Mutability
from the methods, now the mutable/frozen context is enforced in the new decorators.The key resources of this pattern are new method decorators and a
Vmappable
inheritance (introduced in #44). There are, however, some caveats to consider:str
) cannot be jit compiled nor vectorized.None
and the default value has to be configured inside the method.