Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions src/AMSlib/wf/action.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#pragma once

#include "AMSError.hpp"

namespace ams
{

struct EvalContext; // forward declaration

/// Base class for a single step in an AMS evaluation pipeline.
///
/// Actions mutate the shared EvalContext and may fail; failures are reported
/// via AMSStatus so pipelines can short-circuit cleanly.
class Action
{
public:
virtual ~Action() = default;

/// Execute this action on the evaluation context.
virtual AMSStatus run(EvalContext& ctx) = 0;

/// Human-readable name for debugging, logging, and tracing.
virtual const char* name() const noexcept = 0;
};

} // namespace ams
55 changes: 55 additions & 0 deletions src/AMSlib/wf/pipeline.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#pragma once

#include <memory>
#include <vector>

#include "AMSError.hpp" // AMSStatus
#include "wf/action.hpp" // Action

namespace ams
{

struct EvalContext;

/// A linear sequence of Actions executed in order.
///
/// If any Action fails, execution stops and the error is returned.
class Pipeline
{
public:
using ActionPtr = std::unique_ptr<Action>;

Pipeline() = default;

/// Append an Action to the pipeline.
Pipeline& add(ActionPtr Act)
{
Actions.emplace_back(std::move(Act));
return *this;
}

/// Execute all actions in order; stops on first error.
AMSStatus run(EvalContext& Ctx) const
{
for (const auto& Act : Actions) {
if (auto St = Act->run(Ctx); !St) {
return St;
}
}
return {};
}

/// Number of actions in the pipeline.
size_t size() const noexcept { return Actions.size(); }

/// True if there are no actions.
bool empty() const noexcept { return Actions.empty(); }

/// Remove all actions.
void clear() noexcept { Actions.clear(); }

private:
std::vector<ActionPtr> Actions;
};

} // namespace ams
6 changes: 6 additions & 0 deletions tests/AMSlib/wf/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,9 @@ ADD_WORKFLOW_UNIT_TEST(WORKFLOW::EVAL_CONTEXT eval_context)

BUILD_UNIT_TEST(pointwise pointwise_layout_transform.cpp)
ADD_WORKFLOW_UNIT_TEST(WORKFLOW::POINTWISE pointwise)

BUILD_UNIT_TEST(action action.cpp)
ADD_WORKFLOW_UNIT_TEST(WORKFLOW::ACTION action)

BUILD_UNIT_TEST(pipeline pipeline.cpp)
ADD_WORKFLOW_UNIT_TEST(WORKFLOW::PIPELINE pipeline)
50 changes: 50 additions & 0 deletions tests/AMSlib/wf/action.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#include "wf/action.hpp"

#include <catch2/catch_test_macros.hpp>
#include <memory>
#include <type_traits>

// Prefer the real EvalContext if available.
// If your project uses a different header name, adjust accordingly.
#include "wf/eval_context.hpp"

namespace ams
{

namespace
{
class TestAction final : public Action
{
public:
const char* name() const noexcept override { return "TestAction"; }

AMSStatus run(EvalContext& ctx) override
{
ctx.Threshold = ctx.Threshold.value_or(0.0f) + 1.0f;
return {};
}
};
} // namespace

CATCH_TEST_CASE("Action: abstract base class + virtual interface",
"[wf][action]")
{
CATCH_STATIC_REQUIRE(std::is_abstract_v<Action>);
CATCH_STATIC_REQUIRE(std::has_virtual_destructor_v<Action>);

EvalContext ctx{};
ctx.Threshold = 0.0f;
std::unique_ptr<Action> act = std::make_unique<TestAction>();

CATCH_REQUIRE(act->name() == std::string("TestAction"));

auto Err = act->run(ctx);
CATCH_REQUIRE(Err);
CATCH_REQUIRE(ctx.Threshold == 1.0f);

auto Err1 = act->run(ctx);
CATCH_REQUIRE(Err1);
CATCH_REQUIRE(ctx.Threshold == 2.0f);
}

} // namespace ams
75 changes: 75 additions & 0 deletions tests/AMSlib/wf/pipeline.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#include "wf/pipeline.hpp"

#include <catch2/catch_test_macros.hpp>
#include <memory>
#include <string>

#include "wf/eval_context.hpp"

namespace ams
{

namespace
{

class IncAction final : public Action
{
public:
const char* name() const noexcept override { return "IncAction"; }

AMSStatus run(EvalContext& ctx) override
{
ctx.Threshold = ctx.Threshold.value_or(0.0f) + 1.0f;
return {};
}
};

class FailAction final : public Action
{
public:
const char* name() const noexcept override { return "FailAction"; }

AMSStatus run(EvalContext&) override
{
return AMS_MAKE_ERROR(AMSErrorType::Generic, "FailAction triggered");
}
};

} // namespace

CATCH_TEST_CASE("Pipeline runs actions in order and short-circuits on error",
"[wf][pipeline]")
{
EvalContext Ctx{};
Pipeline P;

// Two increments -> Threshold becomes 2, then FailAction stops the pipeline.
P.add(std::make_unique<IncAction>())
.add(std::make_unique<IncAction>())
.add(std::make_unique<FailAction>())
.add(std::make_unique<IncAction>()); // must NOT execute

Ctx.Threshold = 0.0f;

auto St = P.run(Ctx);
CATCH_REQUIRE_FALSE(St);
CATCH_REQUIRE(St.error().getType() == AMSErrorType::Generic);

// Only the first two IncAction should have run.
CATCH_REQUIRE(Ctx.Threshold.value() == 2.0f);
}

CATCH_TEST_CASE("Pipeline succeeds when all actions succeed", "[wf][pipeline]")
{
EvalContext Ctx{};
Pipeline P;

P.add(std::make_unique<IncAction>()).add(std::make_unique<IncAction>());

Ctx.Threshold = 0.0f;
auto St = P.run(Ctx);
CATCH_REQUIRE(St);
CATCH_REQUIRE(Ctx.Threshold.value() == 2.0f);
}

} // namespace ams
Loading