Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions src/AMSlib/wf/policy.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#pragma once

#include "wf/pipeline.hpp"

namespace ams
{

namespace ml
{
class InferenceModel;
}

class LayoutTransform;

/// Policies are factories that construct Pipelines.
///
/// A Policy encodes *what* should happen (control flow, fallback strategy),
/// while the Pipeline and Actions encode *how* it happens.
class Policy
{
public:
virtual ~Policy() = default;

/// Construct a pipeline for the given model and layout. The, potentially
/// nullable, Model is a non-owning pointer.
///
/// The returned Pipeline is ready to run.
virtual Pipeline makePipeline(const ml::InferenceModel* Model,
LayoutTransform& Layout) const = 0;
virtual const char* name() const noexcept = 0;
};

} // namespace ams
3 changes: 3 additions & 0 deletions tests/AMSlib/wf/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,6 @@ ADD_WORKFLOW_UNIT_TEST(WORKFLOW::ACTION action)

BUILD_UNIT_TEST(pipeline pipeline.cpp)
ADD_WORKFLOW_UNIT_TEST(WORKFLOW::PIPELINE pipeline)

BUILD_UNIT_TEST(policy policy.cpp)
ADD_WORKFLOW_UNIT_TEST(WORKFLOW::POLICY policy)
129 changes: 129 additions & 0 deletions tests/AMSlib/wf/policy.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
#include "wf/policy.hpp"

#include <catch2/catch_test_macros.hpp>
#include <memory>
#include <string>
#include <type_traits>

#include "ml/Model.hpp"
#include "wf/action.hpp"
#include "wf/eval_context.hpp"
#include "wf/layout_transform.hpp"
#include "wf/pipeline.hpp"

namespace ams
{

namespace
{

class IncAction final : public Action
{
public:
const char* name() const noexcept override { return "IncAction"; }
AMSStatus run(EvalContext& Ctx) override
{
Ctx.Threshold = Ctx.Threshold.value_or(0.0f) + 1.0f;
return {};
}
};

class FailAction final : public Action
{
public:
const char* name() const noexcept override { return "FailAction"; }
AMSStatus run(EvalContext&) override
{
return AMS_MAKE_ERROR(AMSErrorType::Generic, "FailAction triggered");
}
};

class DummyLayout final : public LayoutTransform
{
public:
const char* name() const noexcept override { return "DummyLayout"; }

AMSExpected<IndexMap> pack(const TensorBundle&,
const TensorBundle&,
at::Tensor&) override
{
return IndexMap{};
}
AMSStatus unpack(const torch::jit::IValue&,
TensorBundle&,
TensorBundle&,
std::optional<at::Tensor>&) override
{
return {};
}
};

class DirectLikePolicy final : public Policy
{
public:
const char* name() const noexcept override { return "DirectLikePolicy"; }

Pipeline makePipeline(const ml::InferenceModel* /*Model*/,
LayoutTransform& /*Layout*/) const override
{
Pipeline P;
P.add(std::make_unique<IncAction>()).add(std::make_unique<IncAction>());
return P;
}
};

class FailingPolicy final : public Policy
{
public:
const char* name() const noexcept override { return "FailingPolicy"; }

Pipeline makePipeline(const ml::InferenceModel* /*Model*/,
LayoutTransform& /*Layout*/) const override
{
Pipeline P;
P.add(std::make_unique<IncAction>())
.add(std::make_unique<FailAction>())
.add(std::make_unique<IncAction>()); // must not run
return P;
}
};

} // namespace

CATCH_TEST_CASE("Policy is an abstract factory for Pipelines", "[wf][policy]")
{
CATCH_STATIC_REQUIRE(std::is_abstract_v<Policy>);
CATCH_STATIC_REQUIRE(std::has_virtual_destructor_v<Policy>);

DummyLayout L;
ml::InferenceModel* Model = nullptr;

DirectLikePolicy Pol;
CATCH_REQUIRE(std::string(Pol.name()) == "DirectLikePolicy");

EvalContext Ctx{};
auto P = Pol.makePipeline(Model, L);

auto St = P.run(Ctx);
CATCH_REQUIRE(St);
CATCH_REQUIRE(Ctx.Threshold == 2.0f);
}

CATCH_TEST_CASE("Policy-built pipeline short-circuits on Action failure",
"[wf][policy]")
{
DummyLayout L;
ml::InferenceModel* Model = nullptr;

FailingPolicy Pol;
EvalContext Ctx{};

auto P = Pol.makePipeline(Model, L);
auto St = P.run(Ctx);

CATCH_REQUIRE_FALSE(St);
CATCH_REQUIRE(St.error().getType() == AMSErrorType::Generic);
CATCH_REQUIRE(Ctx.Threshold == 1.0f);
}

} // namespace ams
Loading