Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 65 additions & 3 deletions src/frontends/onnx/frontend/src/op/lstm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
#include "openvino/op/gather.hpp"
#include "openvino/op/lstm_sequence.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/squeeze.hpp"
#include "openvino/op/tile.hpp"
#include "openvino/util/common_util.hpp"
#include "utils/reshape.hpp"
#include "utils/split.hpp"
Expand All @@ -37,6 +41,47 @@ enum class LSTMInput {
LSTM_INPUT_P
};

// Helper function to reduce tensor rank to target_rank by squeezing or reshaping
std::shared_ptr<ov::Node> reduce_tensor_rank(const ov::Output<ov::Node>& input, int64_t target_rank) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
std::shared_ptr<ov::Node> reduce_tensor_rank(const ov::Output<ov::Node>& input, int64_t target_rank) {
ov::Output<ov::Node> reduce_tensor_rank(const ov::Output<ov::Node>& input, int64_t target_rank) {

Avoid working with nodes as this might be a node with many outouts and you will use first output that way

const auto& input_shape = input.get_partial_shape();

if (!input_shape.rank().is_static()) {
return input.get_node_shared_ptr();
}
Comment on lines +48 to +50
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (!input_shape.rank().is_static()) {
return input.get_node_shared_ptr();
}
if (input_shape.rank().is_dynamic()) {
return input;
}


const auto input_rank = input_shape.rank().get_length();

if (input_rank <= target_rank) {
return input.get_node_shared_ptr();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return input.get_node_shared_ptr();
return input;

}

// Strategy: Try to squeeze all leading dimensions that are equal to 1
std::vector<int64_t> axes_to_squeeze;
for (int64_t i = 0; i < input_rank - target_rank; ++i) {
if (input_shape[i].is_static() && input_shape[i].get_length() == 1) {
axes_to_squeeze.push_back(i);
}
}

if (axes_to_squeeze.size() == static_cast<size_t>(input_rank - target_rank)) {
// All extra dimensions are 1, we can squeeze to get target rank
auto axes_const = v0::Constant::create(ov::element::i64, Shape{axes_to_squeeze.size()}, axes_to_squeeze);
return std::make_shared<v0::Squeeze>(input, axes_const);
} else {
// Some dimensions are not 1 or dynamic, need to reshape
auto shape_of_input = std::make_shared<v3::ShapeOf>(input);
auto start_idx = v0::Constant::create(ov::element::i64, Shape{1}, {input_rank - target_rank});
auto stop_idx = v0::Constant::create(ov::element::i64, Shape{1}, {input_rank});
auto step = v0::Constant::create(ov::element::i64, Shape{1}, {1});

// Get last target_rank dimensions: shape[-target_rank:]
auto last_dims = std::make_shared<v8::Slice>(shape_of_input, start_idx, stop_idx, step);

// Reshape to extract last target_rank dimensions
return std::make_shared<v1::Reshape>(input, last_dims, false);
Comment on lines +72 to +81
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can extra input dimensions be not 1? In such case the reshape will fail

}
}

struct LSTMNgInputMap {
explicit LSTMNgInputMap(const Node& node) {
const auto& ng_inputs = node.get_ov_inputs();
Expand All @@ -48,7 +93,14 @@ struct LSTMNgInputMap {
// Packed input sequences.
// ONNX Shape: [seq_length, batch_size, input_size]
// OpenVino Shape: [batch_size, seq_length, input_size]
m_input_map[LSTMInput::LSTM_INPUT_X] = ov::op::util::reorder_axes(ng_inputs.at(0), {1, 0, 2});

// First reduce rank if needed, THEN reorder axes
// This is important because Squeeze changes dimension indices
auto input_x = ng_inputs.at(0);
input_x = reduce_tensor_rank(input_x, 3);
input_x = ov::op::util::reorder_axes(input_x, {1, 0, 2});

m_input_map[LSTMInput::LSTM_INPUT_X] = input_x;

// Weight tensor for the gates.
// Shape: [num_directions, 4*hidden_size, input_size]
Expand Down Expand Up @@ -124,7 +176,12 @@ struct LSTMNgInputMap {
// ONNX Shape: [num_directions, batch_size, hidden_size]
// OpenVino Shape: [batch_size, num_directions, hidden_size]
if (ng_inputs.size() > 5 && !ov::op::util::is_null(ng_inputs.at(5))) {
m_input_map[LSTMInput::LSTM_INPUT_INIT_H] = ov::op::util::reorder_axes(ng_inputs.at(5), {1, 0, 2});
auto init_h = ng_inputs.at(5);
// First reduce rank, THEN reorder axes
init_h = reduce_tensor_rank(init_h, 3);
init_h = ov::op::util::reorder_axes(init_h, {1, 0, 2});

m_input_map[LSTMInput::LSTM_INPUT_INIT_H] = init_h;
} else {
auto init_h_shape =
std::make_shared<v0::Concat>(ov::OutputVector{batch_size_node, num_directions_node, hidden_size_node},
Expand All @@ -137,7 +194,12 @@ struct LSTMNgInputMap {
// ONNX Shape: [num_directions, batch_size, hidden_size]
// OpenVino Shape: [batch_size, num_directions, hidden_size]
if (ng_inputs.size() > 6 && !ov::op::util::is_null(ng_inputs.at(6))) {
m_input_map[LSTMInput::LSTM_INPUT_INIT_C] = ov::op::util::reorder_axes(ng_inputs.at(6), {1, 0, 2});
auto init_c = ng_inputs.at(6);
// First reduce rank, THEN reorder axes
init_c = reduce_tensor_rank(init_c, 3);
init_c = ov::op::util::reorder_axes(init_c, {1, 0, 2});

m_input_map[LSTMInput::LSTM_INPUT_INIT_C] = init_c;
} else {
auto init_c_shape =
std::make_shared<v0::Concat>(ov::OutputVector{batch_size_node, num_directions_node, hidden_size_node},
Expand Down
167 changes: 167 additions & 0 deletions src/frontends/onnx/tests/models/lstm_high_rank_input.prototxt
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
ir_version: 4
producer_name: "OpenVINO ONNX Frontend"
graph {
node {
input: "X"
input: "W"
input: "R"
input: "B"
output: "Y"
output: "Y_h"
output: "Y_c"
op_type: "LSTM"
attribute {
name: "direction"
s: "forward"
type: STRING
}
attribute {
name: "hidden_size"
i: 2
type: INT
}
}
name: "compute_graph"
input {
name: "X"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 3
}
dim {
dim_value: 2
}
dim {
dim_value: 4
}
}
}
}
}
input {
name: "W"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 8
}
dim {
dim_value: 4
}
}
}
}
}
input {
name: "R"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 8
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "B"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 16
}
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "Y_h"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "Y_c"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 7
}
Loading
Loading