Skip to content

Commit d007255

Browse files
committed
[df] Fix interaction between RVecDS and Snapshot
This commit updates the implementation of the column reader retrieval in RVecDS to use GetColumnReaders with type_info instead of the older version. This in turn allows Snapshot to request the correct column reader when reading a Numpy-based dataset.
1 parent ee08f36 commit d007255

File tree

2 files changed

+65
-33
lines changed

2 files changed

+65
-33
lines changed

bindings/pyroot/pythonizations/test/rdataframe_misc.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
import ROOT
77

8+
import numpy
9+
810

911
class DatasetContext:
1012
"""A helper class to create the dataset for the tutorial below."""
@@ -129,7 +131,30 @@ def test_ttree_ownership(self):
129131

130132
self.assertEqual(rdf.Count().GetValue(), 9)
131133

132-
134+
def test_regression_gh_20291(self):
    """
    Regression test for https://github.com/root-project/root/issues/20291

    Builds a dataframe from two Numpy arrays, snapshots it to a file and
    checks that the written tree can be read back with the same entry count
    and the same column values.
    """
    # Issues on Windows with contention on file deletion
    if platform.system() == "Windows":
        # Report the test as skipped instead of silently passing.
        self.skipTest("Contention on file deletion on Windows")
    out_path = "dataframe_misc_regression_gh20291.root"
    try:
        x, y = numpy.array([1, 2, 3]), numpy.array([4, 5, 6])
        df = ROOT.RDF.FromNumpy({"x": x, "y": y})

        df.Snapshot("tree", out_path)

        df_out = ROOT.RDataFrame("tree", out_path)
        count = df_out.Count()
        take_x = df_out.Take["Long64_t"]("x")
        take_y = df_out.Take["Long64_t"]("y")

        self.assertEqual(count.GetValue(), 3)
        self.assertSequenceEqual(take_x.GetValue(), [1, 2, 3])
        self.assertSequenceEqual(take_y.GetValue(), [4, 5, 6])
    finally:
        # If Snapshot failed before the output file was created, an
        # unconditional os.remove would raise FileNotFoundError from the
        # finally clause and hide the original failure.
        if os.path.exists(out_path):
            os.remove(out_path)
133158

134159

135160
if __name__ == '__main__':

tree/dataframe/inc/ROOT/RVecDS.hxx

Lines changed: 39 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,14 @@ namespace Internal {
3232

3333
namespace RDF {
3434

35+
/// \brief Column reader that serves the pointer kept by a TPointerHolder.
///
/// The holder is stored as a raw pointer; this class never deletes it.
class R__CLING_PTRCHECK(off) RVecDSColumnReader final : public ROOT::Detail::RDF::RColumnReaderBase {
   TPointerHolder *fPtrHolder;

   /// Return the pointer currently exposed by the holder. The Long64_t argument is not used.
   void *GetImpl(Long64_t) final
   {
      return fPtrHolder->GetPointer();
   }

public:
   /// \param[in] ptrHolder Holder whose pointer is returned by this reader.
   RVecDSColumnReader(TPointerHolder *ptrHolder) : fPtrHolder{ptrHolder} {}
};
42+
3543
////////////////////////////////////////////////////////////////////////////////////////////////
3644
/// \brief A RDataSource implementation which takes a collection of RVecs, which
3745
/// are able to adopt data from Numpy arrays
@@ -46,46 +54,18 @@ class RVecDS final : public ROOT::RDF::RDataSource {
4654
using PointerHolderPtrs_t = std::vector<ROOT::Internal::RDF::TPointerHolder *>;
4755

4856
std::tuple<ROOT::RVec<ColumnTypes>...> fColumns;
49-
const std::vector<std::string> fColNames;
50-
const std::map<std::string, std::string> fColTypesMap;
57+
std::vector<std::string> fColNames;
58+
std::unordered_map<std::string, std::string> fColTypesMap;
5159
// The role of the fPointerHoldersModels is to be initialised with the pack
5260
// of arguments in the constructor signature at construction time
5361
// Once the number of slots is known, the fPointerHolders are initialised
5462
// according to the models.
55-
const PointerHolderPtrs_t fPointerHoldersModels;
63+
PointerHolderPtrs_t fPointerHoldersModels;
5664
std::vector<PointerHolderPtrs_t> fPointerHolders;
5765
std::vector<std::pair<ULong64_t, ULong64_t>> fEntryRanges{};
5866
std::function<void()> fDeleteRVecs;
5967

60-
Record_t GetColumnReadersImpl(std::string_view colName, const std::type_info &id)
61-
{
62-
auto colNameStr = std::string(colName);
63-
// This could be optimised and done statically
64-
const auto idName = ROOT::Internal::RDF::TypeID2TypeName(id);
65-
auto it = fColTypesMap.find(colNameStr);
66-
if (fColTypesMap.end() == it) {
67-
std::string err = "The specified column name, \"" + colNameStr + "\" is not known to the data source.";
68-
throw std::runtime_error(err);
69-
}
70-
71-
const auto colIdName = it->second;
72-
if (colIdName != idName) {
73-
std::string err = "Column " + colNameStr + " has type " + colIdName +
74-
" while the id specified is associated to type " + idName;
75-
throw std::runtime_error(err);
76-
}
77-
78-
const auto colBegin = fColNames.begin();
79-
const auto colEnd = fColNames.end();
80-
const auto namesIt = std::find(colBegin, colEnd, colName);
81-
const auto index = std::distance(colBegin, namesIt);
82-
83-
Record_t ret(fNSlots);
84-
for (auto slot : ROOT::TSeqU(fNSlots)) {
85-
ret[slot] = fPointerHolders[index][slot]->GetPointerAddr();
86-
}
87-
return ret;
88-
}
68+
/// Stub: ignores both arguments and always returns an empty record.
/// Column readers are handed out by GetColumnReaders(slot, colName, type_info) instead.
/// NOTE(review): presumably kept only to satisfy the RDataSource interface - confirm.
Record_t GetColumnReadersImpl(std::string_view, const std::type_info &)
{
   return {};
}
8969

9070
/// Number of entries, taken as the size of the first RVec stored in fColumns.
size_t GetEntriesNumber()
{
   const auto &firstColumn = std::get<0>(fColumns);
   return firstColumn.size();
}
9171
template <std::size_t... S>
@@ -146,6 +126,33 @@ public:
146126
fDeleteRVecs();
147127
}
148128

129+
std::unique_ptr<ROOT::Detail::RDF::RColumnReaderBase>
130+
GetColumnReaders(unsigned int slot, std::string_view colName, const std::type_info &id) final
131+
{
132+
auto colNameStr = std::string(colName);
133+
134+
auto it = fColTypesMap.find(colNameStr);
135+
if (fColTypesMap.end() == it) {
136+
std::string err = "The specified column name, \"" + colNameStr + "\" is not known to the data source.";
137+
throw std::runtime_error(err);
138+
}
139+
140+
const auto &colIdName = it->second;
141+
const auto idName = ROOT::Internal::RDF::TypeID2TypeName(id);
142+
if (colIdName != idName) {
143+
std::string err = "Column " + colNameStr + " has type " + colIdName +
144+
" while the id specified is associated to type " + idName;
145+
throw std::runtime_error(err);
146+
}
147+
148+
if (auto colNameIt = std::find(fColNames.begin(), fColNames.end(), colNameStr); colNameIt != fColNames.end()) {
149+
const auto index = std::distance(fColNames.begin(), colNameIt);
150+
return std::make_unique<ROOT::Internal::RDF::RVecDSColumnReader>(fPointerHolders[index][slot]);
151+
}
152+
153+
throw std::runtime_error("Could not find column name \"" + colNameStr + "\" in available column names.");
154+
}
155+
149156
const std::vector<std::string> &GetColumnNames() const { return fColNames; }
150157

151158
std::vector<std::pair<ULong64_t, ULong64_t>> GetEntryRanges()

0 commit comments

Comments
 (0)