@@ -100,43 +100,59 @@ static Value createLinalgBodyCalculationForElementwiseOp(
100100 }
101101
102102 // tosa::MulOp
103- if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy))
104- return rewriter.create <arith::MulFOp>(loc, resultTypes, args);
105-
106- if (isa<tosa::MulOp>(op) && isa<IntegerType>(elementTy)) {
107- Value a = args[0 ];
108- Value b = args[1 ];
109- auto shift =
110- cast<IntegerAttr>(op->getAttr (" shift" )).getValue ().getSExtValue ();
111- if (shift > 0 ) {
112- auto shiftConst =
113- rewriter.create <arith::ConstantIntOp>(loc, shift, /* bitwidth=*/ 8 );
114- if (!a.getType ().isInteger (32 ))
115- a = rewriter.create <arith::ExtSIOp>(loc, rewriter.getI32Type (), a);
116-
117- if (!b.getType ().isInteger (32 ))
118- b = rewriter.create <arith::ExtSIOp>(loc, rewriter.getI32Type (), b);
119-
120- auto result = rewriter.create <tosa::ApplyScaleOp>(
121- loc, rewriter.getI32Type (), a, b, shiftConst,
122- rewriter.getBoolAttr (false ));
123-
124- if (elementTy.isInteger (32 ))
125- return result;
126-
127- return rewriter.create <arith::TruncIOp>(loc, elementTy, result);
103+ if (isa<tosa::MulOp>(op)) {
104+ auto shift_val = cast<tosa::MulOp>(op).getShift ();
105+ ElementsAttr shift_elem;
106+ if (!shift_val.getImpl () ||
107+ !matchPattern (shift_val, m_Constant (&shift_elem))) {
108+ (void )rewriter.notifyMatchFailure (op, " shift value of mul not found" );
128109 }
129110
130- int aWidth = a.getType ().getIntOrFloatBitWidth ();
131- int bWidth = b.getType ().getIntOrFloatBitWidth ();
132- int cWidth = resultTypes[0 ].getIntOrFloatBitWidth ();
111+ int32_t shift = shift_elem.getValues <IntegerAttr>()[0 ].getInt ();
133112
134- if (aWidth < cWidth)
135- a = rewriter.create <arith::ExtSIOp>(loc, resultTypes[0 ], a);
136- if (bWidth < cWidth)
137- b = rewriter.create <arith::ExtSIOp>(loc, resultTypes[0 ], b);
113+ if (isa<FloatType>(elementTy)) {
114+ if (shift != 0 ) {
115+ (void )rewriter.notifyMatchFailure (op,
116+ " Cannot have shift value for float" );
117+ return nullptr ;
118+ }
119+ return rewriter.create <arith::MulFOp>(loc, resultTypes, args[0 ], args[1 ]);
120+ }
121+
122+ if (isa<IntegerType>(elementTy)) {
123+ Value a = args[0 ];
124+ Value b = args[1 ];
125+
126+ if (shift > 0 ) {
127+ auto shiftConst =
128+ rewriter.create <arith::ConstantIntOp>(loc, shift, /* bitwidth=*/ 8 );
129+ if (!a.getType ().isInteger (32 ))
130+ a = rewriter.create <arith::ExtSIOp>(loc, rewriter.getI32Type (), a);
131+
132+ if (!b.getType ().isInteger (32 ))
133+ b = rewriter.create <arith::ExtSIOp>(loc, rewriter.getI32Type (), b);
134+
135+ auto result = rewriter.create <tosa::ApplyScaleOp>(
136+ loc, rewriter.getI32Type (), a, b, shiftConst,
137+ rewriter.getBoolAttr (false ));
138+
139+ if (elementTy.isInteger (32 ))
140+ return result;
138141
139- return rewriter.create <arith::MulIOp>(loc, resultTypes, a, b);
142+ return rewriter.create <arith::TruncIOp>(loc, elementTy, result);
143+ }
144+
145+ int aWidth = a.getType ().getIntOrFloatBitWidth ();
146+ int bWidth = b.getType ().getIntOrFloatBitWidth ();
147+ int cWidth = resultTypes[0 ].getIntOrFloatBitWidth ();
148+
149+ if (aWidth < cWidth)
150+ a = rewriter.create <arith::ExtSIOp>(loc, resultTypes[0 ], a);
151+ if (bWidth < cWidth)
152+ b = rewriter.create <arith::ExtSIOp>(loc, resultTypes[0 ], b);
153+
154+ return rewriter.create <arith::MulIOp>(loc, resultTypes, a, b);
155+ }
140156 }
141157
142158 // tosa::NegateOp
@@ -990,7 +1006,13 @@ elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
9901006 auto loc = operation->getLoc ();
9911007 auto rank =
9921008 cast<RankedTensorType>(operation->getResultTypes ().front ()).getRank ();
993- auto expandedOperands = expandInputRanks (rewriter, loc, operands, rank);
1009+ // For the mul op we need to avoid expanding the rank of the optional shift
1010+ // input.
1011+ auto operandsToExpand =
1012+ isa<tosa::MulOp>(operation) ? operands.take_front (2 ) : operands;
1013+
1014+ auto expandedOperands =
1015+ expandInputRanks (rewriter, loc, operandsToExpand, rank);
9941016 auto [targetShape, masterOperands] =
9951017 computeTargetShape (rewriter, loc, indexPool, expandedOperands);
9961018 auto broadcastOperands = broadcastDynamicDimensions (
0 commit comments