Skip to content

Commit 4d13aa2

Browse files
[MLIR] Legalize certain vector.transfer_read ops of scalable vectors
This patch adds a transform of the `transfer_read` operation to change the vector type to one that can be mapped to an LLVM type. This is done by collapsing trailing dimensions so we obtain a vector type with a single scalable dimension in the rightmost position.
1 parent 1210d59 commit 4d13aa2

File tree

3 files changed

+407
-1
lines changed

3 files changed

+407
-1
lines changed

mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp

Lines changed: 109 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,113 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern<memref::LoadOp> {
298298
}
299299
};
300300

301+
/// Transforms a `transfer_read` operation so it reads vector of a type that
302+
/// can be mapped to an LLVM type. This is done by collapsing trailing
303+
/// dimensions so we obtain a vector type with a single scalable dimension in
304+
/// the rightmost position.
305+
///
306+
/// Example:
307+
/// ```
308+
/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8
309+
/// {in_bounds = [false, true, true, true]}
310+
/// : memref<?x?x2x8xi8>, vector<2x[4]x2x8xi8>
311+
/// ```
312+
/// is rewriten to
313+
/// ```
314+
/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]]
315+
/// : memref<?x?x2x8xi8> into memref<?x?xi8>
316+
/// %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8
317+
/// {in_bounds = [false, true]}
318+
/// : memref<?x?xi8>, vector<2x[64]xi8>
319+
/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
320+
/// ```
321+
struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
322+
using OpRewritePattern::OpRewritePattern;
323+
324+
LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
325+
PatternRewriter &rewriter) const override {
326+
327+
if (!readOp.getPermutationMap().isMinorIdentity())
328+
return rewriter.notifyMatchFailure(readOp, "non-identity permutation");
329+
330+
// We handle transfers of vectors with rank >= 2 and a single scalable
331+
// dimension.
332+
VectorType origVT = readOp.getVectorType();
333+
ArrayRef<bool> origScalableDims = origVT.getScalableDims();
334+
const int64_t origVRank = origVT.getRank();
335+
if (origVRank < 2 || llvm::count(origScalableDims, true) != 1)
336+
return rewriter.notifyMatchFailure(readOp, "wrong dimensions");
337+
338+
// Number of trailing dimensions to collapse, including the scalable
339+
// dimension. Nothing to do if the single scalable dimension is already the
340+
// last one.
341+
const int64_t numCollapseDims = std::distance(
342+
llvm::find(origScalableDims, true), origScalableDims.end());
343+
if (numCollapseDims < 2)
344+
return rewriter.notifyMatchFailure(readOp,
345+
"scalable dimension is trailing");
346+
347+
// We want a simple memref (not a tensor) with contiguous elements for at
348+
// least all the trailing dimensions up to and including the scalable one.
349+
auto memTy = dyn_cast<MemRefType>(readOp.getBase().getType());
350+
if (!(memTy && memTy.areTrailingDimsContiguous(numCollapseDims)))
351+
return rewriter.notifyMatchFailure(
352+
readOp, "non-contiguous memref dimensions to collapse");
353+
354+
// The collapsed dimensions (excluding the scalable one) of the vector and
355+
// the memref must match and the corresponding indices must be in-bounds (it
356+
// follows these indices would be zero). This guarantees that the operation
357+
// transfers a contiguous block.
358+
if (!llvm::equal(memTy.getShape().take_back(numCollapseDims - 1),
359+
origVT.getShape().take_back(numCollapseDims - 1)))
360+
return rewriter.notifyMatchFailure(
361+
readOp, "memref and vector dimensions do not match");
362+
363+
SmallVector<bool> origInBounds = readOp.getInBoundsValues();
364+
if (!llvm::all_of(
365+
ArrayRef<bool>(origInBounds).take_back(numCollapseDims - 1),
366+
[](bool v) { return v; }))
367+
return rewriter.notifyMatchFailure(readOp,
368+
"out-if-bounds index to collapse");
369+
370+
// Collapse the trailing dimensions of the memref.
371+
SmallVector<ReassociationIndices> reassoc;
372+
for (int64_t i = 0; i < memTy.getRank() - numCollapseDims + 1; ++i)
373+
reassoc.push_back({i});
374+
for (int64_t i = memTy.getRank() - numCollapseDims + 1; i < memTy.getRank();
375+
++i)
376+
reassoc.back().push_back(i);
377+
if (!memref::CollapseShapeOp::isGuaranteedCollapsible(memTy, reassoc))
378+
return failure();
379+
Value collapsedMem = rewriter.create<memref::CollapseShapeOp>(
380+
readOp.getLoc(), readOp.getBase(), reassoc);
381+
382+
// Get a vector type with collapsed trailing dimensions.
383+
SmallVector<int64_t> shape(origVT.getShape());
384+
for (int64_t i = origVRank - numCollapseDims + 1; i < origVRank; ++i)
385+
shape[origVRank - numCollapseDims] *= shape[i];
386+
shape.pop_back_n(numCollapseDims - 1);
387+
auto collapsedVT =
388+
VectorType::get(shape, origVT.getElementType(),
389+
origScalableDims.drop_back(numCollapseDims - 1));
390+
391+
// Drop the extra (zero) indices.
392+
auto indices = readOp.getIndices().drop_back(numCollapseDims - 1);
393+
394+
// Create the new `transfer_read`.
395+
auto newReadOp = rewriter.create<vector::TransferReadOp>(
396+
readOp.getLoc(), collapsedVT, collapsedMem, indices,
397+
ArrayRef<bool>(origInBounds).drop_back(numCollapseDims - 1));
398+
399+
// Cast back to the orignal vector type.
400+
auto toOrigShape = rewriter.create<vector::ShapeCastOp>(readOp.getLoc(),
401+
origVT, newReadOp);
402+
403+
rewriter.replaceOp(readOp, toOrigShape);
404+
return success();
405+
}
406+
};
407+
301408
} // namespace
302409

303410
void mlir::arm_sve::populateLegalizeVectorStoragePatterns(
@@ -306,7 +413,8 @@ void mlir::arm_sve::populateLegalizeVectorStoragePatterns(
306413
LegalizeSVEMaskAllocation<memref::AllocaOp>,
307414
LegalizeSVEMaskAllocation<memref::AllocOp>,
308415
LegalizeSVEMaskTypeCastConversion,
309-
LegalizeSVEMaskStoreConversion, LegalizeSVEMaskLoadConversion>(
416+
LegalizeSVEMaskStoreConversion, LegalizeSVEMaskLoadConversion,
417+
LegalizeTransferRead>(
310418
patterns.getContext());
311419
}
312420

Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
// RUN: mlir-opt --arm-sve-legalize-vector-storage --split-input-file %s | FileCheck %s

// -----

// CHECK-LABEL: @test_base_case
// CHECK-SAME: %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]:
// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[M]]
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
// CHECK-SAME: : memref<?x?x?x8xi8> into memref<?x?x?xi8>
// CHECK-NEXT: %[[T0:.+]] = vector.transfer_read %[[COLLAPSE]][%[[I]], %[[J]], %c0], %c0_i8 {in_bounds = [true]}
// CHECK-SAME: : memref<?x?x?xi8>, vector<[32]xi8>
// CHECK-NEXT: %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<[32]xi8> to vector<[4]x8xi8>
// CHECK-NEXT: return %[[T1]] : vector<[4]x8xi8>

func.func @test_base_case(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[4]x8xi8> {
  %c0 = arith.constant 0 : index
  %c0_i8 = arith.constant 0 : i8

  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<[4]x8xi8>

  return %A : vector<[4]x8xi8>
}
23+
24+
// -----

// CHECK-LABEL: @test_using_strided_layout
// CHECK-SAME: %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]
// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[M]]
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
// CHECK-SAME: : memref<?x?x?x8xi8, strided<[?, ?, 8, 1]>> into
// CHECK-SAME: memref<?x?x?xi8, strided<[?, ?, 1]>>
// CHECK-NEXT: %[[T0:.+]] = vector.transfer_read %[[COLLAPSE]][%[[I]], %[[J]], %c0], %c0_i8 {in_bounds = [true]}
// CHECK-SAME: : memref<?x?x?xi8, strided<[?, ?, 1]>>, vector<[32]xi8>
// CHECK-NEXT: %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<[32]xi8> to vector<[4]x8xi8>
// CHECK-NEXT: return %[[T1]] : vector<[4]x8xi8>

#s0 = strided<[?, ?, 8, 1]>

func.func @test_using_strided_layout(%i : index, %j : index, %M : memref<?x?x?x8xi8, #s0>) -> vector<[4]x8xi8> {
  %c0 = arith.constant 0 : index
  %c0_i8 = arith.constant 0 : i8

  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8, #s0>, vector<[4]x8xi8>

  return %A : vector<[4]x8xi8>
}
47+
48+
// -----

// CHECK-LABEL: @test_3d_vector
// CHECK-SAME: %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[M]]
// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
// CHECK-SAME: : memref<?x?x2x8xi8, strided<[?, 16, 8, 1]>> into
// CHECK-SAME: memref<?x?xi8, strided<[?, 1]>>
// CHECK-NEXT: %[[T0:.+]] = vector.transfer_read %[[COLLAPSED]][%[[I]], %[[J]]], %c0_i8 {in_bounds = [true]}
// CHECK-SAME: : memref<?x?xi8, strided<[?, 1]>>, vector<[64]xi8>
// CHECK-NEXT: %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<[64]xi8> to vector<[4]x2x8xi8>
// CHECK-NEXT: return %[[T1]] : vector<[4]x2x8xi8>

#s1 = strided<[?, 16, 8, 1]>

func.func @test_3d_vector(%i : index, %j : index, %M : memref<?x?x2x8xi8, #s1>) -> vector<[4]x2x8xi8> {
  %c0 = arith.constant 0 : index
  %c0_i8 = arith.constant 0 : i8

  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true, true]} : memref<?x?x2x8xi8, #s1>, vector<[4]x2x8xi8>

  return %A : vector<[4]x2x8xi8>
}
71+
72+
// -----

// CHECK-LABEL: @test_4d_vector
// CHECK-SAME: %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[M]]
// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
// CHECK-SAME: : memref<?x?x2x8xi8, strided<[?, 16, 8, 1]>> into
// CHECK-SAME: memref<?x?xi8, strided<[?, 1]>>
// CHECK-NEXT: %[[T0:.+]] = vector.transfer_read %[[COLLAPSED]][%[[I]], %[[J]]], %c0_i8 {in_bounds = [false, true]}
// CHECK-SAME: : memref<?x?xi8, strided<[?, 1]>>, vector<2x[64]xi8>
// CHECK-NEXT: %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
// CHECK-NEXT: return %[[T1]] : vector<2x[4]x2x8xi8>

#s2 = strided<[?, 16, 8, 1]>

func.func @test_4d_vector(%i : index, %j : index, %M : memref<?x?x2x8xi8, #s2>) -> vector<2x[4]x2x8xi8> {
  %c0 = arith.constant 0 : index
  %c0_i8 = arith.constant 0 : i8

  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [false, true, true, true]} : memref<?x?x2x8xi8, #s2>, vector<2x[4]x2x8xi8>

  return %A : vector<2x[4]x2x8xi8>
}
95+
96+
// -----

// Negative test: the vector type is fixed-size only, hence already legal.
// CHECK-LABEL: @negative_test_vector_legal_non_scalable
// CHECK-NOT: memref.collapse

func.func @negative_test_vector_legal_non_scalable(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<8x8xi8> {
  %c0 = arith.constant 0 : index
  %c0_i8 = arith.constant 0 : i8

  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<8x8xi8>

  return %A : vector<8x8xi8>
}
109+
110+
// -----

// Negative test: rank-1 scalable vector is already legal.
// CHECK-LABEL: @negative_test_vector_legal_scalable_0
// CHECK-NOT: memref.collapse

func.func @negative_test_vector_legal_scalable_0(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[8]xi8> {
  %c0 = arith.constant 0 : index
  %c0_i8 = arith.constant 0 : i8

  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true]} : memref<?x?x?x8xi8>, vector<[8]xi8>

  return %A : vector<[8]xi8>
}
123+
124+
// -----

// Negative test: the single scalable dimension is already trailing.
// CHECK-LABEL: @negative_test_vector_legal_scalable_1
// CHECK-NOT: memref.collapse

func.func @negative_test_vector_legal_scalable_1(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<8x[8]xi8> {
  %c0 = arith.constant 0 : index
  %c0_i8 = arith.constant 0 : i8

  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<8x[8]xi8>

  return %A : vector<8x[8]xi8>
}
137+
138+
// -----

// Negative test: more than one scalable dimension is not supported.
// CHECK-LABEL: @negative_test_vector_type_not_supported
// CHECK-NOT: memref.collapse

func.func @negative_test_vector_type_not_supported(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[8]x[8]x8xi8> {
  %c0 = arith.constant 0 : index
  %c0_i8 = arith.constant 0 : i8

  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true, true]} : memref<?x?x?x8xi8>, vector<[8]x[8]x8xi8>

  return %A : vector<[8]x[8]x8xi8>
}
151+
152+
// -----

// Negative test: the source is a tensor, not a memref.
// CHECK-LABEL: @negative_test_non_mem
// CHECK-NOT: memref.collapse

func.func @negative_test_non_mem(%i : index, %j : index, %M : tensor<?x?x?x8xi8>) -> vector<[4]x8xi8> {
  %c0 = arith.constant 0 : index
  %c0_i8 = arith.constant 0 : i8

  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : tensor<?x?x?x8xi8>, vector<[4]x8xi8>

  return %A : vector<[4]x8xi8>
}
165+
166+
// -----

// Negative test: the trailing dimensions to collapse are not contiguous
// (stride 16 over an inner extent of 8).
// CHECK-LABEL: @negative_test_discontig_mem_0
// CHECK-NOT: memref.collapse

#s3 = strided<[?, ?, 16, 1]>

func.func @negative_test_discontig_mem_0(%i : index, %j : index, %M : memref<?x?x?x8xi8, #s3>) -> vector<[4]x8xi8> {
  %c0 = arith.constant 0 : index
  %c0_i8 = arith.constant 0 : i8

  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8, #s3>, vector<[4]x8xi8>

  return %A : vector<[4]x8xi8>
}
181+
182+
// -----

// Negative test: a permuting affine layout makes the memref non-contiguous.
// CHECK-LABEL: @negative_test_discontig_mem_1
// CHECK-NOT: memref.collapse

#layout = affine_map<(i, j, k, p) -> (j, i, k, p)>

func.func @negative_test_discontig_mem_1(%i : index, %j : index, %M : memref<?x?x?x8xi8, #layout>) -> vector<[4]x8xi8> {
  %c0 = arith.constant 0 : index
  %c0_i8 = arith.constant 0 : i8

  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8, #layout>, vector<[4]x8xi8>

  return %A : vector<[4]x8xi8>
}
197+
198+
// -----

// Negative test: the vector reads only part of the trailing memref dimension
// (4 of 8 elements), so the transfer is not a contiguous block.
// CHECK-LABEL: @negative_test_discontig_read_strided_vec
// CHECK-NOT: memref.collapse

func.func @negative_test_discontig_read_strided_vec(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[4]x4xi8> {
  %c0 = arith.constant 0 : index
  %c0_i8 = arith.constant 0 : i8

  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<[4]x4xi8>

  return %A : vector<[4]x4xi8>
}
211+
212+
// -----

// Negative test: a broadcasting/transposing permutation map is not a
// minor identity, so the pattern does not apply.
// CHECK-LABEL: @negative_test_bcast_transp
// CHECK-NOT: memref.collapse

#perm = affine_map<(i, j, k, p) -> (k, 0)>

func.func @negative_test_bcast_transp(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[4]x8xi8> {
  %c0 = arith.constant 0 : index
  %c0_i8 = arith.constant 0 : i8

  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {permutation_map = #perm, in_bounds = [true, true] } : memref<?x?x?x8xi8>, vector<[4]x8xi8>

  return %A : vector<[4]x8xi8>
}

0 commit comments

Comments
 (0)