Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 03512da

Browse files
authored
Merge pull request #558 from facebookresearch/schedule-tree-evolution
ScheduleTree*: hide constructors
2 parents 1499725 + c555112 commit 03512da

File tree

5 files changed

+400
-110
lines changed

5 files changed

+400
-110
lines changed

tc/core/polyhedral/schedule_isl_conversion.cc

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -236,12 +236,14 @@ namespace {
236236

237237
std::unique_ptr<ScheduleTreeBand> fromIslScheduleNodeBand(
238238
isl::schedule_node_band b) {
239-
auto res = ScheduleTreeBand::fromMultiUnionPwAff(b.get_partial_schedule());
240-
res->permutable_ = b.get_permutable();
241-
for (size_t i = 0; i < b.n_member(); ++i) {
242-
res->coincident_[i] = b.member_get_coincident(i);
239+
auto n = b.n_member();
240+
std::vector<bool> coincident(n, false);
241+
std::vector<bool> unroll(n, false);
242+
for (size_t i = 0; i < n; ++i) {
243+
coincident[i] = b.member_get_coincident(i);
243244
}
244-
return res;
245+
return ScheduleTreeBand::make(
246+
b.get_partial_schedule(), b.get_permutable(), coincident, unroll);
245247
}
246248

247249
std::unique_ptr<ScheduleTree> elemFromIslScheduleNode(isl::schedule_node node) {
@@ -250,19 +252,19 @@ std::unique_ptr<ScheduleTree> elemFromIslScheduleNode(isl::schedule_node node) {
250252
return fromIslScheduleNodeBand(band);
251253
} else if (auto context = node.as<isl::schedule_node_context>()) {
252254
auto c = context.get_context();
253-
return std::unique_ptr<ScheduleTreeContext>(new ScheduleTreeContext(c));
255+
return ScheduleTreeContext::make(c);
254256
} else if (auto domain = node.as<isl::schedule_node_domain>()) {
255257
auto c = domain.get_domain();
256-
return std::unique_ptr<ScheduleTreeDomain>(new ScheduleTreeDomain(c));
258+
return ScheduleTreeDomain::make(c);
257259
} else if (auto expansion = node.as<isl::schedule_node_expansion>()) {
258260
LOG(FATAL) << "expansion nodes not supported";
259261
return nullptr;
260262
} else if (auto extension = node.as<isl::schedule_node_extension>()) {
261263
auto e = extension.get_extension();
262-
return std::unique_ptr<ScheduleTreeExtension>(new ScheduleTreeExtension(e));
264+
return ScheduleTreeExtension::make(e);
263265
} else if (auto filter = node.as<isl::schedule_node_filter>()) {
264266
auto f = filter.get_filter();
265-
return std::unique_ptr<ScheduleTreeFilter>(new ScheduleTreeFilter(f));
267+
return ScheduleTreeFilter::make(f);
266268
} else if (auto guard = node.as<isl::schedule_node_guard>()) {
267269
LOG(FATAL) << "guard nodes not supported";
268270
return nullptr;
@@ -273,9 +275,9 @@ std::unique_ptr<ScheduleTree> elemFromIslScheduleNode(isl::schedule_node node) {
273275
LOG(FATAL) << "ScheduleTree::make called on explicit leaf";
274276
return nullptr;
275277
} else if (node.isa<isl::schedule_node_sequence>()) {
276-
return std::unique_ptr<ScheduleTreeSequence>(new ScheduleTreeSequence(ctx));
278+
return ScheduleTreeSequence::make(ctx);
277279
} else if (node.isa<isl::schedule_node_set>()) {
278-
return std::unique_ptr<ScheduleTreeSet>(new ScheduleTreeSet(ctx));
280+
return ScheduleTreeSet::make(ctx);
279281
}
280282
LOG(FATAL) << "NYI: ScheduleTree from type: "
281283
<< isl_schedule_node_get_type(node.get());

tc/core/polyhedral/schedule_tree.cc

Lines changed: 11 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -122,31 +122,6 @@ vector<ScheduleTree*> ancestorsInSubTree(
122122
}
123123
return res;
124124
}
125-
126-
static std::unique_ptr<ScheduleTree> makeElem(const ScheduleTree& st) {
127-
#define ELEM_MAKE_CASE(CLASS) \
128-
else if (st.type_ == CLASS::NodeType) { \
129-
return std::unique_ptr<CLASS>(new CLASS(static_cast<const CLASS&>(st))); \
130-
}
131-
132-
if (st.type_ == detail::ScheduleTreeType::None) {
133-
LOG(FATAL) << "Hit Error node!";
134-
}
135-
ELEM_MAKE_CASE(ScheduleTreeBand)
136-
ELEM_MAKE_CASE(ScheduleTreeContext)
137-
ELEM_MAKE_CASE(ScheduleTreeDomain)
138-
ELEM_MAKE_CASE(ScheduleTreeExtension)
139-
ELEM_MAKE_CASE(ScheduleTreeFilter)
140-
ELEM_MAKE_CASE(ScheduleTreeMapping)
141-
ELEM_MAKE_CASE(ScheduleTreeSequence)
142-
ELEM_MAKE_CASE(ScheduleTreeSet)
143-
ELEM_MAKE_CASE(ScheduleTreeThreadSpecificMarker)
144-
145-
#undef ELEM_MAKE_CASE
146-
147-
LOG(FATAL) << "NYI: ScheduleTree from type: " << static_cast<int>(st.type_);
148-
return nullptr;
149-
}
150125
} // namespace
151126

152127
////////////////////////////////////////////////////////////////////////////////
@@ -163,7 +138,7 @@ ScheduleTree::ScheduleTree(const ScheduleTree& st)
163138
}
164139

165140
ScheduleTreeUPtr ScheduleTree::makeScheduleTree(const ScheduleTree& tree) {
166-
return makeElem(tree);
141+
return tree.clone();
167142
}
168143

169144
ScheduleTree* ScheduleTree::child(const vector<size_t>& positions) {
@@ -226,8 +201,10 @@ size_t ScheduleTree::scheduleDepth(const ScheduleTree* relativeRoot) const {
226201
std::unique_ptr<ScheduleTree> ScheduleTree::makeBand(
227202
isl::multi_union_pw_aff mupa,
228203
std::vector<ScheduleTreeUPtr>&& children) {
229-
auto res = ScheduleTreeBand::fromMultiUnionPwAff(mupa);
230-
res->appendChildren(std::move(children));
204+
std::vector<bool> coincident(mupa.size(), false);
205+
std::vector<bool> unroll(mupa.size(), false);
206+
auto res = ScheduleTreeBand::make(
207+
mupa, false, coincident, unroll, std::move(children));
231208
return res;
232209
}
233210

@@ -243,25 +220,19 @@ ScheduleTreeUPtr ScheduleTree::makeEmptyBand(const ScheduleTree* root) {
243220
std::unique_ptr<ScheduleTree> ScheduleTree::makeDomain(
244221
isl::union_set domain,
245222
std::vector<ScheduleTreeUPtr>&& children) {
246-
auto res = std::unique_ptr<ScheduleTree>(new ScheduleTreeDomain(domain));
247-
res->appendChildren(std::move(children));
248-
return res;
223+
return ScheduleTreeDomain::make(domain, std::move(children));
249224
}
250225

251226
std::unique_ptr<ScheduleTree> ScheduleTree::makeContext(
252227
isl::set context,
253228
std::vector<ScheduleTreeUPtr>&& children) {
254-
auto res = std::unique_ptr<ScheduleTree>(new ScheduleTreeContext(context));
255-
res->appendChildren(std::move(children));
256-
return res;
229+
return ScheduleTreeContext::make(context, std::move(children));
257230
}
258231

259232
std::unique_ptr<ScheduleTree> ScheduleTree::makeFilter(
260233
isl::union_set filter,
261234
std::vector<ScheduleTreeUPtr>&& children) {
262-
auto res = std::unique_ptr<ScheduleTree>(new ScheduleTreeFilter(filter));
263-
res->appendChildren(std::move(children));
264-
return res;
235+
return ScheduleTreeFilter::make(filter, std::move(children));
265236
}
266237

267238
std::unique_ptr<ScheduleTree> ScheduleTree::makeMappingUnsafe(
@@ -278,28 +249,19 @@ std::unique_ptr<ScheduleTree> ScheduleTree::makeMappingUnsafe(
278249
TC_CHECK_EQ(mappedIds.size(), mapping.size())
279250
<< "some id is used more than once in the mapping";
280251
auto ctx = mappedIds[0].get_ctx();
281-
auto res =
282-
std::unique_ptr<ScheduleTree>(new ScheduleTreeMapping(ctx, mapping));
283-
res->appendChildren(std::move(children));
284-
return res;
252+
return ScheduleTreeMapping::make(ctx, mapping, std::move(children));
285253
}
286254

287255
std::unique_ptr<ScheduleTree> ScheduleTree::makeExtension(
288256
isl::union_map extension,
289257
std::vector<ScheduleTreeUPtr>&& children) {
290-
auto res =
291-
std::unique_ptr<ScheduleTree>(new ScheduleTreeExtension(extension));
292-
res->appendChildren(std::move(children));
293-
return res;
258+
return ScheduleTreeExtension::make(extension, std::move(children));
294259
}
295260

296261
std::unique_ptr<ScheduleTree> ScheduleTree::makeThreadSpecificMarker(
297262
isl::ctx ctx,
298263
std::vector<ScheduleTreeUPtr>&& children) {
299-
auto res =
300-
std::unique_ptr<ScheduleTree>(new ScheduleTreeThreadSpecificMarker(ctx));
301-
res->appendChildren(std::move(children));
302-
return res;
264+
return ScheduleTreeThreadSpecificMarker::make(ctx, std::move(children));
303265
}
304266

305267
////////////////////////////////////////////////////////////////////////////////

tc/core/polyhedral/schedule_tree.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -401,11 +401,11 @@ struct ScheduleTree {
401401
"Arguments must be rvalue references to ScheduleTreeUPtr");
402402

403403
auto ctx = arg->ctx_;
404-
auto res = new T(ctx);
405-
flattenSequenceOrSet(res);
404+
auto res = T::make(ctx);
405+
flattenSequenceOrSet(res.get());
406406
res->appendChildren(
407407
vectorFromArgs(std::forward<Arg>(arg), std::forward<Args>(args)...));
408-
return ScheduleTreeUPtr(res);
408+
return res;
409409
}
410410

411411
// Make a (deep) copy of "tree".
@@ -465,6 +465,10 @@ struct ScheduleTree {
465465
// Note that this function does _not_ output the child trees.
466466
virtual std::ostream& write(std::ostream& os) const = 0;
467467

468+
// Clone the current node.
469+
// Note that this function does _not_ clone the child trees.
470+
virtual ScheduleTreeUPtr clone() const = 0;
471+
468472
//
469473
// Data members
470474
//

0 commit comments

Comments
 (0)