@@ -154,6 +154,9 @@ class CreateNdDescToXeVMPattern
154154 matchAndRewrite (xegpu::CreateNdDescOp op,
155155 xegpu::CreateNdDescOp::Adaptor adaptor,
156156 ConversionPatternRewriter &rewriter) const override {
157+ SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets ();
158+ if (mixedOffsets.size () != 0 )
159+ return rewriter.notifyMatchFailure (op, " Offsets not supported." );
157160 auto loc = op.getLoc ();
158161 auto source = op.getSource ();
159162 // Op is lowered to a code sequence that populates payload.
@@ -177,7 +180,6 @@ class CreateNdDescToXeVMPattern
177180
178181 // Source can be a memref or a pointer (ui64, ui32, i64 or i32).
179182 SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes ();
180- SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets ();
181183 // Descriptor shape is expected to be 2D.
182184 int64_t rank = mixedSizes.size ();
183185 if (rank != 2 )
@@ -202,17 +204,9 @@ class CreateNdDescToXeVMPattern
202204 val = getValueOrCreateCastToIndexLike (rewriter, loc, payloadElemTy, val);
203205 return val;
204206 };
205- // Offsets can be either 2D or not provided (0 is used).
206- if (mixedOffsets.size () == 2 ) {
207- offsetW = createOffset (mixedOffsets, 1 );
208- offsetH = createOffset (mixedOffsets, 0 );
209- } else if (mixedOffsets.size () == 0 ) {
210- offsetW = arith::ConstantIntOp::create (rewriter, loc, payloadElemTy, 0 );
211- offsetH = arith::ConstantIntOp::create (rewriter, loc, payloadElemTy, 0 );
212- } else {
213- return rewriter.notifyMatchFailure (op,
214- " Expected 2D offsets or no offsets." );
215- }
207+ // Offsets are not supported (0 is used).
208+ offsetW = arith::ConstantIntOp::create (rewriter, loc, payloadElemTy, 0 );
209+ offsetH = arith::ConstantIntOp::create (rewriter, loc, payloadElemTy, 0 );
216210 // Get shape values from op fold results.
217211 baseShapeW = createOffset (mixedSizes, 1 );
218212 baseShapeH = createOffset (mixedSizes, 0 );
@@ -247,39 +241,6 @@ class CreateNdDescToXeVMPattern
247241 }
248242};
249243
250- class UpdateNdOffsetToXeVMPattern
251- : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
252- using OpConversionPattern::OpConversionPattern;
253- LogicalResult
254- matchAndRewrite (xegpu::UpdateNdOffsetOp op,
255- xegpu::UpdateNdOffsetOp::Adaptor adaptor,
256- ConversionPatternRewriter &rewriter) const override {
257- auto loc = op.getLoc ();
258- auto mixedOffsets = op.getMixedOffsets ();
259- // Only 2D offsets are supported for now.
260- if (mixedOffsets.size () != 2 )
261- return rewriter.notifyMatchFailure (op, " Expected 2D offsets." );
262- auto payload = adaptor.getTensorDesc ();
263- // Utility for updating payload offset values from op fold result.
264- auto updateOffset = [&](unsigned idx, int payloadPos) -> Value {
265- Value offset =
266- getValueOrCreateConstantIntOp (rewriter, loc, mixedOffsets[idx]);
267- offset = getValueOrCreateCastToIndexLike (rewriter, loc,
268- rewriter.getI32Type (), offset);
269- Value oldOffset =
270- vector::ExtractOp::create (rewriter, loc, payload, payloadPos);
271- Value newOffset = arith::AddIOp::create (rewriter, loc, oldOffset, offset);
272- return vector::InsertOp::create (rewriter, loc, newOffset, payload,
273- payloadPos);
274- };
275- // Update offsets in the payload.
276- payload = updateOffset (0 , static_cast <int >(NdTdescOffset::TensorOffsetH));
277- payload = updateOffset (1 , static_cast <int >(NdTdescOffset::TensorOffsetW));
278- rewriter.replaceOp (op, payload);
279- return success ();
280- }
281- };
282-
283244template <
284245 typename OpType,
285246 typename = std::enable_if_t <llvm::is_one_of<
@@ -289,6 +250,10 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
289250 LogicalResult
290251 matchAndRewrite (OpType op, typename OpType::Adaptor adaptor,
291252 ConversionPatternRewriter &rewriter) const override {
253+ auto mixedOffsets = op.getMixedOffsets ();
254+ int64_t opOffsetsSize = mixedOffsets.size ();
255+ if (opOffsetsSize != 2 )
256+ return rewriter.notifyMatchFailure (op, " Expected 2D offsets." );
292257 auto loc = op.getLoc ();
293258 auto ctxt = rewriter.getContext ();
294259
@@ -311,32 +276,16 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
311276 rewriter, loc, tdesc, static_cast <int >(NdTdescOffset::BaseShapeW));
312277 Value baseShapeH = vector::ExtractOp::create (
313278 rewriter, loc, tdesc, static_cast <int >(NdTdescOffset::BaseShapeH));
314- // Offsets provided in two ways:
315- // 1. Offsets are extracted from the tensor descriptor.
316- // 2. (Mixed) offsets which are provided by the op.
317- Value offsetW;
318- Value offsetH;
319- auto mixedOffsets = op.getMixedOffsets ();
320- int64_t opOffsetsSize = mixedOffsets.size ();
321- if (opOffsetsSize != 0 && opOffsetsSize != 2 )
322- return rewriter.notifyMatchFailure (op,
323- " Expected 2D offsets or no offsets." );
324- if (opOffsetsSize) {
325- // If mixed offsets are provided by the op convert them to i32.
326- offsetW = getValueOrCreateConstantIntOp (rewriter, loc, mixedOffsets[1 ]);
327- offsetW = getValueOrCreateCastToIndexLike (rewriter, loc,
328- rewriter.getI32Type (), offsetW);
329- offsetH = getValueOrCreateConstantIntOp (rewriter, loc, mixedOffsets[0 ]);
330- offsetH = getValueOrCreateCastToIndexLike (rewriter, loc,
331- rewriter.getI32Type (), offsetH);
332- } else {
333- // If offsets are not available, we need to extract them from the tensor
334- // descriptor.
335- offsetW = vector::ExtractOp::create (
336- rewriter, loc, tdesc, static_cast <int >(NdTdescOffset::TensorOffsetW));
337- offsetH = vector::ExtractOp::create (
338- rewriter, loc, tdesc, static_cast <int >(NdTdescOffset::TensorOffsetH));
339- }
279+ // Offsets are provided by the op.
280+ // convert them to i32.
281+ Value offsetW =
282+ getValueOrCreateConstantIntOp (rewriter, loc, mixedOffsets[1 ]);
283+ offsetW = getValueOrCreateCastToIndexLike (rewriter, loc,
284+ rewriter.getI32Type (), offsetW);
285+ Value offsetH =
286+ getValueOrCreateConstantIntOp (rewriter, loc, mixedOffsets[0 ]);
287+ offsetH = getValueOrCreateCastToIndexLike (rewriter, loc,
288+ rewriter.getI32Type (), offsetH);
340289 // Get address space from tensor descriptor memory space.
341290 auto ptrTypeLLVM = LLVM::LLVMPointerType::get (
342291 ctxt, getNumericXeVMAddrSpace (tdescTy.getMemorySpace ()));
@@ -422,54 +371,6 @@ static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
422371 return newAddr;
423372}
424373
425- class CreateDescToXeVMPattern
426- : public OpConversionPattern<xegpu::CreateDescOp> {
427- using OpConversionPattern::OpConversionPattern;
428- LogicalResult
429- matchAndRewrite (xegpu::CreateDescOp op, xegpu::CreateDescOp::Adaptor adaptor,
430- ConversionPatternRewriter &rewriter) const override {
431- auto eTy = op.getTensorDescType ().getElementType ();
432- auto eBw = eTy.getIntOrFloatBitWidth ();
433- if (eBw % 8 != 0 )
434- return rewriter.notifyMatchFailure (
435- op, " Expected element type bit width to be multiple of 8." );
436- auto loc = op.getLoc ();
437- // Offsets are provided as scalar i64 by type converter.
438- auto offsets = adaptor.getOffsets ();
439- // Source type can be a 1D memref or pointer type (ui64, ui32, i64 or i32).
440- // But type converter will convert them to integer types.
441- Value addr = adaptor.getSource ();
442- // ui32 or i32 are passed as i32 so they need to be casted to i64.
443- if (addr.getType () != rewriter.getI64Type ())
444- addr = arith::ExtUIOp::create (rewriter, loc, rewriter.getI64Type (), addr);
445- auto laneAddr = addOffset (rewriter, loc, addr, offsets, eBw / 8 );
446- rewriter.replaceOp (op, laneAddr);
447- return success ();
448- }
449- };
450-
451- class UpdateOffsetToXeVMPattern
452- : public OpConversionPattern<xegpu::UpdateOffsetOp> {
453- using OpConversionPattern::OpConversionPattern;
454- LogicalResult
455- matchAndRewrite (xegpu::UpdateOffsetOp op,
456- xegpu::UpdateOffsetOp::Adaptor adaptor,
457- ConversionPatternRewriter &rewriter) const override {
458- auto eTy = op.getTensorDescType ().getElementType ();
459- auto eBw = eTy.getIntOrFloatBitWidth ();
460- if (eBw % 8 != 0 )
461- return rewriter.notifyMatchFailure (
462- op, " Expected element type bit width to be multiple of 8." );
463- auto loc = op.getLoc ();
464- // Scatter descriptor is provided as scalar i64 by type converter.
465- // Offsets are provided as scalar i64 by type converter.
466- Value newOffset = addOffset (rewriter, loc, adaptor.getTensorDesc (),
467- adaptor.getOffsets (), eBw / 8 );
468- rewriter.replaceOp (op, newOffset);
469- return success ();
470- }
471- };
472-
473374template <typename OpType,
474375 typename = std::enable_if_t <llvm::is_one_of<
475376 OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>>
@@ -478,6 +379,9 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
478379 LogicalResult
479380 matchAndRewrite (OpType op, typename OpType::Adaptor adaptor,
480381 ConversionPatternRewriter &rewriter) const override {
382+ Value offset = adaptor.getOffsets ();
383+ if (!offset)
384+ return rewriter.notifyMatchFailure (op, " Expected offset to be provided." );
481385 auto loc = op.getLoc ();
482386 auto ctxt = rewriter.getContext ();
483387 auto tdescTy = op.getTensorDescType ();
@@ -527,21 +431,16 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
527431 basePtrI64 = arith::ExtUIOp::create (rewriter, loc, rewriter.getI64Type (),
528432 basePtrI64);
529433 }
530- Value offsets = adaptor.getOffsets ();
531434 Value mask = adaptor.getMask ();
532- if (offsets) {
533- if (dyn_cast<VectorType>(offsets.getType ())) {
534- // Offset needs be scalar. Single element vector is converted to scalar
535- // by type converter.
536- return rewriter.notifyMatchFailure (op,
537- " Expected offsets to be a scalar." );
538- } else {
539- // If offsets are provided, we add them to the base pointer.
540- // Offsets are in number of elements, we need to multiply by
541- // element byte size.
542- basePtrI64 =
543- addOffset (rewriter, loc, basePtrI64, offsets, elemByteSize);
544- }
435+ if (dyn_cast<VectorType>(offset.getType ())) {
436+ // Offset needs be scalar. Single element vector is converted to scalar
437+ // by type converter.
438+ return rewriter.notifyMatchFailure (op, " Expected offset to be a scalar." );
439+ } else {
440+ // If offset is provided, we add them to the base pointer.
441+ // Offset is in number of elements, we need to multiply by
442+ // element byte size.
443+ basePtrI64 = addOffset (rewriter, loc, basePtrI64, offset, elemByteSize);
545444 }
546445 // Convert base pointer (i64) to LLVM pointer type.
547446 Value basePtrLLVM =
@@ -1011,13 +910,12 @@ struct ConvertXeGPUToXeVMPass
1011910// ===----------------------------------------------------------------------===//
1012911void mlir::populateXeGPUToXeVMConversionPatterns (
1013912 const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1014- patterns.add <CreateNdDescToXeVMPattern, UpdateNdOffsetToXeVMPattern,
913+ patterns.add <CreateNdDescToXeVMPattern,
1015914 LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>,
1016915 LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>,
1017916 LoadStorePrefetchNdToXeVMPattern<xegpu::PrefetchNdOp>>(
1018917 typeConverter, patterns.getContext ());
1019- patterns.add <CreateDescToXeVMPattern, UpdateOffsetToXeVMPattern,
1020- AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
918+ patterns.add <AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
1021919 LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
1022920 LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
1023921 typeConverter, patterns.getContext ());