diff --git a/runtime/executor/method.cpp b/runtime/executor/method.cpp index fe44f49e7e8..4f1b6dd4b26 100644 --- a/runtime/executor/method.cpp +++ b/runtime/executor/method.cpp @@ -14,6 +14,7 @@ #include #include +#include #include #include #include @@ -1512,6 +1513,26 @@ Error Method::experimental_step() { return step(); } +Error Method::update(executorch::runtime::ArrayRef backend_option) { + for (const auto& entry : backend_option) { + const char* backend_name = entry.backend_name; + auto backend_options = entry.options; + + auto backend_class = get_backend_class(backend_name); + if (!backend_class) { + return Error::NotFound; + } + + BackendUpdateContext backend_update_context; + auto update_result = + backend_class->update(backend_update_context, backend_options); + if (update_result != Error::Ok) { + return update_result; + } + } + return Error::Ok; +} + Error Method::execute() { internal::event_tracer_create_event_block(event_tracer_, "Execute"); EventTracerEntry event_tracer_entry = diff --git a/runtime/executor/method.h b/runtime/executor/method.h index 99a6aea439f..4564615be11 100644 --- a/runtime/executor/method.h +++ b/runtime/executor/method.h @@ -14,6 +14,7 @@ #pragma GCC diagnostic ignored "-Wdeprecated-declarations" #endif +#include #include #include #include @@ -240,6 +241,14 @@ class Method final { /// DEPRECATED: Use `reset_execution()` instead. ET_DEPRECATED ET_NODISCARD Error experimental_reset_execution(); + /** + * EXPERIMENTAL: Update backend options, which will be dispatched to different backends. + * + * @retval Error::Ok step succeeded + * @retval non-Ok Method update fails + */ + ET_EXPERIMENTAL ET_NODISCARD Error update(executorch::runtime::ArrayRef backend_option); + /** * Returns the MethodMeta that corresponds to the calling Method. */ diff --git a/runtime/executor/targets.bzl b/runtime/executor/targets.bzl index 649b2c13cc1..e7a87dcf661 100644 --- a/runtime/executor/targets.bzl +++ b/runtime/executor/targets.bzl @@ -108,6 +108,7 @@ def define_common_targets(): ":memory_manager", ":pte_data_map" + aten_suffix, "//executorch/runtime/backend:interface" + aten_suffix, + "//executorch/runtime/backend:backend_options_map" + aten_suffix, "//executorch/runtime/core:core", "//executorch/runtime/core:named_data_map" + aten_suffix, "//executorch/runtime/core:evalue" + aten_suffix, diff --git a/runtime/executor/test/method_update_test.cpp b/runtime/executor/test/method_update_test.cpp new file mode 100644 index 00000000000..11c6f281953 --- /dev/null +++ b/runtime/executor/test/method_update_test.cpp @@ -0,0 +1,187 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + + #include + #include + + #include + #include + #include + #include + #include + #include + #include + #include + #include + #include + #include + #include + #include + #include + + + using namespace ::testing; + using executorch::aten::ArrayRef; + using executorch::runtime::Error; + using executorch::runtime::EValue; + using executorch::runtime::Method; + using executorch::runtime::Program; + using executorch::runtime::Result; + using executorch::runtime::testing::ManagedMemoryManager; + using torch::executor::util::FileDataLoader; + using executorch::runtime::BackendExecutionContext; +using executorch::runtime::BackendInitContext; +using executorch::runtime::BackendInterface; +using executorch::runtime::BackendUpdateContext; +using executorch::runtime::BackendOption; +using executorch::runtime::BackendOptions; +using executorch::runtime::BackendOptionsMap; +using executorch::runtime::BoolKey; +using executorch::runtime::IntKey; +using executorch::runtime::Entry; +using executorch::runtime::CompileSpec; +using executorch::runtime::DataLoader; +using executorch::runtime::DelegateHandle; +using executorch::runtime::FreeableBuffer; + + constexpr size_t kDefaultNonConstMemBytes = 32 * 1024U; + constexpr size_t kDefaultRuntimeMemBytes = 32 * 1024U; + +/** + * A backend class whose methods can be overridden individually. + */ +class StubBackend final : public BackendInterface { + public: + + // Default name that this backend is registered as. + static constexpr char kName[] = "StubBackend"; + + bool is_available() const override { + return true; + } + + Result init( + BackendInitContext& context, + FreeableBuffer* processed, + ArrayRef compile_specs) const override { + return nullptr; + } + + Error execute( + BackendExecutionContext& context, + DelegateHandle* handle, + EValue** args) const override { + return Error::Ok; + } + + int num_threads() const { + return num_threads_; + } + + Error update( + BackendUpdateContext& context, + const executorch::runtime::ArrayRef& backend_options) const override { + int success_update = 0; + for (const auto& backend_option : backend_options) { + if (strcmp(backend_option.key, "NumberOfThreads") == 0) { + if (std::holds_alternative(backend_option.value)) { + num_threads_ = std::get(backend_option.value); + success_update++; + } + } + } + if (success_update == backend_options.size()) { + return Error::Ok; + } + return Error::InvalidArgument; + } + + /** + * Registers the singleton instance if not already registered. + * + * Note that this can be used to install the stub as the implementation for + * any export-time backend by passing in the right name, as long as no other + * backend with that name has been registered yet. + */ + static Error register_singleton(const char* name = kName) { + if (!registered_) { + registered_ = true; + return executorch::runtime::register_backend({name, &singleton_}); + } + return Error::Ok; + } + + /** + * Returns the instance that was added to the backend registry. + */ + static StubBackend& singleton() { + return singleton_; + } + + private: + static bool registered_; + static StubBackend singleton_; + mutable int num_threads_ = 1; + }; + + bool StubBackend::registered_ = false; + StubBackend StubBackend::singleton_; + + class MethodUpdateTest : public ::testing::Test { + protected: + void load_program() { + // Since these tests cause ET_LOG to be called, the PAL must be initialized + // first. + executorch::runtime::runtime_init(); + + // Create a loader for the serialized program. + ASSERT_EQ(StubBackend::register_singleton(), Error::Ok); + + auto loader_res = FileDataLoader::from(std::getenv("ET_MODULE_ADD_MUL_DELEGATED_PATH")); + ASSERT_EQ(loader_res.error(), Error::Ok); + loader_ = std::make_unique(std::move(loader_res.get())); + + // Use it to load the program. + auto program_res = Program::load(loader_.get()); + ASSERT_EQ(program_res.error(), Error::Ok); + program_ = std::make_unique(std::move(program_res.get())); + } + + void SetUp() override { + executorch::runtime::runtime_init(); + + load_program(); + } + + private: + std::unique_ptr loader_; + + protected: + std::unique_ptr program_; + }; + + TEST_F(MethodUpdateTest, MoveTest) { + BackendInterface* backend = + executorch::runtime::get_backend_class(StubBackend::kName); + ASSERT_EQ(backend, &StubBackend::singleton()); + + ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes); + Result method = program_->load_method("forward", &mmm.get()); + // Check that the default number of threads is 1. + ASSERT_EQ(StubBackend::singleton().num_threads(), 1); + ASSERT_EQ(method.error(), Error::Ok); + + BackendOptionsMap<3> map; + BackendOptions<1> backend_options; + int new_num_threads = 4; + backend_options.set_option(IntKey("NumberOfThreads"), new_num_threads); + map.add("StubBackend", backend_options.view()); + Error update_result = method->update(map.entries()); + ASSERT_EQ(update_result, Error::Ok); + ASSERT_EQ(StubBackend::singleton().num_threads(), new_num_threads); +} diff --git a/runtime/executor/test/targets.bzl b/runtime/executor/test/targets.bzl index 39ff0668d5d..b075e5b6b62 100644 --- a/runtime/executor/test/targets.bzl +++ b/runtime/executor/test/targets.bzl @@ -170,6 +170,23 @@ def define_common_targets(is_fbcode = False): env = modules_env, ) + + runtime.cxx_test( + name = "method_update_test", + srcs = [ + "method_update_test.cpp", + ], + deps = [ + ":managed_memory_manager", + "//executorch/runtime/backend:interface", + "//executorch/runtime/executor:program", + "//executorch/extension/data_loader:buffer_data_loader", + "//executorch/extension/data_loader:file_data_loader", + ], + env = { + "ET_MODULE_ADD_MUL_DELEGATED_PATH": "$(location fbcode//executorch/test/models:exported_delegated_add_mul[ModuleAddMul.pte])", + }, ) + runtime.cxx_test( name = "program_test", srcs = [ diff --git a/test/models/targets.bzl b/test/models/targets.bzl index 391ce230ab8..d3fa3230468 100644 --- a/test/models/targets.bzl +++ b/test/models/targets.bzl @@ -248,3 +248,21 @@ def define_common_targets(): "//executorch/test/...", ], ) + + runtime.genrule( + name = "exported_executor_backend_program_linear", + cmd = "$(exe :export_delegated_program)" + + " --modules ModuleLinear" + + " --backend_id ExecutorBackend" + + " --outdir $OUT", + + outs = { + "ExcuTorchBackendLinear.pte": ["ExcuTorchBackendLinear.pte"], + }, + default_outs = ["."], + visibility = [ + "//executorch/runtime/executor/test/...", + "//executorch/extension/flat_tensor/test/...", + "//executorch/test/...", + ], + )