Skip to content

[tmva][sofie] Implement multi-operator fusion for efficient memory usage #18729

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions tmva/sofie/inc/TMVA/RModel.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ private:
std::vector<std::string> fInputTensorNames; // input tensor names using ONNX order

std::vector<std::unique_ptr<ROperator>> fOperators;
std::vector<std::unique_ptr<ROperator>> fConstantOperators;

std::vector<std::shared_ptr<RModel>> fSubGraphs; ///<! sub-graph models (transient)
RModel * fParentGraph = nullptr;
Expand Down Expand Up @@ -59,12 +60,13 @@ public:
bool CheckIfTensorAlreadyExist(std::string tensor_name);
void AddInputTensorInfo(std::string input_name, ETensorType type, std::vector<Dim> shape);
void AddInputTensorInfo(std::string input_name, ETensorType type, std::vector<size_t> shape);
void AddOperator(std::unique_ptr<ROperator> op, int order_execution = -1);
void AddOperatorReference(ROperator *op, int order_execution = -1)
void AddOperator(std::unique_ptr<ROperator> op, size_t order_execution = -1);
void AddOperatorReference(ROperator *op, size_t order_execution = -1)
{
std::unique_ptr<ROperator> tmp(op);
AddOperator(std::move(tmp), order_execution);
}
void AddConstantOperator(std::unique_ptr<ROperator> op);
void AddInitializedTensor(std::string tensor_name, ETensorType type, std::vector<std::size_t> shape,
std::shared_ptr<void> data);
void AddConstantTensor(std::string tensor_name, ETensorType type, std::vector<std::size_t> shape,
Expand Down Expand Up @@ -161,6 +163,7 @@ protected:
void GenerateIntermediateMemoryPool();
// Generate all session code
void GenerateSessionCode();
void CheckAndFuseOperators();

public:
const std::vector<std::string> &GetInputTensorNames() const { return fInputTensorNames; }
Expand Down
41 changes: 39 additions & 2 deletions tmva/sofie/inc/TMVA/ROperator.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define TMVA_SOFIE_ROPERATOR

#include <vector>
#include <set>
#include <memory>

#include "TMVA/SOFIE_common.hxx"
Expand All @@ -15,6 +16,28 @@ namespace SOFIE{

class RModel;

enum class OperatorKind {
GEMM = 0,
LAYERNORM = 1,
RELU = 2,
CONSTANT = 3,
CONSTANTOFSHAPE = 4,
UNDEFINED = 5
};

inline const char* toString(OperatorKind kind) {
switch (kind) {
case OperatorKind::GEMM: return "GEMM";
case OperatorKind::LAYERNORM: return "LAYERNORM";
case OperatorKind::RELU: return "RELU";
case OperatorKind::CONSTANT: return "CONSTANT";
case OperatorKind::CONSTANTOFSHAPE: return "CONSTANTOFSHAPE";
case OperatorKind::UNDEFINED: return "UNDEFINED";
default: return "UNKNOWN";
}
}
inline std::set<OperatorKind> FusableKinds = { OperatorKind::RELU, OperatorKind::LAYERNORM };

class ROperator{


Expand All @@ -32,13 +55,17 @@ public:
// generate session data members specific to operator
virtual std::string GenerateSessionMembersCode(std::string /*opName*/) { return ""; }
virtual std::string Header() { return "";}
virtual std::string GetFusableOutputTensorName() { return "";}
virtual void UpdateFusableTensorName(std::string){ return;};


//virtual void Forward_reference() = 0;
//virtual void Forward_blas() = 0;
virtual ~ROperator(){}

protected:

OperatorKind fKind = OperatorKind::UNDEFINED;
size_t fOpOrder = 0;
const std::string SP = " "; ///< space used to correctly indent the generated C++ code
bool fUseSession = false; ///< flag to identify if using the session class
bool fIsOutputConstant = false; ///< flag to identify if operator has a constant output (no need to generate code)
Expand All @@ -54,7 +81,17 @@ public:
std::span<const std::string_view> GetOpOutputTensors() const {
return fOutputTensorNames;
}


OperatorKind GetOpKind(){
return fKind;
}
void RegisterOperatorOrder(const size_t ord){
fOpOrder = ord;
}
size_t GetOpOrder(){
return fOpOrder;
}

};


Expand Down
12 changes: 9 additions & 3 deletions tmva/sofie/inc/TMVA/ROperator_Constant.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,13 @@ public:
fShape(shape),
fValues(values),
fAttrType(type)
{
{
fKind = OperatorKind::CONSTANT;
if (!fNX.empty()) {
// case of ConstantOfShape (since no inputs in case of Constant operator)
fIsConstantOfShape = true;
fKind = OperatorKind::CONSTANTOFSHAPE;
}
fInputTensorNames = { };
fOutputTensorNames = { };
}
Expand All @@ -50,9 +56,9 @@ public:
void Initialize(RModel& model) override {
//input must be a graph input, or already initialized intermediate tensor
size_t length = 1;

// constant of shape case
if (!fNX.empty()) {
// case of ConstantOfShape (since no inputs in case of Constant operator)
fIsConstantOfShape = true;
if (model.CheckIfTensorAlreadyExist(fNX) == false){
throw std::runtime_error("TMVA SOFIE ConstantOfShape Op Input Tensor is not found in model");
}
Expand Down
11 changes: 10 additions & 1 deletion tmva/sofie/inc/TMVA/ROperator_Gemm.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
fAttrAlpha(alpha), fAttrBeta(beta), fAttrTransA(transA), fAttrTransB(transB), fNA(UTILITY::Clean_name(nameA)),
fNB(UTILITY::Clean_name(nameB)), fNY(UTILITY::Clean_name(nameY))
{
fKind = OperatorKind::GEMM;
fActivation = activation;
fType = "float";
static_assert(std::is_same_v<T, float>,
Expand All @@ -61,9 +62,11 @@
fAttrAlpha(alpha), fAttrBeta(beta), fAttrTransA(transA), fAttrTransB(transB), fNA(UTILITY::Clean_name(nameA)),
fNB(UTILITY::Clean_name(nameB)), fNC(UTILITY::Clean_name(nameC)), fNY(UTILITY::Clean_name(nameY)), fActivation(activation)
{
fKind = OperatorKind::GEMM;
fActivation = activation;
fType = "float";

fInputTensorNames = { fNA, fNB, fNC };
fOutputTensorNames = { fNY };
}

Expand Down Expand Up @@ -383,7 +386,13 @@
}

std::vector<std::string> GetBlasRoutines() override { return { std::string("Gemm"), std::string("Gemv") }; }

std::string GetFusableOutputTensorName() override {
return fNY;
}

void UpdateFusableTensorName(std::string fusable_tensor_name){

Check warning on line 393 in tmva/sofie/inc/TMVA/ROperator_Gemm.hxx

View workflow job for this annotation

GitHub Actions / mac15 ARM64 CMAKE_CXX_STANDARD=23

'UpdateFusableTensorName' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]

Check warning on line 393 in tmva/sofie/inc/TMVA/ROperator_Gemm.hxx

View workflow job for this annotation

GitHub Actions / mac13 ARM64 builtin_zlib=ON

'UpdateFusableTensorName' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]

Check warning on line 393 in tmva/sofie/inc/TMVA/ROperator_Gemm.hxx

View workflow job for this annotation

GitHub Actions / alma9-clang clang CMAKE_C_COMPILER=clang, CMAKE_CXX_COMPILER=clang++

'UpdateFusableTensorName' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]

Check warning on line 393 in tmva/sofie/inc/TMVA/ROperator_Gemm.hxx

View workflow job for this annotation

GitHub Actions / alma9-clang clang CMAKE_C_COMPILER=clang, CMAKE_CXX_COMPILER=clang++

'UpdateFusableTensorName' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]

Check warning on line 393 in tmva/sofie/inc/TMVA/ROperator_Gemm.hxx

View workflow job for this annotation

GitHub Actions / alma9-clang clang CMAKE_C_COMPILER=clang, CMAKE_CXX_COMPILER=clang++

'UpdateFusableTensorName' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]

Check warning on line 393 in tmva/sofie/inc/TMVA/ROperator_Gemm.hxx

View workflow job for this annotation

GitHub Actions / alma9-clang clang CMAKE_C_COMPILER=clang, CMAKE_CXX_COMPILER=clang++

'UpdateFusableTensorName' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]

Check warning on line 393 in tmva/sofie/inc/TMVA/ROperator_Gemm.hxx

View workflow job for this annotation

GitHub Actions / mac-beta ARM64 CMAKE_CXX_STANDARD=23

'UpdateFusableTensorName' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]

Check warning on line 393 in tmva/sofie/inc/TMVA/ROperator_Gemm.hxx

View workflow job for this annotation

GitHub Actions / mac14 X64 CMAKE_CXX_STANDARD=20

'UpdateFusableTensorName' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]
fNY = UTILITY::Clean_name(fusable_tensor_name);
}
};


Expand Down
12 changes: 11 additions & 1 deletion tmva/sofie/inc/TMVA/ROperator_LayerNormalization.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@
: fAttrAxis(axis), fAttrEpsilon(epsilon), fAttrStashType(stashType), fNX(UTILITY::Clean_name(nameX)),
fNScale(UTILITY::Clean_name(nameScale)), fNB(UTILITY::Clean_name(nameB)),
fNY(UTILITY::Clean_name(nameY)), fNMean(UTILITY::Clean_name(nameMean)), fNInvStdDev(UTILITY::Clean_name(nameInvStdDev))
{
{
fKind = OperatorKind::LAYERNORM;
fInputTensorNames = { fNX, fNScale };
if (!fNB.empty()){
fInputTensorNames.emplace_back(fNB);
Expand Down Expand Up @@ -336,6 +337,15 @@
std::vector<std::string> GetBlasRoutines() override { return { std::string("Axpy") }; }

std::vector<std::string> GetStdLibs() override { return { std::string("cmath") }; }

std::string GetFusableOutputTensorName() override {
return fNY;
}

void UpdateFusableTensorName(std::string fusable_tensor_name){

Check warning on line 345 in tmva/sofie/inc/TMVA/ROperator_LayerNormalization.hxx

View workflow job for this annotation

GitHub Actions / mac15 ARM64 CMAKE_CXX_STANDARD=23

'UpdateFusableTensorName' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]

Check warning on line 345 in tmva/sofie/inc/TMVA/ROperator_LayerNormalization.hxx

View workflow job for this annotation

GitHub Actions / mac13 ARM64 builtin_zlib=ON

'UpdateFusableTensorName' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]

Check warning on line 345 in tmva/sofie/inc/TMVA/ROperator_LayerNormalization.hxx

View workflow job for this annotation

GitHub Actions / alma9-clang clang CMAKE_C_COMPILER=clang, CMAKE_CXX_COMPILER=clang++

'UpdateFusableTensorName' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]

Check warning on line 345 in tmva/sofie/inc/TMVA/ROperator_LayerNormalization.hxx

View workflow job for this annotation

GitHub Actions / alma9-clang clang CMAKE_C_COMPILER=clang, CMAKE_CXX_COMPILER=clang++

'UpdateFusableTensorName' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]

Check warning on line 345 in tmva/sofie/inc/TMVA/ROperator_LayerNormalization.hxx

View workflow job for this annotation

GitHub Actions / mac-beta ARM64 CMAKE_CXX_STANDARD=23

'UpdateFusableTensorName' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]

Check warning on line 345 in tmva/sofie/inc/TMVA/ROperator_LayerNormalization.hxx

View workflow job for this annotation

GitHub Actions / mac14 X64 CMAKE_CXX_STANDARD=20

'UpdateFusableTensorName' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]
fNX = UTILITY::Clean_name(fusable_tensor_name);
fNY = UTILITY::Clean_name(fusable_tensor_name);
}
};

} // namespace SOFIE
Expand Down
11 changes: 11 additions & 0 deletions tmva/sofie/inc/TMVA/ROperator_Relu.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
ROperator_Relu(){}
ROperator_Relu(std::string nameX, std::string nameY):
fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY)){
fKind = OperatorKind::RELU;
fInputTensorNames = { fNX };
fOutputTensorNames = { fNY };
}
Expand Down Expand Up @@ -66,6 +67,16 @@
return out.str();
}


std::string GetFusableOutputTensorName() override {
return fNY;
}

void UpdateFusableTensorName(std::string fusable_tensor_name){

Check warning on line 75 in tmva/sofie/inc/TMVA/ROperator_Relu.hxx

View workflow job for this annotation

GitHub Actions / mac15 ARM64 CMAKE_CXX_STANDARD=23

'UpdateFusableTensorName' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]

Check warning on line 75 in tmva/sofie/inc/TMVA/ROperator_Relu.hxx

View workflow job for this annotation

GitHub Actions / mac13 ARM64 builtin_zlib=ON

'UpdateFusableTensorName' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]

Check warning on line 75 in tmva/sofie/inc/TMVA/ROperator_Relu.hxx

View workflow job for this annotation

GitHub Actions / alma9-clang clang CMAKE_C_COMPILER=clang, CMAKE_CXX_COMPILER=clang++

'UpdateFusableTensorName' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]

Check warning on line 75 in tmva/sofie/inc/TMVA/ROperator_Relu.hxx

View workflow job for this annotation

GitHub Actions / alma9-clang clang CMAKE_C_COMPILER=clang, CMAKE_CXX_COMPILER=clang++

'UpdateFusableTensorName' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]

Check warning on line 75 in tmva/sofie/inc/TMVA/ROperator_Relu.hxx

View workflow job for this annotation

GitHub Actions / alma9-clang clang CMAKE_C_COMPILER=clang, CMAKE_CXX_COMPILER=clang++

'UpdateFusableTensorName' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]

Check warning on line 75 in tmva/sofie/inc/TMVA/ROperator_Relu.hxx

View workflow job for this annotation

GitHub Actions / alma9-clang clang CMAKE_C_COMPILER=clang, CMAKE_CXX_COMPILER=clang++

'UpdateFusableTensorName' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]

Check warning on line 75 in tmva/sofie/inc/TMVA/ROperator_Relu.hxx

View workflow job for this annotation

GitHub Actions / mac-beta ARM64 CMAKE_CXX_STANDARD=23

'UpdateFusableTensorName' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]

Check warning on line 75 in tmva/sofie/inc/TMVA/ROperator_Relu.hxx

View workflow job for this annotation

GitHub Actions / mac14 X64 CMAKE_CXX_STANDARD=20

'UpdateFusableTensorName' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]
fNX = UTILITY::Clean_name(fusable_tensor_name);
fNY = UTILITY::Clean_name(fusable_tensor_name);
}

};

}//SOFIE
Expand Down
Loading
Loading