Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/AMSlib/include/AMSError.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ enum class AMSErrorType {
FileDoesNotExist, ///< Path to file or directory does not exist
TorchInternal, ///< An internal error that happens to the torch library
InvalidModel, ///< A torchscripted model that has not been serialized through AMS
InvalidShapes, ///< Some Data shape is not the proper|expected shape
};

/// \brief Strongly-typed error object used across AMS.
Expand Down
25 changes: 25 additions & 0 deletions src/AMSlib/wf/index_map.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#pragma once

#include <cstdint>
#include <string>
#include <vector>

namespace ams
{

/// Field-to-column mapping for layout transformations.
struct IndexMap {
struct FieldInfo {
std::string Name;

enum class Kind { Input, InOut, Output };
Kind EKind;

int64_t Offset; ///< Starting column in the concatenated tensor
int64_t Cols; ///< Number of columns this field covers
};

std::vector<FieldInfo> Fields;
};

} // namespace ams
31 changes: 11 additions & 20 deletions src/AMSlib/wf/layout_transform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

#include <optional>

#include "AMSError.hpp"
#include "wf/index_map.hpp"
#include "wf/tensor_bundle.hpp"

namespace ams
Expand All @@ -24,26 +26,15 @@ class LayoutTransform
public:
virtual ~LayoutTransform() = default;

/// Pack the application-level Inputs and Inouts into a single tensor suitable
/// for feeding into the ML model.
virtual at::Tensor pack(const TensorBundle& Inputs,
const TensorBundle& Inouts) = 0;

/// Unpack the model's output (an IValue that may be a tensor or a tuple of
/// tensors) into:
/// - Outputs
/// - Inouts
/// - Uncertainties (optional)
///
/// Concrete layouts determine how the returned IValue maps back to domain
/// tensors. Only LayoutTransform knows the correct indexing and shapes.
virtual void unpack(const torch::jit::IValue& ModelOutput,
TensorBundle& Outputs,
TensorBundle& Inouts,
std::optional<at::Tensor>& Uncertainties) = 0;

/// Descriptive name used for debugging, logging, and introspection.
/// Must be implemented by all subclasses.
virtual AMSExpected<IndexMap> pack(const TensorBundle& Inputs,
const TensorBundle& InOuts,
at::Tensor& ModelInput) = 0;

virtual AMSStatus unpack(const torch::jit::IValue& ModelOutput,
TensorBundle& Outs,
TensorBundle& InOuts,
std::optional<at::Tensor>& Uncertainties) = 0;

virtual const char* name() const noexcept = 0;
};

Expand Down
167 changes: 167 additions & 0 deletions src/AMSlib/wf/pointwise_layout_transform.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
#pragma once

#include <ATen/ATen.h>
#include <torch/script.h>

#include <cstdint>
#include <optional>
#include <vector>

#include "wf/index_map.hpp"
#include "wf/layout_transform.hpp"
#include "wf/tensor_bundle.hpp"

namespace ams
{

/// PointwiseConcatTransform:
///
/// Converts Inputs + InOuts into a single matrix [N, SUM(K_i)] where:
/// - N = batch size (outer dim)
/// - K_i = flattened size of each tensor field except the batch dimension
///
/// Supports:
/// ✔ Scalar fields (shape [N])
/// ✔ Multi-channel fields (shape [N, K])
/// ✔ Arbitrary shapes [N, ...] → flattened to [N, M]
/// ✔ Prediction-only models
/// ✔ Uncertainty-aware models returning (pred, uncertainty)
///
/// Produces IndexMap for both pack() and unpack().
class PointwiseConcatTransform : public LayoutTransform
{
public:
const char* name() const noexcept override
{
return "PointwiseConcatTransform";
}

// ------------------------------------------------------------------
// PACK
// ------------------------------------------------------------------
AMSExpected<IndexMap> pack(const TensorBundle& Inputs,
const TensorBundle& InOuts,
at::Tensor& ModelInput) override
{
IndexMap map;
std::vector<at::Tensor> cols;
int total_cols{0};

if (auto st = process(
Inputs, IndexMap::FieldInfo::Kind::Input, map, cols, total_cols);
!st)
return tl::unexpected(st.error());
if (auto st = process(
InOuts, IndexMap::FieldInfo::Kind::InOut, map, cols, total_cols);
!st)
return tl::unexpected(st.error());

if (total_cols <= 0) {
return AMS_MAKE_ERROR(AMSErrorType::InvalidShapes,
fmt::format("PointwiseConcatTransform expected at "
"least a single dimension in pack"));
}
// Concatenate horizontally
ModelInput = at::cat(cols, /*dim=*/1);
return map;
}

// ------------------------------------------------------------------
// UNPACK
// ------------------------------------------------------------------
AMSStatus unpack(const torch::jit::IValue& ModelOutput,
TensorBundle& Outs,
TensorBundle& InOuts,
std::optional<at::Tensor>& Uncertainties) override
{
at::Tensor ModelOut;
at::Tensor Uncertainty;
bool has_uncertainty = false;

// --------------------------------------------
// Case 1: Single tensor prediction
// --------------------------------------------
if (ModelOutput.isTensor()) {
ModelOut = ModelOutput.toTensor();
}
// --------------------------------------------
// Case 2: Tuple(pred, uncertainty)
// --------------------------------------------
else if (ModelOutput.isTuple()) {
auto tup = ModelOutput.toTuple();
if (tup->elements().size() != 2)
return AMS_MAKE_ERROR(AMSErrorType::InvalidShapes,
"PointwiseConcatTransform: expected "
"tuple(pred,uncertainty).");

ModelOut = tup->elements()[0].toTensor();
Uncertainty = tup->elements()[1].toTensor();
has_uncertainty = true;
} else {
return AMS_MAKE_ERROR(AMSErrorType::InvalidShapes,
"PointwiseConcatTransform: ModelOutput must be "
"tensor or "
"tuple.");
}

// Uncertainties
if (has_uncertainty) {
Uncertainties = Uncertainty;
} else {
Uncertainties.reset();
}

if (ModelOut.size(1) != Outs.size() + InOuts.size())
return AMS_MAKE_ERROR(AMSErrorType::InvalidShapes,
"Expected the output size to match the Application "
"output dimensions");

int k = 0;
for (; k < Outs.size(); ++k) {
Outs[k].tensor =
ModelOut.narrow(/*dim=*/1, /*start=*/k, /*length=*/1).squeeze();
}

for (int i = 0; i < InOuts.size(); ++k, ++i) {
InOuts[i].tensor =
ModelOut.narrow(/*dim=*/1, /*start=*/k, /*length=*/1).squeeze();
}

return {};
}

private:
AMSStatus process(const TensorBundle& tb,
IndexMap::FieldInfo::Kind kind,
IndexMap& map,
std::vector<at::Tensor>& cols,
int& total_cols)
{
for (size_t i = 0; i < tb.size(); i++) {
const auto& item = tb.items[i];
at::Tensor t = item.tensor;

if (t.dim() < 1)
return AMS_MAKE_ERROR(AMSErrorType::InvalidShapes,
fmt::format("PointwiseConcatTransform for "
"field {} must have at least 1 "
"dimension",
item.name));
int64_t N = t.size(0);

// Flatten everything except outer dimension.
at::Tensor flat = t.reshape({N, -1});
int64_t M = flat.size(1);

int64_t offset = total_cols;
total_cols += M;

map.Fields.push_back({item.name, kind, offset, M});

cols.push_back(flat);
}
return {};
}
};

} // namespace ams
5 changes: 3 additions & 2 deletions tests/AMSlib/wf/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ ADD_WORKFLOW_UNIT_TEST(WORKFLOW::EVALUATE_IN_OUTS evaluate_in_and_outs)
BUILD_UNIT_TEST(tensor_bundle tensor_bundle.cpp)
ADD_WORKFLOW_UNIT_TEST(WORKFLOW::TENSOR_BUNDLE tensor_bundle)

BUILD_UNIT_TEST(layout_transform layout_transform.cpp)
ADD_WORKFLOW_UNIT_TEST(WORKFLOW::LAYOUT_TRANSFORM layout_transform)
BUILD_UNIT_TEST(eval_context eval_context.cpp)
ADD_WORKFLOW_UNIT_TEST(WORKFLOW::EVAL_CONTEXT eval_context)

BUILD_UNIT_TEST(pointwise pointwise_layout_transform.cpp)
ADD_WORKFLOW_UNIT_TEST(WORKFLOW::POINTWISE pointwise)
108 changes: 0 additions & 108 deletions tests/AMSlib/wf/layout_transform.cpp

This file was deleted.

Loading
Loading