
An input to an Add node keeps getting cast to float32 despite being float16 in the ONNX file, causing issues with the Add op #681

Closed
AD-lite24 opened this issue Aug 26, 2024 · 17 comments
Labels
Dynamic batch / Dynamic shape, third party (Third-party tool issues)

Comments


Issue Type

Documentation Feature Request

OS

Linux

onnx2tf version number

1.25.7

onnx version number

1.16.2

onnxruntime version number

1.18.1

onnxsim (onnx_simplifier) version number

0.4.36

tensorflow version number

2.17.0

Download URL for ONNX

https://huggingface.co/onnx-community/metric3d-vit-small/resolve/main/onnx/model_fp16.onnx

Parameter Replacement JSON

{
        "format_version": 1,
        "operations": [
                {
                        "op_name": "wa/depth_model/encoder/Add",
                        "param_target": "inputs",
                        "param_name": "wa/depth_model/encoder/Constant_14_output_0",
                        "values": 10
                }
        ]
}

Description

  1. Trying to deploy a monocular depth estimation model to an autonomous drone flight controller built around Snapdragon SoCs. It is both for research (an academic curiosity about running a large model on the edge) and for product development. I have spent a long time trying to make these models run on Snapdragon SoCs, and TFLite is the only framework that interacts well with the Snapdragon neural engines, so it is crucial that I get this converted. Solving this problem will be the culmination of hours of study and research on deploying large models on the edge, and I can finally move on to the next stage of development.

  2. When I run
    onnx2tf -i model_fp16.onnx
    it runs up to layer 69/1875, at which point it raises the following exception:

INFO: 20 / 1806
INFO: onnx_op_type: Add onnx_op_name: wa/depth_model/encoder/Add
INFO:  input_name.1: wa/depth_model/encoder/Cast_4_output_0 shape: [] dtype: float16
INFO:  input_name.2: wa/depth_model/encoder/Constant_14_output_0 shape: [] dtype: float16
INFO:  output_name.1: wa/depth_model/encoder/Add_output_0 shape: [] dtype: float16
ERROR: The trace log is below.

TypeError: Exception encountered when calling layer "tf.math.add_3" (type TFOpLambda).

Input 'y' of 'AddV2' Op has type float32 that does not match type float16 of argument 'x'.

Call arguments received by layer "tf.math.add_3" (type TFOpLambda):
  • x=tf.Tensor(shape=(), dtype=float16)
  • y=tf.Tensor(shape=(), dtype=float32)
  • name='wa/depth_model/encoder/Add'

The entire log output is attached as a txt file. I went through the model graph in Netron and both inputs to the Add op appear to be float16, so I am not sure why one got converted to float32.
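
For reference, a minimal sketch of how the declared dtypes could be double-checked outside Netron (an assumption-laden sketch using the onnx Python package; the node name is taken from the log above, and this is not part of onnx2tf):

import onnx

# Run shape inference so intermediate tensors appear in value_info, then collect
# the declared element type of every tensor (TensorProto.FLOAT16 == 10, FLOAT == 1).
model = onnx.shape_inference.infer_shapes(onnx.load("model_fp16.onnx"))
dtypes = {}
for vi in list(model.graph.value_info) + list(model.graph.input) + list(model.graph.output):
    dtypes[vi.name] = vi.type.tensor_type.elem_type
for init in model.graph.initializer:
    dtypes[init.name] = init.data_type

# Print the element types of the inputs feeding the failing Add node.
for node in model.graph.node:
    if node.name == "wa/depth_model/encoder/Add":
        for inp in node.input:
            print(inp, dtypes.get(inp))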

  1. I tried to debug by going over the Netron graph and finding the node where it was breaking; it looked fine to me. I tried going through the parameter replacement documentation, but I couldn't figure out how to cast the input back to float16. I ran some experiments (added in the Parameter Replacement field above), but they didn't go anywhere. I then went through the source code to try to understand the issue, but it seems to be an implicit type cast somewhere.

  2. These SoCs do not really work with any framework other than TFLite, so it is crucial that I get this converted to TF. Moreover, the model itself is only available in ONNX format. I have spent many hours trying to make this work, and this project is the closest I have gotten after all the dependency hell I have been through. Having this problem solved would be a massive relief and would let me move on to the next stages of my project.

  3. The other converter scripts are now deprecated or incompatible with the latest ONNX models, especially given the constantly changing API. This project seems to be the best in terms of dependency management, and it would be great to have this resolved.

log.txt
Screenshot 2024-08-27 at 1 39 33 AM

Repository owner deleted a comment Aug 27, 2024
PINTO0309 (Owner) commented Aug 27, 2024

onnx2tf has worked according to the same specification since it was first created: the tool does not successfully convert models other than those with Float32 inputs.

If you simply want to generate a Float16 tflite file, it will work just fine if you first generate a Float32 ONNX file. If you have a strong reason why you cannot generate a Float32 ONNX file, please explain that reason in detail before we continue on this issue.

I know your R&D is very interesting, but I am concerned that you may have made a mistake in how you set up the tflite generation in the first place.

AD-lite24 (Author) commented Aug 27, 2024

Ah, that makes sense. I will try it with a float32 file then. While I would have preferred float16 given that it is for embedded devices, it does output float16 tflite files as you said.

AD-lite24 reopened this Aug 27, 2024
AD-lite24 (Author) commented Aug 27, 2024

OK, I'm back. I switched to float32 and the conversion got significantly further. However, I ran into an issue with the ConvTranspose op.

2024-08-27 14:19:07.721322515 [E:onnxruntime:, sequential_executor.cc:516 ExecuteKernel] Non-zero status code returned while running Conv node. Name:'wa/depth_model/encoder/patch_embed/proj/Conv_output_0_nchwc' Status Message: Invalid input shape: {1,1}

This is actually an error emitted before the conversion process itself starts; in fact it is an exception thrown by onnxruntime. But the ONNX file runs fine and gives the expected inference, and in Netron the input and output shapes look fine too. So I simply cannot figure out where the input shape {1,1} is coming from.

Screenshot 2024-08-27 at 2 22 43 PM

The input X to the Conv node is meant to be of shape (N, 384, H/14, W/14) and appears to be, yet it is somehow receiving an input of shape {1, 1}, which is very odd. I again went through the source code as well as the parameter replacement docs you pointed to, but I am unable to identify the source of this {1, 1} shaped tensor.

Here is the snippet of the log where the error occurs; the full verbose log file is attached.


INFO: 763 / 1763
INFO: onnx_op_type: ConvTranspose onnx_op_name: wa/depth_model/decoder/token2feature/read_1/sample/ConvTranspose
INFO:  input_name.1: wa/depth_model/decoder/token2feature/read_1/Transpose_output_0 shape: ['batch_size', 384, 'floor(height/14)', 'floor(width/14)'] dtype: float32
INFO:  input_name.2: meta_arch.depth_model.decoder.token2feature.read_1.sample.weight shape: [384, 192, 2, 2] dtype: float32
INFO:  input_name.3: meta_arch.depth_model.decoder.token2feature.read_1.sample.bias shape: [192] dtype: float32
INFO:  output_name.1: wa/depth_model/decoder/token2feature/read_1/sample/ConvTranspose_output_0 shape: ['batch_size', 192, 'ConvTranspose_761_o0__d2', 'ConvTranspose_761_o0__d3'] dtype: float32
2024-08-27 14:19:07.721322515 [E:onnxruntime:, sequential_executor.cc:516 ExecuteKernel] Non-zero status code returned while running Conv node. Name:'wa/depth_model/encoder/patch_embed/proj/Conv_output_0_nchwc' Status Message: Invalid input shape: {1,1}
ERROR: The trace log is below.
Traceback (most recent call last):
  File "/home/robocon2/miniconda3/envs/onnx-tflite/lib/python3.10/site-packages/onnx2tf/utils/common_functions.py", line 312, in print_wrapper_func
    result = func(*args, **kwargs)
  File "/home/robocon2/miniconda3/envs/onnx-tflite/lib/python3.10/site-packages/onnx2tf/utils/common_functions.py", line 385, in inverted_operation_enable_disable_wrapper_func
    result = func(*args, **kwargs)
  File "/home/robocon2/miniconda3/envs/onnx-tflite/lib/python3.10/site-packages/onnx2tf/utils/common_functions.py", line 55, in get_replacement_parameter_wrapper_func
    func(*args, **kwargs)
  File "/home/robocon2/miniconda3/envs/onnx-tflite/lib/python3.10/site-packages/onnx2tf/ops/ConvTranspose.py", line 178, in make_node
    dummy_onnx_inference(
  File "/home/robocon2/miniconda3/envs/onnx-tflite/lib/python3.10/site-packages/onnx2tf/utils/common_functions.py", line 3864, in dummy_onnx_inference
    outputs = onnx_session.run(None, input_datas)
  File "/home/robocon2/miniconda3/envs/onnx-tflite/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 220, in run
    return self._sess.run(output_names, input_feed, run_options)
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Non-zero status code returned while running Conv node. Name:'wa/depth_model/encoder/patch_embed/proj/Conv_output_0_nchwc' Status Message: Invalid input shape: {1,1}

log.txt

PINTO0309 (Owner) commented Aug 27, 2024

Please share the Float32 model. As you surmise, it's an ONNX issue, but without the model files I have no idea.

Looking only at the partial image you shared, it is only natural that such an error would occur if the input resolution of the model is dynamic.

PINTO0309 added the third party (Third-party tool issues) label Aug 27, 2024
AD-lite24 (Author) commented Aug 27, 2024

Right, my bad. Please find the file on Hugging Face:
https://huggingface.co/onnx-community/metric3d-vit-small/blob/main/onnx/model.onnx

The input resolution is indeed dynamic; there is no height or width it needs to be resized to in preprocessing. But I am not sure how this comes into play in the conversion process. I have a lot to learn from your codebase, which I am going through as we speak.

PINTO0309 (Owner) commented Aug 27, 2024

  • Tests
    • Customized ONNX

      1. format="ONNX v10" -> format="ONNX v9"
      2. Fixed shape 1,3,480,640
      3. onnxsim
        image
    • Results
      image

    • TFLite Float16
      image

PINTO0309 added the TODO label Aug 27, 2024
AD-lite24 (Author) commented

Can you tell me how you customized the ONNX file? The ONNX modifier tool does not allow changes to the format version, and a static shape does not get reflected downstream.

PINTO0309 (Owner) commented

Can you tell me how you customized the ONNX file?

Please wait a moment. I am in the process of examining what is wrong with the design of this model.

I found the process very redundant by design.

NCHW -> Transpose -> NHWC -> ReduceMean axis=-1 -> Transpose -> NCHW: at every point in a block of this pattern, onnx2tf loses track of the correct channel position. In this model, the ReduceMean could simply have been done with axis=1, but instead the channels are transposed twice around it (a small numpy illustration follows below).

  • e.g.
    image
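
In numpy terms, the redundancy looks roughly like this (a toy illustration only; keepdims is assumed here, the actual node attributes may differ):

import numpy as np

x = np.random.rand(1, 384, 30, 40).astype(np.float32)  # NCHW

# Pattern found in the model: NCHW -> Transpose -> NHWC -> ReduceMean(axis=-1) -> Transpose -> NCHW
redundant = np.transpose(x, (0, 2, 3, 1)).mean(axis=-1, keepdims=True)
redundant = np.transpose(redundant, (0, 3, 1, 2))

# Equivalent single op applied directly to the original NCHW tensor
direct = x.mean(axis=1, keepdims=True)

assert np.allclose(redundant, direct)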

PINTO0309 (Owner) commented Aug 27, 2024

If changing the model input to a fixed resolution and optimizing the overall model structure is not a problem, the easiest way to generate tflite is described below. I don't know how you generated your ONNX file, but I don't recommend ir_version=10.

  • ONNX to JSON
onnx2json \
--input_onnx_file_path metric3d-vit-small.onnx \
--output_json_path metric3d-vit-small.json \
--json_indent 2
  • Edit JSON (nano, vi, vscode, etc...)
nano metric3d-vit-small.json
  • From
 {
   "irVersion": "10",
   "producerName": "pytorch",
   "producerVersion": "2.0.1",
   "domain": "",
   "graph": {
     "node": [
  • To
 {
   "irVersion": "9",
   "producerName": "pytorch",
   "producerVersion": "2.0.1",
   "domain": "",
   "graph": {
     "node": [
  • JSON to ONNX
json2onnx \
--input_json_path metric3d-vit-small.json \
--output_onnx_file_path metric3d-vit-small.onnx
  • Optimization + Static shape
onnxsim metric3d-vit-small.onnx metric3d-vit-small.onnx --overwrite-input-shape "pixel_values:1,3,480,640"
  • onnx2tf
onnx2tf -i metric3d-vit-small.onnx
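
For reference, the same irVersion change can also be made directly from Python (a minimal sketch assuming the onnx package; the onnx2json/json2onnx route above is what was actually tested):

import onnx

model = onnx.load("metric3d-vit-small.onnx")
model.ir_version = 9  # same edit as changing "irVersion" from "10" to "9" in the JSON dump
onnx.save(model, "metric3d-vit-small.onnx")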

PINTO0309 removed the TODO label Aug 27, 2024
AD-lite24 (Author) commented

Wow, this is brilliant! Thanks so much for the help. I still have to test it with the TFLite API and validate the model outputs, but you clearly have amazing skills.

AD-lite24 (Author) commented

Just curious though, why does it not support variable sizes? In every node the input sizes are written relative to the original height and width, so it should be simple enough to plug those values in, right? Of course it may well be complicated, since I am not aware of the entire scope of the project, but if you could explain why, I could learn something more from you.

PINTO0309 (Owner) commented

You are diligent. It is now 2:00 AM in Japan, so I will explain tomorrow morning. Good night.

PINTO0309 (Owner) commented Aug 27, 2024

I will explain step by step.

The reasons why this model is so difficult to convert are simply as follows.

  1. There is no hint information in the ONNX file that would indicate the order of the channels. Therefore, it is impossible to tell from its appearance whether input data of the form [a, 3, b, c] is image data, audio data, or some other sensor data.
  2. There is no way to tell whether it is a ViT model, a CNN, an LSTM, or anything else.
  3. Since the type of data cannot be identified from the outside, it is not at all clear whether the input data is NCHW, ABCD, or EFGH. In other words, there is no way to find out what the channel order of the input data means.
  4. As a workaround to identify the correct channel order, it is possible to use the size of each dimension as a cue and automatically correct as much as possible using the size values of each rank during the conversion. However, it becomes extremely difficult to correct to [N, H, W, 3] when, as in this ViT model, the three undefined ranks in [N, 3, H, W] are not known to actually be N, H, and W.
  5. When an ONNX file with the definition [a, 3, b, c] is input, onnx2tf runs inference once on dummy np.ones([1, 3, 1, 1], dtype=np.float32) sample data to infer the correct channel order as far as possible, computes provisional rank sizes for each OP in the middle of the model, and then performs a correction to the correct Transpose order using those provisional rank sizes (see the sketch at the end of this comment).
  6. If we were confident that the input data is an image, or if there were some external hint that the data is an image, it would be possible to specify that b and c effectively mean H and W. However, since there is no such hint information in the model definition, onnx2tf has no choice but to perform that single inference using the dummy sizes b=1 and c=1.
  7. The same problem you posted has been discussed many times in closed issues in the past, and everyone understands the difficulty of this automatic conversion.
  8. There are three reasons why you may have encountered your first conversion error.
    • ir_version=10 has a problem that interferes with ONNX's optimization step, so onnx2tf's attempt to optimize the model fails, resulting in a model structure that is more redundant than necessary.
    • Since b and c, i.e. H and W, are undefined, there was no choice but to perform shape estimation on impossible dummy data assuming b=1 and c=1, resulting in an impossibly small input tensor to ConvTranspose.
    • When a model containing a large number of unnecessary Transposes is transformed, the input tensor just before ConvTranspose ended up as [N, H, C, W] because the correct channel order was not known. This kind of mistake can occur in various ways depending on the complexity of the model.

A fairly complex conversion operation is implemented to accommodate automatic conversion of any model that assumes input data other than images. You are the only one who knows that the input data for this model is an image.
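
To make point 5 concrete, here is a rough sketch of what that one-shot dummy inference looks like (an illustration with onnxruntime only, not the actual onnx2tf code; running it against this model reproduces the Invalid input shape {1,1} error):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("metric3d-vit-small.onnx", providers=["CPUExecutionProvider"])

# Fill every symbolic dimension ('batch_size', 'height', 'width', ...) with 1.
input_feed = {}
for inp in sess.get_inputs():
    shape = [d if isinstance(d, int) else 1 for d in inp.shape]
    input_feed[inp.name] = np.ones(shape, dtype=np.float32)

# For a ViT that expects H and W to be multiples of the patch size, a 1x1 spatial
# input is far too small, which is what the Invalid input shape {1,1} error reflects.
outputs = sess.run(None, input_feed)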

AD-lite24 (Author) commented Aug 28, 2024

Thanks for the detailed explanation. I get what you are saying. However, in each node:

      {
        "name": "/depth_model/encoder/blocks.0/blocks.0.0/norm1/Sub_output_0",
        "type": {
          "tensorType": {
            "elemType": 1,
            "shape": {
              "dim": [
                {
                  "dimParam": "batch_size"
                },
                {
                  "dimParam": "floor(height/14)*floor(width/14) + 5"
                },
                {
                  "dimValue": "384"
                }
              ]
            }
          }
        }
      },

^ This is from the json file

The dimParam is always an arithmetic expression in the original variables height and width. The way I understand it from your comment, onnx2tf builds a numpy array from it. Now, since the values of height and width are not known, we perhaps cannot create an appropriately shaped array, but then how does ONNX manage to have its model graph with shapes defined by variables that are only defined at runtime?

Referring to point 6: if onnx2tf is given the information that the input variables height and width are dynamic, and every node downstream is always defined relative to these variables fixed at runtime, could there not be a workaround that defines the nodes in TFLite with similar variables tied to the entry node?

Since b and c, i.e. H and W, are undefined, there was no choice but to perform shape estimation on impossible dummy data assuming b=1 and c=1, resulting in an impossibly small input tensor to ConvTranspose.

Right, so during the conversion steps H and W are recognized as None and onnx2tf sets them to 1, which then causes the issue with ConvTranspose. But if values are not provided, could we perhaps create a map between the variable names (as strings) and integer values, and expect the map to be filled in at runtime? Until then, the strings would act as placeholders, and the ops defined on them would be resolved at runtime. It might not fit how TFLite works, but ONNX supports something like this (though likely not in the way I have described it), so I have a feeling it could be done with TFLite as well.

PINTO0309 (Owner) commented Aug 28, 2024

Look closely at the names of the attributes. The way dimParam and dimValue are handled inside ONNX is completely different.

  • dimParam is treated as a simple string. That is, your JSON just contains a string that happens to look like an equation; by itself, the string "floor(height/14)*floor(width/14) + 5" means very little. onnxruntime dynamically calculates the size during inference. When onnx2tf refers to the dimension named "floor(height/14)*floor(width/14) + 5", it gets None; in other words, no value. Your model is special in that a formula is set in dimParam; models generated by other people usually do not have a formula, just a meaningless string like abcdefgh. In addition, this string is usually an unintelligible name, such as a sequential number generated by PyTorch and the ONNX exporter according to internal rules.
  • dimValue is treated as an INT64 integer value. That is, it is a constant.
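
For reference, the same distinction seen through the onnx Python API (a minimal sketch; the tensor name is the one from your JSON snippet):

import onnx

model = onnx.load("metric3d-vit-small.onnx")
for vi in model.graph.value_info:
    if vi.name == "/depth_model/encoder/blocks.0/blocks.0.0/norm1/Sub_output_0":
        for dim in vi.type.tensor_type.shape.dim:
            if dim.dim_param:   # symbolic: just a string such as "floor(height/14)*floor(width/14) + 5"
                print("dim_param:", dim.dim_param)
            else:               # concrete: an INT64 constant such as 384
                print("dim_value:", dim.dim_value)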

Right, so during the conversion steps H and W are recognized as None and onnx2tf sets them to 1, which then causes the issue with ConvTranspose. But if values are not provided, could we perhaps create a map between the variable names (as strings) and integer values, and expect the map to be filled in at runtime? Until then, the strings would act as placeholders, and the ops defined on them would be resolved at runtime. It might not fit how TFLite works, but ONNX supports something like this (though likely not in the way I have described it), so I have a feeling it could be done with TFLite as well.

No, that's not possible. (Or rather, I think I probably can't do it.) It is already handled precisely, but the TFLite runtime is not flexible. To add a little more: the APIs for Keras, TensorFlow, and TensorFlow Lite are not flexible. I did not say that ONNX with dynamic tensors cannot be converted to TFLite; in fact, I am able to convert models with some dynamic inputs other than yours.

The tutorial below explains how to generate a TFLite model with a dynamic tensor and how to perform inference with it.

https://github.com/PINTO0309/onnx2tf?tab=readme-ov-file#14-inference-with-dynamic-tensors-in-tflite
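
Roughly what that looks like in code (a sketch only; the file name and the 1x480x640x3 NHWC shape are placeholders, see the linked tutorial for the authoritative steps):

import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="model_float16.tflite")
input_details = interpreter.get_input_details()

# Undefined dimensions are stored as -1 in TFLite; resize to a concrete shape first.
interpreter.resize_tensor_input(input_details[0]["index"], [1, 480, 640, 3])
interpreter.allocate_tensors()

interpreter.set_tensor(input_details[0]["index"], np.zeros([1, 480, 640, 3], dtype=np.float32))
interpreter.invoke()
output = interpreter.get_tensor(interpreter.get_output_details()[0]["index"])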

However, TensorFlow is very annoying because it cannot define a flexible model structure the way ONNX can, and this has been troubling me for more than two years.

TensorFlow, TFLite runtime and Keras completely ignore the name given to the OP. LOL.

For your reference, I'll post a JSON image of TFLite's model structure. It is not possible to name each element. All undefined elements can only be declared as -1.
image

Just to be clear, we have a function to customize the behavior of onnx2tf using the file param_replacement.json, so I know from the start that your model will support dynamic tensor transformation by disabling all Transpose. However, it is very annoying to have to specify custom behavior for every Transpose that should be disabled.

https://github.com/PINTO0309/onnx2tf?tab=readme-ov-file#parameter-replacement

Rather than spending the effort on such things, the following very useful tools can be used to generate tflite models in a more flexible, straightforward, and efficient manner.

https://github.com/AlexanderLutsenko/nobuco

Alternatively, some tools, such as the one below, preserve channel order by inserting a large number of Transposes immediately before and after most OPs. Although maintenance has been abandoned, almost any model pattern can be converted. However, the inference speed of the generated tflite is 25% slower than that of models generated by onnx2tf.

https://github.com/onnx/onnx-tensorflow

Since onnx2tf is not a tool designed simply to convert models accurately, but rather to convert them while greatly optimizing the redundant model structure of the conversion source, the behavior of thoroughly removing or fusing unnecessary OPs is implemented. Therefore, defining a simple OP name mapping or mapping by element name is itself quite difficult.

Statically and accurately tying all the elements together during model transformation and dynamically estimating the shape during inference are completely different behaviors.

AD-lite24 (Author) commented Aug 28, 2024

onnxruntime dynamically calculates the size during inference. When onnx2tf refers to the dimension named "floor(height/14)*floor(width/14) + 5", it gets None; in other words, no value. Your model is special in that a formula is set in dimParam; models generated by other people usually do not have a formula, just a meaningless string like abcdefgh.

Oh, so any fix could not be generalized to all models. But in my case (i.e. this particular model), if I write an adapter that parses the string into an arithmetic expression and creates some data structure to store it, it should theoretically be possible?
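
Something like this toy sketch is what I have in mind (purely hypothetical, just to illustrate the idea; resolve_dim_param is a name I made up):

import math

def resolve_dim_param(formula: str, **known_dims: int) -> int:
    # Evaluate a dimParam formula once the dynamic dims are known; only floor()
    # and the supplied dimension values are visible to the expression.
    allowed = {"floor": math.floor, **known_dims}
    return int(eval(formula, {"__builtins__": {}}, allowed))

print(resolve_dim_param("floor(height/14)*floor(width/14) + 5", height=480, width=640))  # -> 1535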

Just to be clear, we have a function to customize the behavior of onnx2tf using the file param_replacement.json, so I know from the start that your model will support dynamic tensor transformation by disabling all Transpose. However, it is very annoying to have to specify custom behavior for every Transpose that should be disabled.

Ah, so the Transpose ops mess with the rank ordering of the dimensions, and turning them off (if they are redundant) or simplifying them away (using onnxsim) resolves the issue. Got it.

Rather than spending the effort on such things, the following very useful tools can be used to generate tflite models in a more flexible, straightforward, and efficient manner.
https://github.com/AlexanderLutsenko/nobuco

I was not aware of this, I will try it out soon! The onnx-tensorflow module is not maintained, though, and is not really compatible with anything, or at least its compatibility is poorly defined. I spent a few hours dealing with that before I landed on your project (thank God for that). It is odd that they stopped maintaining an official converter, but maybe the update to TensorFlow 2.x made them give up, I suppose, lol.

PINTO0309 (Owner) commented

Oh, so any fix could not be generalized to all models. But in my case (i.e. this particular model), if I write an adapter that parses the string into an arithmetic expression and creates some data structure to store it, it should theoretically be possible?

I think you're right, if you work super hard you can do it. 😸
