Skip to content

Commit 5496f97

Browse files
[fixup] Prevent the pattern from (incorrectly) applying to masked transfers
1 parent 050a4ad commit 5496f97

File tree

2 files changed

+52
-1
lines changed

2 files changed

+52
-1
lines changed

mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern<memref::LoadOp> {
309309
/// {in_bounds = [false, true, true, true]}
310310
/// : memref<?x?x2x8xi8>, vector<2x[4]x2x8xi8>
311311
/// ```
312-
/// is rewriten to
312+
/// is rewritten to
313313
/// ```
314314
/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]]
315315
/// : memref<?x?x2x8xi8> into memref<?x?xi8>
@@ -324,6 +324,21 @@ struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
324324
LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
325325
PatternRewriter &rewriter) const override {
326326

327+
// Do not try to transform masked reads. For example, if we have a transfer
328+
// to a `vector<[4]x4xi8>` we could have a mask like
329+
// 1 1 1 0
330+
// 1 1 1 0
331+
// 1 1 1 0
332+
// 0 0 0 0
333+
// Flattening this mask would look like
334+
// 1 1 1 0 1 1 1 0 1 1 1 0 0 0 0 0
335+
// and we have not yet figured out an efficient way to build such a mask,
336+
// either from the mask operand or from the original `vector.create_mask`
337+
// operation (if visible at all).
338+
if (readOp.isMasked() || readOp.getMask())
339+
return rewriter.notifyMatchFailure(readOp,
340+
"masked transfers not-supported");
341+
327342
if (!readOp.getPermutationMap().isMinorIdentity())
328343
return rewriter.notifyMatchFailure(readOp, "non-identity permutation");
329344

mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,3 +224,39 @@ func.func @negative_test_bcast_transp(%i : index, %j : index, %M : memref<?x?x?x
224224

225225
return %A : vector<[4]x8xi8>
226226
}
227+
228+
// -----
229+
230+
// CHECK-LABEL: @negative_test_vector_mask
231+
// CHECK-NOT: memref.collapse
232+
233+
func.func @negative_test_vector_mask(
234+
%i : index, %j : index,
235+
%M : memref<?x?x?x8xi8>, %mask : vector<[4]x8xi1>) -> vector<[4]x8xi8> {
236+
237+
%c0 = arith.constant 0 : index
238+
%c0_i8 = arith.constant 0 : i8
239+
240+
%A = vector.mask %mask {
241+
vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true] } : memref<?x?x?x8xi8>, vector<[4]x8xi8>
242+
} : vector<[4]x8xi1> -> vector<[4]x8xi8>
243+
244+
return %A : vector<[4]x8xi8>
245+
}
246+
247+
// -----
248+
249+
// CHECK-LABEL: @negative_test_mask_operand
250+
// CHECK-NOT: memref.collapse
251+
252+
func.func @negative_test_mask_operand(
253+
%i : index, %j : index,
254+
%M : memref<?x?x?x8xi8>, %mask : vector<[4]x8xi1>) -> vector<[4]x8xi8> {
255+
256+
%c0 = arith.constant 0 : index
257+
%c0_i8 = arith.constant 0 : i8
258+
259+
%A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8, %mask {in_bounds = [true, true] } : memref<?x?x?x8xi8>, vector<[4]x8xi8>
260+
261+
return %A : vector<[4]x8xi8>
262+
}

0 commit comments

Comments
 (0)