2727#include " mlir/IR/OpImplementation.h"
2828#include " mlir/IR/TypeUtilities.h"
2929#include " mlir/Support/LLVM.h"
30+ #include " llvm/ADT/StringSet.h"
3031
3132using namespace mlir ;
3233using namespace mlir ::vector;
@@ -56,7 +57,10 @@ static ParseResult parseContractionOp(OpAsmParser &parser,
5657 SmallVector<Type, 2 > types;
5758 Type resultVectorType;
5859 auto loc = parser.getCurrentLocation ();
59- if (parser.parseOperand (lhsInfo) || parser.parseComma () ||
60+ DictionaryAttr dictAttr;
61+ // TODO(andydavis, ntv) Unify linalg op attribute parsing.
62+ if (parser.parseAttribute (dictAttr, " _" , result.attributes ) ||
63+ parser.parseOperand (lhsInfo) || parser.parseComma () ||
6064 parser.parseOperand (rhsInfo) || parser.parseComma () ||
6165 parser.parseOperand (accInfo) ||
6266 parser.parseTrailingOperandList (masksInfo) ||
@@ -68,7 +72,8 @@ static ParseResult parseContractionOp(OpAsmParser &parser,
6872 parser.resolveOperand (accInfo, resultVectorType, result.operands ) ||
6973 parser.addTypeToList (resultVectorType, result.types ))
7074 return failure ();
71-
75+ result.attributes .assign (dictAttr.getValue ().begin (),
76+ dictAttr.getValue ().end ());
7277 if (masksInfo.empty ())
7378 return success ();
7479 if (masksInfo.size () != 2 )
@@ -90,13 +95,23 @@ static ParseResult parseContractionOp(OpAsmParser &parser,
9095}
9196
9297static void print (OpAsmPrinter &p, ContractionOp op) {
93- p << op.getOperationName () << " " << *op.lhs () << " , " << *op.rhs ();
94- p << " , " << *op.acc ();
98+ // TODO(andydavis, ntv) Unify printing code with linalg ops.
99+ auto attrNames = op.getTraitAttrNames ();
100+ llvm::StringSet<> traitAttrsSet;
101+ traitAttrsSet.insert (attrNames.begin (), attrNames.end ());
102+ SmallVector<NamedAttribute, 8 > attrs;
103+ for (auto attr : op.getAttrs ()) {
104+ if (traitAttrsSet.count (attr.first .strref ()) > 0 )
105+ attrs.push_back (attr);
106+ }
107+ auto dictAttr = DictionaryAttr::get (attrs, op.getContext ());
108+ p << op.getOperationName () << " " << dictAttr << " " << *op.lhs () << " , " ;
109+ p << *op.rhs () << " , " << *op.acc ();
95110 if (llvm::size (op.masks ()) == 2 ) {
96111 p << " , " << **op.masks ().begin ();
97112 p << " , " << **(op.masks ().begin () + 1 );
98113 }
99- p.printOptionalAttrDict (op.getAttrs ());
114+ p.printOptionalAttrDict (op.getAttrs (), attrNames );
100115 p << " : " << op.lhs ()->getType () << " , " << op.rhs ()->getType () << " into "
101116 << op.getResultType ();
102117}
@@ -159,6 +174,34 @@ static LogicalResult verify(ContractionOp op) {
159174 auto rhsType = op.getRhsType ();
160175 auto accType = op.getAccType ();
161176 auto resType = op.getResultType ();
177+
178+ // Verify that an indexing map was specified for each vector operand.
179+ if (op.indexing_maps ().size () != 3 )
180+ return op.emitOpError (" expected an indexing map for each vector operand" );
181+
182+ // Verify that each index map has 'numIterators' inputs, no symbols, and
183+ // that the number of map outputs equals the rank of its associated
184+ // vector operand.
185+ unsigned numIterators = op.iterator_types ().getValue ().size ();
186+ for (auto it : llvm::enumerate (op.indexing_maps ())) {
187+ auto index = it.index ();
188+ auto map = it.value ().cast <AffineMapAttr>().getValue ();
189+ if (map.getNumSymbols () != 0 )
190+ return op.emitOpError (" expected indexing map " )
191+ << index << " to have no symbols" ;
192+ if (map.getNumDims () != numIterators)
193+ return op.emitOpError (" expected indexing map " )
194+ << index << " to have " << numIterators << " number of inputs" ;
195+ auto operandType = op.getOperand (index)->getType ().cast <VectorType>();
196+ unsigned rank = operandType.getShape ().size ();
197+ if (map.getNumResults () != rank)
198+ return op.emitOpError (" expected indexing map " )
199+ << index << " to have " << rank << " number of outputs" ;
200+ if (!map.isProjectedPermutation ())
201+ return op.emitOpError (" expected indexing map " )
202+ << index << " to be a projected permutation of its inputs" ;
203+ }
204+
162205 auto contractingDimMap = op.getContractingDimMap ();
163206 auto batchDimMap = op.getBatchDimMap ();
164207
@@ -198,27 +241,54 @@ static LogicalResult verify(ContractionOp op) {
198241 return success ();
199242}
200243
201- static std::vector<std::pair<int64_t , int64_t >> getDimMap (Attribute attr) {
244+ SmallVector<StringRef, 2 > ContractionOp::getTraitAttrNames () {
245+ return SmallVector<StringRef, 2 >{" indexing_maps" , " iterator_types" };
246+ }
247+
248+ static int64_t getResultIndex (AffineMap map, AffineExpr targetExpr) {
249+ for (int64_t i = 0 , e = map.getNumResults (); i < e; ++i)
250+ if (targetExpr == map.getResult (i))
251+ return i;
252+ return -1 ;
253+ }
254+
255+ static std::vector<std::pair<int64_t , int64_t >>
256+ getDimMap (ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
257+ StringRef targetIteratorTypeName, MLIRContext *context) {
202258 std::vector<std::pair<int64_t , int64_t >> dimMap;
203- auto dimPairs = attr.dyn_cast_or_null <ArrayAttr>();
204- if (!dimPairs)
205- return dimMap;
206- for (auto dimPairAttr : dimPairs) {
207- auto dimPair = dimPairAttr.cast <ArrayAttr>();
208- assert (dimPair.size () == 2 );
209- auto lhsDim = dimPair.begin ()->cast <IntegerAttr>().getInt ();
210- auto rhsDim = std::prev (dimPair.end ())->cast <IntegerAttr>().getInt ();
211- dimMap.push_back ({lhsDim, rhsDim});
259+ for (auto it : llvm::enumerate (iteratorTypes)) {
260+ auto iteratorTypeName = it.value ().cast <StringAttr>().getValue ();
261+ if (iteratorTypeName != targetIteratorTypeName)
262+ continue ;
263+ // Search lhs/rhs map results for 'targetExpr'.
264+ auto targetExpr = getAffineDimExpr (it.index (), context);
265+ int64_t lhsDim = getResultIndex (indexingMaps[0 ], targetExpr);
266+ int64_t rhsDim = getResultIndex (indexingMaps[1 ], targetExpr);
267+ if (lhsDim >= 0 && rhsDim >= 0 )
268+ dimMap.push_back ({lhsDim, rhsDim});
212269 }
213270 return dimMap;
214271}
215272
216273std::vector<std::pair<int64_t , int64_t >> ContractionOp::getContractingDimMap () {
217- return getDimMap (getAttr (getContractingDimMapAttrName ()));
274+ SmallVector<AffineMap, 4 > indexingMaps (getIndexingMaps ());
275+ return getDimMap (indexingMaps, iterator_types (),
276+ getReductionIteratorTypeName (), getContext ());
218277}
219278
220279std::vector<std::pair<int64_t , int64_t >> ContractionOp::getBatchDimMap () {
221- return getDimMap (getAttr (getBatchDimMapAttrName ()));
280+ SmallVector<AffineMap, 4 > indexingMaps (getIndexingMaps ());
281+ return getDimMap (indexingMaps, iterator_types (),
282+ getParallelIteratorTypeName (), getContext ());
283+ }
284+
285+ SmallVector<AffineMap, 4 > ContractionOp::getIndexingMaps () {
286+ SmallVector<AffineMap, 4 > res;
287+ auto mapAttrs = indexing_maps ().getValue ();
288+ res.reserve (mapAttrs.size ());
289+ for (auto mapAttr : mapAttrs)
290+ res.push_back (mapAttr.cast <AffineMapAttr>().getValue ());
291+ return res;
222292}
223293
224294// ===----------------------------------------------------------------------===//
0 commit comments