Kernel error for T5-style beam search with FP-16 subgraphs #23728

Open
KarelZe opened this issue Feb 17, 2025 · 0 comments
Labels
model:transformer issues related to a transformer model: BERT, GPT2, Hugging Face, Longformer, T5, etc.

KarelZe commented Feb 17, 2025

Describe the issue

Thanks for your work on onnx. 💯

I'm currently running into an issue with the BeamSearch node (https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.BeamSearch) when the T5-style encoder and decoder subgraphs have been converted to FP16.

According to the internal documentation, the inputs can be either float or float16 tensors:

Data type of input or output is float or float16 if not specified.

For float16 inputs, however, I get the error:

2025-02-15 20:15:02.999160115 [E:onnxruntime:, sequential_executor.cc:516 ExecuteKernel] Non-zero status code returned while running BeamSearch node. Name:'beam_search' Status Message: bad_function_call

To reproduce

I crafted a minimal reproducer using the CPU execution provider (example adapted from here:)

Minimal reproducer

pip freeze:

numpy==1.26.4
onnx==1.17.0
onnxruntime==1.20.1

Example:

import onnx
import onnxruntime
import numpy as np

VOCAB_SIZE = 20
NUM_HEADS = 2
HEAD_SIZE = 4
HIDDEN_SIZE = NUM_HEADS * HEAD_SIZE

DTYPE = onnx.TensorProto.FLOAT16

def create_encoder_graph(encoder_embedding_weight, decoder_embedding_weight, decoder_linear_weight):

    input_name = "encoder_input_ids"

    # Create input tensor
    input_tensor = onnx.helper.make_tensor_value_info(input_name, onnx.TensorProto.INT32, ["batch_size", "encode_sequence_length"])
    mask_tensor = onnx.helper.make_tensor_value_info("encoder_attention_mask", onnx.TensorProto.INT32, ["batch_size", "encode_sequence_length"])

    # Create embedding layer
    embedding_node = onnx.helper.make_node("Gather", ["embedding_weight", input_name], ["encoder_hidden_states"], name="embedding")
    embedding_weight_initializer = onnx.helper.make_tensor("embedding_weight", DTYPE, [VOCAB_SIZE, HIDDEN_SIZE], encoder_embedding_weight.flatten())

    # Create encoder layer
    encoder_output = onnx.helper.make_tensor_value_info("encoder_hidden_states", DTYPE, ["batch_size", "encode_sequence_length", HIDDEN_SIZE])

    # Create decoder input
    decoder_input = onnx.helper.make_tensor_value_info("decoder_input_ids", onnx.TensorProto.INT32, ["batch_size", 1])

    # Create decoder embedding layer
    decoder_embedding_node = onnx.helper.make_node("Gather", ["decoder_embedding_weight", "decoder_input_ids"], ["decoder_embedding_output"], name="decoder_embedding")
    decoder_embedding_weight_initializer = onnx.helper.make_tensor("decoder_embedding_weight", DTYPE, [VOCAB_SIZE, HIDDEN_SIZE], decoder_embedding_weight.flatten())

    # Create decoder output
    decoder_output = onnx.helper.make_tensor_value_info("logits", DTYPE, ["batch_size", 1, VOCAB_SIZE])

    # Reduce mean of encoder output
    encoder_output_mean = onnx.helper.make_node("ReduceMean", ["encoder_hidden_states"], ["encoder_hidden_states_mean"], axes=[1])

    # Create sum node
    sum_node = onnx.helper.make_node("Add", ["decoder_embedding_output", "encoder_hidden_states_mean"], ["sum_output"], name="sum")

    # Create linear projection
    linear_node = onnx.helper.make_node("MatMul", ["sum_output", "W"], ["linear_output"], name="linear")
    linear_weight_initializer = onnx.helper.make_tensor("W", DTYPE, [HIDDEN_SIZE, VOCAB_SIZE], decoder_linear_weight.flatten())

    # Create softmax node
    softmax_node = onnx.helper.make_node("Softmax", ["linear_output"], ["logits"], name="softmax")

    # Create output key and value states
    present_self_key = onnx.helper.make_tensor_value_info("present_key_self_0", DTYPE, ["batch_size", NUM_HEADS, 1, HEAD_SIZE])
    present_self_value = onnx.helper.make_tensor_value_info("present_value_self_0", DTYPE, ["batch_size", NUM_HEADS, 1, HEAD_SIZE])

    # Obtain key and value states from reshaping the encoder/decoder sum
    final_shape_as_constant = onnx.helper.make_node("Constant", [], ["final_shape"], value=onnx.helper.make_tensor("value", onnx.TensorProto.INT64, [4], [-1, 1, NUM_HEADS, HEAD_SIZE]))
    key_node = onnx.helper.make_node("Reshape", ["sum_output", "final_shape"], ["to_transpose_self_key"], name="key_reshape")
    value_node = onnx.helper.make_node("Reshape", ["sum_output", "final_shape"], ["to_transpose_self_value"], name="value_reshape")
    transposed_key_node = onnx.helper.make_node("Transpose", ["to_transpose_self_key"], ["present_key_self_0"], perm=[0, 2, 1, 3])
    transposed_value_node = onnx.helper.make_node("Transpose", ["to_transpose_self_value"], ["present_value_self_0"], perm=[0, 2, 1, 3])

    # Create output key and value states from the encoder
    present_cross_key = onnx.helper.make_tensor_value_info("present_cross_key", DTYPE, ["batch_size", NUM_HEADS, "encode_sequence_length", HEAD_SIZE])
    present_cross_value = onnx.helper.make_tensor_value_info("present_cross_value", DTYPE, ["batch_size", NUM_HEADS, "encode_sequence_length", HEAD_SIZE])

    # Obtain key and value states from reshaping the encoder output
    encoder_batch_seq_len = onnx.helper.make_node("Shape", ["encoder_hidden_states"], ["encoder_batch_seq_len"], end=2)
    num_heads_and_size = onnx.helper.make_node("Constant", [], ["num_heads_and_size"], value=onnx.helper.make_tensor("value", onnx.TensorProto.INT64, [2], [NUM_HEADS, HEAD_SIZE]))
    encoder_final_shape = onnx.helper.make_node("Concat", ["encoder_batch_seq_len", "num_heads_and_size"], ["encoder_final_shape"], axis=0)
    encoder_key_node = onnx.helper.make_node("Reshape", ["encoder_hidden_states", "encoder_final_shape"], ["to_transpose_cross_key"], name="encoder_key_reshape")
    encoder_value_node = onnx.helper.make_node("Reshape", ["encoder_hidden_states", "encoder_final_shape"], ["to_transpose_cross_value"], name="encoder_value_reshape")
    encoder_transposed_key_node = onnx.helper.make_node("Transpose", ["to_transpose_cross_key"], ["present_cross_key"], perm=[0, 2, 1, 3])
    encoder_transposed_value_node = onnx.helper.make_node("Transpose", ["to_transpose_cross_value"], ["present_cross_value"], perm=[0, 2, 1, 3])


    # Create graph
    graph = onnx.helper.make_graph(
        nodes=[final_shape_as_constant, embedding_node, decoder_embedding_node, encoder_output_mean, sum_node, linear_node, softmax_node, key_node, value_node, encoder_batch_seq_len, num_heads_and_size, encoder_final_shape, encoder_key_node, encoder_value_node, transposed_key_node, transposed_value_node, encoder_transposed_key_node, encoder_transposed_value_node],
        name="encoder_decoder_init",
        inputs=[input_tensor, mask_tensor, decoder_input],
        outputs=[decoder_output, encoder_output, present_self_key, present_self_value, present_cross_key, present_cross_value],
        initializer=[embedding_weight_initializer, decoder_embedding_weight_initializer, linear_weight_initializer]
    )

    return graph

def create_decoder_graph(decoder_embedding_weight, decoder_linear_weight):

    input_name = "input_ids"
    input_tensor = onnx.helper.make_tensor_value_info(input_name, onnx.TensorProto.INT32, ["batch_size", "decode_sequence_length"])
    encoder_attention_mask = onnx.helper.make_tensor_value_info("encoder_attention_mask", onnx.TensorProto.INT32, ["batch_size", "encode_sequence_length"])
    encoder_hidden_states = onnx.helper.make_tensor_value_info("encoder_hidden_states", DTYPE, ["batch_size", "encode_sequence_length", HIDDEN_SIZE])
    past_self_key = onnx.helper.make_tensor_value_info("past_self_key", DTYPE, ["batch_size", NUM_HEADS, "decode_sequence_length", HEAD_SIZE])
    past_self_value = onnx.helper.make_tensor_value_info("past_self_value", DTYPE, ["batch_size", NUM_HEADS, "decode_sequence_length", HEAD_SIZE])
    past_cross_key = onnx.helper.make_tensor_value_info("past_cross_key", DTYPE, ["batch_size", NUM_HEADS, "encode_sequence_length", HEAD_SIZE])
    past_cross_value = onnx.helper.make_tensor_value_info("past_cross_value", DTYPE, ["batch_size", NUM_HEADS, "encode_sequence_length", HEAD_SIZE])

    decoder_embedding_node = onnx.helper.make_node("Gather", ["decoder_embedding_weight", input_name], ["decoder_embedding_output"], name="decoder_embedding")
    decoder_embedding_weight_initializer = onnx.helper.make_tensor("decoder_embedding_weight", DTYPE, [VOCAB_SIZE, HIDDEN_SIZE], decoder_embedding_weight.flatten())

    logits = onnx.helper.make_tensor_value_info("logits", DTYPE, ["batch_size", "decode_sequence_length", VOCAB_SIZE])

    encoder_output_mean = onnx.helper.make_node("ReduceMean", ["encoder_hidden_states"], ["encoder_hidden_states_mean"], axes=[1])

    sum_node = onnx.helper.make_node("Add", ["decoder_embedding_output", "encoder_hidden_states_mean"], ["sum_output"], name="sum")

    linear_node = onnx.helper.make_node("MatMul", ["sum_output", "W"], ["linear_output"], name="linear")
    linear_weight_initializer = onnx.helper.make_tensor("W", DTYPE, [HIDDEN_SIZE, VOCAB_SIZE], decoder_linear_weight.flatten())

    softmax_node = onnx.helper.make_node("Softmax", ["linear_output"], ["logits"], name="softmax")

    output_key = onnx.helper.make_tensor_value_info("present_key", DTYPE, ["batch_size", NUM_HEADS, "present_sequence_length", HEAD_SIZE])
    output_value = onnx.helper.make_tensor_value_info("present_value", DTYPE, ["batch_size", NUM_HEADS, "present_sequence_length", HEAD_SIZE])

    batch_size = onnx.helper.make_node("Shape", ["sum_output"], ["batch_size"], end=1)
    final_shape_without_batch = onnx.helper.make_node("Constant", [], ["final_shape_without_batch"], value=onnx.helper.make_tensor("value", onnx.TensorProto.INT64, [3], [-1, NUM_HEADS, HEAD_SIZE]))
    final_shape = onnx.helper.make_node("Concat", ["batch_size", "final_shape_without_batch"], ["final_shape"], axis=0)

    key_node = onnx.helper.make_node("Reshape", ["sum_output", "final_shape"], ["output_key"], name="key_reshape")
    value_node = onnx.helper.make_node("Reshape", ["sum_output", "final_shape"], ["output_value"], name="value_reshape")
    transposed_key_node = onnx.helper.make_node("Transpose", ["output_key"], ["output_key_transposed"], perm=[0, 2, 1, 3])
    transposed_value_node = onnx.helper.make_node("Transpose", ["output_value"], ["output_value_transposed"], perm=[0, 2, 1, 3])
    key_concat_node = onnx.helper.make_node("Concat", ["past_self_key", "output_key_transposed"], ["present_key"], axis=2)
    value_concat_node = onnx.helper.make_node("Concat", ["past_self_value", "output_value_transposed"], ["present_value"], axis=2)

    graph = onnx.helper.make_graph(
        nodes=[decoder_embedding_node, encoder_output_mean, sum_node, batch_size, final_shape_without_batch, final_shape, linear_node, softmax_node, key_node, value_node, transposed_key_node, transposed_value_node, key_concat_node, value_concat_node],
        name="decoder_step",
        inputs=[input_tensor, encoder_attention_mask, encoder_hidden_states, past_self_key, past_self_value, past_cross_key, past_cross_value],
        outputs=[logits, output_key, output_value],
        initializer=[decoder_embedding_weight_initializer, linear_weight_initializer]
    )

    return graph

def create_model_with_beam_search():

    encoder_embedding_weight = np.random.rand(VOCAB_SIZE, HIDDEN_SIZE)
    decoder_embedding_weight = np.random.rand(VOCAB_SIZE, HIDDEN_SIZE)
    decoder_linear_weight = np.random.rand(HIDDEN_SIZE, VOCAB_SIZE)
    encoder_graph = create_encoder_graph(encoder_embedding_weight, decoder_embedding_weight, decoder_linear_weight)
    decoder_graph = create_decoder_graph(decoder_embedding_weight, decoder_linear_weight)

    # Create input tensor
    encoder_input = onnx.helper.make_tensor_value_info("encoder_input_ids", onnx.TensorProto.INT32, ["batch_size", "encode_sequence_length"])

    # Create output tensor
    sequences_output = onnx.helper.make_tensor_value_info("sequences", onnx.TensorProto.INT32, ["batch_size", 3, None])

    num_beams_tensor = onnx.helper.make_tensor("num_beams", onnx.TensorProto.INT32, [], [3])
    num_beams_as_constant = onnx.helper.make_node("Constant", [], ["num_beams"], value=num_beams_tensor)
    min_length_tensor = onnx.helper.make_tensor("min_length", onnx.TensorProto.INT32, [], [1])
    min_length_as_constant = onnx.helper.make_node("Constant", [], ["min_length"], value=min_length_tensor)
    max_length_tensor = onnx.helper.make_tensor("max_length", onnx.TensorProto.INT32, [], [10])
    max_length_as_constant = onnx.helper.make_node("Constant", [], ["max_length"], value=max_length_tensor)
    length_penalty_tensor = onnx.helper.make_tensor("length_penalty", onnx.TensorProto.FLOAT, [], [0.6])
    length_penalty_as_constant = onnx.helper.make_node("Constant", [], ["length_penalty"], value=length_penalty_tensor)

    # Create beam search node
    beam_search_node = onnx.helper.make_node(
        "BeamSearch",
        ["encoder_input_ids", "max_length", "min_length", "num_beams", "num_beams", "length_penalty"],
        ["sequences"],
        decoder=decoder_graph,
        encoder=encoder_graph,
        decoder_start_token_id=2,
        early_stopping=0,
        eos_token_id=2,
        model_type=1,
        pad_token_id=1,
        name="beam_search",
        domain="com.microsoft"
    )

    # Create main graph
    graph = onnx.helper.make_graph(
        nodes=[num_beams_as_constant, min_length_as_constant, max_length_as_constant, length_penalty_as_constant, beam_search_node],
        name="model",
        inputs=[encoder_input],
        outputs=[sequences_output],
    )

    # Create model
    model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 17), onnx.helper.make_opsetid(domain="com.microsoft", version=1)])

    onnx.checker.check_model(model, full_check=True)
    onnx.shape_inference.infer_shapes(model, strict_mode=True)

    return model

def run_model_with_dtype():
    cpu_session = onnxruntime.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
    input = np.random.randint(0, VOCAB_SIZE, (1, 5)).astype(np.int32)
    print("Input: ", input)
    cpu_output = cpu_session.run([], {"encoder_input_ids": input})
    print("Output: ", cpu_output)

# Create model
model = create_model_with_beam_search()

# Save model
onnx.save(model, "model.onnx")

run_model_with_dtype()

print("Model saved and tested successfully")

Output for DTYPE=TensorProto.FLOAT16:

Input:  [[15  3  9 17  8]]
2025-02-17 10:18:56.734928937 [E:onnxruntime:, sequential_executor.cc:516 ExecuteKernel] Non-zero status code returned while running BeamSearch node. Name:'beam_search' Status Message: bad_function_call
---------------------------------------------------------------------------
RuntimeException                          Traceback (most recent call last)
Cell In[2], line 200
    197 # Save model
    198 onnx.save(model, "model.onnx")
--> 200 run_model_with_dtype()
    202 print("Model saved and tested successfully")

Cell In[2], line 191
    189 input = np.random.randint(0, VOCAB_SIZE, (1, 5)).astype(np.int32)
    190 print("Input: ", input)
--> 191 cpu_output = cpu_session.run([], {"encoder_input_ids": input})
    192 print("Output: ", cpu_output)

File ~/git/smad-legi-extraction-lib/venv/lib/python3.11/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py:266, in Session.run(self, output_names, input_feed, run_options)
    264     output_names = [output.name for output in self._outputs_meta]
    265 try:
--> 266     return self._sess.run(output_names, input_feed, run_options)
    267 except C.EPFail as err:
    268     if self._enable_fallback:

RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running BeamSearch node. Name:'beam_search' Status Message: bad_function_call

Output for DTYPE=TensorProto.FLOAT:

Input:  [[12 17 10  5  8]]
Output:  [array([[[ 2, 11, 11, 11, 11, 11, 11, 11, 11, 11],
        [ 2, 11, 14, 11, 11, 11, 11, 11, 11, 11],
        [ 2, 11,  9, 11, 11, 11, 11, 11, 11, 11]]], dtype=int32)]
Model saved and tested successfully
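For reference, the FLOAT output above has shape (1, 3, 10), which lines up with the reproducer's settings: batch_size=1, num_return_sequences=3 (num_beams is passed for both the num_beams and num_return_sequences inputs), and max_length=10. The short sketch below, assuming only NumPy, copies the printed values and unpacks them under that assumed (batch_size, num_return_sequences, length) layout; it also shows that every returned beam begins with token 2, the decoder_start_token_id used in the reproducer.

```python
import numpy as np

# Values copied from the FLOAT run above. Layout assumed to be
# (batch_size, num_return_sequences, sequence_length).
sequences = np.array(
    [[[2, 11, 11, 11, 11, 11, 11, 11, 11, 11],
      [2, 11, 14, 11, 11, 11, 11, 11, 11, 11],
      [2, 11, 9, 11, 11, 11, 11, 11, 11, 11]]],
    dtype=np.int32,
)

DECODER_START_TOKEN_ID = 2  # decoder_start_token_id in the reproducer

batch_size, num_return_sequences, max_length = sequences.shape
for b in range(batch_size):
    for r in range(num_return_sequences):
        print(f"batch {b}, beam {r}: {sequences[b, r].tolist()}")

# Every beam starts with the decoder start token.
starts_ok = bool((sequences[:, :, 0] == DECODER_START_TOKEN_ID).all())
print("all beams start with decoder_start_token_id:", starts_ok)
```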

I get the same error when converting larger encoder/decoder models to FP16 with the transformer optimizer's FP16 conversion, m.convert_float_to_float16(keep_io_types=False). If I keep the original input/output types (keep_io_types=True), execution succeeds.

Could you please look into this issue? Maybe @tianleiwu? Please let me know if you need more info. 👍

Urgency

Yes, mid-spring would be great.

Platform

Linux

OS Version

Ubuntu in WSL.

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.20.1

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU

Execution Provider Library Version

No response

github-actions bot added the model:transformer label (issues related to a transformer model: BERT, GPT2, Hugging Face, Longformer, T5, etc.) on Feb 17, 2025
tianleiwu self-assigned this on Feb 22, 2025