1
1
// RUN: mlir-opt --arm-sve-legalize-vector-storage --split-input-file %s | FileCheck %s
2
2
3
+
4
+ // Test the `LegalizeTransferRead` pattern
5
+ // (mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp)
6
+
3
7
// -----
4
8
9
+ // This is the base case, unremarkable in any way, except that it's our main
10
+ // motivating example and use case.
11
+
5
12
// CHECK-LABEL: @test_base_case
6
13
// CHECK-SAME: %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]:
7
14
// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[M]]
@@ -23,82 +30,35 @@ func.func @test_base_case(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> ve
23
30
24
31
// -----
25
32
26
- // CHECK-LABEL: @test_using_strided_layout
27
- // CHECK-SAME: %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]
28
- // CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[M]]
29
- // CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
30
- // CHECK-SAME: : memref<?x?x?x8xi8, strided<[?, ?, 8, 1]>> into
31
- // CHECK-SAME: memref<?x?x?xi8, strided<[?, ?, 1]>>
32
- // CHECK-NEXT: %[[T0:.+]] = vector.transfer_read %[[COLLAPSE]][%[[I]], %[[J]], %c0], %c0_i8 {in_bounds = [true]}
33
- // CHECK-SAME: : memref<?x?x?xi8, strided<[?, ?, 1]>>, vector<[32]xi8>
34
- // CHECK-NEXT: %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<[32]xi8> to vector<[4]x8xi8>
35
- // CHECK-NEXT: return %[[T1]] : vector<[4]x8xi8>
36
-
37
- #s0 = strided <[?, ?, 8 , 1 ]>
38
-
39
- func.func @test_using_strided_layout (%i : index , %j : index , %M : memref <?x?x?x8 xi8 , #s0 >) -> vector <[4 ]x8 xi8 > {
40
- %c0 = arith.constant 0 : index
41
- %c0_i8 = arith.constant 0 : i8
42
-
43
- %A = vector.transfer_read %M [%i , %j , %c0 , %c0 ], %c0_i8 {in_bounds = [true , true ]} : memref <?x?x?x8 xi8 , #s0 >, vector <[4 ]x8 xi8 >
44
-
45
- return %A : vector <[4 ]x8 xi8 >
46
- }
47
-
48
- // -----
33
+ // Test the case where the scalable dimension is not the second-to-last.
49
34
50
35
// CHECK-LABEL: @test_3d_vector
51
36
// CHECK-SAME: %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]
52
37
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[M]]
53
38
// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
54
- // CHECK-SAME: : memref<?x?x2x8xi8, strided<[?, 16, 8, 1]>> into
55
- // CHECK-SAME: memref<?x?xi8, strided<[?, 1]>>
39
+ // CHECK-SAME: : memref<?x?x2x8xi8> into memref<?x?xi8>
56
40
// CHECK-NEXT: %[[T0:.+]] = vector.transfer_read %[[COLLAPSED]][%[[I]], %[[J]]], %c0_i8 {in_bounds = [true]}
57
- // CHECK-SAME: : memref<?x?xi8, strided<[?, 1]> >, vector<[64]xi8>
41
+ // CHECK-SAME: : memref<?x?xi8>, vector<[64]xi8>
58
42
// CHECK-NEXT: %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<[64]xi8> to vector<[4]x2x8xi8>
59
43
// CHECK-NEXT: return %[[T1]] : vector<[4]x2x8xi8>
60
44
61
- #s1 = strided <[?, 16 , 8 , 1 ]>
62
-
63
- func.func @test_3d_vector (%i : index , %j : index , %M : memref <?x?x2 x8 xi8 , #s1 >) -> vector <[4 ]x2 x8 xi8 > {
45
+ func.func @test_3d_vector (%i : index , %j : index , %M : memref <?x?x2 x8 xi8 >) -> vector <[4 ]x2 x8 xi8 > {
64
46
%c0 = arith.constant 0 : index
65
47
%c0_i8 = arith.constant 0 : i8
66
48
67
- %A = vector.transfer_read %M [%i , %j , %c0 , %c0 ], %c0_i8 {in_bounds = [true , true , true ]} : memref <?x?x2 x8 xi8 , #s1 >, vector <[4 ]x2 x8 xi8 >
49
+ %A = vector.transfer_read %M [%i , %j , %c0 , %c0 ], %c0_i8 {in_bounds = [true , true , true ]} : memref <?x?x2 x8 xi8 >, vector <[4 ]x2 x8 xi8 >
68
50
69
51
return %A : vector <[4 ]x2 x8 xi8 >
70
52
}
71
53
72
54
// -----
73
55
74
- // CHECK-LABEL: @test_4d_vector
75
- // CHECK-SAME: %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]
76
- // CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[M]]
77
- // CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
78
- // CHECK-SAME: : memref<?x?x2x8xi8, strided<[?, 16, 8, 1]>> into
79
- // CHECK-SAME: memref<?x?xi8, strided<[?, 1]>>
80
- // CHECK-NEXT: %[[T0:.+]] = vector.transfer_read %[[COLLAPSED]][%[[I]], %[[J]]], %c0_i8 {in_bounds = [false, true]}
81
- // CHECK-SAME: : memref<?x?xi8, strided<[?, 1]>>, vector<2x[64]xi8>
82
- // CHECK-NEXT: %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
83
- // CHECK-NEXT: return %[[T1]] : vector<2x[4]x2x8xi8>
84
-
85
- #s2 = strided <[?, 16 , 8 , 1 ]>
86
-
87
- func.func @test_4d_vector (%i : index , %j : index , %M : memref <?x?x2 x8 xi8 , #s2 >) -> vector <2 x[4 ]x2 x8 xi8 > {
88
- %c0 = arith.constant 0 : index
89
- %c0_i8 = arith.constant 0 : i8
90
-
91
- %A = vector.transfer_read %M [%i , %j , %c0 , %c0 ], %c0_i8 {in_bounds = [false , true , true , true ]} : memref <?x?x2 x8 xi8 , #s2 >, vector <2 x[4 ]x2 x8 xi8 >
92
-
93
- return %A : vector <2 x[4 ]x2 x8 xi8 >
94
- }
95
-
96
- // -----
56
+ // Test the case when the vector is already LLVM-legal (fixed).
97
57
98
- // CHECK-LABEL: @negative_test_vector_legal_non_scalable
58
+ // CHECK-LABEL: @negative_test_vector_legal_fixed
99
59
// CHECK-NOT: memref.collapse
100
60
101
- func.func @negative_test_vector_legal_non_scalable (%i : index , %j : index , %M : memref <?x?x?x8 xi8 >) -> vector <8 x8 xi8 > {
61
+ func.func @negative_test_vector_legal_fixed (%i : index , %j : index , %M : memref <?x?x?x8 xi8 >) -> vector <8 x8 xi8 > {
102
62
%c0 = arith.constant 0 : index
103
63
%c0_i8 = arith.constant 0 : i8
104
64
@@ -109,10 +69,12 @@ func.func @negative_test_vector_legal_non_scalable(%i : index, %j : index, %M :
109
69
110
70
// -----
111
71
112
- // CHECK-LABEL: @negative_test_vector_legal_scalable_0
72
+ // Test the case when the vector is already LLVM-legal (single-dimension scalable).
73
+
74
+ // CHECK-LABEL: @negative_test_vector_legal_1d_scalable
113
75
// CHECK-NOT: memref.collapse
114
76
115
- func.func @negative_test_vector_legal_scalable_0 (%i : index , %j : index , %M : memref <?x?x?x8 xi8 >) -> vector <[8 ]xi8 > {
77
+ func.func @negative_test_vector_legal_1d_scalable (%i : index , %j : index , %M : memref <?x?x?x8 xi8 >) -> vector <[8 ]xi8 > {
116
78
%c0 = arith.constant 0 : index
117
79
%c0_i8 = arith.constant 0 : i8
118
80
@@ -123,10 +85,13 @@ func.func @negative_test_vector_legal_scalable_0(%i : index, %j : index, %M : me
123
85
124
86
// -----
125
87
126
- // CHECK-LABEL: @negative_test_vector_legal_scalable_1
88
+ // Test the case when the vector is already LLVM-legal (single trailing
89
+ // scalable dimension).
90
+
91
+ // CHECK-LABEL: @negative_test_vector_legal_trailing_scalable_dim
127
92
// CHECK-NOT: memref.collapse
128
93
129
- func.func @negative_test_vector_legal_scalable_1 (%i : index , %j : index , %M : memref <?x?x?x8 xi8 >) -> vector <8 x[8 ]xi8 > {
94
+ func.func @negative_test_vector_legal_trailing_scalable_dim (%i : index , %j : index , %M : memref <?x?x?x8 xi8 >) -> vector <8 x[8 ]xi8 > {
130
95
%c0 = arith.constant 0 : index
131
96
%c0_i8 = arith.constant 0 : i8
132
97
@@ -137,6 +102,8 @@ func.func @negative_test_vector_legal_scalable_1(%i : index, %j : index, %M : me
137
102
138
103
// -----
139
104
105
+ // Test the case of unsupported vector type (more than one scalable dimension)
106
+
140
107
// CHECK-LABEL: @negative_test_vector_type_not_supported
141
108
// CHECK-NOT: memref.collapse
142
109
@@ -151,10 +118,14 @@ func.func @negative_test_vector_type_not_supported(%i : index, %j : index, %M :
151
118
152
119
// -----
153
120
154
- // CHECK-LABEL: @negative_test_non_mem
121
+ // Test the case of reading from a tensor - not supported, since the
122
+ // transform reasons about memory layouts.
123
+
124
+ // CHECK-LABEL: @negative_test_tensor_transfer
125
+
155
126
// CHECK-NOT: memref.collapse
156
127
157
- func.func @negative_test_non_mem (%i : index , %j : index , %M : tensor <?x?x?x8 xi8 >) -> vector <[4 ]x8 xi8 > {
128
+ func.func @negative_test_tensor_transfer (%i : index , %j : index , %M : tensor <?x?x?x8 xi8 >) -> vector <[4 ]x8 xi8 > {
158
129
%c0 = arith.constant 0 : index
159
130
%c0_i8 = arith.constant 0 : i8
160
131
@@ -165,98 +136,120 @@ func.func @negative_test_non_mem(%i : index, %j : index, %M : tensor<?x?x?x8xi8>
165
136
166
137
// -----
167
138
168
- // CHECK-LABEL: @negative_test_discontig_mem_0
139
+ // Test the case when the transfer is discontiguous because the memref
140
+ // is discontiguous.
141
+ // There are other ways to make a memref discontiguous. The transformation
142
+ // is not concerned with the particular reason a memref is discontiguous, but
143
+ // only with the fact. Therefore there are no variations with the memref made
144
+ // discontiguous by some other mechanism.
145
+
146
+ // CHECK-LABEL: @negative_test_discontig_mem
169
147
// CHECK-NOT: memref.collapse
170
148
171
- #s3 = strided <[?, ?, 16 , 1 ]>
149
+ #strides = strided <[?, ?, 16 , 1 ]>
172
150
173
- func.func @negative_test_discontig_mem_0 (%i : index , %j : index , %M : memref <?x?x?x8 xi8 , #s3 >) -> vector <[4 ]x8 xi8 > {
151
+ func.func @negative_test_discontig_mem (%i : index , %j : index , %M : memref <?x?x?x8 xi8 , #strides >) -> vector <[4 ]x8 xi8 > {
174
152
%c0 = arith.constant 0 : index
175
153
%c0_i8 = arith.constant 0 : i8
176
154
177
- %A = vector.transfer_read %M [%i , %j , %c0 , %c0 ], %c0_i8 {in_bounds = [true , true ]} : memref <?x?x?x8 xi8 , #s3 >, vector <[4 ]x8 xi8 >
155
+ %A = vector.transfer_read %M [%i , %j , %c0 , %c0 ], %c0_i8 {in_bounds = [true , true ]} : memref <?x?x?x8 xi8 , #strides >, vector <[4 ]x8 xi8 >
178
156
179
157
return %A : vector <[4 ]x8 xi8 >
180
158
}
181
159
182
160
// -----
183
161
184
- // CHECK-LABEL: @negative_test_discontig_mem_1
162
+ // Test the case when the transformation is not applied because of
163
+ // a non-trivial permutation map (broadcast).
164
+
165
+ // CHECK-LABEL: @negative_test_broadcast
185
166
// CHECK-NOT: memref.collapse
186
167
187
- #layout = affine_map <(i , j , k , p ) -> (j , i , k , p )>
168
+ #perm = affine_map <(i , j , k , p ) -> (k , 0 )>
188
169
189
- func.func @negative_test_discontig_mem_1 (%i : index , %j : index , %M : memref <?x?x?x8 xi8 , #layout >) -> vector <[4 ]x8 xi8 > {
170
+ func.func @negative_test_broadcast (%i : index , %j : index , %M : memref <?x?x?x8 xi8 >) -> vector <[4 ]x8 xi8 > {
190
171
%c0 = arith.constant 0 : index
191
172
%c0_i8 = arith.constant 0 : i8
192
173
193
- %A = vector.transfer_read %M [%i , %j , %c0 , %c0 ], %c0_i8 {in_bounds = [true , true ]} : memref <?x?x?x8 xi8 , #layout >, vector <[4 ]x8 xi8 >
174
+ %A = vector.transfer_read %M [%i , %j , %c0 , %c0 ], %c0_i8 {permutation_map = #perm , in_bounds = [true , true ] } : memref <?x?x?x8 xi8 >, vector <[4 ]x8 xi8 >
194
175
195
176
return %A : vector <[4 ]x8 xi8 >
196
177
}
197
178
198
179
// -----
199
180
200
- // CHECK-LABEL: @negative_test_discontig_read_strided_vec
181
+ // Test the case of a masked read - not supported right now.
182
+ // (see mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp)
183
+
184
+ // CHECK-LABEL: @negative_test_masked
201
185
// CHECK-NOT: memref.collapse
202
186
203
- func.func @negative_test_discontig_read_strided_vec (%i : index , %j : index , %M : memref <?x?x?x8 xi8 >) -> vector <[4 ]x4 xi8 > {
187
+ func.func @negative_test_masked (
188
+ %i : index , %j : index ,
189
+ %M : memref <?x?x?x8 xi8 >, %mask : vector <[4 ]x8 xi1 >) -> vector <[4 ]x8 xi8 > {
190
+
204
191
%c0 = arith.constant 0 : index
205
192
%c0_i8 = arith.constant 0 : i8
206
193
207
- %A = vector.transfer_read %M [%i , %j , %c0 , %c0 ], %c0_i8 {in_bounds = [true , true ]} : memref <?x?x?x8 xi8 >, vector <[4 ]x4 xi8 >
194
+ %A = vector.mask %mask {
195
+ vector.transfer_read %M [%i , %j , %c0 , %c0 ], %c0_i8 {in_bounds = [true , true ] } : memref <?x?x?x8 xi8 >, vector <[4 ]x8 xi8 >
196
+ } : vector <[4 ]x8 xi1 > -> vector <[4 ]x8 xi8 >
208
197
209
- return %A : vector <[4 ]x 4 x i8 >
198
+ return %A : vector <[4 ]x 8 x i8 >
210
199
}
211
200
212
201
// -----
213
202
214
- // CHECK-LABEL: @negative_test_bcast_transp
215
- // CHECK-NOT: memref.collapse
203
+ // Test case with a mask operand - not supported right now.
204
+ // (see mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp)
216
205
217
- #perm = affine_map <(i , j , k , p ) -> (k , 0 )>
206
+ // CHECK-LABEL: @negative_test_with_mask
207
+ // CHECK-NOT: memref.collapse
218
208
219
- func.func @negative_test_bcast_transp (%i : index , %j : index , %M : memref <?x?x?x8 xi8 >) -> vector <[4 ]x8 xi8 > {
209
+ func.func @negative_test_with_mask (
210
+ %i : index , %j : index ,
211
+ %M : memref <?x?x?x8 xi8 >, %mask : vector <[4 ]x8 xi1 >) -> vector <[4 ]x8 xi8 > {
212
+
220
213
%c0 = arith.constant 0 : index
221
214
%c0_i8 = arith.constant 0 : i8
222
215
223
- %A = vector.transfer_read %M [%i , %j , %c0 , %c0 ], %c0_i8 { permutation_map = #perm , in_bounds = [true , true ] } : memref <?x?x?x8 xi8 >, vector <[4 ]x8 xi8 >
216
+ %A = vector.transfer_read %M [%i , %j , %c0 , %c0 ], %c0_i8 , %mask { in_bounds = [true , true ] } : memref <?x?x?x8 xi8 >, vector <[4 ]x8 xi8 >
224
217
225
218
return %A : vector <[4 ]x8 xi8 >
226
219
}
227
220
228
221
// -----
229
222
230
- // CHECK-LABEL: @negative_test_vector_mask
223
+ // Test the case when the dimensions to collapse (excluding the scalable one)
224
+ // of the vector and the memref do not match (static non matching dimension).
225
+
226
+ // CHECK-LABEL: @negative_test_non_matching_dim_static
231
227
// CHECK-NOT: memref.collapse
232
228
233
- func.func @negative_test_vector_mask (
234
- %i : index , %j : index ,
235
- %M : memref <?x?x?x8 xi8 >, %mask : vector <[4 ]x8 xi1 >) -> vector <[4 ]x8 xi8 > {
229
+ func.func @negative_test_non_matching_dim_static (%i : index , %j : index , %M : memref <?x?x?x8 xi8 >) -> vector <[4 ]x4 xi8 > {
236
230
237
231
%c0 = arith.constant 0 : index
238
232
%c0_i8 = arith.constant 0 : i8
239
233
240
- %A = vector.mask %mask {
241
- vector.transfer_read %M [%i , %j , %c0 , %c0 ], %c0_i8 {in_bounds = [true , true ] } : memref <?x?x?x8 xi8 >, vector <[4 ]x8 xi8 >
242
- } : vector <[4 ]x8 xi1 > -> vector <[4 ]x8 xi8 >
234
+ %A = vector.transfer_read %M [%i , %j , %c0 , %c0 ], %c0_i8 {in_bounds = [true , true ] } : memref <?x?x?x8 xi8 >, vector <[4 ]x4 xi8 >
243
235
244
- return %A : vector <[4 ]x 8 x i8 >
236
+ return %A : vector <[4 ]x 4 x i8 >
245
237
}
246
238
247
239
// -----
248
240
249
- // CHECK-LABEL: @negative_test_mask_operand
241
+ // Test the case when the dimensions to collapse (excluding the scalable one)
242
+ // of the vector and the memref do not match (dynamic non matching dimension).
243
+
244
+ // CHECK-LABEL: @negative_test_non_matching_dim_dynamic
250
245
// CHECK-NOT: memref.collapse
251
246
252
- func.func @negative_test_mask_operand (
253
- %i : index , %j : index ,
254
- %M : memref <?x?x?x8 xi8 >, %mask : vector <[4 ]x8 xi1 >) -> vector <[4 ]x8 xi8 > {
247
+ func.func @negative_test_non_matching_dim_dynamic (%i : index , %j : index , %M : memref <?x?x?x?xi8 >) -> vector <[4 ]x4 xi8 > {
255
248
256
249
%c0 = arith.constant 0 : index
257
250
%c0_i8 = arith.constant 0 : i8
258
251
259
- %A = vector.transfer_read %M [%i , %j , %c0 , %c0 ], %c0_i8 , %mask {in_bounds = [true , true ] } : memref <?x?x?x 8 x i8 >, vector <[4 ]x 8 x i8 >
252
+ %A = vector.transfer_read %M [%i , %j , %c0 , %c0 ], %c0_i8 {in_bounds = [true , true ] } : memref <?x?x?x?x i8 >, vector <[4 ]x 4 x i8 >
260
253
261
- return %A : vector <[4 ]x 8 x i8 >
254
+ return %A : vector <[4 ]x 4 x i8 >
262
255
}
0 commit comments