From 469714d823e1382a9cdac7ba8067c056bd35fe5c Mon Sep 17 00:00:00 2001 From: Oleksandr Zinenko Date: Wed, 25 Jul 2018 11:47:06 +0200 Subject: [PATCH 1/5] ScheduleTreeMapping::operator==: compare affine expression maps The mapping between identifiers and affine functions was introduced to the ScheduleTreeMapping in a630d0456 (ScheduleTreeElemMappingFilter: store mapping between identifiers and functions, Wed May 16 17:44:08 2018 +0200), however the comparison operator was only comparing the union_set filter_. While the two representations are generally expected to be equivalent, it is possible to modify them individually. So we need to also compare the mappings in ScheduleTreeMapping::operator==. --- tc/core/polyhedral/schedule_tree_elem.cc | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tc/core/polyhedral/schedule_tree_elem.cc b/tc/core/polyhedral/schedule_tree_elem.cc index cc7ce86ff..1a7102a45 100644 --- a/tc/core/polyhedral/schedule_tree_elem.cc +++ b/tc/core/polyhedral/schedule_tree_elem.cc @@ -364,8 +364,18 @@ bool ScheduleTreeFilter::operator==(const ScheduleTreeFilter& other) const { } bool ScheduleTreeMapping::operator==(const ScheduleTreeMapping& other) const { - auto res = filter_.is_equal(other.filter_); - return res; + if (mapping.size() != other.mapping.size()) { + return false; + } + for (const auto& kvp : mapping) { + if (other.mapping.count(kvp.first) == 0) { + return false; + } + if (!other.mapping.at(kvp.first).plain_is_equal(kvp.second)) { + return false; + } + } + return filter_.is_equal(other.filter_); } bool ScheduleTreeSequence::operator==(const ScheduleTreeSequence& other) const { From e58e53a5693ac05956fb661b8456f4c98495dabe Mon Sep 17 00:00:00 2001 From: Oleksandr Zinenko Date: Wed, 25 Jul 2018 11:51:36 +0200 Subject: [PATCH 2/5] ScheduleTree*: introduce virtual nodeEquals method This method is abstract in the base class and is defined in derived classes to perform the type comparison and, if types match, call the class-specific comparison operator (or return true immediately for tree nodes that are always considered equal to each other, e.g., sequences). Its name indicates that that it applies to the current node only, ignoring the subtree that may be rooted at this node. It will be used in the upcoming commits to use virtual methods instead of macro-based dispatch on schedule node types. --- tc/core/polyhedral/schedule_tree.h | 5 ++++ tc/core/polyhedral/schedule_tree_elem.h | 36 +++++++++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/tc/core/polyhedral/schedule_tree.h b/tc/core/polyhedral/schedule_tree.h index 403dd0210..f0ec82cff 100644 --- a/tc/core/polyhedral/schedule_tree.h +++ b/tc/core/polyhedral/schedule_tree.h @@ -469,6 +469,11 @@ struct ScheduleTree { // Note that this function does _not_ clone the child trees. virtual ScheduleTreeUPtr clone() const = 0; + // Compare the current node to the "other" node. + // Note that this function does _not_ compare the child trees, + // use treeEquals() instead to compare entire trees. + virtual bool nodeEquals(const ScheduleTree* other) const = 0; + // // Data members // diff --git a/tc/core/polyhedral/schedule_tree_elem.h b/tc/core/polyhedral/schedule_tree_elem.h index 5d782049d..86d7010d2 100644 --- a/tc/core/polyhedral/schedule_tree_elem.h +++ b/tc/core/polyhedral/schedule_tree_elem.h @@ -61,6 +61,10 @@ struct ScheduleTreeContext : public ScheduleTree { virtual ScheduleTreeUPtr clone() const override { return make(this); } + virtual bool nodeEquals(const ScheduleTree* other) const override { + auto otherContext = other->as(); + return otherContext && *this == *otherContext; + } public: isl::set context_; @@ -97,6 +101,10 @@ struct ScheduleTreeDomain : public ScheduleTree { virtual ScheduleTreeUPtr clone() const override { return make(this); } + virtual bool nodeEquals(const ScheduleTree* other) const override { + auto otherDomain = other->as(); + return otherDomain && *this == *otherDomain; + } public: isl::union_set domain_; @@ -133,6 +141,10 @@ struct ScheduleTreeExtension : public ScheduleTree { virtual ScheduleTreeUPtr clone() const override { return make(this); } + virtual bool nodeEquals(const ScheduleTree* other) const override { + auto otherExtension = other->as(); + return otherExtension && *this == *otherExtension; + } public: isl::union_map extension_; @@ -169,6 +181,10 @@ struct ScheduleTreeFilter : public ScheduleTree { virtual ScheduleTreeUPtr clone() const override { return make(this); } + virtual bool nodeEquals(const ScheduleTree* other) const override { + auto otherFilter = other->as(); + return otherFilter && *this == *otherFilter; + } public: isl::union_set filter_; @@ -210,6 +226,10 @@ struct ScheduleTreeMapping : public ScheduleTree { virtual ScheduleTreeUPtr clone() const override { return make(this); } + virtual bool nodeEquals(const ScheduleTree* other) const override { + auto otherMapping = other->as(); + return otherMapping && *this == *otherMapping; + } public: // Mapping from identifiers to affine functions on domain elements. @@ -247,6 +267,10 @@ struct ScheduleTreeSequence : public ScheduleTree { virtual ScheduleTreeUPtr clone() const override { return make(this); } + virtual bool nodeEquals(const ScheduleTree* other) const override { + auto otherSequence = other->as(); + return otherSequence && *this == *otherSequence; + } }; struct ScheduleTreeSet : public ScheduleTree { @@ -277,6 +301,10 @@ struct ScheduleTreeSet : public ScheduleTree { virtual ScheduleTreeUPtr clone() const override { return make(this); } + virtual bool nodeEquals(const ScheduleTree* other) const override { + auto otherSet = other->as(); + return otherSet && *this == *otherSet; + } }; struct ScheduleTreeBand : public ScheduleTree { @@ -304,6 +332,10 @@ struct ScheduleTreeBand : public ScheduleTree { virtual ScheduleTreeUPtr clone() const override { return make(this); } + virtual bool nodeEquals(const ScheduleTree* other) const override { + auto otherBand = other->as(); + return otherBand && *this == *otherBand; + } // Make a schedule node band from partial schedule. // Replace "mupa" by its greatest integer part to ensure that the @@ -380,6 +412,10 @@ struct ScheduleTreeThreadSpecificMarker : public ScheduleTree { virtual ScheduleTreeUPtr clone() const override { return make(this); } + virtual bool nodeEquals(const ScheduleTree* other) const override { + auto otherMarker = other->as(); + return otherMarker && *this == *otherMarker; + } }; bool elemEquals( From 4becd3d1181b3c7a4b7b65d83e46cd95b09c61d1 Mon Sep 17 00:00:00 2001 From: Oleksandr Zinenko Date: Wed, 25 Jul 2018 11:55:58 +0200 Subject: [PATCH 3/5] ScheduleTree: drop external elemEquals in favor of nodeEquals The elemEquals function, as its name suggests, was introduced when ScheduleTreeElement was a separate concept. It was kept mostly intact to minimize the changes when individual tree node classes were made inheriting ScheduleTree instead. Replace elemEquals with a call to the virtual function nodeEquals that leverages polymorphism to perform dynmaic dispatch instead of manual macro-based one. --- tc/core/polyhedral/schedule_tree.cc | 2 +- tc/core/polyhedral/schedule_tree_elem.cc | 30 ------------------------ tc/core/polyhedral/schedule_tree_elem.h | 5 ---- 3 files changed, 1 insertion(+), 36 deletions(-) diff --git a/tc/core/polyhedral/schedule_tree.cc b/tc/core/polyhedral/schedule_tree.cc index c17c2542a..1fbe57f68 100644 --- a/tc/core/polyhedral/schedule_tree.cc +++ b/tc/core/polyhedral/schedule_tree.cc @@ -344,7 +344,7 @@ bool ScheduleTree::operator==(const ScheduleTree& other) const { if (children_.size() != other.children_.size()) { return false; } - if (!elemEquals(this, &other, type_)) { + if (!this->nodeEquals(&other)) { return false; } TC_CHECK(!other.as()) diff --git a/tc/core/polyhedral/schedule_tree_elem.cc b/tc/core/polyhedral/schedule_tree_elem.cc index 1a7102a45..af3c8afbb 100644 --- a/tc/core/polyhedral/schedule_tree_elem.cc +++ b/tc/core/polyhedral/schedule_tree_elem.cc @@ -385,36 +385,6 @@ bool ScheduleTreeSequence::operator==(const ScheduleTreeSequence& other) const { bool ScheduleTreeSet::operator==(const ScheduleTreeSet& other) const { return true; } - -bool elemEquals( - const ScheduleTree* e1, - const ScheduleTree* e2, - detail::ScheduleTreeType type) { -#define ELEM_EQUALS_CASE(CLASS) \ - else if (type == CLASS::NodeType) { \ - return *static_cast(e1) == *static_cast(e2); \ - } - - if (type == detail::ScheduleTreeType::None) { - LOG(FATAL) << "Hit Error node!"; - } - ELEM_EQUALS_CASE(ScheduleTreeBand) - ELEM_EQUALS_CASE(ScheduleTreeContext) - ELEM_EQUALS_CASE(ScheduleTreeDomain) - ELEM_EQUALS_CASE(ScheduleTreeExtension) - ELEM_EQUALS_CASE(ScheduleTreeFilter) - ELEM_EQUALS_CASE(ScheduleTreeMapping) - ELEM_EQUALS_CASE(ScheduleTreeSequence) - ELEM_EQUALS_CASE(ScheduleTreeSet) - else { - LOG(FATAL) << "NYI: ScheduleTree::operator== for type: " - << static_cast(type); - } - -#undef ELEM_EQUALS_CASE - - return false; -} } // namespace detail } // namespace polyhedral } // namespace tc diff --git a/tc/core/polyhedral/schedule_tree_elem.h b/tc/core/polyhedral/schedule_tree_elem.h index 86d7010d2..7ee732f77 100644 --- a/tc/core/polyhedral/schedule_tree_elem.h +++ b/tc/core/polyhedral/schedule_tree_elem.h @@ -418,11 +418,6 @@ struct ScheduleTreeThreadSpecificMarker : public ScheduleTree { } }; -bool elemEquals( - const ScheduleTree* e1, - const ScheduleTree* e2, - detail::ScheduleTreeType type); - std::ostream& operator<<(std::ostream& os, detail::ScheduleTreeType nt); std::ostream& operator<<( std::ostream& os, From 82656e91d78b8b7031253a66a0e8371b34e95d29 Mon Sep 17 00:00:00 2001 From: Oleksandr Zinenko Date: Wed, 25 Jul 2018 12:28:23 +0200 Subject: [PATCH 4/5] ScheduleTree*: replace comparison operators with nodeEquals In tree-like structures, the behavior of the default comparison operators (operator== and operator!=) are not intuitive. They may compare individual tree nodes or the subtrees rooted at those nodes. The current implementation made this confusion even worse where ScheduleTree::operator== compared subtrees but ScheduleTree*::operator== did not. Replace comparison operators by type-safe non-virtual overloads of nodeEquals in specific tree node classes. Subtree comparison can be performed by a dedicated method that will be introduced next. --- tc/core/polyhedral/schedule_tree_elem.cc | 69 ++++++++++++---------- tc/core/polyhedral/schedule_tree_elem.h | 74 ++++++------------------ 2 files changed, 58 insertions(+), 85 deletions(-) diff --git a/tc/core/polyhedral/schedule_tree_elem.cc b/tc/core/polyhedral/schedule_tree_elem.cc index af3c8afbb..77973894a 100644 --- a/tc/core/polyhedral/schedule_tree_elem.cc +++ b/tc/core/polyhedral/schedule_tree_elem.cc @@ -281,21 +281,26 @@ ScheduleTreeThreadSpecificMarker::make( return res; } -bool ScheduleTreeBand::operator==(const ScheduleTreeBand& other) const { - if (permutable_ != other.permutable_) { +bool ScheduleTreeBand::nodeEquals(const ScheduleTreeBand* otherBand) const { + if (!otherBand) { return false; } - if (coincident_.size() != other.coincident_.size()) { + if (permutable_ != otherBand->permutable_) { return false; } - if (unroll_.size() != other.unroll_.size()) { + if (coincident_.size() != otherBand->coincident_.size()) { + return false; + } + if (unroll_.size() != otherBand->unroll_.size()) { return false; } if (!std::equal( - coincident_.begin(), coincident_.end(), other.coincident_.begin())) { + coincident_.begin(), + coincident_.end(), + otherBand->coincident_.begin())) { return false; } - if (!std::equal(unroll_.begin(), unroll_.end(), other.unroll_.begin())) { + if (!std::equal(unroll_.begin(), unroll_.end(), otherBand->unroll_.begin())) { return false; } @@ -305,13 +310,13 @@ bool ScheduleTreeBand::operator==(const ScheduleTreeBand& other) const { // .domain() returns a zero-dimensional union set (in purely parameter space) // if there is no explicit domain. bool mupaIs0D = nMember() == 0; - bool otherMupaIs0D = other.nMember() == 0; + bool otherMupaIs0D = otherBand->nMember() == 0; if (mupaIs0D ^ otherMupaIs0D) { return false; } if (mupaIs0D && otherMupaIs0D) { auto d1 = mupa_.domain(); - auto d2 = other.mupa_.domain(); + auto d2 = otherBand->mupa_.domain(); auto res = d1.is_equal(d2); if (!res) { LOG_IF(INFO, FLAGS_debug_tc_mapper) @@ -322,7 +327,7 @@ bool ScheduleTreeBand::operator==(const ScheduleTreeBand& other) const { } } else { auto m1 = isl::union_map::from(mupa_); - auto m2 = isl::union_map::from(other.mupa_); + auto m2 = isl::union_map::from(otherBand->mupa_); { auto res = m1.is_equal(m2); if (!res) { @@ -337,54 +342,60 @@ bool ScheduleTreeBand::operator==(const ScheduleTreeBand& other) const { return true; } -bool ScheduleTreeContext::operator==(const ScheduleTreeContext& other) const { - auto res = context_.is_equal(other.context_); - return res; +bool ScheduleTreeContext::nodeEquals(const ScheduleTreeContext* other) const { + return other && context_.is_equal(other->context_); } -bool ScheduleTreeDomain::operator==(const ScheduleTreeDomain& other) const { - auto res = domain_.is_equal(other.domain_); +bool ScheduleTreeDomain::nodeEquals(const ScheduleTreeDomain* other) const { + if (!other) { + return false; + } + auto res = domain_.is_equal(other->domain_); if (!res) { LOG_IF(INFO, FLAGS_debug_tc_mapper) << "ScheduleTreeDomain difference: " << domain_ << " VS " - << other.domain_ << "\n"; + << other->domain_ << "\n"; } return res; } -bool ScheduleTreeExtension::operator==( - const ScheduleTreeExtension& other) const { - auto res = extension_.is_equal(other.extension_); - return res; +bool ScheduleTreeExtension::nodeEquals( + const ScheduleTreeExtension* other) const { + return other && extension_.is_equal(other->extension_); } -bool ScheduleTreeFilter::operator==(const ScheduleTreeFilter& other) const { - auto res = filter_.is_equal(other.filter_); - return res; +bool ScheduleTreeFilter::nodeEquals(const ScheduleTreeFilter* other) const { + return other && filter_.is_equal(other->filter_); } -bool ScheduleTreeMapping::operator==(const ScheduleTreeMapping& other) const { - if (mapping.size() != other.mapping.size()) { +bool ScheduleTreeMapping::nodeEquals(const ScheduleTreeMapping* other) const { + if (mapping.size() != other->mapping.size()) { return false; } for (const auto& kvp : mapping) { - if (other.mapping.count(kvp.first) == 0) { + if (other->mapping.count(kvp.first) == 0) { return false; } - if (!other.mapping.at(kvp.first).plain_is_equal(kvp.second)) { + if (!other->mapping.at(kvp.first).plain_is_equal(kvp.second)) { return false; } } - return filter_.is_equal(other.filter_); + return filter_.is_equal(other->filter_); } -bool ScheduleTreeSequence::operator==(const ScheduleTreeSequence& other) const { +bool ScheduleTreeSequence::nodeEquals(const ScheduleTreeSequence* other) const { return true; } -bool ScheduleTreeSet::operator==(const ScheduleTreeSet& other) const { +bool ScheduleTreeSet::nodeEquals(const ScheduleTreeSet* other) const { return true; } + +bool ScheduleTreeThreadSpecificMarker::nodeEquals( + const ScheduleTreeThreadSpecificMarker* other) const { + return true; +} + } // namespace detail } // namespace polyhedral } // namespace tc diff --git a/tc/core/polyhedral/schedule_tree_elem.h b/tc/core/polyhedral/schedule_tree_elem.h index 7ee732f77..bc6e47ddf 100644 --- a/tc/core/polyhedral/schedule_tree_elem.h +++ b/tc/core/polyhedral/schedule_tree_elem.h @@ -52,19 +52,15 @@ struct ScheduleTreeContext : public ScheduleTree { const ScheduleTreeContext* tree, std::vector&& children = {}); - bool operator==(const ScheduleTreeContext& other) const; - bool operator!=(const ScheduleTreeContext& other) const { - return !(*this == other); - } - virtual std::ostream& write(std::ostream& os) const override; virtual ScheduleTreeUPtr clone() const override { return make(this); } virtual bool nodeEquals(const ScheduleTree* other) const override { auto otherContext = other->as(); - return otherContext && *this == *otherContext; + return otherContext && nodeEquals(otherContext); } + bool nodeEquals(const ScheduleTreeContext* otherContext) const; public: isl::set context_; @@ -92,19 +88,15 @@ struct ScheduleTreeDomain : public ScheduleTree { const ScheduleTreeDomain* tree, std::vector&& children = {}); - bool operator==(const ScheduleTreeDomain& other) const; - bool operator!=(const ScheduleTreeDomain& other) const { - return !(*this == other); - } - virtual std::ostream& write(std::ostream& os) const override; virtual ScheduleTreeUPtr clone() const override { return make(this); } virtual bool nodeEquals(const ScheduleTree* other) const override { auto otherDomain = other->as(); - return otherDomain && *this == *otherDomain; + return otherDomain && nodeEquals(otherDomain); } + bool nodeEquals(const ScheduleTreeDomain* otherDomain) const; public: isl::union_set domain_; @@ -132,19 +124,15 @@ struct ScheduleTreeExtension : public ScheduleTree { const ScheduleTreeExtension* tree, std::vector&& children = {}); - bool operator==(const ScheduleTreeExtension& other) const; - bool operator!=(const ScheduleTreeExtension& other) const { - return !(*this == other); - } - virtual std::ostream& write(std::ostream& os) const override; virtual ScheduleTreeUPtr clone() const override { return make(this); } virtual bool nodeEquals(const ScheduleTree* other) const override { auto otherExtension = other->as(); - return otherExtension && *this == *otherExtension; + return otherExtension && nodeEquals(otherExtension); } + bool nodeEquals(const ScheduleTreeExtension* otherExtension) const; public: isl::union_map extension_; @@ -165,11 +153,6 @@ struct ScheduleTreeFilter : public ScheduleTree { public: virtual ~ScheduleTreeFilter() override {} - bool operator==(const ScheduleTreeFilter& other) const; - bool operator!=(const ScheduleTreeFilter& other) const { - return !(*this == other); - } - static std::unique_ptr make( isl::union_set filter, std::vector&& children = {}); @@ -183,8 +166,9 @@ struct ScheduleTreeFilter : public ScheduleTree { } virtual bool nodeEquals(const ScheduleTree* other) const override { auto otherFilter = other->as(); - return otherFilter && *this == *otherFilter; + return otherFilter && nodeEquals(otherFilter); } + bool nodeEquals(const ScheduleTreeFilter* otherFilter) const; public: isl::union_set filter_; @@ -209,11 +193,6 @@ struct ScheduleTreeMapping : public ScheduleTree { public: virtual ~ScheduleTreeMapping() override {} - bool operator==(const ScheduleTreeMapping& other) const; - bool operator!=(const ScheduleTreeMapping& other) const { - return !(*this == other); - } - static std::unique_ptr make( isl::ctx ctx, const Mapping& mapping, @@ -228,8 +207,9 @@ struct ScheduleTreeMapping : public ScheduleTree { } virtual bool nodeEquals(const ScheduleTree* other) const override { auto otherMapping = other->as(); - return otherMapping && *this == *otherMapping; + return otherMapping && nodeEquals(otherMapping); } + bool nodeEquals(const ScheduleTreeMapping* otherMapping) const; public: // Mapping from identifiers to affine functions on domain elements. @@ -251,11 +231,6 @@ struct ScheduleTreeSequence : public ScheduleTree { public: virtual ~ScheduleTreeSequence() override {} - bool operator==(const ScheduleTreeSequence& other) const; - bool operator!=(const ScheduleTreeSequence& other) const { - return !(*this == other); - } - static std::unique_ptr make( isl::ctx ctx, std::vector&& children = {}); @@ -269,8 +244,9 @@ struct ScheduleTreeSequence : public ScheduleTree { } virtual bool nodeEquals(const ScheduleTree* other) const override { auto otherSequence = other->as(); - return otherSequence && *this == *otherSequence; + return otherSequence && nodeEquals(otherSequence); } + bool nodeEquals(const ScheduleTreeSequence* otherSequence) const; }; struct ScheduleTreeSet : public ScheduleTree { @@ -285,11 +261,6 @@ struct ScheduleTreeSet : public ScheduleTree { public: virtual ~ScheduleTreeSet() override {} - bool operator==(const ScheduleTreeSet& other) const; - bool operator!=(const ScheduleTreeSet& other) const { - return !(*this == other); - } - static std::unique_ptr make( isl::ctx ctx, std::vector&& children = {}); @@ -303,8 +274,9 @@ struct ScheduleTreeSet : public ScheduleTree { } virtual bool nodeEquals(const ScheduleTree* other) const override { auto otherSet = other->as(); - return otherSet && *this == *otherSet; + return otherSet && nodeEquals(otherSet); } + bool nodeEquals(const ScheduleTreeSet* otherSet) const; }; struct ScheduleTreeBand : public ScheduleTree { @@ -323,19 +295,15 @@ struct ScheduleTreeBand : public ScheduleTree { virtual ~ScheduleTreeBand() override {} - bool operator==(const ScheduleTreeBand& other) const; - bool operator!=(const ScheduleTreeBand& other) const { - return !(*this == other); - } - virtual std::ostream& write(std::ostream& os) const override; virtual ScheduleTreeUPtr clone() const override { return make(this); } virtual bool nodeEquals(const ScheduleTree* other) const override { auto otherBand = other->as(); - return otherBand && *this == *otherBand; + return otherBand && nodeEquals(otherBand); } + bool nodeEquals(const ScheduleTreeBand* other) const; // Make a schedule node band from partial schedule. // Replace "mupa" by its greatest integer part to ensure that the @@ -394,13 +362,6 @@ struct ScheduleTreeThreadSpecificMarker : public ScheduleTree { public: virtual ~ScheduleTreeThreadSpecificMarker() override {} - bool operator==(const ScheduleTreeThreadSpecificMarker& other) const { - return true; - } - bool operator!=(const ScheduleTreeThreadSpecificMarker& other) const { - return !(*this == other); - } - static std::unique_ptr make( isl::ctx ctx, std::vector&& children = {}); @@ -414,8 +375,9 @@ struct ScheduleTreeThreadSpecificMarker : public ScheduleTree { } virtual bool nodeEquals(const ScheduleTree* other) const override { auto otherMarker = other->as(); - return otherMarker && *this == *otherMarker; + return otherMarker && nodeEquals(otherMarker); } + bool nodeEquals(const ScheduleTreeThreadSpecificMarker* other) const; }; std::ostream& operator<<(std::ostream& os, detail::ScheduleTreeType nt); From 597e1da671859218d39598f66a65e1642dbebba6 Mon Sep 17 00:00:00 2001 From: Oleksandr Zinenko Date: Wed, 25 Jul 2018 13:29:57 +0200 Subject: [PATCH 5/5] ScheduleTree: replace operator== with treeEquals In tree-like structures, the behavior of the default comparison operators (operator== and operator!=) is not intuitive. They may compare individual tree nodes or the subtrees rooted at those nodes. Replace the equality comparison operator on ScheduleTree with the treeEquals method, which makes it clear that subtrees are compared (as opposed to nodeEquals introduced previously). Removing the overloaded comparison operators may make it harder to use standard containers and algorithms on ScheduleTrees. However, the caller is never supposed to operate on ScheduleTrees by-value, and pointers are trivially comparable. Internal functions may define and use a comparator class with clear intended behavior when necessary. For external uses, explicitly-named functions offer a better alternative. --- tc/core/polyhedral/schedule_isl_conversion.cc | 2 +- tc/core/polyhedral/schedule_tree.cc | 16 ++++++---------- tc/core/polyhedral/schedule_tree.h | 9 ++++----- test/test_cuda_mapper.cc | 5 +++-- 4 files changed, 14 insertions(+), 18 deletions(-) diff --git a/tc/core/polyhedral/schedule_isl_conversion.cc b/tc/core/polyhedral/schedule_isl_conversion.cc index 2b01979b9..e6dae4452 100644 --- a/tc/core/polyhedral/schedule_isl_conversion.cc +++ b/tc/core/polyhedral/schedule_isl_conversion.cc @@ -311,7 +311,7 @@ std::unique_ptr fromIslSchedule(isl::schedule schedule) { // Note that the children of set and sequence nodes are always filters, so // they cannot be replaced by empty trees. bool validateSchedule(const ScheduleTree* st) { - return *st == *fromIslSchedule(toIslSchedule(st)); + return st->treeEquals(fromIslSchedule(toIslSchedule(st)).get()); } bool validateSchedule(isl::schedule sc) { diff --git a/tc/core/polyhedral/schedule_tree.cc b/tc/core/polyhedral/schedule_tree.cc index 1fbe57f68..c5ba49a45 100644 --- a/tc/core/polyhedral/schedule_tree.cc +++ b/tc/core/polyhedral/schedule_tree.cc @@ -336,21 +336,17 @@ vector ScheduleTree::collectDFSPreorder( return functional::Filter(filterType, collectDFSPreorder(tree)); } -bool ScheduleTree::operator==(const ScheduleTree& other) const { - // ctx_ cmp ? - if (type_ != other.type_) { +bool ScheduleTree::treeEquals(const ScheduleTree* other) const { + if (!nodeEquals(other)) { return false; } - if (children_.size() != other.children_.size()) { + if (numChildren() != other->numChildren()) { return false; } - if (!this->nodeEquals(&other)) { - return false; - } - TC_CHECK(!other.as()) + TC_CHECK(!other->as()) << "NYI: ScheduleTreeType::Set comparison"; - for (size_t i = 0; i < children_.size(); ++i) { - if (*children_[i] != *other.children_[i]) { + for (size_t i = 0, e = numChildren(); i < e; ++i) { + if (!child({i})->treeEquals(other->child({i}))) { return false; } } diff --git a/tc/core/polyhedral/schedule_tree.h b/tc/core/polyhedral/schedule_tree.h index f0ec82cff..92144087f 100644 --- a/tc/core/polyhedral/schedule_tree.h +++ b/tc/core/polyhedral/schedule_tree.h @@ -156,11 +156,6 @@ struct ScheduleTree { public: virtual ~ScheduleTree(); - bool operator==(const ScheduleTree& other) const; - bool operator!=(const ScheduleTree& other) const { - return !(*this == other); - } - // Swap a tree with with the given tree. void swapChild(size_t pos, ScheduleTreeUPtr& swappee) { TC_CHECK_GE(pos, 0u) << "position out of children bounds"; @@ -474,6 +469,10 @@ struct ScheduleTree { // use treeEquals() instead to compare entire trees. virtual bool nodeEquals(const ScheduleTree* other) const = 0; + // Comapre the subtree rooted at the current node to the subtree + // rooted at "other". + bool treeEquals(const ScheduleTree* other) const; + // // Data members // diff --git a/test/test_cuda_mapper.cc b/test/test_cuda_mapper.cc index 545610a6e..48cf236c9 100644 --- a/test/test_cuda_mapper.cc +++ b/test/test_cuda_mapper.cc @@ -144,8 +144,9 @@ struct PolyhedralMapperTest : public ::testing::Test { islNode = islNode.as().tile(mv); auto scheduleISL = fromIslSchedule(islNode.get_schedule().reset_user()); - ASSERT_TRUE(*scheduleISL == *scheduleISLPP) << *scheduleISL << "\nVS\n" - << *scheduleISLPP; + ASSERT_TRUE(scheduleISL->treeEquals(scheduleISLPP.get())) + << *scheduleISL << "\nVS\n" + << *scheduleISLPP; } }