Skip to content

Commit 2cf37b3

Browse files
preetha-intel and sfatimar
authored and committed
Fix the model copies and redefinitions for CPU fallback (#728)
* Fix the model copies and redefinitions for CPU fallback
* OV compatibility is not needed

---------

Co-authored-by: sfatimar <[email protected]>
1 parent 08de9ce commit 2cf37b3

File tree

1 file changed

+23
-25
lines changed

1 file changed

+23
-25
lines changed

onnxruntime/core/providers/openvino/backends/basic_backend.cc

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,6 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
3636
if (ValidateSubgraph(const_outputs_map_))
3737
return;
3838

39-
// Pre-requisite is provider_option "context" must be set
40-
auto auto_unified_compile = ((hw_target.find("AUTO") == std::string::npos) ||
41-
(session_context_.OpenVINO_Version.at(0) >= 2024 &&
42-
session_context_.OpenVINO_Version.at(1) > 2));
4339
ov::AnyMap device_config;
4440
SetOVDeviceConfiguration(device_config);
4541
if (subgraph_context_.is_ep_ctx_graph) {
@@ -81,42 +77,46 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
8177
ORT_THROW(msg);
8278
} // Delete stream after it is no longer needed
8379
} else {
80+
std::shared_ptr<const onnxruntime::openvino_ep::OVNetwork> ov_model;
8481
std::string model = model_proto->SerializeAsString();
8582
if (!subgraph_context.has_dynamic_input_shape) {
8683
model_proto.reset();
8784
}
85+
bool eligible_for_cpu_fallback = session_context_.device_type.find("NPU") != std::string::npos &&
86+
!session_context_.so_disable_cpu_ep_fallback &&
87+
!subgraph_context_.is_ep_ctx_graph;
88+
#if defined(OPENVINO_DISABLE_NPU_FALLBACK)
89+
eligible_for_cpu_fallback = false;
90+
#endif
91+
auto auto_unified_compile = (hw_target.find("AUTO") == std::string::npos);
92+
93+
// Unified compile is efficient with cache_dir cached model loading, which bypasses Read Model.
94+
// Does not support models with external weights, dynamic input shape, Epctx onnx cached model,
95+
// reshape, enable_causallm, and for NPU CPU fallback
96+
97+
auto is_unified_compile = (!session_context_.has_external_weights &&
98+
!subgraph_context_.has_dynamic_input_shape &&
99+
!session_context_.so_context_enable &&
100+
session_context_.reshape.empty() &&
101+
!enable_causallm &&
102+
!eligible_for_cpu_fallback &&
103+
auto_unified_compile);
88104
try {
89-
// SetOVDeviceConfiguration(device_config);
90-
if (!session_context_.has_external_weights &&
91-
!subgraph_context_.has_dynamic_input_shape &&
92-
!session_context_.so_context_enable &&
93-
session_context_.reshape.empty() &&
94-
!enable_causallm &&
95-
auto_unified_compile) {
96-
// Unified OV compile_model is efficient when ov model caching is enabled
97-
// Unified OV compile_model API is supported with AUTO from version 2024.3 and above
98-
// Inputs with static dimensions
99-
// Not enabled for models with external weights and when ep context is set.
100-
105+
if (is_unified_compile) {
101106
exe_network_ = OVCore::Get()->CompileModel(model,
102107
hw_target,
103108
device_config,
104109
subgraph_context_.subgraph_name);
105110
} else { // For all other types use ov::ov_core read_model() to generate OV IR
106111
// followed by ov::ov_core compile_model()
107-
auto ov_model = CreateOVModel(std::move(model), session_context_, const_outputs_map_);
112+
ov_model = CreateOVModel(std::move(model), session_context_, const_outputs_map_);
108113
exe_network_ = OVCore::Get()->CompileModel(
109114
ov_model, hw_target, device_config, enable_causallm, subgraph_context_.subgraph_name);
110115
}
111116
LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin";
112117
} catch (const OnnxRuntimeException& ex) {
113118
std::string exception_str = ex.what();
114-
bool eligible_for_cpu_fallback = session_context_.device_type.find("NPU") != std::string::npos &&
115-
!session_context_.so_disable_cpu_ep_fallback &&
116-
!subgraph_context_.is_ep_ctx_graph;
117-
#if defined(OPENVINO_DISABLE_NPU_FALLBACK)
118-
eligible_for_cpu_fallback = false;
119-
#endif
119+
120120
if (eligible_for_cpu_fallback) {
121121
LOGS_DEFAULT(WARNING) << "Model compilation failed at OV NPU."
122122
<< "Falling back to OV CPU for execution";
@@ -125,8 +125,6 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
125125
device_config.clear();
126126
SetOVDeviceConfiguration(device_config);
127127
try {
128-
// Recreate the model with CPU device type
129-
auto ov_model = CreateOVModel(std::move(model), session_context_, const_outputs_map_);
130128
exe_network_ = OVCore::Get()->CompileModel(
131129
ov_model, hw_target, device_config, enable_causallm, subgraph_context_.subgraph_name);
132130
} catch (std::string const& msg) {

0 commit comments

Comments
 (0)