1515#ifndef TMVA_RBATCHGENERATOR
1616#define TMVA_RBATCHGENERATOR
1717
18- #include " TMVA/RTensor .hxx"
18+ #include " TMVA/BatchGenerator/RFlat2DMatrix .hxx"
1919#include " ROOT/RDF/RDatasetSpec.hxx"
2020#include " TMVA/BatchGenerator/RChunkLoader.hxx"
2121#include " TMVA/BatchGenerator/RBatchLoader.hxx"
@@ -100,11 +100,12 @@ private:
100100 std::size_t fNumTrainingBatches ;
101101 std::size_t fNumValidationBatches ;
102102
103- TMVA::Experimental::RTensor<float > fTrainTensor ;
104- TMVA::Experimental::RTensor<float > fTrainChunkTensor ;
103+ // flattened buffers for chunks and temporary tensors (rows * cols)
104+ RFlat2DMatrix fTrainTensor ;
105+ RFlat2DMatrix fTrainChunkTensor ;
105106
106- TMVA::Experimental::RTensor< float > fValidationTensor ;
107- TMVA::Experimental::RTensor< float > fValidationChunkTensor ;
107+ RFlat2DMatrix fValidationTensor ;
108+ RFlat2DMatrix fValidationChunkTensor ;
108109
109110public:
110111 RBatchGenerator (ROOT::RDF::RNode &rdf, const std::size_t chunkSize, const std::size_t blockSize,
@@ -125,11 +126,7 @@ public:
125126 fShuffle(shuffle),
126127 fNotFiltered(f_rdf.GetFilterNames().empty()),
127128 fUseWholeFile(maxChunks == 0 ),
128- fNumColumns(cols.size()),
129- fTrainTensor({0 , 0 }),
130- fTrainChunkTensor({0 , 0 }),
131- fValidationTensor({0 , 0 }),
132- fValidationChunkTensor({0 , 0 })
129+ fNumColumns(cols.size())
133130 {
134131
135132 fNumEntries = f_rdf.Count ().GetValue ();
@@ -255,7 +252,7 @@ public:
255252 }
256253
257254 // / \brief Loads a training batch from the queue
258- TMVA::Experimental::RTensor< float > GetTrainBatch ()
255+ RFlat2DMatrix GetTrainBatch ()
259256 {
260257 auto batchQueue = fBatchLoader ->GetNumTrainingBatchQueue ();
261258
@@ -276,8 +273,8 @@ public:
276273 return fBatchLoader ->GetTrainBatch ();
277274 }
278275
279- // / \brief Loads a validation batch from the queue
280- TMVA::Experimental::RTensor< float > GetValidationBatch ()
276+ // / \brief Loads a validation batch from the queue
277+ RFlat2DMatrix GetValidationBatch ()
281278 {
282279 auto batchQueue = fBatchLoader ->GetNumValidationBatchQueue ();
283280
0 commit comments