@@ -36,10 +36,6 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
36
36
if (ValidateSubgraph (const_outputs_map_))
37
37
return ;
38
38
39
- // Pre-requisite is provider_option "context" must be set
40
- auto auto_unified_compile = ((hw_target.find (" AUTO" ) == std::string::npos) ||
41
- (session_context_.OpenVINO_Version .at (0 ) >= 2024 &&
42
- session_context_.OpenVINO_Version .at (1 ) > 2 ));
43
39
ov::AnyMap device_config;
44
40
SetOVDeviceConfiguration (device_config);
45
41
if (subgraph_context_.is_ep_ctx_graph ) {
@@ -81,42 +77,46 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
81
77
ORT_THROW (msg);
82
78
} // Delete stream after it is no longer needed
83
79
} else {
80
+ std::shared_ptr<const onnxruntime::openvino_ep::OVNetwork> ov_model;
84
81
std::string model = model_proto->SerializeAsString ();
85
82
if (!subgraph_context.has_dynamic_input_shape ) {
86
83
model_proto.reset ();
87
84
}
85
+ bool eligible_for_cpu_fallback = session_context_.device_type .find (" NPU" ) != std::string::npos &&
86
+ !session_context_.so_disable_cpu_ep_fallback &&
87
+ !subgraph_context_.is_ep_ctx_graph ;
88
+ #if defined(OPENVINO_DISABLE_NPU_FALLBACK)
89
+ eligible_for_cpu_fallback = false ;
90
+ #endif
91
+ auto auto_unified_compile = (hw_target.find (" AUTO" ) == std::string::npos);
92
+
93
+ // Unified compile is efficient with cahce_dir cached model loading that bypass Read Model
94
+ // Does not support model with exteral weights, dynamic input shape, Epctx onnx cached model,
95
+ // reshape, enable_causallm, and for NPU CPU fallback
96
+
97
+ auto is_unified_compile = (!session_context_.has_external_weights &&
98
+ !subgraph_context_.has_dynamic_input_shape &&
99
+ !session_context_.so_context_enable &&
100
+ session_context_.reshape .empty () &&
101
+ !enable_causallm &&
102
+ !eligible_for_cpu_fallback &&
103
+ auto_unified_compile);
88
104
try {
89
- // SetOVDeviceConfiguration(device_config);
90
- if (!session_context_.has_external_weights &&
91
- !subgraph_context_.has_dynamic_input_shape &&
92
- !session_context_.so_context_enable &&
93
- session_context_.reshape .empty () &&
94
- !enable_causallm &&
95
- auto_unified_compile) {
96
- // Unified OV compile_model is efficient when ov model caching is enabled
97
- // Unified OV compile_model API is supported with AUTO from version 2024.3 and above
98
- // Inputs with static dimensions
99
- // Not enabled for models with external weights and when ep context is set.
100
-
105
+ if (is_unified_compile) {
101
106
exe_network_ = OVCore::Get ()->CompileModel (model,
102
107
hw_target,
103
108
device_config,
104
109
subgraph_context_.subgraph_name );
105
110
} else { // For all other types use ov::ov_core read_model() to generate OV IR
106
111
// followed by ov::ov_core compile_model()
107
- auto ov_model = CreateOVModel (std::move (model), session_context_, const_outputs_map_);
112
+ ov_model = CreateOVModel (std::move (model), session_context_, const_outputs_map_);
108
113
exe_network_ = OVCore::Get ()->CompileModel (
109
114
ov_model, hw_target, device_config, enable_causallm, subgraph_context_.subgraph_name );
110
115
}
111
116
LOGS_DEFAULT (INFO) << log_tag << " Loaded model to the plugin" ;
112
117
} catch (const OnnxRuntimeException& ex) {
113
118
std::string exception_str = ex.what ();
114
- bool eligible_for_cpu_fallback = session_context_.device_type .find (" NPU" ) != std::string::npos &&
115
- !session_context_.so_disable_cpu_ep_fallback &&
116
- !subgraph_context_.is_ep_ctx_graph ;
117
- #if defined(OPENVINO_DISABLE_NPU_FALLBACK)
118
- eligible_for_cpu_fallback = false ;
119
- #endif
119
+
120
120
if (eligible_for_cpu_fallback) {
121
121
LOGS_DEFAULT (WARNING) << " Model compilation failed at OV NPU."
122
122
<< " Falling back to OV CPU for execution" ;
@@ -125,8 +125,6 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
125
125
device_config.clear ();
126
126
SetOVDeviceConfiguration (device_config);
127
127
try {
128
- // Recreate the model with CPU device type
129
- auto ov_model = CreateOVModel (std::move (model), session_context_, const_outputs_map_);
130
128
exe_network_ = OVCore::Get ()->CompileModel (
131
129
ov_model, hw_target, device_config, enable_causallm, subgraph_context_.subgraph_name );
132
130
} catch (std::string const & msg) {
0 commit comments