Skip to content

Commit 898b7cc

Browse files
author
Mahmood Yassin
committed
Add createCastsForTypeOfSameSize
1 parent 411a74a commit 898b7cc

File tree

5 files changed

+115
-52
lines changed

5 files changed

+115
-52
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,9 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
574574

575575
mlir::Value createPtrBitcast(mlir::Value src, mlir::Type newPointeeTy) {
576576
assert(mlir::isa<cir::PointerType>(src.getType()) && "expected ptr src");
577-
return createBitcast(src, getPointerTo(newPointeeTy));
577+
auto srcPtrTy = mlir::cast<cir::PointerType>(src.getType());
578+
mlir::Type newPtrTy = getPointerTo(newPointeeTy, srcPtrTy.getAddrSpace());
579+
return createBitcast(src, newPtrTy);
578580
}
579581

580582
mlir::Value createAddrSpaceCast(mlir::Location loc, mlir::Value src,
@@ -586,6 +588,29 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
586588
return createAddrSpaceCast(src.getLoc(), src, newTy);
587589
}
588590

591+
mlir::Value createPointerBitCastOrAddrSpaceCast(mlir::Location loc,
592+
mlir::Value src,
593+
mlir::Type newPointerTy) {
594+
assert(mlir::isa<cir::PointerType>(src.getType()) &&
595+
"expected source pointer");
596+
assert(mlir::isa<cir::PointerType>(newPointerTy) &&
597+
"expected destination pointer type");
598+
599+
auto srcPtrTy = mlir::cast<cir::PointerType>(src.getType());
600+
auto dstPtrTy = mlir::cast<cir::PointerType>(newPointerTy);
601+
602+
mlir::Value addrSpaceCasted = src;
603+
if (srcPtrTy.getAddrSpace() != dstPtrTy.getAddrSpace())
604+
addrSpaceCasted = createAddrSpaceCast(loc, src, dstPtrTy);
605+
606+
return createPtrBitcast(addrSpaceCasted, dstPtrTy.getPointee());
607+
}
608+
609+
mlir::Value createPointerBitCastOrAddrSpaceCast(mlir::Value src,
610+
mlir::Type newPointerTy) {
611+
return createPointerBitCastOrAddrSpaceCast(src.getLoc(), src, newPointerTy);
612+
}
613+
589614
mlir::Value createPtrIsNull(mlir::Value ptr) {
590615
return createNot(createPtrToBoolCast(ptr));
591616
}

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 48 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -821,20 +821,56 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
821821
llvm_unreachable("NYI");
822822
}
823823

824+
// Create cast instructions for converting LLVM value Src to MLIR type DstTy.
825+
// Src has the same size as DstTy. Both are single value types
826+
// but could be scalar or vectors of different lengths, and either can be
827+
// pointer.
828+
mlir::Value createCastsForTypeOfSameSize(mlir::Value Src, mlir::Type DstTy) {
829+
auto SrcTy = Src.getType();
830+
831+
// Case 1.
832+
if (!isa<cir::PointerType>(SrcTy) && !isa<cir::PointerType>(DstTy))
833+
return Builder.createBitcast(Src, DstTy);
834+
835+
// Case 2.
836+
if (isa<cir::PointerType>(SrcTy) && isa<cir::PointerType>(DstTy))
837+
return Builder.createPointerBitCastOrAddrSpaceCast(Src, DstTy);
838+
839+
// Case 3.
840+
if (isa<cir::PointerType>(SrcTy) && !isa<cir::PointerType>(DstTy)) {
841+
// Case 3b.
842+
if (!Builder.isInt(DstTy))
843+
llvm_unreachable("NYI");
844+
// Cases 3a and 3b.
845+
llvm_unreachable("NYI");
846+
}
847+
848+
// Case 4b.
849+
if (!Builder.isInt(SrcTy))
850+
llvm_unreachable("NYI");
851+
852+
// Cases 4a and 4b.
853+
llvm_unreachable("NYI");
854+
}
855+
824856
mlir::Value VisitAsTypeExpr(AsTypeExpr *E) {
825-
mlir::Value src = CGF.emitScalarExpr(E->getSrcExpr());
857+
unsigned numSrcElems = 0;
826858
QualType qualSrcTy = E->getSrcExpr()->getType();
827-
QualType qualDstTy = E->getType();
828-
829859
mlir::Type srcTy = CGF.convertType(qualSrcTy);
830-
mlir::Type dstTy = CGF.convertType(qualDstTy);
831-
auto loc = CGF.getLoc(E->getExprLoc());
832-
833-
unsigned numSrcElems = 0, numDstElems = 0;
834-
if (auto v = dyn_cast<cir::VectorType>(srcTy))
860+
if (auto v = dyn_cast<cir::VectorType>(srcTy)) {
861+
assert(!cir::MissingFeatures::scalableVectors() &&
862+
"NYI: non-fixed (scalable) vector src");
835863
numSrcElems = v.getSize();
836-
if (auto v = dyn_cast<cir::VectorType>(dstTy))
864+
}
865+
866+
unsigned numDstElems = 0;
867+
QualType qualDstTy = E->getType();
868+
mlir::Type dstTy = CGF.convertType(qualDstTy);
869+
if (auto v = dyn_cast<cir::VectorType>(dstTy)) {
870+
assert(!cir::MissingFeatures::scalableVectors() &&
871+
"NYI: non-fixed (scalable) vector dst");
837872
numDstElems = v.getSize();
873+
}
838874

839875
// Use bit vector expansion for ext_vector_type boolean vectors.
840876
if (qualDstTy->isExtVectorBoolType()) {
@@ -854,12 +890,9 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
854890
llvm_unreachable("NYI");
855891
}
856892

857-
// If types are identical, return the source
858-
if (srcTy == dstTy)
859-
return src;
860-
861-
// Otherwise, fallback to CIR bitcast
862-
return cir::CastOp::create(Builder, loc, dstTy, cir::CastKind::bitcast, src);
893+
// Otherwise, fallback to bitcast of same size
894+
mlir::Value src = CGF.emitScalarExpr(E->getSrcExpr());
895+
return createCastsForTypeOfSameSize(src, dstTy);
863896
}
864897

865898
mlir::Value VisitAtomicExpr(AtomicExpr *E) {

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -662,8 +662,10 @@ LogicalResult cir::CastOp::verify() {
662662
auto resPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
663663
if (!srcPtrTy || !resPtrTy)
664664
return emitOpError() << "requires !cir.ptr type for source and result";
665-
if (srcPtrTy.getPointee() != resPtrTy.getPointee())
666-
return emitOpError() << "requires two types differ in addrspace only";
665+
// Address space verification is sufficient here. The pointee types need not
666+
// be verified as they are handled by bitcast verification logic, which
667+
// ensures address space compatibility. Verifying pointee types would create
668+
// a circular dependency between address space and pointee type casting.
667669
return success();
668670
}
669671
case cir::CastKind::float_to_complex: {

clang/test/CIR/CodeGen/OpenCL/as_type.cl

Lines changed: 37 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,37 +7,49 @@
77
// RUN: %clang_cc1 %s -cl-std=CL2.0 -emit-llvm -triple spirv64-unknown-unknown -o %t.ll
88
// RUN: FileCheck %s --input-file=%t.ll --check-prefix=OG-LLVM
99

10-
typedef __attribute__(( ext_vector_type(3) )) char char3;
1110
typedef __attribute__(( ext_vector_type(4) )) char char4;
12-
typedef __attribute__(( ext_vector_type(16) )) char char16;
13-
typedef __attribute__(( ext_vector_type(3) )) int int3;
1411

15-
//CIR: cir.func @f4(%{{.*}}: !s32i loc({{.*}})) -> !cir.vector<!s8i x 4>
16-
//CIR: %[[x:.*]] = cir.load align(4) %{{.*}} : !cir.ptr<!s32i, addrspace(offload_private)>
17-
//CIR: cir.cast bitcast %[[x]] : !s32i -> !cir.vector<!s8i x 4>
18-
//LLVM: define spir_func <4 x i8> @f4(i32 %[[x:.*]])
19-
//LLVM: %[[astype:.*]] = bitcast i32 %[[x]] to <4 x i8>
20-
//LLVM-NOT: shufflevector
21-
//LLVM: ret <4 x i8> %[[astype]]
22-
//OG-LLVM: define spir_func noundef <4 x i8> @f4(i32 noundef %[[x:.*]])
23-
//OG-LLVM: %[[astype:.*]] = bitcast i32 %[[x]] to <4 x i8>
24-
//OG-LLVM-NOT: shufflevector
25-
//OG-LLVM: ret <4 x i8> %[[astype]]
12+
// CIR: cir.func @f4(%{{.*}}: !s32i loc({{.*}})) -> !cir.vector<!s8i x 4>
13+
// CIR: %[[x:.*]] = cir.load align(4) %{{.*}} : !cir.ptr<!s32i, addrspace(offload_private)>
14+
// CIR: cir.cast bitcast %[[x]] : !s32i -> !cir.vector<!s8i x 4>
15+
// LLVM: define spir_func <4 x i8> @f4(i32 %[[x:.*]])
16+
// LLVM: %[[astype:.*]] = bitcast i32 %[[x]] to <4 x i8>
17+
// LLVM-NOT: shufflevector
18+
// LLVM: ret <4 x i8> %[[astype]]
19+
// OG-LLVM: define spir_func noundef <4 x i8> @f4(i32 noundef %[[x:.*]])
20+
// OG-LLVM: %[[astype:.*]] = bitcast i32 %[[x]] to <4 x i8>
21+
// OG-LLVM-NOT: shufflevector
22+
// OG-LLVM: ret <4 x i8> %[[astype]]
2623
char4 f4(int x) {
2724
return __builtin_astype(x, char4);
2825
}
2926

30-
//CIR: cir.func @f6(%{{.*}}: !cir.vector<!s8i x 4> loc({{.*}})) -> !s32i
31-
//CIR: %[[x:.*]] = cir.load align(4) %{{.*}} : !cir.ptr<!cir.vector<!s8i x 4>, addrspace(offload_private)>, !cir.vector<!s8i x 4>
32-
//CIR: cir.cast bitcast %[[x]] : !cir.vector<!s8i x 4> -> !s32i
33-
//LLVM: define{{.*}} spir_func i32 @f6(<4 x i8> %[[x:.*]])
34-
//LLVM: %[[astype:.*]] = bitcast <4 x i8> %[[x]] to i32
35-
//LLVM-NOT: shufflevector
36-
//LLVM: ret i32 %[[astype]]
37-
//OG-LLVM: define{{.*}} spir_func noundef i32 @f6(<4 x i8> noundef %[[x:.*]])
38-
//OG-LLVM: %[[astype:.*]] = bitcast <4 x i8> %[[x]] to i32
39-
//OG-LLVM-NOT: shufflevector
40-
//OG-LLVM: ret i32 %[[astype]]
27+
// CIR: cir.func @f6(%{{.*}}: !cir.vector<!s8i x 4> loc({{.*}})) -> !s32i
28+
// CIR: %[[x:.*]] = cir.load align(4) %{{.*}} : !cir.ptr<!cir.vector<!s8i x 4>, addrspace(offload_private)>, !cir.vector<!s8i x 4>
29+
// CIR: cir.cast bitcast %[[x]] : !cir.vector<!s8i x 4> -> !s32i
30+
// LLVM: define{{.*}} spir_func i32 @f6(<4 x i8> %[[x:.*]])
31+
// LLVM: %[[astype:.*]] = bitcast <4 x i8> %[[x]] to i32
32+
// LLVM-NOT: shufflevector
33+
// LLVM: ret i32 %[[astype]]
34+
// OG-LLVM: define{{.*}} spir_func noundef i32 @f6(<4 x i8> noundef %[[x:.*]])
35+
// OG-LLVM: %[[astype:.*]] = bitcast <4 x i8> %[[x]] to i32
36+
// OG-LLVM-NOT: shufflevector
37+
// OG-LLVM: ret i32 %[[astype]]
4138
int f6(char4 x) {
4239
return __builtin_astype(x, int);
40+
}
41+
42+
// CIR: cir.func @f4_ptr(%{{.*}}: !cir.ptr<!s32i, addrspace(offload_global)> loc({{.*}})) -> !cir.ptr<!cir.vector<!s8i x 4>, addrspace(offload_local)>
43+
// CIR: %[[x:.*]] = cir.load align(8) %{{.*}} : !cir.ptr<!cir.ptr<!s32i, addrspace(offload_global)>, addrspace(offload_private)>, !cir.ptr<!s32i, addrspace(offload_global)>
44+
// CIR: cir.cast address_space %[[x]] : !cir.ptr<!s32i, addrspace(offload_global)> -> !cir.ptr<!cir.vector<!s8i x 4>, addrspace(offload_local)>
45+
// LLVM: define spir_func ptr addrspace(3) @f4_ptr(ptr addrspace(1) readnone captures(ret: address, provenance) %[[x:.*]])
46+
// LLVM: %[[astype:.*]] = addrspacecast ptr addrspace(1) %[[x]] to ptr addrspace(3)
47+
// LLVM-NOT: shufflevector
48+
// LLVM: ret ptr addrspace(3) %[[astype]]
49+
// OG-LLVM: define spir_func ptr addrspace(3) @f4_ptr(ptr addrspace(1) noundef readnone captures(ret: address, provenance) %[[x:.*]])
50+
// OG-LLVM: %[[astype:.*]] = addrspacecast ptr addrspace(1) %[[x]] to ptr addrspace(3)
51+
// OG-LLVM-NOT: shufflevector
52+
// OG-LLVM: ret ptr addrspace(3) %[[astype]]
53+
__local char4* f4_ptr(__global int* x) {
54+
return __builtin_astype(x, __local char4*);
4355
}

clang/test/CIR/IR/invalid.cir

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -300,15 +300,6 @@ cir.func @cast24(%p : !u32i) {
300300

301301
// -----
302302

303-
!u32i = !cir.int<u, 32>
304-
!u64i = !cir.int<u, 64>
305-
cir.func @cast25(%p : !cir.ptr<!u32i, addrspace(target<1>)>) {
306-
%0 = cir.cast address_space %p : !cir.ptr<!u32i, addrspace(target<1>)> -> !cir.ptr<!u64i, addrspace(target<2>)> // expected-error {{requires two types differ in addrspace only}}
307-
cir.return
308-
}
309-
310-
// -----
311-
312303
!u64i = !cir.int<u, 64>
313304
cir.func @cast26(%p : !cir.ptr<!u64i, addrspace(target<1>)>) {
314305
%0 = cir.cast address_space %p : !cir.ptr<!u64i, addrspace(target<1>)> -> !u64i // expected-error {{requires !cir.ptr type for source and result}}

0 commit comments

Comments
 (0)