@@ -124,14 +124,14 @@ Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
124124
125125// Convert a MemRef to an LLVM type. If the memref is statically-shaped, then
126126// we return a pointer to the converted element type. Otherwise we return an
127- // LLVM stucture type, where the first element of the structure type is a
127+ // LLVM structure type, where the first element of the structure type is a
128128// pointer to the elemental type of the MemRef and the following N elements are
129129// values of the Index type, one for each of N dynamic dimensions of the MemRef.
130130Type LLVMTypeConverter::convertMemRefType (MemRefType type) {
131131 LLVM::LLVMType elementType = unwrap (convertType (type.getElementType ()));
132132 if (!elementType)
133133 return {};
134- auto ptrType = elementType.getPointerTo ();
134+ auto ptrType = elementType.getPointerTo (type. getMemorySpace () );
135135
136136 // Extra value for the memory space.
137137 unsigned numDynamicSizes = type.getNumDynamicDims ();
@@ -189,7 +189,7 @@ static Type getMemRefElementPtrType(MemRefType t, LLVMTypeConverter &lowering) {
189189 auto converted = lowering.convertType (elementType);
190190 if (!converted)
191191 return {};
192- return converted.cast <LLVM::LLVMType>().getPointerTo ();
192+ return converted.cast <LLVM::LLVMType>().getPointerTo (t. getMemorySpace () );
193193}
194194
195195LLVMOpLowering::LLVMOpLowering (StringRef rootOpName, MLIRContext *context,
@@ -514,14 +514,11 @@ struct ConstLLVMOpLowering
514514 using Super::Super;
515515};
516516
517- // Check if the MemRefType `type` is supported by the lowering. We currently do
518- // not support memrefs with affine maps and non-default memory spaces .
517+ // Check if the MemRefType `type` is supported by the lowering. We currently
518+ // only support memrefs with identity maps.
519519static bool isSupportedMemRefType (MemRefType type) {
520- if (!type.getAffineMaps ().empty ())
521- return false ;
522- if (type.getMemorySpace () != 0 )
523- return false ;
524- return true ;
520+ return llvm::all_of (type.getAffineMaps (),
521+ [](AffineMap map) { return map.isIdentity (); });
525522}
526523
527524// An `alloc` is converted into a definition of a memref descriptor value and
@@ -598,8 +595,8 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
598595 cumulativeSize)
599596 .getResult (0 );
600597 auto structElementType = lowering.convertType (elementType);
601- auto elementPtrType =
602- structElementType. cast <LLVM::LLVMType>(). getPointerTo ( );
598+ auto elementPtrType = structElementType. cast <LLVM::LLVMType>(). getPointerTo (
599+ type. getMemorySpace () );
603600 allocated = rewriter.create <LLVM::BitcastOp>(op->getLoc (), elementPtrType,
604601 ArrayRef<Value *>(allocated));
605602
0 commit comments