Skip to content

[5/N] Add update in method #11463

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: gh/cccclai/25/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions runtime/executor/method.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <cstdint>
#include <cstdio>

#include <executorch/runtime/backend/backend_options_map.h>
#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/core/event_tracer_hooks.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
Expand Down Expand Up @@ -1512,6 +1513,26 @@ Error Method::experimental_step() {
return step();
}

Error Method::update(executorch::runtime::ArrayRef<executorch::runtime::Entry> backend_option) {
for (const auto& entry : backend_option) {
const char* backend_name = entry.backend_name;
auto backend_options = entry.options;

auto backend_class = get_backend_class(backend_name);
if (!backend_class) {
return Error::NotFound;
}

BackendUpdateContext backend_update_context;
auto update_result =
backend_class->update(backend_update_context, backend_options);
if (update_result != Error::Ok) {
return update_result;
}
}
return Error::Ok;
}

Error Method::execute() {
internal::event_tracer_create_event_block(event_tracer_, "Execute");
EventTracerEntry event_tracer_entry =
Expand Down
10 changes: 10 additions & 0 deletions runtime/executor/method.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
#endif

#include <executorch/runtime/backend/backend_options_map.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/event_tracer.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
Expand Down Expand Up @@ -240,6 +241,15 @@ class Method final {
/// DEPRECATED: Use `reset_execution()` instead.
ET_DEPRECATED ET_NODISCARD Error experimental_reset_execution();

/**
* EXPERIMENTAL: Advances/executes a single instruction in the method.
*
* @retval Error::Ok step succeeded
* @retval non-Ok step failed
* @retval Error::EndOfMethod method finished executing successfully
*/
ET_EXPERIMENTAL ET_NODISCARD Error update(executorch::runtime::ArrayRef<executorch::runtime::Entry> backend_option);

/**
* Returns the MethodMeta that corresponds to the calling Method.
*/
Expand Down
1 change: 1 addition & 0 deletions runtime/executor/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def define_common_targets():
":memory_manager",
":pte_data_map" + aten_suffix,
"//executorch/runtime/backend:interface" + aten_suffix,
"//executorch/runtime/backend:backend_options_map" + aten_suffix,
"//executorch/runtime/core:core",
"//executorch/runtime/core:named_data_map" + aten_suffix,
"//executorch/runtime/core:evalue" + aten_suffix,
Expand Down
188 changes: 188 additions & 0 deletions runtime/executor/test/method_update_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <cstdlib>
#include <filesystem>

#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/executor/method.h>
#include <executorch/runtime/executor/program.h>
#include <executorch/runtime/executor/test/managed_memory_manager.h>
#include <executorch/runtime/platform/runtime.h>
#include <executorch/test/utils/DeathTest.h>
#include <gtest/gtest.h>
#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/backend/backend_update_context.h>
#include <executorch/runtime/backend/backend_options.h>
#include <executorch/runtime/backend/backend_options_map.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/result.h>


using namespace ::testing;
using executorch::aten::ArrayRef;
using executorch::runtime::Error;
using executorch::runtime::EValue;
using executorch::runtime::Method;
using executorch::runtime::Program;
using executorch::runtime::Result;
using executorch::runtime::testing::ManagedMemoryManager;
using torch::executor::util::FileDataLoader;
using executorch::runtime::BackendExecutionContext;
using executorch::runtime::BackendInitContext;
using executorch::runtime::BackendInterface;
using executorch::runtime::BackendUpdateContext;
using executorch::runtime::BackendOption;
using executorch::runtime::BackendOptions;
using executorch::runtime::BackendOptionsMap;
using executorch::runtime::BoolKey;
using executorch::runtime::IntKey;
using executorch::runtime::Entry;
using executorch::runtime::OptionType;
using executorch::runtime::CompileSpec;
using executorch::runtime::DataLoader;
using executorch::runtime::DelegateHandle;
using executorch::runtime::FreeableBuffer;

constexpr size_t kDefaultNonConstMemBytes = 32 * 1024U;
constexpr size_t kDefaultRuntimeMemBytes = 32 * 1024U;

/**
* A backend class whose methods can be overridden individually.
*/
class StubBackend final : public BackendInterface {
public:

// Default name that this backend is registered as.
static constexpr char kName[] = "StubBackend";

bool is_available() const override {
return true;
}

Result<DelegateHandle*> init(
BackendInitContext& context,
FreeableBuffer* processed,
ArrayRef<CompileSpec> compile_specs) const override {
return nullptr;
}

Error execute(
BackendExecutionContext& context,
DelegateHandle* handle,
EValue** args) const override {
return Error::Ok;
}

int num_threads() const {
return num_threads_;
}

Error update(
BackendUpdateContext& context,
const executorch::runtime::ArrayRef<BackendOption>& backend_options) const override {
int success_update = 0;
for (const auto& backend_option : backend_options) {
if (strcmp(backend_option.key, "NumberOfThreads") == 0) {
if (backend_option.type == OptionType::INT) {
num_threads_ = backend_option.value.int_value;
success_update++;
}
}
}
if (success_update == backend_options.size()) {
return Error::Ok;
}
return Error::InvalidArgument;
}

/**
* Registers the singleton instance if not already registered.
*
* Note that this can be used to install the stub as the implementation for
* any export-time backend by passing in the right name, as long as no other
* backend with that name has been registered yet.
*/
static Error register_singleton(const char* name = kName) {
if (!registered_) {
registered_ = true;
return executorch::runtime::register_backend({name, &singleton_});
}
return Error::Ok;
}

/**
* Returns the instance that was added to the backend registry.
*/
static StubBackend& singleton() {
return singleton_;
}

private:
static bool registered_;
static StubBackend singleton_;
mutable int num_threads_ = 1;
};

bool StubBackend::registered_ = false;
StubBackend StubBackend::singleton_;

class MethodUpdateTest : public ::testing::Test {
protected:
void load_program() {
// Since these tests cause ET_LOG to be called, the PAL must be initialized
// first.
executorch::runtime::runtime_init();

// Create a loader for the serialized program.
ASSERT_EQ(StubBackend::register_singleton(), Error::Ok);

auto loader_res = FileDataLoader::from(std::getenv("ET_MODULE_ADD_MUL_DELEGATED_PATH"));
ASSERT_EQ(loader_res.error(), Error::Ok);
loader_ = std::make_unique<FileDataLoader>(std::move(loader_res.get()));

// Use it to load the program.
auto program_res = Program::load(loader_.get());
ASSERT_EQ(program_res.error(), Error::Ok);
program_ = std::make_unique<Program>(std::move(program_res.get()));
}

void SetUp() override {
executorch::runtime::runtime_init();

load_program();
}

private:
std::unique_ptr<FileDataLoader> loader_;

protected:
std::unique_ptr<Program> program_;
};

TEST_F(MethodUpdateTest, MoveTest) {
BackendInterface* backend =
executorch::runtime::get_backend_class(StubBackend::kName);
ASSERT_EQ(backend, &StubBackend::singleton());

ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
Result<Method> method = program_->load_method("forward", &mmm.get());
// Check that the default number of threads is 1.
ASSERT_EQ(StubBackend::singleton().num_threads(), 1);
ASSERT_EQ(method.error(), Error::Ok);

BackendOptionsMap<3> map;
BackendOptions<1> backend_options;
int new_num_threads = 4;
backend_options.set_option(IntKey("NumberOfThreads"), new_num_threads);
map.add("StubBackend", backend_options.view());
Error update_result = method->update(map.entries());
ASSERT_EQ(update_result, Error::Ok);
ASSERT_EQ(StubBackend::singleton().num_threads(), new_num_threads);
}
17 changes: 17 additions & 0 deletions runtime/executor/test/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,23 @@ def define_common_targets(is_fbcode = False):
env = modules_env,
)


runtime.cxx_test(
name = "method_update_test",
srcs = [
"method_update_test.cpp",
],
deps = [
":managed_memory_manager",
"//executorch/runtime/backend:interface",
"//executorch/runtime/executor:program",
"//executorch/extension/data_loader:buffer_data_loader",
"//executorch/extension/data_loader:file_data_loader",
],
env = {
"ET_MODULE_ADD_MUL_DELEGATED_PATH": "$(location fbcode//executorch/test/models:exported_delegated_add_mul[ModuleAddMul.pte])",
}, )

runtime.cxx_test(
name = "program_test",
srcs = [
Expand Down
18 changes: 18 additions & 0 deletions test/models/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -248,3 +248,21 @@ def define_common_targets():
"//executorch/test/...",
],
)

runtime.genrule(
name = "exported_executor_backend_program_linear",
cmd = "$(exe :export_delegated_program)" +
" --modules ModuleLinear" +
" --backend_id ExecutorBackend" +
" --outdir $OUT",

outs = {
"ExcuTorchBackendLinear.pte": ["ExcuTorchBackendLinear.pte"],
},
default_outs = ["."],
visibility = [
"//executorch/runtime/executor/test/...",
"//executorch/extension/flat_tensor/test/...",
"//executorch/test/...",
],
)
Loading