Skip to content

Commit 3d09e79

Browse files
authored
[NPU] Refactoring getGraphDescriptor by setting the flags inside it (#32449)
### Details:
- *Small refactoring by setting the compilation flags in the ze_graph_ext_wrappers*
- *Skip tests that aren't available on different platforms*

---------

Signed-off-by: Bogdan Pereanu <[email protected]>
1 parent 7c616dc commit 3d09e79

File tree

8 files changed

+197
-166
lines changed

8 files changed

+197
-166
lines changed

src/plugins/intel_npu/src/compiler_adapter/include/ze_graph_ext_wrappers.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class ZeGraphExtWrappers {
3535

3636
GraphDescriptor getGraphDescriptor(SerializedIR serializedIR,
3737
const std::string& buildFlags,
38-
const uint32_t& flags) const;
38+
const bool bypassUmdCache = false) const;
3939

4040
GraphDescriptor getGraphDescriptor(void* data, size_t size) const;
4141

src/plugins/intel_npu/src/compiler_adapter/src/driver_compiler_adapter.cpp

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -101,15 +101,10 @@ std::shared_ptr<IGraph> DriverCompilerAdapter::compile(const std::shared_ptr<con
101101

102102
_logger.debug("compileIR Build flags : %s", buildFlags.c_str());
103103

104-
// If UMD Caching is requested to be bypassed or if OV cache is enabled, disable driver caching
105-
uint32_t flags = ZE_GRAPH_FLAG_NONE;
106-
const auto set_cache_dir = config.get<CACHE_DIR>();
107-
if (!set_cache_dir.empty() || config.get<BYPASS_UMD_CACHING>()) {
108-
flags = flags | ZE_GRAPH_FLAG_DISABLE_CACHING;
109-
}
110-
111104
_logger.debug("compile start");
112-
auto graphDesc = _zeGraphExt->getGraphDescriptor(std::move(serializedIR), buildFlags, flags);
105+
// If UMD Caching is requested to be bypassed or if OV cache is enabled, disable driver caching
106+
const bool bypassCache = !config.get<CACHE_DIR>().empty() || config.get<BYPASS_UMD_CACHING>();
107+
auto graphDesc = _zeGraphExt->getGraphDescriptor(std::move(serializedIR), buildFlags, bypassCache);
113108
_logger.debug("compile end");
114109

115110
OV_ITT_TASK_NEXT(COMPILE_BLOB, "getNetworkMeta");
@@ -161,13 +156,6 @@ std::shared_ptr<IGraph> DriverCompilerAdapter::compileWS(const std::shared_ptr<o
161156
}
162157
FilteredConfig updatedConfig = *plgConfig;
163158

164-
// If UMD Caching is requested to be bypassed or if OV cache is enabled, disable driver caching
165-
uint32_t flags = ZE_GRAPH_FLAG_NONE;
166-
const auto set_cache_dir = config.get<CACHE_DIR>();
167-
if (!set_cache_dir.empty() || config.get<BYPASS_UMD_CACHING>()) {
168-
flags = flags | ZE_GRAPH_FLAG_DISABLE_CACHING;
169-
}
170-
171159
// WS v3 is based on a stateless compiler. We'll use a separate config entry for informing the compiler the index of
172160
// the current call iteration.
173161
std::vector<NetworkMetadata> initNetworkMetadata;
@@ -191,7 +179,9 @@ std::shared_ptr<IGraph> DriverCompilerAdapter::compileWS(const std::shared_ptr<o
191179
buildFlags += irSerializer.serializeConfig(updatedConfig, compilerVersion);
192180

193181
_logger.debug("compile start");
194-
auto graphDesc = _zeGraphExt->getGraphDescriptor(serializedIR, buildFlags, flags);
182+
// If UMD Caching is requested to be bypassed or if OV cache is enabled, disable driver caching
183+
const bool bypassCache = !config.get<CACHE_DIR>().empty() || config.get<BYPASS_UMD_CACHING>();
184+
auto graphDesc = _zeGraphExt->getGraphDescriptor(serializedIR, buildFlags, bypassCache);
195185
_logger.debug("compile end");
196186

197187
OV_ITT_TASK_NEXT(COMPILE_BLOB, "getNetworkMeta");

src/plugins/intel_npu/src/compiler_adapter/src/ze_graph_ext_wrappers.cpp

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -131,10 +131,6 @@ ZeGraphExtWrappers::ZeGraphExtWrappers(const std::shared_ptr<ZeroInitStructsHold
131131
ZE_MAJOR_VERSION(_graphExtVersion),
132132
ZE_MINOR_VERSION(_graphExtVersion));
133133
_logger.debug("capabilities:");
134-
_logger.debug("-SupportQuery: %d", true);
135-
_logger.debug("-SupportAPIGraphQueryNetworkV1: %d", true);
136-
_logger.debug("-SupportAPIGraphQueryNetworkV2 :%d", true);
137-
_logger.debug("-SupportpfnCreate2 :%d", true);
138134
_logger.debug("-SupportArgumentMetadata :%d", !NotSupportArgumentMetadata(_graphExtVersion));
139135
_logger.debug("-UseCopyForNativeBinary :%d", UseCopyForNativeBinary(_graphExtVersion));
140136
}
@@ -274,10 +270,8 @@ static std::unordered_set<std::string> parseQueryResult(std::vector<char>& data)
274270

275271
std::unordered_set<std::string> ZeGraphExtWrappers::queryGraph(SerializedIR serializedIR,
276272
const std::string& buildFlags) const {
277-
// For ext version >= 1.5
278273
ze_graph_query_network_handle_t hGraphQueryNetwork = nullptr;
279274

280-
// For ext version >= 1.5
281275
ze_graph_desc_2_t desc = {ZE_STRUCTURE_TYPE_GRAPH_DESC_PROPERTIES,
282276
nullptr,
283277
ZE_GRAPH_FORMAT_NGRAPH_LITE,
@@ -286,14 +280,14 @@ std::unordered_set<std::string> ZeGraphExtWrappers::queryGraph(SerializedIR seri
286280
buildFlags.c_str(),
287281
ZE_GRAPH_FLAG_NONE};
288282

289-
// Create querynetwork handle
290-
_logger.debug("For ext larger than 1.4 - perform pfnQueryNetworkCreate2");
291-
ze_result_t result = _zeroInitStruct->getGraphDdiTable().pfnQueryNetworkCreate2(_zeroInitStruct->getContext(),
292-
_zeroInitStruct->getDevice(),
293-
&desc,
294-
&hGraphQueryNetwork);
283+
_logger.debug("queryGraph - perform pfnQueryNetworkCreate2");
284+
auto result = _zeroInitStruct->getGraphDdiTable().pfnQueryNetworkCreate2(_zeroInitStruct->getContext(),
285+
_zeroInitStruct->getDevice(),
286+
&desc,
287+
&hGraphQueryNetwork);
295288
THROW_ON_FAIL_FOR_LEVELZERO_EXT("pfnQueryNetworkCreate2", result, _zeroInitStruct->getGraphDdiTable());
296289

290+
// Get the size of query result
297291
_logger.debug("queryGraph - perform pfnQueryNetworkGetSupportedLayers to get size");
298292
size_t size = 0;
299293
result = _zeroInitStruct->getGraphDdiTable().pfnQueryNetworkGetSupportedLayers(hGraphQueryNetwork, &size, nullptr);
@@ -341,8 +335,15 @@ bool ZeGraphExtWrappers::canCpuVaBeImported(void* data, size_t size) const {
341335

342336
GraphDescriptor ZeGraphExtWrappers::getGraphDescriptor(SerializedIR serializedIR,
343337
const std::string& buildFlags,
344-
const uint32_t& flags) const {
345-
// For ext version >= 1.5, calling pfnCreate2 api in _zeroInitStruct->getGraphDdiTable()
338+
const bool bypassUmdCache) const {
339+
ze_graph_handle_t graphHandle = nullptr;
340+
341+
uint32_t flags = ZE_GRAPH_FLAG_NONE;
342+
if (bypassUmdCache) {
343+
_logger.debug("getGraphDescriptor - set ZE_GRAPH_FLAG_DISABLE_CACHING");
344+
flags |= ZE_GRAPH_FLAG_DISABLE_CACHING;
345+
}
346+
346347
ze_graph_desc_2_t desc = {ZE_STRUCTURE_TYPE_GRAPH_DESC_PROPERTIES,
347348
nullptr,
348349
ZE_GRAPH_FORMAT_NGRAPH_LITE,
@@ -352,8 +353,6 @@ GraphDescriptor ZeGraphExtWrappers::getGraphDescriptor(SerializedIR serializedIR
352353
flags};
353354

354355
_logger.debug("getGraphDescriptor - perform pfnCreate2");
355-
// Create querynetwork handle
356-
ze_graph_handle_t graphHandle = nullptr;
357356
auto result = _zeroInitStruct->getGraphDdiTable().pfnCreate2(_zeroInitStruct->getContext(),
358357
_zeroInitStruct->getDevice(),
359358
&desc,
@@ -370,12 +369,11 @@ GraphDescriptor ZeGraphExtWrappers::getGraphDescriptor(void* blobData, size_t bl
370369
OPENVINO_THROW("Empty blob");
371370
}
372371

373-
uint32_t flags = 0;
372+
uint32_t flags = ZE_GRAPH_FLAG_NONE;
374373
bool setPersistentFlag = canCpuVaBeImported(blobData, blobSize);
375-
376374
if (setPersistentFlag) {
377375
_logger.debug("getGraphDescriptor - set ZE_GRAPH_FLAG_INPUT_GRAPH_PERSISTENT");
378-
flags = ZE_GRAPH_FLAG_INPUT_GRAPH_PERSISTENT;
376+
flags |= ZE_GRAPH_FLAG_INPUT_GRAPH_PERSISTENT;
379377
}
380378

381379
ze_graph_desc_2_t desc = {ZE_STRUCTURE_TYPE_GRAPH_DESC_PROPERTIES,
@@ -387,7 +385,6 @@ GraphDescriptor ZeGraphExtWrappers::getGraphDescriptor(void* blobData, size_t bl
387385
flags};
388386

389387
_logger.debug("getGraphDescriptor - perform pfnCreate2");
390-
391388
auto result = _zeroInitStruct->getGraphDdiTable().pfnCreate2(_zeroInitStruct->getContext(),
392389
_zeroInitStruct->getDevice(),
393390
&desc,

src/plugins/intel_npu/tests/functional/internal/backend/zero_tensor_tests.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ class ZeroTensorTests : public ov::test::behavior::OVPluginTestBase,
6161
ov::element::Type type;
6262
std::tie(targetDevice, configuration, type) = obj.param;
6363
std::replace(targetDevice.begin(), targetDevice.end(), ':', '_');
64-
targetDevice = ov::test::utils::getTestsPlatformFromEnvironmentOr(ov::test::utils::DEVICE_NPU);
6564

6665
std::ostringstream result;
6766
result << "targetDevice=" << targetDevice << "_";

src/plugins/intel_npu/tests/functional/internal/backend/zero_variable_state_tests.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ class ZeroVariableStateTests : public ov::test::behavior::OVPluginTestBase,
5959
ov::AnyMap configuration;
6060
std::tie(targetDevice, configuration) = obj.param;
6161
std::replace(targetDevice.begin(), targetDevice.end(), ':', '_');
62-
targetDevice = ov::test::utils::getTestsPlatformFromEnvironmentOr(ov::test::utils::DEVICE_NPU);
6362

6463
std::ostringstream result;
6564
result << "targetDevice=" << targetDevice << "_";

src/plugins/intel_npu/tests/functional/internal/compiler_adapter/zero_graph.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,27 @@
44

55
#include "zero_graph.hpp"
66

7-
#include <common_test_utils/test_assertions.hpp>
8-
97
namespace {
10-
std::vector<int> graphDescflags = {ZE_GRAPH_FLAG_NONE, ZE_GRAPH_FLAG_DISABLE_CACHING, ZE_GRAPH_FLAG_ENABLE_PROFILING};
8+
const std::vector<ov::AnyMap> configsGraphCompilationTests = {{},
9+
{ov::cache_dir("test")},
10+
{ov::intel_npu::bypass_umd_caching(true)}};
1111

1212
// tested versions interval is [1.5, CURRENT + 1)
13-
auto extVersions = ::testing::Range(ZE_MAKE_VERSION(1, 5), ZE_GRAPH_EXT_VERSION_CURRENT + 1);
13+
auto graphExtVersions = ::testing::Range(ZE_MAKE_VERSION(1, 5), ZE_GRAPH_EXT_VERSION_CURRENT + 1);
1414

1515
INSTANTIATE_TEST_SUITE_P(smoke_BehaviorTest,
1616
ZeroGraphCompilationTests,
17-
::testing::Combine(::testing::ValuesIn(graphDescflags), extVersions),
17+
::testing::Combine(::testing::Values(ov::test::utils::DEVICE_NPU),
18+
::testing::ValuesIn(configsGraphCompilationTests),
19+
graphExtVersions),
1820
ZeroGraphTest::getTestCaseName);
1921

20-
std::vector<int> noneGraphDescflags = {ZE_GRAPH_FLAG_NONE};
22+
const std::vector<ov::AnyMap> emptyConfigsTests = {{}};
2123

2224
INSTANTIATE_TEST_SUITE_P(smoke_BehaviorTest,
2325
ZeroGraphTest,
24-
::testing::Combine(::testing::ValuesIn(noneGraphDescflags), extVersions),
26+
::testing::Combine(::testing::Values(ov::test::utils::DEVICE_NPU),
27+
::testing::ValuesIn(emptyConfigsTests),
28+
graphExtVersions),
2529
ZeroGraphTest::getTestCaseName);
2630
} // namespace

0 commit comments

Comments
 (0)