diff --git a/clang/lib/AST/ByteCode/Interp.h b/clang/lib/AST/ByteCode/Interp.h index dcc4587751974..cc146157ac3be 100644 --- a/clang/lib/AST/ByteCode/Interp.h +++ b/clang/lib/AST/ByteCode/Interp.h @@ -1143,30 +1143,12 @@ inline bool CmpHelperEQ(InterpState &S, CodePtr OpPC, CompareFn Fn) { } if (Pointer::hasSameBase(LHS, RHS)) { - if (LHS.inUnion() && RHS.inUnion()) { - // If the pointers point into a union, things are a little more - // complicated since the offset we save in interp::Pointer can't be used - // to compare the pointers directly. - size_t A = LHS.computeOffsetForComparison(); - size_t B = RHS.computeOffsetForComparison(); - S.Stk.push(BoolT::from(Fn(Compare(A, B)))); - return true; - } - - unsigned VL = LHS.getByteOffset(); - unsigned VR = RHS.getByteOffset(); - // In our Pointer class, a pointer to an array and a pointer to the first - // element in the same array are NOT equal. They have the same Base value, - // but a different Offset. This is a pretty rare case, so we fix this here - // by comparing pointers to the first elements. - if (!LHS.isZero() && LHS.isArrayRoot()) - VL = LHS.atIndex(0).getByteOffset(); - if (!RHS.isZero() && RHS.isArrayRoot()) - VR = RHS.atIndex(0).getByteOffset(); - - S.Stk.push(BoolT::from(Fn(Compare(VL, VR)))); + size_t A = LHS.computeOffsetForComparison(); + size_t B = RHS.computeOffsetForComparison(); + S.Stk.push(BoolT::from(Fn(Compare(A, B)))); return true; } + // Otherwise we need to do a bunch of extra checks before returning Unordered. if (LHS.isOnePastEnd() && !RHS.isOnePastEnd() && !RHS.isZero() && RHS.getOffset() == 0) { diff --git a/clang/lib/AST/ByteCode/Pointer.cpp b/clang/lib/AST/ByteCode/Pointer.cpp index f0b0384f32ac8..7bfdb641abf62 100644 --- a/clang/lib/AST/ByteCode/Pointer.cpp +++ b/clang/lib/AST/ByteCode/Pointer.cpp @@ -349,16 +349,28 @@ void Pointer::print(llvm::raw_ostream &OS) const { } } -/// Compute an integer that can be used to compare this pointer to -/// another one. size_t Pointer::computeOffsetForComparison() const { + if (isIntegralPointer()) + return asIntPointer().Value + Offset; + if (isTypeidPointer()) + return reinterpret_cast(asTypeidPointer().TypePtr) + Offset; + if (!isBlockPointer()) return Offset; size_t Result = 0; Pointer P = *this; - while (!P.isRoot()) { - if (P.isArrayRoot()) { + while (true) { + + if (P.isVirtualBaseClass()) { + Result += getInlineDesc()->Offset; + P = P.getBase(); + continue; + } + + if (P.isBaseClass()) { + if (P.getRecord()->getNumVirtualBases() > 0) + Result += P.getInlineDesc()->Offset; P = P.getBase(); continue; } @@ -369,14 +381,23 @@ size_t Pointer::computeOffsetForComparison() const { continue; } + if (P.isRoot()) { + if (P.isOnePastEnd()) + Result += P.Offset; + break; + } + if (const Record *R = P.getBase().getRecord(); R && R->isUnion()) { // Direct child of a union - all have offset 0. P = P.getBase(); continue; } + // Fields, etc. Result += P.getInlineDesc()->Offset; P = P.getBase(); + if (P.isRoot()) + break; } return Result; diff --git a/clang/lib/AST/ByteCode/Pointer.h b/clang/lib/AST/ByteCode/Pointer.h index 0234ab02ab8f6..d525f84fd6605 100644 --- a/clang/lib/AST/ByteCode/Pointer.h +++ b/clang/lib/AST/ByteCode/Pointer.h @@ -761,6 +761,9 @@ class Pointer { /// Prints the pointer. void print(llvm::raw_ostream &OS) const; + /// Compute an integer that can be used to compare this pointer to + /// another one. This is usually NOT the same as the pointer offset + /// regarding the AST record layout. size_t computeOffsetForComparison() const; private: diff --git a/clang/test/AST/ByteCode/new-delete.cpp b/clang/test/AST/ByteCode/new-delete.cpp index 9c293e5d15fc8..840736f332250 100644 --- a/clang/test/AST/ByteCode/new-delete.cpp +++ b/clang/test/AST/ByteCode/new-delete.cpp @@ -1022,6 +1022,42 @@ namespace OpNewNothrow { // both-note {{in call to}} } +namespace BaseCompare { + struct Cmp { + void *p; + + template + constexpr Cmp(T *t) : p(t) {} + + constexpr friend bool operator==(Cmp a, Cmp b) { + return a.p == b.p; + } + }; + + class Base {}; + class Derived : public Base {}; + constexpr bool foo() { + Derived *D = std::allocator{}.allocate(1);; + std::construct_at(D); + + Derived *d = D; + Base *b = D; + + Cmp ca(d); + Cmp cb(b); + + if (ca == cb) { + std::allocator{}.deallocate(D); + return true; + } + std::allocator{}.deallocate(D); + + return false; + + } + static_assert(foo()); +} + #else /// Make sure we reject this prior to C++20 constexpr int a() { // both-error {{never produces a constant expression}}