Skip to content

Commit e422213

Browse files
[fixup] Commenting and rework the tests
1 parent d9ce375 commit e422213

File tree

2 files changed

+112
-103
lines changed

2 files changed

+112
-103
lines changed

mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -299,9 +299,9 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern<memref::LoadOp> {
299299
};
300300

301301
/// Transforms a `transfer_read` operation so it reads vector of a type that
302-
/// can be mapped to an LLVM type. This is done by collapsing trailing
303-
/// dimensions so we obtain a vector type with a single scalable dimension in
304-
/// the rightmost position.
302+
/// can be mapped to an LLVM type ("LLVM-legal" type). This is done by
303+
/// collapsing trailing dimensions so we obtain a vector type with a single
304+
/// scalable dimension in the rightmost position.
305305
///
306306
/// Example:
307307
/// ```
@@ -339,15 +339,30 @@ struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
339339
return rewriter.notifyMatchFailure(readOp,
340340
"masked transfers not-supported");
341341

342+
// General permutation maps are not supported. The issue is with transpose,
343+
// broadcast, and other forms of non-identity mapping in the minor
344+
// dimensions which is impossible to represent after collapsing (at least
345+
// because the resulting "collapsed" maps would have a smaller number of
346+
// dimension indices).
347+
// TODO: We have not yet had the need for it, but some forms of permutation
348+
// maps with identity in the minor dimensions could be supported, for
349+
// example `(i, j, k, p) -> (j, i, k, p)` where we need to collapse only `k`
350+
// and `p`.
342351
if (!readOp.getPermutationMap().isMinorIdentity())
343352
return rewriter.notifyMatchFailure(readOp, "non-identity permutation");
344353

345354
// We handle transfers of vectors with rank >= 2 and a single scalable
346-
// dimension.
355+
// dimension. This transformation aims to transform an LLVM-illegal type
356+
// into an LLVM-legal type and one dimensional vectors are already
357+
// LLVM-legal, even if scalable. A value of a vector type with more than one
358+
// scalable dimension is impossible to represent using a vector type with no
359+
// scalable dimensions or a single one. For example a `vector<[4]x[4]xi8>`
360+
// would have `4 * 4 * vscale * vscale` elements and this quantity is
361+
// impossible to represent as `N` or `N * vscale` (where `N` is a constant).
347362
VectorType origVT = readOp.getVectorType();
348363
ArrayRef<bool> origScalableDims = origVT.getScalableDims();
349364
const int64_t origVRank = origVT.getRank();
350-
if (origVRank < 2 || llvm::count(origScalableDims, true) != 1)
365+
if (origVRank < 2 || origVT.getNumScalableDims() != 1)
351366
return rewriter.notifyMatchFailure(readOp, "wrong dimensions");
352367

353368
// Number of trailing dimensions to collapse, including the scalable
@@ -366,10 +381,11 @@ struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
366381
return rewriter.notifyMatchFailure(
367382
readOp, "non-contiguous memref dimensions to collapse");
368383

369-
// The collapsed dimensions (excluding the scalable one) of the vector and
370-
// the memref must match and the corresponding indices must be in-bounds (it
371-
// follows these indices would be zero). This guarantees that the operation
372-
// transfers a contiguous block.
384+
// The dimensions to collapse (excluding the scalable one) of the vector and
385+
// the memref must match. A dynamic memref dimension is considered
386+
// non-matching. The transfers from the dimensions to collapse must be
387+
// in-bounds (it follows the corresponding indices would be zero). This
388+
// guarantees that the operation transfers a contiguous block.
373389
if (!llvm::equal(memTy.getShape().take_back(numCollapseDims - 1),
374390
origVT.getShape().take_back(numCollapseDims - 1)))
375391
return rewriter.notifyMatchFailure(
@@ -379,8 +395,8 @@ struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
379395
if (!llvm::all_of(
380396
ArrayRef<bool>(origInBounds).take_back(numCollapseDims - 1),
381397
[](bool v) { return v; }))
382-
return rewriter.notifyMatchFailure(readOp,
383-
"out-if-bounds index to collapse");
398+
return rewriter.notifyMatchFailure(
399+
readOp, "out-of-bounds transfer from a dimension to collapse");
384400

385401
// Collapse the trailing dimensions of the memref.
386402
SmallVector<ReassociationIndices> reassoc;
Lines changed: 85 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
11
// RUN: mlir-opt --arm-sve-legalize-vector-storage --split-input-file %s | FileCheck %s
22

3+
4+
// Test the `LegalizeTransferRead` pattern
5+
// (mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp)
6+
37
// -----
48

9+
// This is the base case, unremarkable in any way, except that it's our main
10+
// motivating example and use case.
11+
512
// CHECK-LABEL: @test_base_case
613
// CHECK-SAME: %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]:
714
// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[M]]
@@ -23,82 +30,35 @@ func.func @test_base_case(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> ve
2330

2431
// -----
2532

26-
// CHECK-LABEL: @test_using_strided_layout
27-
// CHECK-SAME: %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]
28-
// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[M]]
29-
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
30-
// CHECK-SAME: : memref<?x?x?x8xi8, strided<[?, ?, 8, 1]>> into
31-
// CHECK-SAME: memref<?x?x?xi8, strided<[?, ?, 1]>>
32-
// CHECK-NEXT: %[[T0:.+]] = vector.transfer_read %[[COLLAPSE]][%[[I]], %[[J]], %c0], %c0_i8 {in_bounds = [true]}
33-
// CHECK-SAME: : memref<?x?x?xi8, strided<[?, ?, 1]>>, vector<[32]xi8>
34-
// CHECK-NEXT: %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<[32]xi8> to vector<[4]x8xi8>
35-
// CHECK-NEXT: return %[[T1]] : vector<[4]x8xi8>
36-
37-
#s0 = strided<[?, ?, 8, 1]>
38-
39-
func.func @test_using_strided_layout(%i : index, %j : index, %M : memref<?x?x?x8xi8, #s0>) -> vector<[4]x8xi8> {
40-
%c0 = arith.constant 0 : index
41-
%c0_i8 = arith.constant 0 : i8
42-
43-
%A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8, #s0>, vector<[4]x8xi8>
44-
45-
return %A : vector<[4]x8xi8>
46-
}
47-
48-
// -----
33+
// Test the case where the scalable dimension is not the second-to-last.
4934

5035
// CHECK-LABEL: @test_3d_vector
5136
// CHECK-SAME: %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]
5237
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[M]]
5338
// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
54-
// CHECK-SAME: : memref<?x?x2x8xi8, strided<[?, 16, 8, 1]>> into
55-
// CHECK-SAME: memref<?x?xi8, strided<[?, 1]>>
39+
// CHECK-SAME: : memref<?x?x2x8xi8> into memref<?x?xi8>
5640
// CHECK-NEXT: %[[T0:.+]] = vector.transfer_read %[[COLLAPSED]][%[[I]], %[[J]]], %c0_i8 {in_bounds = [true]}
57-
// CHECK-SAME: : memref<?x?xi8, strided<[?, 1]>>, vector<[64]xi8>
41+
// CHECK-SAME: : memref<?x?xi8>, vector<[64]xi8>
5842
// CHECK-NEXT: %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<[64]xi8> to vector<[4]x2x8xi8>
5943
// CHECK-NEXT: return %[[T1]] : vector<[4]x2x8xi8>
6044

61-
#s1 = strided<[?, 16, 8, 1]>
62-
63-
func.func @test_3d_vector(%i : index, %j : index, %M : memref<?x?x2x8xi8, #s1>) -> vector<[4]x2x8xi8> {
45+
func.func @test_3d_vector(%i : index, %j : index, %M : memref<?x?x2x8xi8>) -> vector<[4]x2x8xi8> {
6446
%c0 = arith.constant 0 : index
6547
%c0_i8 = arith.constant 0 : i8
6648

67-
%A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true, true]} : memref<?x?x2x8xi8, #s1>, vector<[4]x2x8xi8>
49+
%A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true, true]} : memref<?x?x2x8xi8>, vector<[4]x2x8xi8>
6850

6951
return %A : vector<[4]x2x8xi8>
7052
}
7153

7254
// -----
7355

74-
// CHECK-LABEL: @test_4d_vector
75-
// CHECK-SAME: %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]
76-
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[M]]
77-
// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
78-
// CHECK-SAME: : memref<?x?x2x8xi8, strided<[?, 16, 8, 1]>> into
79-
// CHECK-SAME: memref<?x?xi8, strided<[?, 1]>>
80-
// CHECK-NEXT: %[[T0:.+]] = vector.transfer_read %[[COLLAPSED]][%[[I]], %[[J]]], %c0_i8 {in_bounds = [false, true]}
81-
// CHECK-SAME: : memref<?x?xi8, strided<[?, 1]>>, vector<2x[64]xi8>
82-
// CHECK-NEXT: %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
83-
// CHECK-NEXT: return %[[T1]] : vector<2x[4]x2x8xi8>
84-
85-
#s2 = strided<[?, 16, 8, 1]>
86-
87-
func.func @test_4d_vector(%i : index, %j : index, %M : memref<?x?x2x8xi8, #s2>) -> vector<2x[4]x2x8xi8> {
88-
%c0 = arith.constant 0 : index
89-
%c0_i8 = arith.constant 0 : i8
90-
91-
%A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [false, true, true, true]} : memref<?x?x2x8xi8, #s2>, vector<2x[4]x2x8xi8>
92-
93-
return %A : vector<2x[4]x2x8xi8>
94-
}
95-
96-
// -----
56+
// Test the case when the vector is already LLVM-legal (fixed).
9757

98-
// CHECK-LABEL: @negative_test_vector_legal_non_scalable
58+
// CHECK-LABEL: @negative_test_vector_legal_fixed
9959
// CHECK-NOT: memref.collapse
10060

101-
func.func @negative_test_vector_legal_non_scalable(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<8x8xi8> {
61+
func.func @negative_test_vector_legal_fixed(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<8x8xi8> {
10262
%c0 = arith.constant 0 : index
10363
%c0_i8 = arith.constant 0 : i8
10464

@@ -109,10 +69,12 @@ func.func @negative_test_vector_legal_non_scalable(%i : index, %j : index, %M :
10969

11070
// -----
11171

112-
// CHECK-LABEL: @negative_test_vector_legal_scalable_0
72+
// Test the case when the vector is already LLVM-legal (single-dimension scalable).
73+
74+
// CHECK-LABEL: @negative_test_vector_legal_1d_scalable
11375
// CHECK-NOT: memref.collapse
11476

115-
func.func @negative_test_vector_legal_scalable_0(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[8]xi8> {
77+
func.func @negative_test_vector_legal_1d_scalable(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[8]xi8> {
11678
%c0 = arith.constant 0 : index
11779
%c0_i8 = arith.constant 0 : i8
11880

@@ -123,10 +85,13 @@ func.func @negative_test_vector_legal_scalable_0(%i : index, %j : index, %M : me
12385

12486
// -----
12587

126-
// CHECK-LABEL: @negative_test_vector_legal_scalable_1
88+
// Test the case when the vector is already LLVM-legal (single trailing
89+
// scalable dimension).
90+
91+
// CHECK-LABEL: @negative_test_vector_legal_trailing_scalable_dim
12792
// CHECK-NOT: memref.collapse
12893

129-
func.func @negative_test_vector_legal_scalable_1(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<8x[8]xi8> {
94+
func.func @negative_test_vector_legal_trailing_scalable_dim(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<8x[8]xi8> {
13095
%c0 = arith.constant 0 : index
13196
%c0_i8 = arith.constant 0 : i8
13297

@@ -137,6 +102,8 @@ func.func @negative_test_vector_legal_scalable_1(%i : index, %j : index, %M : me
137102

138103
// -----
139104

105+
// Test the case of unsupported vector type (more than one scalable dimension)
106+
140107
// CHECK-LABEL: @negative_test_vector_type_not_supported
141108
// CHECK-NOT: memref.collapse
142109

@@ -151,10 +118,14 @@ func.func @negative_test_vector_type_not_supported(%i : index, %j : index, %M :
151118

152119
// -----
153120

154-
// CHECK-LABEL: @negative_test_non_mem
121+
// Test the case of reading from a tensor - not supported, since the
122+
// transform reasons about memory layouts.
123+
124+
// CHECK-LABEL: @negative_test_tensor_transfer
125+
155126
// CHECK-NOT: memref.collapse
156127

157-
func.func @negative_test_non_mem(%i : index, %j : index, %M : tensor<?x?x?x8xi8>) -> vector<[4]x8xi8> {
128+
func.func @negative_test_tensor_transfer(%i : index, %j : index, %M : tensor<?x?x?x8xi8>) -> vector<[4]x8xi8> {
158129
%c0 = arith.constant 0 : index
159130
%c0_i8 = arith.constant 0 : i8
160131

@@ -165,98 +136,120 @@ func.func @negative_test_non_mem(%i : index, %j : index, %M : tensor<?x?x?x8xi8>
165136

166137
// -----
167138

168-
// CHECK-LABEL: @negative_test_discontig_mem_0
139+
// Test the case when the transfer is discontiguous because the memref
140+
// is discontiguous.
141+
// There are other ways to make a memref discontiguous. The transformation
142+
// is not concerned with the particular reason a memref is discontiguous, but
143+
// only with the fact. Therefore there are no variations with the memref made
144+
// discontiguous by some other mechanism.
145+
146+
// CHECK-LABEL: @negative_test_discontig_mem
169147
// CHECK-NOT: memref.collapse
170148

171-
#s3 = strided<[?, ?, 16, 1]>
149+
#strides = strided<[?, ?, 16, 1]>
172150

173-
func.func @negative_test_discontig_mem_0(%i : index, %j : index, %M : memref<?x?x?x8xi8, #s3>) -> vector<[4]x8xi8> {
151+
func.func @negative_test_discontig_mem(%i : index, %j : index, %M : memref<?x?x?x8xi8, #strides>) -> vector<[4]x8xi8> {
174152
%c0 = arith.constant 0 : index
175153
%c0_i8 = arith.constant 0 : i8
176154

177-
%A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8, #s3>, vector<[4]x8xi8>
155+
%A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8, #strides>, vector<[4]x8xi8>
178156

179157
return %A : vector<[4]x8xi8>
180158
}
181159

182160
// -----
183161

184-
// CHECK-LABEL: @negative_test_discontig_mem_1
162+
// Test the case when the transformation is not applied because of
163+
// a non-trivial permutation map (broadcast).
164+
165+
// CHECK-LABEL: @negative_test_broadcast
185166
// CHECK-NOT: memref.collapse
186167

187-
#layout = affine_map<(i, j, k, p) -> (j, i, k, p)>
168+
#perm = affine_map<(i, j, k, p) -> (k, 0)>
188169

189-
func.func @negative_test_discontig_mem_1(%i : index, %j : index, %M : memref<?x?x?x8xi8, #layout>) -> vector<[4]x8xi8> {
170+
func.func @negative_test_broadcast(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[4]x8xi8> {
190171
%c0 = arith.constant 0 : index
191172
%c0_i8 = arith.constant 0 : i8
192173

193-
%A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8, #layout>, vector<[4]x8xi8>
174+
%A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {permutation_map = #perm, in_bounds = [true, true] } : memref<?x?x?x8xi8>, vector<[4]x8xi8>
194175

195176
return %A : vector<[4]x8xi8>
196177
}
197178

198179
// -----
199180

200-
// CHECK-LABEL: @negative_test_discontig_read_strided_vec
181+
// Test the case of a masked read - not supported right now.
182+
// (see mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp)
183+
184+
// CHECK-LABEL: @negative_test_masked
201185
// CHECK-NOT: memref.collapse
202186

203-
func.func @negative_test_discontig_read_strided_vec(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[4]x4xi8> {
187+
func.func @negative_test_masked(
188+
%i : index, %j : index,
189+
%M : memref<?x?x?x8xi8>, %mask : vector<[4]x8xi1>) -> vector<[4]x8xi8> {
190+
204191
%c0 = arith.constant 0 : index
205192
%c0_i8 = arith.constant 0 : i8
206193

207-
%A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<[4]x4xi8>
194+
%A = vector.mask %mask {
195+
vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true] } : memref<?x?x?x8xi8>, vector<[4]x8xi8>
196+
} : vector<[4]x8xi1> -> vector<[4]x8xi8>
208197

209-
return %A : vector<[4]x4xi8>
198+
return %A : vector<[4]x8xi8>
210199
}
211200

212201
// -----
213202

214-
// CHECK-LABEL: @negative_test_bcast_transp
215-
// CHECK-NOT: memref.collapse
203+
// Test case with a mask operand - not supported right now.
204+
// (see mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp)
216205

217-
#perm = affine_map<(i, j, k, p) -> (k, 0)>
206+
// CHECK-LABEL: @negative_test_with_mask
207+
// CHECK-NOT: memref.collapse
218208

219-
func.func @negative_test_bcast_transp(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[4]x8xi8> {
209+
func.func @negative_test_with_mask(
210+
%i : index, %j : index,
211+
%M : memref<?x?x?x8xi8>, %mask : vector<[4]x8xi1>) -> vector<[4]x8xi8> {
212+
220213
%c0 = arith.constant 0 : index
221214
%c0_i8 = arith.constant 0 : i8
222215

223-
%A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {permutation_map = #perm, in_bounds = [true, true] } : memref<?x?x?x8xi8>, vector<[4]x8xi8>
216+
%A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8, %mask {in_bounds = [true, true] } : memref<?x?x?x8xi8>, vector<[4]x8xi8>
224217

225218
return %A : vector<[4]x8xi8>
226219
}
227220

228221
// -----
229222

230-
// CHECK-LABEL: @negative_test_vector_mask
223+
// Test the case when the dimensions to collapse (excluding the scalable one)
224+
// of the vector and the memref do not match (static non matching dimension).
225+
226+
// CHECK-LABEL: @negative_test_non_matching_dim_static
231227
// CHECK-NOT: memref.collapse
232228

233-
func.func @negative_test_vector_mask(
234-
%i : index, %j : index,
235-
%M : memref<?x?x?x8xi8>, %mask : vector<[4]x8xi1>) -> vector<[4]x8xi8> {
229+
func.func @negative_test_non_matching_dim_static(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[4]x4xi8> {
236230

237231
%c0 = arith.constant 0 : index
238232
%c0_i8 = arith.constant 0 : i8
239233

240-
%A = vector.mask %mask {
241-
vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true] } : memref<?x?x?x8xi8>, vector<[4]x8xi8>
242-
} : vector<[4]x8xi1> -> vector<[4]x8xi8>
234+
%A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true] } : memref<?x?x?x8xi8>, vector<[4]x4xi8>
243235

244-
return %A : vector<[4]x8xi8>
236+
return %A : vector<[4]x4xi8>
245237
}
246238

247239
// -----
248240

249-
// CHECK-LABEL: @negative_test_mask_operand
241+
// Test the case when the dimensions to collapse (excluding the scalable one)
242+
// of the vector and the memref do not match (dynamic non-matching dimension).
243+
244+
// CHECK-LABEL: @negative_test_non_matching_dim_dynamic
250245
// CHECK-NOT: memref.collapse
251246

252-
func.func @negative_test_mask_operand(
253-
%i : index, %j : index,
254-
%M : memref<?x?x?x8xi8>, %mask : vector<[4]x8xi1>) -> vector<[4]x8xi8> {
247+
func.func @negative_test_non_matching_dim_dynamic(%i : index, %j : index, %M : memref<?x?x?x?xi8>) -> vector<[4]x4xi8> {
255248

256249
%c0 = arith.constant 0 : index
257250
%c0_i8 = arith.constant 0 : i8
258251

259-
%A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8, %mask {in_bounds = [true, true] } : memref<?x?x?x8xi8>, vector<[4]x8xi8>
252+
%A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true] } : memref<?x?x?x?xi8>, vector<[4]x4xi8>
260253

261-
return %A : vector<[4]x8xi8>
254+
return %A : vector<[4]x4xi8>
262255
}

0 commit comments

Comments
 (0)