diff --git a/src/AMSlib/wf/action.hpp b/src/AMSlib/wf/action.hpp new file mode 100644 index 00000000..6f875119 --- /dev/null +++ b/src/AMSlib/wf/action.hpp @@ -0,0 +1,26 @@ +#pragma once + +#include "AMSError.hpp" + +namespace ams +{ + +struct EvalContext; // forward declaration + +/// Base class for a single step in an AMS evaluation pipeline. +/// +/// Actions mutate the shared EvalContext and may fail; failures are reported +/// via AMSStatus so pipelines can short-circuit cleanly. +class Action +{ +public: + virtual ~Action() = default; + + /// Execute this action on the evaluation context. + virtual AMSStatus run(EvalContext& ctx) = 0; + + /// Human-readable name for debugging, logging, and tracing. + virtual const char* name() const noexcept = 0; +}; + +} // namespace ams diff --git a/src/AMSlib/wf/pipeline.hpp b/src/AMSlib/wf/pipeline.hpp new file mode 100644 index 00000000..e412f0f0 --- /dev/null +++ b/src/AMSlib/wf/pipeline.hpp @@ -0,0 +1,55 @@ +#pragma once + +#include +#include + +#include "AMSError.hpp" // AMSStatus +#include "wf/action.hpp" // Action + +namespace ams +{ + +struct EvalContext; + +/// A linear sequence of Actions executed in order. +/// +/// If any Action fails, execution stops and the error is returned. +class Pipeline +{ +public: + using ActionPtr = std::unique_ptr; + + Pipeline() = default; + + /// Append an Action to the pipeline. + Pipeline& add(ActionPtr Act) + { + Actions.emplace_back(std::move(Act)); + return *this; + } + + /// Execute all actions in order; stops on first error. + AMSStatus run(EvalContext& Ctx) const + { + for (const auto& Act : Actions) { + if (auto St = Act->run(Ctx); !St) { + return St; + } + } + return {}; + } + + /// Number of actions in the pipeline. + size_t size() const noexcept { return Actions.size(); } + + /// True if there are no actions. + bool empty() const noexcept { return Actions.empty(); } + + /// Remove all actions. + void clear() noexcept { Actions.clear(); } + +private: + std::vector Actions; +}; + +} // namespace ams diff --git a/tests/AMSlib/wf/CMakeLists.txt b/tests/AMSlib/wf/CMakeLists.txt index 018b04b2..2c825b3c 100644 --- a/tests/AMSlib/wf/CMakeLists.txt +++ b/tests/AMSlib/wf/CMakeLists.txt @@ -56,3 +56,9 @@ ADD_WORKFLOW_UNIT_TEST(WORKFLOW::EVAL_CONTEXT eval_context) BUILD_UNIT_TEST(pointwise pointwise_layout_transform.cpp) ADD_WORKFLOW_UNIT_TEST(WORKFLOW::POINTWISE pointwise) + +BUILD_UNIT_TEST(action action.cpp) +ADD_WORKFLOW_UNIT_TEST(WORKFLOW::ACTION action) + +BUILD_UNIT_TEST(pipeline pipeline.cpp) +ADD_WORKFLOW_UNIT_TEST(WORKFLOW::PIPELINE pipeline) diff --git a/tests/AMSlib/wf/action.cpp b/tests/AMSlib/wf/action.cpp new file mode 100644 index 00000000..746cc18a --- /dev/null +++ b/tests/AMSlib/wf/action.cpp @@ -0,0 +1,50 @@ +#include "wf/action.hpp" + +#include +#include +#include + +// Prefer the real EvalContext if available. +// If your project uses a different header name, adjust accordingly. +#include "wf/eval_context.hpp" + +namespace ams +{ + +namespace +{ +class TestAction final : public Action +{ +public: + const char* name() const noexcept override { return "TestAction"; } + + AMSStatus run(EvalContext& ctx) override + { + ctx.Threshold = ctx.Threshold.value_or(0.0f) + 1.0f; + return {}; + } +}; +} // namespace + +CATCH_TEST_CASE("Action: abstract base class + virtual interface", + "[wf][action]") +{ + CATCH_STATIC_REQUIRE(std::is_abstract_v); + CATCH_STATIC_REQUIRE(std::has_virtual_destructor_v); + + EvalContext ctx{}; + ctx.Threshold = 0.0f; + std::unique_ptr act = std::make_unique(); + + CATCH_REQUIRE(act->name() == std::string("TestAction")); + + auto Err = act->run(ctx); + CATCH_REQUIRE(Err); + CATCH_REQUIRE(ctx.Threshold == 1.0f); + + auto Err1 = act->run(ctx); + CATCH_REQUIRE(Err1); + CATCH_REQUIRE(ctx.Threshold == 2.0f); +} + +} // namespace ams diff --git a/tests/AMSlib/wf/pipeline.cpp b/tests/AMSlib/wf/pipeline.cpp new file mode 100644 index 00000000..c02bf26c --- /dev/null +++ b/tests/AMSlib/wf/pipeline.cpp @@ -0,0 +1,75 @@ +#include "wf/pipeline.hpp" + +#include +#include +#include + +#include "wf/eval_context.hpp" + +namespace ams +{ + +namespace +{ + +class IncAction final : public Action +{ +public: + const char* name() const noexcept override { return "IncAction"; } + + AMSStatus run(EvalContext& ctx) override + { + ctx.Threshold = ctx.Threshold.value_or(0.0f) + 1.0f; + return {}; + } +}; + +class FailAction final : public Action +{ +public: + const char* name() const noexcept override { return "FailAction"; } + + AMSStatus run(EvalContext&) override + { + return AMS_MAKE_ERROR(AMSErrorType::Generic, "FailAction triggered"); + } +}; + +} // namespace + +CATCH_TEST_CASE("Pipeline runs actions in order and short-circuits on error", + "[wf][pipeline]") +{ + EvalContext Ctx{}; + Pipeline P; + + // Two increments -> Threshold becomes 2, then FailAction stops the pipeline. + P.add(std::make_unique()) + .add(std::make_unique()) + .add(std::make_unique()) + .add(std::make_unique()); // must NOT execute + + Ctx.Threshold = 0.0f; + + auto St = P.run(Ctx); + CATCH_REQUIRE_FALSE(St); + CATCH_REQUIRE(St.error().getType() == AMSErrorType::Generic); + + // Only the first two IncAction should have run. + CATCH_REQUIRE(Ctx.Threshold.value() == 2.0f); +} + +CATCH_TEST_CASE("Pipeline succeeds when all actions succeed", "[wf][pipeline]") +{ + EvalContext Ctx{}; + Pipeline P; + + P.add(std::make_unique()).add(std::make_unique()); + + Ctx.Threshold = 0.0f; + auto St = P.run(Ctx); + CATCH_REQUIRE(St); + CATCH_REQUIRE(Ctx.Threshold.value() == 2.0f); +} + +} // namespace ams