Skip to content

Commit d919191

Browse files
committed
Remove RTensor dependency from RBatchGenerator classes
1 parent ee08f36 commit d919191

File tree

6 files changed

+165
-161
lines changed

6 files changed

+165
-161
lines changed

tmva/tmva/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,7 @@ ROOT_STANDARD_LIBRARY_PACKAGE(TMVAUtils
447447
TMVA/BatchGenerator/RBatchLoader.hxx
448448
TMVA/BatchGenerator/RChunkLoader.hxx
449449
TMVA/BatchGenerator/RChunkConstructor.hxx
450+
TMVA/BatchGenerator/RFlat2DMatrix.hxx
450451

451452
SOURCES
452453

tmva/tmva/inc/TMVA/BatchGenerator/RBatchGenerator.hxx

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
#ifndef TMVA_RBATCHGENERATOR
1616
#define TMVA_RBATCHGENERATOR
1717

18-
#include "TMVA/RTensor.hxx"
18+
#include "TMVA/BatchGenerator/RFlat2DMatrix.hxx"
1919
#include "ROOT/RDF/RDatasetSpec.hxx"
2020
#include "TMVA/BatchGenerator/RChunkLoader.hxx"
2121
#include "TMVA/BatchGenerator/RBatchLoader.hxx"
@@ -100,11 +100,12 @@ private:
100100
std::size_t fNumTrainingBatches;
101101
std::size_t fNumValidationBatches;
102102

103-
TMVA::Experimental::RTensor<float> fTrainTensor;
104-
TMVA::Experimental::RTensor<float> fTrainChunkTensor;
103+
// flattened buffers for chunks and temporary tensors (rows * cols)
104+
RFlat2DMatrix fTrainTensor;
105+
RFlat2DMatrix fTrainChunkTensor;
105106

106-
TMVA::Experimental::RTensor<float> fValidationTensor;
107-
TMVA::Experimental::RTensor<float> fValidationChunkTensor;
107+
RFlat2DMatrix fValidationTensor;
108+
RFlat2DMatrix fValidationChunkTensor;
108109

109110
public:
110111
RBatchGenerator(ROOT::RDF::RNode &rdf, const std::size_t chunkSize, const std::size_t blockSize,
@@ -125,11 +126,7 @@ public:
125126
fShuffle(shuffle),
126127
fNotFiltered(f_rdf.GetFilterNames().empty()),
127128
fUseWholeFile(maxChunks == 0),
128-
fNumColumns(cols.size()),
129-
fTrainTensor({0, 0}),
130-
fTrainChunkTensor({0, 0}),
131-
fValidationTensor({0, 0}),
132-
fValidationChunkTensor({0, 0})
129+
fNumColumns(cols.size())
133130
{
134131

135132
fNumEntries = f_rdf.Count().GetValue();
@@ -255,7 +252,7 @@ public:
255252
}
256253

257254
/// \brief Loads a training batch from the queue
258-
TMVA::Experimental::RTensor<float> GetTrainBatch()
255+
RFlat2DMatrix GetTrainBatch()
259256
{
260257
auto batchQueue = fBatchLoader->GetNumTrainingBatchQueue();
261258

@@ -276,8 +273,8 @@ public:
276273
return fBatchLoader->GetTrainBatch();
277274
}
278275

279-
/// \brief Loads a validation batch from the queue
280-
TMVA::Experimental::RTensor<float> GetValidationBatch()
276+
/// \brief Loads a validation batch from the queue
277+
RFlat2DMatrix GetValidationBatch()
281278
{
282279
auto batchQueue = fBatchLoader->GetNumValidationBatchQueue();
283280

0 commit comments

Comments
 (0)