diff --git a/tmva/sofie/inc/TMVA/RModel.hxx b/tmva/sofie/inc/TMVA/RModel.hxx
index 80dcc9a9c45d5..8b521f2e6c03f 100644
--- a/tmva/sofie/inc/TMVA/RModel.hxx
+++ b/tmva/sofie/inc/TMVA/RModel.hxx
@@ -34,6 +34,7 @@ private:
    std::vector<std::unique_ptr<ROperator>> fOperators;
+   std::vector<std::unique_ptr<ROperator>> fConstantOperators;
    std::vector<std::unique_ptr<RModel>> fSubGraphs;   ///<! ...
@@ ... @@ public:
    void AddInputTensorInfo(std::string input_name, ETensorType type, std::vector<Dim> shape);
    void AddInputTensorInfo(std::string input_name, ETensorType type, std::vector<size_t> shape);
-   void AddOperator(std::unique_ptr<ROperator> op, int order_execution = -1);
-   void AddOperatorReference(ROperator *op, int order_execution = -1)
+   void AddOperator(std::unique_ptr<ROperator> op, size_t order_execution = -1);
+   void AddOperatorReference(ROperator *op, size_t order_execution = -1)
    {
       std::unique_ptr<ROperator> tmp(op);
       AddOperator(std::move(tmp), order_execution);
    }
+   void AddConstantOperator(std::unique_ptr<ROperator> op);
    void AddInitializedTensor(std::string tensor_name, ETensorType type, std::vector<std::size_t> shape,
                              std::shared_ptr<void> data);
    void AddConstantTensor(std::string tensor_name, ETensorType type, std::vector<std::size_t> shape,
@@ -159,11 +161,15 @@ public:
    std::string GenerateInferSignature(bool isdecl = true);
 
    // calculate total intermediate memory and position intermediate tensor addresses
-   std::string AllocateIntermediateMemory(std::span<const std::string_view> op_output_tensors);
-   void CheckAndFlushIntermediateMemory(std::span<const std::string_view> op_output_tensors, const size_t& op_idx);
+   std::string AllocateIntermediateMemory(std::span<const std::string> op_output_tensors, std::set<std::string>& allocated_tensors);
+   void CheckAndFlushIntermediateMemory(std::span<const std::string> op_output_tensors, const size_t& op_idx);
 
    void SetOptimizationLevel(const OptimizationLevel &optim_level) { fOptimizationLevel = optim_level; }
 
+   void RemoveIntermediateTensor(const std::string& tensor_name){
+      fIntermediateTensorInfos.erase(tensor_name);
+   }
+
 protected:
    // internal functions
    // generate code for the initialized tensors
@@ -180,6 +186,7 @@ protected:
    void GenerateIntermediateMemoryPool();
    // Generate all session code
    void GenerateSessionCode();
+   void CheckAndFuseOperators();
 
 public:
    const std::vector<std::string> & GetInputTensorNames() const { return fInputTensorNames; }
diff --git a/tmva/sofie/inc/TMVA/ROperator.hxx b/tmva/sofie/inc/TMVA/ROperator.hxx
index f0afd9c4374c1..1c166491bf8b4 100644
--- a/tmva/sofie/inc/TMVA/ROperator.hxx
+++ b/tmva/sofie/inc/TMVA/ROperator.hxx
@@ -2,6 +2,7 @@
 #define TMVA_SOFIE_ROPERATOR
 
 #include <vector>
+#include <set>
 #include <memory>
 
 #include "TMVA/SOFIE_common.hxx"
@@ -15,6 +16,32 @@ namespace SOFIE{
 
 class RModel;
 
+enum class OperatorKind {
+   GEMM = 0,
+   LAYERNORM = 1,
+   RELU = 2,
+   CONSTANT = 3,
+   CONSTANTOFSHAPE = 4,
+   UNDEFINED = 5,
+   CONV=6,
+   BATCHNORM=7
+};
+
+inline const char* toString(OperatorKind kind) {
+   switch (kind) {
+      case OperatorKind::GEMM: return "GEMM";
+      case OperatorKind::LAYERNORM: return "LAYERNORM";
+      case OperatorKind::RELU: return "RELU";
+      case OperatorKind::CONSTANT: return "CONSTANT";
+      case OperatorKind::CONSTANTOFSHAPE: return "CONSTANTOFSHAPE";
+      case OperatorKind::BATCHNORM: return "batchnorm";
+      case OperatorKind::CONV: return "conv";
+      case OperatorKind::UNDEFINED: return "UNDEFINED";
+      default: return "UNKNOWN";
+   }
+}
+
+inline std::set<OperatorKind> FusableKinds = { OperatorKind::RELU, OperatorKind::LAYERNORM, OperatorKind::BATCHNORM};
+
 class ROperator{
@@ -32,30 +59,44 @@ public:
    // generate session data members specific to operator
    virtual std::string GenerateSessionMembersCode(std::string /*opName*/) { return ""; }
    virtual std::string Header() { return "";}
+   virtual std::string GetFusableOutputTensorName() { return "";}
+   virtual void UpdateFusableTensorName(std::string, const std::function<void(const std::string&)>& removal_func){ return;};
+
    //virtual void Forward_reference() = 0;
    //virtual void Forward_blas() = 0;
    virtual ~ROperator(){}
 
 protected:
-
+   OperatorKind fKind = OperatorKind::UNDEFINED;
+   size_t fOpOrder = 0;
    const std::string SP = " ";        ///< space used to correctly indent the generated C++ code
    bool fUseSession = false;          ///< flag to identify if using the session class
    bool fIsOutputConstant = false;    ///< flag to identify if operator has a constant output (no need to generate code)
    bool fIsOutputParamShape = false;  ///< flag to identify of the output represents a parametric shape (can be knwon at compile time)
 
-   mutable std::vector<std::string_view> fInputTensorNames;
-   mutable std::vector<std::string_view> fOutputTensorNames;
+   mutable std::vector<std::string> fInputTensorNames;
+   mutable std::vector<std::string> fOutputTensorNames;
 
 public:
-   std::span<const std::string_view> GetOpInputTensors() const {
+   std::span<const std::string> GetOpInputTensors() const {
      return fInputTensorNames;
    }
-   std::span<const std::string_view> GetOpOutputTensors() const {
+   std::span<const std::string> GetOpOutputTensors() const {
      return fOutputTensorNames;
    }
+   OperatorKind GetOpKind(){
+      return fKind;
+   }
+   void RegisterOperatorOrder(const size_t ord){
+      fOpOrder = ord;
+   }
+   size_t GetOpOrder(){
+      return fOpOrder;
+   }
+
 };
diff --git a/tmva/sofie/inc/TMVA/ROperator_BasicNary.hxx b/tmva/sofie/inc/TMVA/ROperator_BasicNary.hxx
index bcc0e52a40ca3..22cd0edcf75e4 100644
--- a/tmva/sofie/inc/TMVA/ROperator_BasicNary.hxx
+++ b/tmva/sofie/inc/TMVA/ROperator_BasicNary.hxx
@@ -103,7 +103,7 @@ public:
 
      fInputTensorNames.resize(fNInputs.size());
      std::transform(fNInputs.begin(), fNInputs.end(), fInputTensorNames.begin(),
-                  [](const std::string& s) -> std::string_view { return s; });
+                  [](const std::string& s) -> std::string { return s; });
      fOutputTensorNames = { fNY };
   }
diff --git a/tmva/sofie/inc/TMVA/ROperator_BatchNormalization.hxx b/tmva/sofie/inc/TMVA/ROperator_BatchNormalization.hxx
index 16fc3d6c07ba5..a565ea5d314c3 100644
--- a/tmva/sofie/inc/TMVA/ROperator_BatchNormalization.hxx
+++ b/tmva/sofie/inc/TMVA/ROperator_BatchNormalization.hxx
@@ -53,6 +53,7 @@ public:
      fNB(UTILITY::Clean_name(nameB)), fNMean(UTILITY::Clean_name(nameMean)), fNVar(UTILITY::Clean_name(nameVar)),
      fNY(UTILITY::Clean_name(nameY)), fActivation(activation)
   {
+      fKind = OperatorKind::BATCHNORM;
      fInputTensorNames = { fNX };
      fOutputTensorNames = { fNY };
@@ -233,6 +234,19 @@ public:
   }
   std::vector<std::string> GetBlasRoutines() override { return { std::string("Copy"), std::string("Axpy") }; }
 
+   std::string GetFusableOutputTensorName() override {
+      return fNY;
+   }
+
+   void UpdateFusableTensorName(std::string fusable_tensor_name, const std::function<void(const std::string&)>& removal_func){
+      removal_func(fNX);
+      removal_func(fNY);
+      fNX = fusable_tensor_name;
+      fNY = fusable_tensor_name;
+      fInputTensorNames[0] = fNX;
+      fOutputTensorNames[0] = fNY;
+   }
+
 };
 
 }//SOFIE
diff --git a/tmva/sofie/inc/TMVA/ROperator_Concat.hxx b/tmva/sofie/inc/TMVA/ROperator_Concat.hxx
index ad855341dfc17..1161eaf5e8e0f 100644
--- a/tmva/sofie/inc/TMVA/ROperator_Concat.hxx
+++ b/tmva/sofie/inc/TMVA/ROperator_Concat.hxx
@@ -37,7 +37,7 @@
 
      fInputTensorNames.resize(fInputs.size());
     std::transform(fInputs.begin(), fInputs.end(), fInputTensorNames.begin(),
-                  [](const std::string& s) -> std::string_view { return s; });
+                  [](const std::string& s) -> std::string { return s; });
      fOutputTensorNames = { fOutput };
   }
diff --git a/tmva/sofie/inc/TMVA/ROperator_Constant.hxx b/tmva/sofie/inc/TMVA/ROperator_Constant.hxx
index 1cf5d13f5cd6f..736adac1c3526 100644
--- a/tmva/sofie/inc/TMVA/ROperator_Constant.hxx
+++ b/tmva/sofie/inc/TMVA/ROperator_Constant.hxx
@@ -36,7 +36,13 @@ public:
      fShape(shape), fValues(values), fAttrType(type)
-   {
+   {
+      fKind = OperatorKind::CONSTANT;
+      if (!fNX.empty()) {
+         // case of ConstantOfShape (since no inputs in case of Constant operator)
+         fIsConstantOfShape = true;
+         fKind = OperatorKind::CONSTANTOFSHAPE;
+      }
      fInputTensorNames = { };
      fOutputTensorNames = { };
   }
@@ -53,10 +59,9 @@ public:
   void Initialize(RModel& model) override {
      //input must be a graph input, or already initialized intermediate tensor
      size_t length = 1;
-      /// ConstantOfShape-------------
+
+      /// ------- ConstantOfShape ---------
      if (!fNX.empty()) {
-         // case of ConstantOfShape (since no inputs in case of Constant operator)
-         fIsConstantOfShape = true;
         if (model.CheckIfTensorAlreadyExist(fNX) == false){
            throw std::runtime_error("TMVA SOFIE ConstantOfShape Op Input Tensor is not found in model");
         }
diff --git a/tmva/sofie/inc/TMVA/ROperator_Conv.hxx b/tmva/sofie/inc/TMVA/ROperator_Conv.hxx
index 6d5d54262036f..ee9a24c159ea2 100644
--- a/tmva/sofie/inc/TMVA/ROperator_Conv.hxx
+++ b/tmva/sofie/inc/TMVA/ROperator_Conv.hxx
@@ -58,7 +58,8 @@ public:
      fAttrPads(pads), fAttrStrides(strides),
      fNX(UTILITY::Clean_name(nameX)), fNW(UTILITY::Clean_name(nameW)),
      fNB(UTILITY::Clean_name(nameB)), fNY(UTILITY::Clean_name(nameY))
-   {
+   {
+      fKind = OperatorKind::CONV;
      if(std::is_same<T, float>::value) {
         fType = "float";
      } else {
@@ -77,6 +78,7 @@ public:
      fAttrPads(pads), fAttrStrides(strides),
      fNX(UTILITY::Clean_name(nameX)), fNW(UTILITY::Clean_name(nameW)), fNY(UTILITY::Clean_name(nameY))
   {
+      fKind = OperatorKind::CONV;
      if(std::is_same<T, float>::value) {
         fType = "float";
      } else {
@@ -569,6 +571,14 @@ public:
   /*! \brief Returns the blas routines needed to compile the generated code
    */
   std::vector<std::string> GetBlasRoutines() override { return { std::string("Gemm"), std::string("Axpy") }; }
+   std::string GetFusableOutputTensorName() override {
+      return fNY;
+   }
+   void UpdateFusableTensorName(std::string fusable_tensor_name, const std::function<void(const std::string&)>& removal_func) override {
+      removal_func(fNY);
+      fNY = fusable_tensor_name;
+      fOutputTensorNames[0] = fNY;
+   }
 };
 
 } // namespace SOFIE
diff --git a/tmva/sofie/inc/TMVA/ROperator_Einsum.hxx b/tmva/sofie/inc/TMVA/ROperator_Einsum.hxx
index fbf6659058d36..78e138d6b5dee 100644
--- a/tmva/sofie/inc/TMVA/ROperator_Einsum.hxx
+++ b/tmva/sofie/inc/TMVA/ROperator_Einsum.hxx
@@ -51,7 +51,7 @@ public:
 
      fInputTensorNames.resize(fNInputs.size());
      std::transform(fNInputs.begin(), fNInputs.end(), fInputTensorNames.begin(),
-                  [](const std::string& s) -> std::string_view { return s; });
+                  [](const std::string& s) -> std::string { return s; });
      fOutputTensorNames = { fNY };
   }
diff --git a/tmva/sofie/inc/TMVA/ROperator_Gemm.hxx b/tmva/sofie/inc/TMVA/ROperator_Gemm.hxx
index d954720396151..00e3a9158d6d8 100644
--- a/tmva/sofie/inc/TMVA/ROperator_Gemm.hxx
+++ b/tmva/sofie/inc/TMVA/ROperator_Gemm.hxx
@@ -50,6 +50,7 @@ namespace SOFIE{
         fAttrAlpha(alpha), fAttrBeta(beta), fAttrTransA(transA), fAttrTransB(transB), fNA(UTILITY::Clean_name(nameA)),
         fNB(UTILITY::Clean_name(nameB)), fNY(UTILITY::Clean_name(nameY))
      {
+        fKind = OperatorKind::GEMM;
         fActivation = activation;
         fType = "float";
         static_assert(std::is_same_v<T, float>,
@@ -62,10 +63,11 @@ namespace SOFIE{
         fAttrAlpha(alpha), fAttrBeta(beta), fAttrTransA(transA), fAttrTransB(transB), fNA(UTILITY::Clean_name(nameA)),
         fNB(UTILITY::Clean_name(nameB)), fNC(UTILITY::Clean_name(nameC)), fNY(UTILITY::Clean_name(nameY)),
         fActivation(activation)
      {
+        fKind = OperatorKind::GEMM;
         fActivation = activation;
         fType = "float";
-        fInputTensorNames = {fNA, fNB, fNC};
+        fInputTensorNames = { fNA, fNB, fNC };
         fOutputTensorNames = { fNY };
      }
@@ -402,7 +404,16 @@ namespace SOFIE{
      }
      std::vector<std::string> GetBlasRoutines() override { return { std::string("Gemm"), std::string("Gemv") }; }
-
+     std::string GetFusableOutputTensorName() override {
+        return fNY;
+     }
+
+     void UpdateFusableTensorName(std::string fusable_tensor_name, const std::function<void(const std::string&)>& removal_func){
+        removal_func(fNY);
+        fNY = fusable_tensor_name;
+        fOutputTensorNames[0] = fNY;
+     }
+
   };
diff --git a/tmva/sofie/inc/TMVA/ROperator_LayerNormalization.hxx b/tmva/sofie/inc/TMVA/ROperator_LayerNormalization.hxx
index 239c5332172b0..d09db1a19979e 100644
--- a/tmva/sofie/inc/TMVA/ROperator_LayerNormalization.hxx
+++ b/tmva/sofie/inc/TMVA/ROperator_LayerNormalization.hxx
@@ -58,7 +58,8 @@ public:
      : fAttrAxis(axis), fAttrEpsilon(epsilon), fAttrStashType(stashType), fNX(UTILITY::Clean_name(nameX)),
        fNScale(UTILITY::Clean_name(nameScale)), fNB(UTILITY::Clean_name(nameB)), fNY(UTILITY::Clean_name(nameY)),
        fNMean(UTILITY::Clean_name(nameMean)), fNInvStdDev(UTILITY::Clean_name(nameInvStdDev))
-   {
+   {
+      fKind = OperatorKind::LAYERNORM;
      fInputTensorNames = { fNX, fNScale };
      if (!fNB.empty()){
         fInputTensorNames.emplace_back(fNB);
@@ -336,6 +337,20 @@ public:
   std::vector<std::string> GetBlasRoutines() override { return { std::string("Axpy") }; }
   std::vector<std::string> GetStdLibs() override { return { std::string("cmath") }; }
+
+   std::string GetFusableOutputTensorName() override {
+      return fNY;
+   }
+
+   void UpdateFusableTensorName(std::string fusable_tensor_name, const std::function<void(const std::string&)>& removal_func){
+      removal_func(fNX);
+      removal_func(fNY);
+      fNX = fusable_tensor_name;
+      fNY = fusable_tensor_name;
+      fInputTensorNames[0] = fNX;
+      fOutputTensorNames[0] = fNY;
+   }
+
 };
 
 } // namespace SOFIE
diff --git a/tmva/sofie/inc/TMVA/ROperator_Range.hxx b/tmva/sofie/inc/TMVA/ROperator_Range.hxx
index 9cac15a14fc52..cc475ec6fcb03 100644
--- a/tmva/sofie/inc/TMVA/ROperator_Range.hxx
+++ b/tmva/sofie/inc/TMVA/ROperator_Range.hxx
@@ -28,7 +28,7 @@ public:
   ROperator_Range(){}
 
   ROperator_Range(std::string start, std::string limit, std::string delta, std::string nameOutput):
-      fNStart(start), fNLimit(limit), fNDelta(delta),
+      fNStart(UTILITY::Clean_name(start)), fNLimit(UTILITY::Clean_name(limit)), fNDelta(UTILITY::Clean_name(delta)),
      fNOutput(UTILITY::Clean_name(nameOutput)) {
      if (std::is_same<T, float>::value) {
         fType = "float";
@@ -37,6 +37,8 @@ public:
      }
      static_assert( (std::is_same_v<T, float> || std::is_same_v<T, int64_t>),
                    "TMVA::SOFIE - Unsupported type by Range operator");
+      fInputTensorNames = {fNStart, fNLimit, fNDelta};
+      fOutputTensorNames = {fNOutput};
   }
 
   std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
diff --git a/tmva/sofie/inc/TMVA/ROperator_Relu.hxx b/tmva/sofie/inc/TMVA/ROperator_Relu.hxx
index a3b1df8ee0abf..dea69818e978e 100644
--- a/tmva/sofie/inc/TMVA/ROperator_Relu.hxx
+++ b/tmva/sofie/inc/TMVA/ROperator_Relu.hxx
@@ -25,6 +25,7 @@ public:
   ROperator_Relu(){}
   ROperator_Relu(std::string nameX, std::string nameY):
      fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY)){
+      fKind = OperatorKind::RELU;
      fInputTensorNames = { fNX };
      fOutputTensorNames = { fNY };
   }
@@ -66,6 +67,20 @@ public:
      return out.str();
   }
+
+   std::string GetFusableOutputTensorName() override {
+      return fNY;
+   }
+
+   void UpdateFusableTensorName(std::string fusable_tensor_name, const std::function<void(const std::string&)>& removal_func){
+      removal_func(fNX);
+      removal_func(fNY);
+      fNX = fusable_tensor_name;
+      fNY = fusable_tensor_name;
+      fInputTensorNames[0] = fNX;
+      fOutputTensorNames[0] = fNY;
+   }
+
 };
 
 }//SOFIE
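The same two hooks recur in each operator header above (BatchNormalization, Conv, Gemm, LayerNormalization, Relu): GetFusableOutputTensorName exposes the tensor the operator writes, and UpdateFusableTensorName rebinds the operator's input/output names to the propagated tensor while handing the old names to a removal callback so the model can drop the intermediates that become dead. The following is a minimal standalone sketch of that contract; FusableOp and the printed callback are illustrative stand-ins, not the SOFIE classes.

#include <functional>
#include <iostream>
#include <string>

// Stand-in for an activation-like operator that reads fNX and writes fNY.
struct FusableOp {
   std::string fNX, fNY;
   std::string GetFusableOutputTensorName() const { return fNY; }
   void UpdateFusableTensorName(std::string fused_name,
                                const std::function<void(const std::string&)>& remove_tensor) {
      remove_tensor(fNX);       // the old intermediate names are no longer needed ...
      remove_tensor(fNY);
      fNX = fNY = fused_name;   // ... the op now works directly on the shared buffer
   }
};

int main() {
   FusableOp relu{"gemm_out", "relu_out"};
   relu.UpdateFusableTensorName("relu_out",
      [](const std::string& name) { std::cout << "dropping intermediate tensor " << name << "\n"; });
   std::cout << relu.fNX << " -> " << relu.fNY << "\n";   // both bound to "relu_out"
}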
diff --git a/tmva/sofie/inc/TMVA/ROperator_Split.hxx b/tmva/sofie/inc/TMVA/ROperator_Split.hxx
index f191f9d014238..0936a13415313 100644
--- a/tmva/sofie/inc/TMVA/ROperator_Split.hxx
+++ b/tmva/sofie/inc/TMVA/ROperator_Split.hxx
@@ -38,7 +38,7 @@ public:
      fInputTensorNames = { fNX };
      fOutputTensorNames.resize(fNYs.size());
      std::transform(fNYs.begin(), fNYs.end(), fOutputTensorNames.begin(),
-                  [](const std::string& s) -> std::string_view { return s; });
+                  [](const std::string& s) -> std::string { return s; });
   }
 
   std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
diff --git a/tmva/sofie/inc/TMVA/ROperator_SubGraph.hxx b/tmva/sofie/inc/TMVA/ROperator_SubGraph.hxx
index 683d40de835d7..4eea16e059210 100644
--- a/tmva/sofie/inc/TMVA/ROperator_SubGraph.hxx
+++ b/tmva/sofie/inc/TMVA/ROperator_SubGraph.hxx
@@ -36,7 +36,7 @@ public:
      fInputTensorNames = { fNX };
      std::transform(fNYs.begin(), fNYs.end(), fOutputTensorNames.begin(),
-                  [](const std::string& s) -> std::string_view { return s; });
+                  [](const std::string& s) -> std::string { return s; });
   }
 
   std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
diff --git a/tmva/sofie/inc/TMVA/SOFIE_common.hxx b/tmva/sofie/inc/TMVA/SOFIE_common.hxx
index 57d829ac1acff..ffc79e9be94c6 100644
--- a/tmva/sofie/inc/TMVA/SOFIE_common.hxx
+++ b/tmva/sofie/inc/TMVA/SOFIE_common.hxx
@@ -171,7 +171,7 @@ struct TensorMemoryInfo {
 
   TensorMemoryInfo split(const std::string_view new_name, size_t new_size) {
      if (new_size > tensor_size) {
-         throw std::invalid_argument("New size exceeds available tensor size.");
+         throw std::invalid_argument("New size "+ std::to_string(new_size) + " exceeds available tensor size of " + std::to_string(tensor_size)+".");
      }
      tensor_size -= new_size;
      return TensorMemoryInfo{new_name, new_size};
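The only functional change in SOFIE_common.hxx is the more informative exception text, but it helps to recall what split does in the intermediate-memory pool: it carves a new chunk out of an existing free chunk and shrinks the original, which is how AllocateIntermediateMemory reuses freed space. A self-contained sketch of that behaviour (a plain Chunk struct, not the SOFIE type):

#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <string>

struct Chunk {
   std::string name;
   size_t size;
   // Carve `new_size` bytes off this free chunk for `new_name`; keep the remainder.
   Chunk split(const std::string& new_name, size_t new_size) {
      if (new_size > size)
         throw std::invalid_argument("New size " + std::to_string(new_size) +
                                     " exceeds available tensor size of " + std::to_string(size) + ".");
      size -= new_size;
      return Chunk{new_name, new_size};
   }
};

int main() {
   Chunk free_chunk{"free", 1024};
   Chunk reused = free_chunk.split("relu_out", 256);
   std::cout << reused.name << ":" << reused.size << ", remaining " << free_chunk.size << "\n";
}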
- for(size_t index = 0; index op){ + fConstantOperators.push_back(std::move(op)); } void RModel::AddInitializedTensor(std::string tensor_name, ETensorType type, std::vector shape, std::shared_ptr data) { @@ -255,7 +266,8 @@ void RModel::AddIntermediateTensor(std::string tensor_name, ETensorType type, st void RModel::AddIntermediateTensor(std::string tensor_name, ETensorType type, std::vector shape) { tensor_name = UTILITY::Clean_name(tensor_name); if (CheckIfTensorAlreadyExist(tensor_name)) { - throw std::runtime_error("TMVA-SOFIE: intermediate tensor with name " + tensor_name + " already exists \n"); + // throw std::runtime_error("TMVA-SOFIE: intermediate tensor with name " + tensor_name + " already exists \n"); + return; } TensorInfo new_tensor {type, shape}; fIntermediateTensorInfos[tensor_name] = new_tensor; @@ -326,7 +338,7 @@ void RModel::SetNotWritableInitializedTensor(const std::string & tensor_name) { t->second.SetNotWritable(); } -std::string RModel::AllocateIntermediateMemory(std::span op_output_tensors) +std::string RModel::AllocateIntermediateMemory(std::span op_output_tensors, std::set& allocated_tensors) { std::stringstream code; @@ -343,8 +355,9 @@ std::string RModel::AllocateIntermediateMemory(std::span bool allocated = false; if (GetTensorType(name) == ETensorType::BOOL || fInitializedTensors.find(name) != fInitializedTensors.end() || - fDynamicTensorInfos.find(name) != fDynamicTensorInfos.end()) continue; - + fDynamicTensorInfos.find(name) != fDynamicTensorInfos.end() || + allocated_tensors.count(it)) continue; + auto tensor_size = GetTypeSize(GetTensorType(name)) * ConvertShapeToLength(GetTensorShape(name)); for (auto chunk = fIntermediateMemoryInfo.available_stack.begin(); chunk != fIntermediateMemoryInfo.available_stack.end(); ) { @@ -369,6 +382,7 @@ std::string RModel::AllocateIntermediateMemory(std::span ++chunk; } + if (!allocated) { size_t chunk_idx = fIntermediateMemoryInfo.total_stack.empty() ? 0 @@ -378,18 +392,18 @@ std::string RModel::AllocateIntermediateMemory(std::span declareIntermediateTensor(name, tensor_size, chunk_idx); } + allocated_tensors.insert(it); } return code.str(); } -void RModel::CheckAndFlushIntermediateMemory(std::span op_input_tensors, const size_t& op_idx){ +void RModel::CheckAndFlushIntermediateMemory(std::span op_input_tensors, const size_t& op_idx){ for (auto &it : op_input_tensors){ // last occurence of the tensor is reached => flush it from memory - if (fIntermediateTensorFrequencyLookup[it] == op_idx) { + if (fIntermediateTensorFrequencyLookup[it] == fOperators[op_idx]->GetOpOrder()){ for (auto chunk = fIntermediateMemoryInfo.total_stack.begin(); chunk != fIntermediateMemoryInfo.total_stack.end(); ++chunk ) { if (chunk->second.tensor_name == it) { - // check if nearby chunks in available memory can coalesce auto first_greater = fIntermediateMemoryInfo.available_stack.upper_bound(chunk->first); // smallest element greater than the flushed chunk idx auto last_smaller = (first_greater == fIntermediateMemoryInfo.available_stack.begin()) ? 
@@ -419,7 +433,54 @@
       }
    }
 }
-
+void RModel::CheckAndFuseOperators() {
+   size_t idx = 0;
+   std::vector<size_t> fusable_indices;
+   std::string fusable_propagate_tensor_name;
+   while (idx < fOperators.size()) {
+      if (fOperators[idx]->GetOpKind() != OperatorKind::GEMM && fOperators[idx]->GetOpKind() != OperatorKind::CONV) {
+         ++idx;
+         continue;
+      }
+
+      fusable_indices.clear();
+      fusable_propagate_tensor_name.clear();
+
+      fusable_indices.push_back(idx);
+      size_t j = idx + 1;
+      for (; j < fOperators.size()-1; ++j) {
+         auto opKind = fOperators[j]->GetOpKind();
+         // Only consider operators with fusable kinds
+         if (!FusableKinds.count(opKind)) {
+            break;
+         }
+
+         const auto& tensorName = fOperators[j]->GetFusableOutputTensorName();
+         auto freqIt = fIntermediateTensorFrequencyLookup.find(tensorName);
+
+         // Propagate tensor name only if it's not used multiple times
+         fusable_indices.push_back(j);
+         if (freqIt != fIntermediateTensorFrequencyLookup.end() &&
+             (freqIt->second != fOperators[j + 1]->GetOpOrder() ||
+              FusableKinds.count(fOperators[j + 1]->GetOpKind()) == 0)) {
+            fusable_propagate_tensor_name = tensorName;
+            break;
+         }
+      }
+      if (!fusable_propagate_tensor_name.empty()) {
+         auto fusable_tensor_type = GetTensorType(fusable_propagate_tensor_name);
+         auto fusable_tensor_shape = GetDynamicTensorShape(fusable_propagate_tensor_name);
+         for (auto& index : fusable_indices) {
+            fOperators[index]->UpdateFusableTensorName(fusable_propagate_tensor_name, [this](const std::string& name) {
+               this->RemoveIntermediateTensor(name);
+            });
+         }
+         AddIntermediateTensor(fusable_propagate_tensor_name, fusable_tensor_type, fusable_tensor_shape);
+      }
+
+      idx = std::max(idx + 1, j);
+   }
+}
 
 void RModel::Initialize(int batchSize, bool verbose) {
    std::map<std::string, size_t> inputParams;
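The new CheckAndFuseOperators walk anchors on a GEMM or CONV node, extends the chain while the following operators are of a fusable kind, picks the output of the last operator in the chain as the name to propagate, and then rewires every operator in the chain onto that single buffer via UpdateFusableTensorName. The sketch below reproduces that control flow with plain structs; it is a simplified standalone illustration (Op, Kind and the chain collection are stand-ins) and omits the frequency-lookup check that the real code uses to make sure a tensor consumed elsewhere is never renamed away.

#include <cstddef>
#include <iostream>
#include <set>
#include <string>
#include <vector>

enum class Kind { GEMM, CONV, RELU, LAYERNORM, BATCHNORM, OTHER };
struct Op { Kind kind; std::string out; };               // stand-in for ROperator
static const std::set<Kind> fusable = {Kind::RELU, Kind::LAYERNORM, Kind::BATCHNORM};

// Returns, for each fused chain, the operator indices that will share one output buffer.
std::vector<std::vector<size_t>> findChains(const std::vector<Op>& ops) {
   std::vector<std::vector<size_t>> chains;
   size_t idx = 0;
   while (idx < ops.size()) {
      if (ops[idx].kind != Kind::GEMM && ops[idx].kind != Kind::CONV) { ++idx; continue; }
      std::vector<size_t> chain{idx};
      size_t j = idx + 1;
      while (j < ops.size() && fusable.count(ops[j].kind)) { chain.push_back(j); ++j; }
      if (chain.size() > 1) chains.push_back(chain);      // something to fuse
      idx = (j > idx + 1) ? j : idx + 1;
   }
   return chains;
}

int main() {
   std::vector<Op> ops = {{Kind::GEMM, "t0"}, {Kind::RELU, "t1"}, {Kind::OTHER, "t2"}};
   for (auto& chain : findChains(ops)) {
      const std::string shared = ops[chain.back()].out;   // name propagated back through the chain
      for (size_t k : chain) ops[k].out = shared;         // what UpdateFusableTensorName does
   }
   std::cout << ops[0].out << " " << ops[1].out << "\n";  // both now write/read "t1"
}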
@@ -434,7 +495,7 @@ void RModel::Initialize(const std::map<std::string, size_t> & inputParams, bool
 
    fVerbose = int(verbose);
-
+   fVerbose = 0;
    if (fIsInitialized) {
       if (verbose)
          std::cout << "Model is already initialized - skip initialization " << std::endl;
@@ -510,28 +571,35 @@ void RModel::Initialize(const std::map<std::string, size_t> & inputParams, bool
       if (!modelHasWeights)
          fUseWeightFile = false;
    }
-   // Go through model and initialize each operator
-   int i = 0;
-   std::vector temp_available_stack; // vector stores individual chunks of available memory that maybe reused
-
-   for(size_t op_idx = 0; op_idx < fOperators.size(); ++op_idx){
+   for (size_t op_const_idx = 0; op_const_idx < fConstantOperators.size(); ++op_const_idx) {
       if (verbose) {
-         auto& r = *fOperators[op_idx].get();
-         std::cout << "Initializing operator " << i << " " << typeid(r).name() << std::endl;
-      }
-      fOperators[op_idx]->Initialize(*this);
-      for(auto &it:fOperators[op_idx]->GetOpOutputTensors()){
-         std::string name = std::string{it};
-         if (fIntermediateTensorFrequencyLookup.find(it) == fIntermediateTensorFrequencyLookup.end() &&
-             std::find(fOutputTensorNames.begin(), fOutputTensorNames.end(), name) == fOutputTensorNames.end() &&
-             fInitializedTensors.find(name) == fInitializedTensors.end() &&
-             fDynamicTensorInfos.find(name) == fDynamicTensorInfos.end()){
-                fIntermediateTensorFrequencyLookup[it] = op_idx;
-         }
+         auto& r = *fConstantOperators[op_const_idx].get();
+         std::cout << "Initializing constant operator " << op_const_idx << " " << typeid(r).name() << std::endl;
       }
-      i++;
+      fConstantOperators[op_const_idx]->Initialize(*this);
    }
+
+   // Go through model and initialize each operator
+   for (size_t op_idx = 0; op_idx < fOperators.size(); ++op_idx ) {
+      if (verbose) {
+         auto& r = *fOperators[op_idx].get();
+         std::cout << "Initializing operator " << op_idx << " " << typeid(r).name() << std::endl;
+      }
+
+      fOperators[op_idx]->Initialize(*this);
+      for (auto &it : fOperators[op_idx]->GetOpOutputTensors()) {
+         std::string name{it};
+         if (fIntermediateTensorFrequencyLookup.find(it) == fIntermediateTensorFrequencyLookup.end() &&
+             fInputTensorInfos.find(name) == fInputTensorInfos.end() &&
+             fInitializedTensors.find(name) == fInitializedTensors.end() &&
+             fDynamicTensorInfos.find(name) == fDynamicTensorInfos.end()) {
+            fIntermediateTensorFrequencyLookup[it] = fOperators[op_idx]->GetOpOrder();
+         }
+      }
    }
+   CheckAndFuseOperators();
    fIsInitialized = true;
 }
@@ -843,10 +911,11 @@ void RModel::GenerateSessionCode()
    if (fOptimizationLevel == OptimizationLevel::kExtended) {
       // evaluate total intermediate memory and position intermediate tensor addresses
+      std::set<std::string> allocated_tensors;
       std::string intermediate_memory_alloc_string = "";
       intermediate_memory_alloc_string += "\n// --- Positioning intermediate tensor memory --";
       for (size_t op_idx = 0; op_idx < fOperators.size(); ++op_idx) {
-         intermediate_memory_alloc_string += AllocateIntermediateMemory(fOperators[op_idx]->GetOpOutputTensors());
+         intermediate_memory_alloc_string += AllocateIntermediateMemory(fOperators[op_idx]->GetOpOutputTensors(), allocated_tensors);
          CheckAndFlushIntermediateMemory(fOperators[op_idx]->GetOpInputTensors(), op_idx);
       }
@@ -929,7 +998,7 @@ void RModel::GenerateSessionCode()
       fGC += "}\n\n";
    }
 
-
+
    fGC += doInferSignature + "{\n";
    fGC += "\n";
@@ -1339,7 +1408,6 @@ void RModel::HeadInitializedTensors(std::string name, int n_print) {
 void RModel::OutputGenerated(std::string filename, bool append) {
 
    RModel_Base::OutputGenerated(filename, append);
-
    // write weights in a text file
    if (fUseWeightFile) {
       if (!filename.empty()) {
diff --git a/tmva/sofie_parsers/src/RModelParser_ONNX.cxx b/tmva/sofie_parsers/src/RModelParser_ONNX.cxx
index 7b4ade2b6bc09..1303d37150c28 100644
--- a/tmva/sofie_parsers/src/RModelParser_ONNX.cxx
+++ b/tmva/sofie_parsers/src/RModelParser_ONNX.cxx
@@ -90,7 +90,7 @@ extern ParserFuncSignature ParseScatterElements;
 // Declaration of fused operators
 extern ParserFuseFuncSignature ParseFuseConvAdd;
 extern ParserFuseFuncSignature ParseFuseGemmRelu;
-extern ParserFuseFuncSignature ParseFuseBatchnormRelu;
+// extern ParserFuseFuncSignature ParseFuseBatchnormRelu;
 extern ParserFuseFuncSignature ParseFuseConvTransposeAdd;
 extern ParserFuseFuncSignature ParseFuseMatMulAdd;
@@ -320,12 +320,13 @@ RModelParser_ONNX::ParseOperator(const size_t i, const onnx::GraphProto &graphpr
            fFusedOperators[idx2] = true;
            return ParseFuseGemmRelu(*this, graphproto.node(idx), graphproto.node(idx2));
         }
-      } else if (nodeproto.op_type() == "BatchNormalization") {
-         if (idx2 < graphproto.node_size() && graphproto.node(idx2).op_type() == "Relu") {
-            fFusedOperators[idx2] = true;
-            return ParseFuseBatchnormRelu(*this, graphproto.node(idx), graphproto.node(idx2));
-         }
-      }
+      }
+      // else if (nodeproto.op_type() == "BatchNormalization") {
+      //    if (idx2 < graphproto.node_size() && graphproto.node(idx2).op_type() == "Relu") {
+      //       fFusedOperators[idx2] = true;
+      //       return ParseFuseBatchnormRelu(*this, graphproto.node(idx), graphproto.node(idx2));
+      //    }
+      // }
    }
@@ -676,7 +677,7 @@ void RModelParser_ONNX::ParseONNXGraph(RModel & rmodel, const onnx::GraphProto &
    } while ((int)nodesOrder.size() < graph.node_size());
 
-   // find list of children for each operator (used for fusing oiperators)
+   // find list of children for each operator (used for fusing operators)
    std::vector<std::vector<size_t>> nodesChildren(graph.node_size());
    for (int k = 0; k < graph.node_size(); k++) {
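For completeness, this is roughly how the new path would be exercised from user code. The parser, Generate and OutputGenerated calls follow the existing SOFIE API; the assumption is that OptimizationLevel::kExtended is visible as spelled here (the enum lives in the RModel base classes) and that Generate() still drives Initialize(), which with this patch splits off constant operators and runs CheckAndFuseOperators().

#include "TMVA/RModel.hxx"
#include "TMVA/RModelParser_ONNX.hxx"

using namespace TMVA::Experimental::SOFIE;

int main() {
   RModelParser_ONNX parser;
   RModel model = parser.Parse("model.onnx");

   // Opt into the intermediate-memory pool; Gemm/Conv + activation/normalization
   // chains then share one output buffer in the generated inference code.
   model.SetOptimizationLevel(OptimizationLevel::kExtended);

   model.Generate();                  // calls Initialize(), which now runs CheckAndFuseOperators()
   model.OutputGenerated("Model.hxx");
}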