diff --git a/src/AMSlib/wf/policy.hpp b/src/AMSlib/wf/policy.hpp new file mode 100644 index 00000000..7d020404 --- /dev/null +++ b/src/AMSlib/wf/policy.hpp @@ -0,0 +1,33 @@ +#pragma once + +#include "wf/pipeline.hpp" + +namespace ams +{ + +namespace ml +{ +class InferenceModel; +} + +class LayoutTransform; + +/// Policies are factories that construct Pipelines. +/// +/// A Policy encodes *what* should happen (control flow, fallback strategy), +/// while the Pipeline and Actions encode *how* it happens. +class Policy +{ +public: + virtual ~Policy() = default; + + /// Construct a pipeline for the given model and layout. The, potentially + /// nullable, Model is a non-owning pointer. + /// + /// The returned Pipeline is ready to run. + virtual Pipeline makePipeline(const ml::InferenceModel* Model, + LayoutTransform& Layout) const = 0; + virtual const char* name() const noexcept = 0; +}; + +} // namespace ams diff --git a/tests/AMSlib/wf/CMakeLists.txt b/tests/AMSlib/wf/CMakeLists.txt index 2c825b3c..4e3f98fb 100644 --- a/tests/AMSlib/wf/CMakeLists.txt +++ b/tests/AMSlib/wf/CMakeLists.txt @@ -62,3 +62,6 @@ ADD_WORKFLOW_UNIT_TEST(WORKFLOW::ACTION action) BUILD_UNIT_TEST(pipeline pipeline.cpp) ADD_WORKFLOW_UNIT_TEST(WORKFLOW::PIPELINE pipeline) + +BUILD_UNIT_TEST(policy policy.cpp) +ADD_WORKFLOW_UNIT_TEST(WORKFLOW::POLICY policy) diff --git a/tests/AMSlib/wf/policy.cpp b/tests/AMSlib/wf/policy.cpp new file mode 100644 index 00000000..853cba8d --- /dev/null +++ b/tests/AMSlib/wf/policy.cpp @@ -0,0 +1,129 @@ +#include "wf/policy.hpp" + +#include +#include +#include +#include + +#include "ml/Model.hpp" +#include "wf/action.hpp" +#include "wf/eval_context.hpp" +#include "wf/layout_transform.hpp" +#include "wf/pipeline.hpp" + +namespace ams +{ + +namespace +{ + +class IncAction final : public Action +{ +public: + const char* name() const noexcept override { return "IncAction"; } + AMSStatus run(EvalContext& Ctx) override + { + Ctx.Threshold = Ctx.Threshold.value_or(0.0f) + 1.0f; + return {}; + } +}; + +class FailAction final : public Action +{ +public: + const char* name() const noexcept override { return "FailAction"; } + AMSStatus run(EvalContext&) override + { + return AMS_MAKE_ERROR(AMSErrorType::Generic, "FailAction triggered"); + } +}; + +class DummyLayout final : public LayoutTransform +{ +public: + const char* name() const noexcept override { return "DummyLayout"; } + + AMSExpected pack(const TensorBundle&, + const TensorBundle&, + at::Tensor&) override + { + return IndexMap{}; + } + AMSStatus unpack(const torch::jit::IValue&, + TensorBundle&, + TensorBundle&, + std::optional&) override + { + return {}; + } +}; + +class DirectLikePolicy final : public Policy +{ +public: + const char* name() const noexcept override { return "DirectLikePolicy"; } + + Pipeline makePipeline(const ml::InferenceModel* /*Model*/, + LayoutTransform& /*Layout*/) const override + { + Pipeline P; + P.add(std::make_unique()).add(std::make_unique()); + return P; + } +}; + +class FailingPolicy final : public Policy +{ +public: + const char* name() const noexcept override { return "FailingPolicy"; } + + Pipeline makePipeline(const ml::InferenceModel* /*Model*/, + LayoutTransform& /*Layout*/) const override + { + Pipeline P; + P.add(std::make_unique()) + .add(std::make_unique()) + .add(std::make_unique()); // must not run + return P; + } +}; + +} // namespace + +CATCH_TEST_CASE("Policy is an abstract factory for Pipelines", "[wf][policy]") +{ + CATCH_STATIC_REQUIRE(std::is_abstract_v); + CATCH_STATIC_REQUIRE(std::has_virtual_destructor_v); + + DummyLayout L; + ml::InferenceModel* Model = nullptr; + + DirectLikePolicy Pol; + CATCH_REQUIRE(std::string(Pol.name()) == "DirectLikePolicy"); + + EvalContext Ctx{}; + auto P = Pol.makePipeline(Model, L); + + auto St = P.run(Ctx); + CATCH_REQUIRE(St); + CATCH_REQUIRE(Ctx.Threshold == 2.0f); +} + +CATCH_TEST_CASE("Policy-built pipeline short-circuits on Action failure", + "[wf][policy]") +{ + DummyLayout L; + ml::InferenceModel* Model = nullptr; + + FailingPolicy Pol; + EvalContext Ctx{}; + + auto P = Pol.makePipeline(Model, L); + auto St = P.run(Ctx); + + CATCH_REQUIRE_FALSE(St); + CATCH_REQUIRE(St.error().getType() == AMSErrorType::Generic); + CATCH_REQUIRE(Ctx.Threshold == 1.0f); +} + +} // namespace ams