diff --git a/runtime/backend/options_map.h b/runtime/backend/options_map.h index c761704e69..b833f46120 100644 --- a/runtime/backend/options_map.h +++ b/runtime/backend/options_map.h @@ -11,6 +11,8 @@ #include #include #include +#include +#include #include namespace executorch { namespace runtime { @@ -82,5 +84,106 @@ class BackendOptionsMap { size_t size_ = 0; // Current number of entries }; + +/** + * Retrieves backend options for a specific backend. + * + * @param backend_name The name of the backend to get options from + * @param backend_options The backend option objects that will be filled with + * the populated values from the backend + * @return Error::Ok on success, Error::NotFound if backend is not found, or + * other error codes on failure + */ + Error get_option( + const char* backend_name, + executorch::runtime::Span + backend_options) { +auto backend_class = get_backend_class(backend_name); +if (!backend_class) { + return Error::NotFound; +} +executorch::runtime::BackendOptionContext backend_option_context; +executorch::runtime::Span backend_options_ref( + backend_options.data(), backend_options.size()); +auto result = + backend_class->get_option(backend_option_context, backend_options_ref); +if (result != Error::Ok) { + return result; +} +return Error::Ok; +} + +/** +* Retrieves backend options for multiple backends using a backend options map. +* +* @param backend_options_map The backend option map containing backend names +* and their associated options, which will be filled with the populated values +* from the backend +* @return Error::Ok on success, or the first error encountered when processing +* the entries +*/ +Error get_option( + executorch::runtime::Span backend_options_map) { +Error result = Error::Ok; +for (auto& entry : backend_options_map) { + const char* backend_name = entry.backend_name; + auto backend_options = entry.options; + auto result = get_option(backend_name, backend_options); + if (result != Error::Ok) { + return result; + } +} +return Error::Ok; +} + +/** +* Sets backend options for a specific backend. +* +* @param backend_name The name of the backend to set options for +* @param backend_options The backend option list containing the options +* to set +* @return Error::Ok on success, Error::NotFound if backend is not found, or +* other error codes on failure +*/ +Error set_option( + const char* backend_name, + const executorch::runtime::Span + backend_options) { +auto backend_class = get_backend_class(backend_name); +if (!backend_class) { + return Error::NotFound; +} + +executorch::runtime::BackendOptionContext backend_option_context; +Error result = + backend_class->set_option(backend_option_context, backend_options); +if (result != Error::Ok) { + return result; +} +return Error::Ok; +} + +/** +* Sets backend options for multiple backends using a backend options map. +* +* @param backend_options_map The backend option map containing backend names +* and their associated backend options to set +* @return Error::Ok on success, or the first error encountered when processing +*/ +Error set_option(const executorch::runtime::Span + backend_options_map) { +Error result = Error::Ok; +for (const auto& entry : backend_options_map) { + const char* backend_name = entry.backend_name; + auto backend_options = entry.options; + result = set_option(backend_name, backend_options); + + if (result != Error::Ok) { + return result; + } +} +return Error::Ok; +} + } // namespace runtime } // namespace executorch diff --git a/runtime/backend/targets.bzl b/runtime/backend/targets.bzl index 4c6c7c5c48..03a38900d6 100644 --- a/runtime/backend/targets.bzl +++ b/runtime/backend/targets.bzl @@ -68,5 +68,6 @@ def define_common_targets(): exported_deps = [ "//executorch/runtime/core:core", ":options" + aten_suffix, + ":interface" + aten_suffix, ], ) diff --git a/runtime/backend/test/backend_options_map_test.cpp b/runtime/backend/test/backend_options_map_test.cpp index 1457d06632..bb6c9b24e4 100644 --- a/runtime/backend/test/backend_options_map_test.cpp +++ b/runtime/backend/test/backend_options_map_test.cpp @@ -6,6 +6,7 @@ * LICENSE file in the root directory of this source tree. */ +#include #include #include #include @@ -13,11 +14,23 @@ #include using namespace ::testing; +using executorch::runtime::ArrayRef; +using executorch::runtime::Backend; +using executorch::runtime::BackendExecutionContext; +using executorch::runtime::BackendInitContext; +using executorch::runtime::BackendInterface; using executorch::runtime::BackendOption; +using executorch::runtime::BackendOptionContext; using executorch::runtime::BackendOptions; using executorch::runtime::BackendOptionsMap; +using executorch::runtime::CompileSpec; +using executorch::runtime::DelegateHandle; using executorch::runtime::Error; +using executorch::runtime::EValue; +using executorch::runtime::FreeableBuffer; using executorch::runtime::OptionKey; +using executorch::runtime::register_backend; +using executorch::runtime::Result; namespace executorch { namespace runtime { @@ -144,5 +157,155 @@ TEST_F(BackendOptionsMapTest, OptionIsolation) { EXPECT_STREQ(gpu_opts[2].key, "Hardware"); EXPECT_STREQ(std::get(gpu_opts[2].value), "H100"); } + +// Mock backend for testing +class StubBackend : public BackendInterface { + public: + ~StubBackend() override = default; + + bool is_available() const override { + return true; + } + + Result init( + BackendInitContext& context, + FreeableBuffer* processed, + ArrayRef compile_specs) const override { + return nullptr; + } + + Error execute( + BackendExecutionContext& context, + DelegateHandle* handle, + EValue** args) const override { + return Error::Ok; + } + + Error get_option( + BackendOptionContext& context, + executorch::runtime::Span& + backend_options) override { + // For testing purposes, just record that get_option was called + // and verify the input parameters + get_option_called = true; + get_option_call_count++; + last_get_option_size = backend_options.size(); + + // Verify that the expected option key is present and modify the value + for (size_t i = 0; i < backend_options.size(); ++i) { + if (strcmp(backend_options[i].key, "NumberOfThreads") == 0) { + // Set the value to what was stored by set_option + backend_options[i].value = last_num_threads; + found_expected_key = true; + break; + } + } + + return Error::Ok; + } + + Error set_option( + BackendOptionContext& context, + const Span& backend_options) + override { + // Store the options for verification + last_options_size = backend_options.size(); + if (backend_options.size() > 0) { + for (const auto& option : backend_options) { + if (strcmp(option.key, "NumberOfThreads") == 0) { + if (auto* val = std::get_if(&option.value)) { + last_num_threads = *val; + } + } + } + } + return Error::Ok; + } + + // Mutable for testing verification + size_t last_options_size = 0; + int last_num_threads = 0; + bool get_option_called = false; + int get_option_call_count = 0; + size_t last_get_option_size = 0; + bool found_expected_key = false; +}; + +class BackendUpdateTest : public ::testing::Test { + protected: + void SetUp() override { + // Since these tests cause ET_LOG to be called, the PAL must be initialized + // first. + executorch::runtime::runtime_init(); + + // Register the stub backend + stub_backend = std::make_unique(); + Backend backend_config{"StubBackend", stub_backend.get()}; + auto register_result = register_backend(backend_config); + ASSERT_EQ(register_result, Error::Ok); + } + + std::unique_ptr stub_backend; +}; + +// Test basic string functionality +TEST_F(BackendUpdateTest, TestSetOption) { + BackendOptionsMap<3> map; + BackendOptions<1> backend_options; + int new_num_threads = 4; + backend_options.set_option("NumberOfThreads", new_num_threads); + map.add("StubBackend", backend_options.view()); + + auto status = set_option(map.entries()); + ASSERT_EQ(status, Error::Ok); + + // Verify the map contains the expected data + ASSERT_EQ(map.size(), 1); + auto options = map.get("StubBackend"); + ASSERT_EQ(options.size(), 1); + + // Verify that the backend actually received the options + ASSERT_EQ(stub_backend->last_options_size, 1); + ASSERT_EQ(stub_backend->last_num_threads, new_num_threads); +} + +// Test get_option functionality +TEST_F(BackendUpdateTest, TestGetOption) { + // First, set some options in the backend + BackendOptionsMap<3> set_map; + BackendOptions<1> set_backend_options; + int expected_num_threads = 8; + set_backend_options.set_option("NumberOfThreads", expected_num_threads); + set_map.add("StubBackend", set_backend_options.view()); + + auto set_status = set_option(set_map.entries()); + ASSERT_EQ(set_status, Error::Ok); + ASSERT_EQ(stub_backend->last_num_threads, expected_num_threads); + + // Reset get_option tracking variables + stub_backend->get_option_called = false; + stub_backend->get_option_call_count = 0; + stub_backend->found_expected_key = false; + + // Now create a map with options for get_option to process + BackendOptionsMap<3> get_map; + BackendOptions<1> get_backend_options; + get_backend_options.set_option("NumberOfThreads", 0); + get_map.add("StubBackend", get_backend_options.view()); + + // Call get_option to test the API + auto get_status = get_option(get_map.entries()); + ASSERT_EQ(get_status, Error::Ok); + + ASSERT_TRUE( + std::get(get_map.entries()[0].options[0].value) == + expected_num_threads); + + // Verify that the backend's get_option method was called correctly + ASSERT_TRUE(stub_backend->get_option_called); + ASSERT_EQ(stub_backend->get_option_call_count, 1); + ASSERT_EQ(stub_backend->last_get_option_size, 1); + ASSERT_TRUE(stub_backend->found_expected_key); +} } // namespace runtime } // namespace executorch