Skip to content
This repository was archived by the owner on Apr 23, 2021. It is now read-only.

Commit a3ba537

Browse files
MLIR Teamtensorflower-gardener
authored andcommitted
Retain address space during MLIR > LLVM conversion.
PiperOrigin-RevId: 267206460
1 parent 68dc98b commit a3ba537

File tree

2 files changed

+20
-12
lines changed

2 files changed

+20
-12
lines changed

lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -124,14 +124,14 @@ Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
124124

125125
// Convert a MemRef to an LLVM type. If the memref is statically-shaped, then
126126
// we return a pointer to the converted element type. Otherwise we return an
127-
// LLVM stucture type, where the first element of the structure type is a
127+
// LLVM structure type, where the first element of the structure type is a
128128
// pointer to the elemental type of the MemRef and the following N elements are
129129
// values of the Index type, one for each of N dynamic dimensions of the MemRef.
130130
Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
131131
LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
132132
if (!elementType)
133133
return {};
134-
auto ptrType = elementType.getPointerTo();
134+
auto ptrType = elementType.getPointerTo(type.getMemorySpace());
135135

136136
// Extra value for the memory space.
137137
unsigned numDynamicSizes = type.getNumDynamicDims();
@@ -189,7 +189,7 @@ static Type getMemRefElementPtrType(MemRefType t, LLVMTypeConverter &lowering) {
189189
auto converted = lowering.convertType(elementType);
190190
if (!converted)
191191
return {};
192-
return converted.cast<LLVM::LLVMType>().getPointerTo();
192+
return converted.cast<LLVM::LLVMType>().getPointerTo(t.getMemorySpace());
193193
}
194194

195195
LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
@@ -514,14 +514,11 @@ struct ConstLLVMOpLowering
514514
using Super::Super;
515515
};
516516

517-
// Check if the MemRefType `type` is supported by the lowering. We currently do
518-
// not support memrefs with affine maps and non-default memory spaces.
517+
// Check if the MemRefType `type` is supported by the lowering. We currently
518+
// only support memrefs with identity maps.
519519
static bool isSupportedMemRefType(MemRefType type) {
520-
if (!type.getAffineMaps().empty())
521-
return false;
522-
if (type.getMemorySpace() != 0)
523-
return false;
524-
return true;
520+
return llvm::all_of(type.getAffineMaps(),
521+
[](AffineMap map) { return map.isIdentity(); });
525522
}
526523

527524
// An `alloc` is converted into a definition of a memref descriptor value and
@@ -598,8 +595,8 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
598595
cumulativeSize)
599596
.getResult(0);
600597
auto structElementType = lowering.convertType(elementType);
601-
auto elementPtrType =
602-
structElementType.cast<LLVM::LLVMType>().getPointerTo();
598+
auto elementPtrType = structElementType.cast<LLVM::LLVMType>().getPointerTo(
599+
type.getMemorySpace());
603600
allocated = rewriter.create<LLVM::BitcastOp>(op->getLoc(), elementPtrType,
604601
ArrayRef<Value *>(allocated));
605602

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// RUN: mlir-opt %s -lower-to-llvm | FileCheck %s
2+
3+
// CHECK-LABEL: func @address_space(%{{.*}}: !llvm<"float addrspace(7)*">)
4+
func @address_space(%arg0 : memref<32xf32, (d0) -> (d0), 7>) {
5+
%0 = alloc() : memref<32xf32, (d0) -> (d0), 5>
6+
%1 = constant 7 : index
7+
// CHECK: llvm.load %{{.*}} : !llvm<"float addrspace(5)*">
8+
%2 = load %0[%1] : memref<32xf32, (d0) -> (d0), 5>
9+
std.return
10+
}
11+

0 commit comments

Comments
 (0)