Fix Shape, Reshape, and invalid rust variable names. #2939

Closed
wants to merge 3 commits into from

Conversation

Knight-Ops
Contributor

Pull Request Template

Checklist

  • Confirmed that run-checks all script has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

#2930

Changes

Shape

  • shape_update_outputs : Changed the output type to a Tensor, as per the ONNX specification. The operator always outputs a Tensor, not a Shape.

Reshape

  • Added improved error reporting. It's really hard to tell where errors come from during the build, so I made sure every message is prefixed with Reshape: for some extra information.
  • reshape_config : Removed the check on input length and input value. input.value is None because the input is a Tensor, so you can't use value; you have to check Tensor.shape() instead.
  • shape_update_outputs : Fixed the output Tensor elem_type and ensured that shape is a valid Vec. Again, this is done by looking at the shape on the Tensor instead of trying to get the value.

format_name

  • Simple change: replace characters that are invalid in Rust variable names with underscores. This could possibly be done with a regex, but I don't think that is necessary right now. I only replaced the characters I ran into issues with while testing.
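
For illustration, a minimal sketch of this kind of name sanitizing. It is not the code in this PR: sanitize_ident is a hypothetical name, and it uses an allowlist (ASCII letters, digits, underscore) instead of replacing only the specific characters mentioned above. Rust keywords and empty names are not handled.

```rust
/// Hypothetical stand-in for `format_name` (not the function from this PR):
/// maps an ONNX node or tensor name to text usable as a Rust identifier.
fn sanitize_ident(name: &str) -> String {
    // Keep ASCII letters, digits, and underscores; replace everything else
    // with an underscore.
    let mut out: String = name
        .chars()
        .map(|c| if c.is_ascii_alphanumeric() || c == '_' { c } else { '_' })
        .collect();

    // A Rust identifier cannot start with a digit, so prefix an underscore.
    if out.starts_with(|c: char| c.is_ascii_digit()) {
        out.insert(0, '_');
    }
    out
}
```

With this sketch, a name such as model/layer.0:out becomes model_layer_0_out. An allowlist avoids having to enumerate every problematic character, at the cost of also rewriting non-ASCII letters that Rust would accept.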

Testing

Ran cargo test on the burn-import crate, with all tests passing. Also used the file referenced in #2930 as a test file.

@Knight-Ops
Contributor Author

Knight-Ops commented Mar 21, 2025

Not really sure what to add here. As a first approach I added back some of the old code specific to Reshape, since the model referenced in #2930 didn't like how Reshape was implemented (the value never existed). However, I have a TF-exported model that regressed due to these changes. I believe this may be caused by constant folding for the Reshape, but I don't have 100% confirmation of that.

My best guess is that the original code worked for constant shapes, whereas in theory, based on my understanding, the shape can also be a Tensor calculated at runtime. We should check for constants first and then fall back to dynamic tensors.
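
A rough sketch of that "constant first, then dynamic" order, assuming simplified types. ShapeInput, ReshapeTarget, and resolve_reshape_target are hypothetical names for illustration and are not the burn-import API. The fallback relies on the ONNX rule that the shape input of Reshape is a 1-D tensor, so its length gives the output rank.

```rust
/// Simplified, hypothetical stand-in for the shape input of a Reshape node.
struct ShapeInput {
    /// Constant data, present only when the shape is known at build time.
    value: Option<Vec<i64>>,
    /// Statically known shape of the shape tensor itself, e.g. [4] for a
    /// 1-D tensor holding four target dimensions.
    static_shape: Option<Vec<usize>>,
}

/// What can be known about the target shape at build time.
#[derive(Debug, PartialEq)]
enum ReshapeTarget {
    /// Every target dimension is known.
    Static(Vec<i64>),
    /// Only the output rank is known; dimensions are computed at runtime.
    RuntimeRank(usize),
}

fn resolve_reshape_target(input: &ShapeInput) -> Result<ReshapeTarget, String> {
    // 1. Prefer a constant value when one exists.
    if let Some(dims) = &input.value {
        return Ok(ReshapeTarget::Static(dims.clone()));
    }
    // 2. Otherwise fall back to the shape of the shape tensor: it must be
    //    1-D, and its single dimension is the rank of the output.
    match input.static_shape.as_deref() {
        Some([rank]) => Ok(ReshapeTarget::RuntimeRank(*rank)),
        Some(other) => Err(format!(
            "Reshape: shape input must be 1-D, got {} dimensions",
            other.len()
        )),
        None => Err(
            "Reshape: shape input has neither a constant value nor a static shape".to_string(),
        ),
    }
}
```

Every error string carries the Reshape: prefix so that a failure during the build can be traced back to the operator that raised it.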

Member

@laggui laggui left a comment

Just a comment regarding the ArgType::Shape.

This is part of the intermediate representation between ONNX <> Burn. In Burn, shapes are not actually tensors but arrays. I believe this was the initial motivation behind the variants for ArgType.

ONNX uses tensor types for pretty much all values, but the reality is a bit more complex. Even scalars are simply represented as 0-dim tensors.

So the output of shape should remain ArgType::Shape to maintain compatibility with the current codegen.
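
To make the distinction concrete, here is a simplified sketch. This ArgType is a hypothetical, reduced version written for illustration (the real enum carries more information), and shape_node_output is an invented helper name.

```rust
/// Simplified, hypothetical version of the intermediate-representation type.
#[derive(Debug, Clone, PartialEq)]
enum ArgType {
    /// A single value (ONNX models this as a 0-dim tensor).
    Scalar,
    /// A shape: an array with one entry per dimension, not a tensor.
    Shape(usize),
    /// A tensor of the given rank.
    Tensor { rank: usize },
}

/// Output type of a Shape node in this sketch: the result stays a
/// `Shape` whose length equals the rank of the input tensor.
fn shape_node_output(input: &ArgType) -> Result<ArgType, String> {
    match input {
        ArgType::Tensor { rank } => Ok(ArgType::Shape(*rank)),
        other => Err(format!("Shape: expected a tensor input, got {:?}", other)),
    }
}
```

In this sketch a rank-4 tensor input yields ArgType::Shape(4), so any downstream node that consumes the result has to accept the Shape variant and not only Tensor.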

I haven't looked at the underlying cause in the linked issue, but from the error it looks like there is a node that should handle the ArgType::Shape.

@laggui laggui closed this Mar 31, 2025