-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Fix LSTM conversion for models with rank > 3 inputs from Unsqueeze operations #33023
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
d1689ce
de3353d
b06e7f2
c58c5bc
5ba2f09
08c7dda
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -12,7 +12,11 @@ | |||||||||||||
| #include "openvino/op/gather.hpp" | ||||||||||||||
| #include "openvino/op/lstm_sequence.hpp" | ||||||||||||||
| #include "openvino/op/multiply.hpp" | ||||||||||||||
| #include "openvino/op/reshape.hpp" | ||||||||||||||
| #include "openvino/op/shape_of.hpp" | ||||||||||||||
| #include "openvino/op/slice.hpp" | ||||||||||||||
| #include "openvino/op/squeeze.hpp" | ||||||||||||||
| #include "openvino/op/tile.hpp" | ||||||||||||||
| #include "openvino/util/common_util.hpp" | ||||||||||||||
| #include "utils/reshape.hpp" | ||||||||||||||
| #include "utils/split.hpp" | ||||||||||||||
|
|
@@ -37,6 +41,47 @@ enum class LSTMInput { | |||||||||||||
| LSTM_INPUT_P | ||||||||||||||
| }; | ||||||||||||||
|
|
||||||||||||||
| // Helper function to reduce tensor rank to target_rank by squeezing or reshaping | ||||||||||||||
| std::shared_ptr<ov::Node> reduce_tensor_rank(const ov::Output<ov::Node>& input, int64_t target_rank) { | ||||||||||||||
| const auto& input_shape = input.get_partial_shape(); | ||||||||||||||
|
|
||||||||||||||
| if (!input_shape.rank().is_static()) { | ||||||||||||||
| return input.get_node_shared_ptr(); | ||||||||||||||
| } | ||||||||||||||
|
Comment on lines
+48
to
+50
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||
|
|
||||||||||||||
| const auto input_rank = input_shape.rank().get_length(); | ||||||||||||||
|
|
||||||||||||||
| if (input_rank <= target_rank) { | ||||||||||||||
| return input.get_node_shared_ptr(); | ||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| // Strategy: Try to squeeze all leading dimensions that are equal to 1 | ||||||||||||||
| std::vector<int64_t> axes_to_squeeze; | ||||||||||||||
| for (int64_t i = 0; i < input_rank - target_rank; ++i) { | ||||||||||||||
| if (input_shape[i].is_static() && input_shape[i].get_length() == 1) { | ||||||||||||||
| axes_to_squeeze.push_back(i); | ||||||||||||||
| } | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| if (axes_to_squeeze.size() == static_cast<size_t>(input_rank - target_rank)) { | ||||||||||||||
| // All extra dimensions are 1, we can squeeze to get target rank | ||||||||||||||
| auto axes_const = v0::Constant::create(ov::element::i64, Shape{axes_to_squeeze.size()}, axes_to_squeeze); | ||||||||||||||
| return std::make_shared<v0::Squeeze>(input, axes_const); | ||||||||||||||
| } else { | ||||||||||||||
| // Some dimensions are not 1 or dynamic, need to reshape | ||||||||||||||
| auto shape_of_input = std::make_shared<v3::ShapeOf>(input); | ||||||||||||||
| auto start_idx = v0::Constant::create(ov::element::i64, Shape{1}, {input_rank - target_rank}); | ||||||||||||||
| auto stop_idx = v0::Constant::create(ov::element::i64, Shape{1}, {input_rank}); | ||||||||||||||
| auto step = v0::Constant::create(ov::element::i64, Shape{1}, {1}); | ||||||||||||||
|
|
||||||||||||||
| // Get last target_rank dimensions: shape[-target_rank:] | ||||||||||||||
| auto last_dims = std::make_shared<v8::Slice>(shape_of_input, start_idx, stop_idx, step); | ||||||||||||||
|
|
||||||||||||||
| // Reshape to extract last target_rank dimensions | ||||||||||||||
| return std::make_shared<v1::Reshape>(input, last_dims, false); | ||||||||||||||
|
Comment on lines
+72
to
+81
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can extra input dimensions be not 1? In such case the reshape will fail |
||||||||||||||
| } | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| struct LSTMNgInputMap { | ||||||||||||||
| explicit LSTMNgInputMap(const Node& node) { | ||||||||||||||
| const auto& ng_inputs = node.get_ov_inputs(); | ||||||||||||||
|
|
@@ -48,7 +93,14 @@ struct LSTMNgInputMap { | |||||||||||||
| // Packed input sequences. | ||||||||||||||
| // ONNX Shape: [seq_length, batch_size, input_size] | ||||||||||||||
| // OpenVino Shape: [batch_size, seq_length, input_size] | ||||||||||||||
| m_input_map[LSTMInput::LSTM_INPUT_X] = ov::op::util::reorder_axes(ng_inputs.at(0), {1, 0, 2}); | ||||||||||||||
|
|
||||||||||||||
| // First reduce rank if needed, THEN reorder axes | ||||||||||||||
| // This is important because Squeeze changes dimension indices | ||||||||||||||
| auto input_x = ng_inputs.at(0); | ||||||||||||||
| input_x = reduce_tensor_rank(input_x, 3); | ||||||||||||||
| input_x = ov::op::util::reorder_axes(input_x, {1, 0, 2}); | ||||||||||||||
|
|
||||||||||||||
| m_input_map[LSTMInput::LSTM_INPUT_X] = input_x; | ||||||||||||||
|
|
||||||||||||||
| // Weight tensor for the gates. | ||||||||||||||
| // Shape: [num_directions, 4*hidden_size, input_size] | ||||||||||||||
|
|
@@ -124,7 +176,12 @@ struct LSTMNgInputMap { | |||||||||||||
| // ONNX Shape: [num_directions, batch_size, hidden_size] | ||||||||||||||
| // OpenVino Shape: [batch_size, num_directions, hidden_size] | ||||||||||||||
| if (ng_inputs.size() > 5 && !ov::op::util::is_null(ng_inputs.at(5))) { | ||||||||||||||
| m_input_map[LSTMInput::LSTM_INPUT_INIT_H] = ov::op::util::reorder_axes(ng_inputs.at(5), {1, 0, 2}); | ||||||||||||||
| auto init_h = ng_inputs.at(5); | ||||||||||||||
| // First reduce rank, THEN reorder axes | ||||||||||||||
| init_h = reduce_tensor_rank(init_h, 3); | ||||||||||||||
| init_h = ov::op::util::reorder_axes(init_h, {1, 0, 2}); | ||||||||||||||
|
|
||||||||||||||
| m_input_map[LSTMInput::LSTM_INPUT_INIT_H] = init_h; | ||||||||||||||
| } else { | ||||||||||||||
| auto init_h_shape = | ||||||||||||||
| std::make_shared<v0::Concat>(ov::OutputVector{batch_size_node, num_directions_node, hidden_size_node}, | ||||||||||||||
|
|
@@ -137,7 +194,12 @@ struct LSTMNgInputMap { | |||||||||||||
| // ONNX Shape: [num_directions, batch_size, hidden_size] | ||||||||||||||
| // OpenVino Shape: [batch_size, num_directions, hidden_size] | ||||||||||||||
| if (ng_inputs.size() > 6 && !ov::op::util::is_null(ng_inputs.at(6))) { | ||||||||||||||
| m_input_map[LSTMInput::LSTM_INPUT_INIT_C] = ov::op::util::reorder_axes(ng_inputs.at(6), {1, 0, 2}); | ||||||||||||||
| auto init_c = ng_inputs.at(6); | ||||||||||||||
| // First reduce rank, THEN reorder axes | ||||||||||||||
| init_c = reduce_tensor_rank(init_c, 3); | ||||||||||||||
| init_c = ov::op::util::reorder_axes(init_c, {1, 0, 2}); | ||||||||||||||
|
|
||||||||||||||
| m_input_map[LSTMInput::LSTM_INPUT_INIT_C] = init_c; | ||||||||||||||
| } else { | ||||||||||||||
| auto init_c_shape = | ||||||||||||||
| std::make_shared<v0::Concat>(ov::OutputVector{batch_size_node, num_directions_node, hidden_size_node}, | ||||||||||||||
|
|
||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,167 @@ | ||
| ir_version: 4 | ||
| producer_name: "OpenVINO ONNX Frontend" | ||
| graph { | ||
| node { | ||
| input: "X" | ||
| input: "W" | ||
| input: "R" | ||
| input: "B" | ||
| output: "Y" | ||
| output: "Y_h" | ||
| output: "Y_c" | ||
| op_type: "LSTM" | ||
| attribute { | ||
| name: "direction" | ||
| s: "forward" | ||
| type: STRING | ||
| } | ||
| attribute { | ||
| name: "hidden_size" | ||
| i: 2 | ||
| type: INT | ||
| } | ||
| } | ||
| name: "compute_graph" | ||
| input { | ||
| name: "X" | ||
| type { | ||
| tensor_type { | ||
| elem_type: 1 | ||
| shape { | ||
| dim { | ||
| dim_value: 1 | ||
| } | ||
| dim { | ||
| dim_value: 1 | ||
| } | ||
| dim { | ||
| dim_value: 3 | ||
| } | ||
| dim { | ||
| dim_value: 2 | ||
| } | ||
| dim { | ||
| dim_value: 4 | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| input { | ||
| name: "W" | ||
| type { | ||
| tensor_type { | ||
| elem_type: 1 | ||
| shape { | ||
| dim { | ||
| dim_value: 1 | ||
| } | ||
| dim { | ||
| dim_value: 8 | ||
| } | ||
| dim { | ||
| dim_value: 4 | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| input { | ||
| name: "R" | ||
| type { | ||
| tensor_type { | ||
| elem_type: 1 | ||
| shape { | ||
| dim { | ||
| dim_value: 1 | ||
| } | ||
| dim { | ||
| dim_value: 8 | ||
| } | ||
| dim { | ||
| dim_value: 2 | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| input { | ||
| name: "B" | ||
| type { | ||
| tensor_type { | ||
| elem_type: 1 | ||
| shape { | ||
| dim { | ||
| dim_value: 1 | ||
| } | ||
| dim { | ||
| dim_value: 16 | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| output { | ||
| name: "Y" | ||
| type { | ||
| tensor_type { | ||
| elem_type: 1 | ||
| shape { | ||
| dim { | ||
| dim_value: 3 | ||
| } | ||
| dim { | ||
| dim_value: 1 | ||
| } | ||
| dim { | ||
| dim_value: 2 | ||
| } | ||
| dim { | ||
| dim_value: 2 | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| output { | ||
| name: "Y_h" | ||
| type { | ||
| tensor_type { | ||
| elem_type: 1 | ||
| shape { | ||
| dim { | ||
| dim_value: 1 | ||
| } | ||
| dim { | ||
| dim_value: 2 | ||
| } | ||
| dim { | ||
| dim_value: 2 | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| output { | ||
| name: "Y_c" | ||
| type { | ||
| tensor_type { | ||
| elem_type: 1 | ||
| shape { | ||
| dim { | ||
| dim_value: 1 | ||
| } | ||
| dim { | ||
| dim_value: 2 | ||
| } | ||
| dim { | ||
| dim_value: 2 | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| opset_import { | ||
| version: 7 | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Avoid working with nodes as this might be a node with many outouts and you will use first output that way