Skip to content

Commit 60a1aba

Browse files
mgehre-amdjorickert
authored andcommitted
Solve merge conflicts and fixed tests after bump
Propagate the new TOSA mul shift operand into downstream test expectations and update onnx-mlir and torch-mlir to always supply the shift constant.
1 parent 1a3a85e commit 60a1aba

File tree

3 files changed

+16
-11
lines changed

3 files changed

+16
-11
lines changed

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -607,8 +607,9 @@ func.func @mul_zero_broadcast_dynamic_result(%arg0: tensor<?x3xf32>) -> (tensor<
607607
// CHECK: tosa.mul
608608
// CHECK: tosa.mul
609609
%zeros = "tosa.const"() {value = dense<0.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
610-
%1 = tosa.mul %arg0, %zeros {shift = 0 : i8} : (tensor<?x3xf32>, tensor<1x1xf32>) -> tensor<?x3xf32>
611-
%2 = tosa.mul %zeros, %arg0 {shift = 0 : i8} : (tensor<1x1xf32>, tensor<?x3xf32>) -> tensor<?x3xf32>
610+
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
611+
%1 = tosa.mul %arg0, %zeros, %shift : (tensor<?x3xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<?x3xf32>
612+
%2 = tosa.mul %zeros, %arg0, %shift : (tensor<1x1xf32>, tensor<?x3xf32>, tensor<1xi8>) -> tensor<?x3xf32>
612613
return %1, %2 : tensor<?x3xf32>, tensor<?x3xf32>
613614
}
614615

@@ -1581,7 +1582,8 @@ func.func @canonicalize_select_lrelu_zero_pattern(%arg0: tensor<13x21x3xf32>) ->
15811582
// CHECK: %[[VAL_1:.*]] = tosa.clamp %arg{{.*}} {max_fp = 0x7F800000 : f32, max_int = 9223372036854775807 : i64, min_fp = 0.000000e+00 : f32, min_int = -9223372036854775808 : i64} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
15821583
// CHECK: return %[[VAL_1]] : tensor<13x21x3xf32>
15831584
%0 = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x1x1xf32>}>: () -> tensor<1x1x1xf32>
1584-
%1 = tosa.mul %arg0, %0 {shift = 0 : i8}: (tensor<13x21x3xf32>, tensor<1x1x1xf32>) -> tensor<13x21x3xf32>
1585+
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
1586+
%1 = tosa.mul %arg0, %0, %shift : (tensor<13x21x3xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
15851587
%2 = tosa.greater_equal %arg0, %0: (tensor<13x21x3xf32>, tensor<1x1x1xf32>) -> tensor<13x21x3xi1>
15861588
%3 = tosa.select %2, %arg0, %1: ( tensor<13x21x3xi1>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
15871589
return %3 : tensor<13x21x3xf32>
@@ -1642,4 +1644,3 @@ func.func @canonicalize_select_to_clamp_i8_and_i64_pat2(%arg0: tensor<13x21x3xi8
16421644
%3 = tosa.select %2, %1, %arg1: ( tensor<13x21x3xi1>, tensor<13x21x3xi64>, tensor<13x21x3xi64>) -> tensor<13x21x3xi64>
16431645
return %3 : tensor<13x21x3xi64>
16441646
}
1645-

mlir/test/Dialect/Tosa/constant-mul-opt.mlir

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ func.func @mul_fold_float() -> tensor<4xf16> {
1717
dense<[-132.7, -3.0, -0.0, 5.0]> :
1818
tensor<4xf16>
1919
} : () -> tensor<4xf16>
20-
%2 = "tosa.mul"(%0, %1) : (tensor<4xf16>, tensor<4xf16>) -> tensor<4xf16>
20+
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
21+
%2 = "tosa.mul"(%0, %1, %shift) : (tensor<4xf16>, tensor<4xf16>, tensor<1xi8>) -> tensor<4xf16>
2122
return %2 : tensor<4xf16>
2223
}
2324

@@ -34,7 +35,8 @@ func.func @mul_fold_float_infinity_nan() -> tensor<7xf32> {
3435
dense<[3.0, -3.0, -3.0, 3.0, 1.0, 0xFF800000, 0.0]> :
3536
tensor<7xf32>
3637
} : () -> tensor<7xf32>
37-
%2 = "tosa.mul"(%0, %1) : (tensor<7xf32>, tensor<7xf32>) -> tensor<7xf32>
38+
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
39+
%2 = "tosa.mul"(%0, %1, %shift) : (tensor<7xf32>, tensor<7xf32>, tensor<1xi8>) -> tensor<7xf32>
3840
return %2 : tensor<7xf32>
3941
}
4042

@@ -51,7 +53,8 @@ func.func @add_fold_float_overflow() -> tensor<2xf32> {
5153
dense<[2.1e+38, 1.1e+38]> :
5254
tensor<2xf32>
5355
} : () -> tensor<2xf32>
54-
%2 = "tosa.mul"(%0, %1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
56+
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
57+
%2 = "tosa.mul"(%0, %1, %shift) : (tensor<2xf32>, tensor<2xf32>, tensor<1xi8>) -> tensor<2xf32>
5558
return %2 : tensor<2xf32>
5659
}
5760

mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ func.func @test_reshape_for_broadcast(%arg0: tensor<4x3x2xi32>) -> tensor<4x3x2x
187187

188188
// CHECK-LABEL: @test_resnet18_common_case
189189
// COM: note that %74 is now represented by %arg2
190+
// CHECK-DAG: %[[CONST0:.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
190191
// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense_resource<torch_tensor_64_torch.float32_1> : tensor<64xf32>}> : () -> tensor<64xf32>
191192
// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{value = dense_resource<torch_tensor_64_torch.float32> : tensor<64xf32>}> : () -> tensor<64xf32>
192193
// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<9.99999974E-6> : tensor<1xf32>}> : () -> tensor<1xf32>
@@ -197,9 +198,9 @@ func.func @test_reshape_for_broadcast(%arg0: tensor<4x3x2xi32>) -> tensor<4x3x2x
197198
// CHECK-DAG: %[[VAL_9:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 64>} : (tensor<64xf32>) -> tensor<1x1x1x64xf32>
198199
// CHECK-DAG: %[[VAL_10:.*]] = tosa.sub %arg2, %[[VAL_9]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32>
199200
// CHECK-DAG: %[[VAL_11:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array<i64: 1, 1, 1, 64>} : (tensor<64xf32>) -> tensor<1x1x1x64xf32>
200-
// CHECK-DAG: %[[VAL_12:.*]] = tosa.mul %[[VAL_10]], %[[VAL_11]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32>
201+
// CHECK-DAG: %[[VAL_12:.*]] = tosa.mul %[[VAL_10]], %[[VAL_11]], %[[CONST0]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>, tensor<1xi8>) -> tensor<1x112x112x64xf32>
201202
// CHECK-DAG: %[[VAL_13:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 1, 1, 1, 64>} : (tensor<64xf32>) -> tensor<1x1x1x64xf32>
202-
// CHECK-DAG: %[[VAL_14:.*]] = tosa.mul %[[VAL_12]], %[[VAL_13]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32>
203+
// CHECK-DAG: %[[VAL_14:.*]] = tosa.mul %[[VAL_12]], %[[VAL_13]], %[[CONST0]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>, tensor<1xi8>) -> tensor<1x112x112x64xf32>
203204
// CHECK-DAG: %[[VAL_15:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array<i64: 1, 1, 1, 64>} : (tensor<64xf32>) -> tensor<1x1x1x64xf32>
204205
// CHECK-DAG: %[[VAL_16:.*]] = tosa.add %[[VAL_14]], %[[VAL_15]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32>
205206
// CHECK-DAG: %[[VAL_17:.*]] = tosa.clamp %[[VAL_16]] {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x112x112x64xf32>) -> tensor<1x112x112x64xf32>
@@ -220,9 +221,9 @@ func.func @test_resnet18_common_case(%arg0: tensor<64xf32>, %arg1: tensor<64xf32
220221
%79 = tosa.reshape %arg0 {new_shape = array<i64: 1, 64, 1, 1>} : (tensor<64xf32>) -> tensor<1x64x1x1xf32>
221222
%80 = tosa.sub %75, %79 : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32>
222223
%81 = tosa.reshape %78 {new_shape = array<i64: 1, 64, 1, 1>} : (tensor<64xf32>) -> tensor<1x64x1x1xf32>
223-
%82 = tosa.mul %80, %81 : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32>
224+
%82 = tosa.mul %80, %81, %shift : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>, tensor<1xi8>) -> tensor<1x64x112x112xf32>
224225
%83 = tosa.reshape %60 {new_shape = array<i64: 1, 64, 1, 1>} : (tensor<64xf32>) -> tensor<1x64x1x1xf32>
225-
%84 = tosa.mul %82, %83 : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32>
226+
%84 = tosa.mul %82, %83, %shift : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>, tensor<1xi8>) -> tensor<1x64x112x112xf32>
226227
%85 = tosa.reshape %59 {new_shape = array<i64: 1, 64, 1, 1>} : (tensor<64xf32>) -> tensor<1x64x1x1xf32>
227228
%86 = tosa.add %84, %85 : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32>
228229
%87 = tosa.clamp %86 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x64x112x112xf32>) -> tensor<1x64x112x112xf32>

0 commit comments

Comments
 (0)