//===- TestXeGPUTransforms.cpp -- Test Vector transforms and lowerings ----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" #include "mlir/Dialect/XeGPU/Transforms/Transforms.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; using namespace mlir::xegpu; namespace { #define DEBUG_TYPE "test-xegpu-unroll" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") struct TestXeGPUUnrollingPatterns : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestXeGPUUnrollingPatterns) StringRef getArgument() const final { return "test-xegpu-unrolling-patterns"; } StringRef getDescription() const final { return "Test lowering patterns to unroll ops in the xegpu dialect"; } void getDependentDialects(::mlir::DialectRegistry ®istry) const override { registry.insert(); registry.insert(); registry.insert(); } TestXeGPUUnrollingPatterns() = default; TestXeGPUUnrollingPatterns(const TestXeGPUUnrollingPatterns &pass) : PassWrapper(pass) {} void runOnOperation() override { MLIRContext *ctx = &getContext(); xegpu::UnrollOptions options; options.setNativeShapeFn( [&](Operation *op) -> std::optional> { if (isa(op)) { xegpu::TensorDescType tdescTy; if (auto createNdOp = dyn_cast(op)) { tdescTy = createNdOp.getType(); } else if (auto updateNdOp = dyn_cast(op)) { tdescTy = updateNdOp.getTensorDescType(); } else if (auto prefetchNdOp = dyn_cast(op)) { tdescTy = prefetchNdOp.getTensorDescType(); } else if (auto loadNdOp = dyn_cast(op)) { tdescTy = loadNdOp.getTensorDescType(); } else if (auto storeNdOp = dyn_cast(op)) { tdescTy = storeNdOp.getTensorDescType(); } else if (auto createOp = dyn_cast(op)) { tdescTy = createOp.getType(); } else if (auto updateOp = dyn_cast(op)) { tdescTy = updateOp.getTensorDescType(); } else if (auto prefetchOp = dyn_cast(op)) { tdescTy = prefetchOp.getTensorDescType(); } else if (auto loadOp = dyn_cast(op)) { tdescTy = loadOp.getTensorDescType(); } else if (auto storeOp = dyn_cast(op)) { tdescTy = storeOp.getTensorDescType(); } if (auto layout = tdescTy.getLayoutAttr()) { auto inst_data = layout.getInstData(); if (inst_data && layout.isSgLayout()) return SmallVector(inst_data.asArrayRef().begin(), inst_data.asArrayRef().end()); } } if (isa(op)) return SmallVector{8, 16, 16}; return std::nullopt; }); options.setUnrolledTypesFn( [&](ShapedType type, ArrayRef tileShape) -> SmallVector { Type elemTy = type.getElementType(); Type newTy; // TensorDescType needs to drop the inst_data field in the layout // attribute if (auto tdescTy = dyn_cast(type)) { Attribute encoding = tdescTy.getEncoding(); auto layout = llvm::dyn_cast_if_present( tdescTy.getLayout()); // If the encoding is a ScatterTensorDescAttr, we need to // potentially adjust the chunk size based on the inst_data. if (encoding && mlir::isa(encoding)) { auto scatterAttr = mlir::dyn_cast(encoding); int64_t chunkSize = scatterAttr.getChunkSize().getInt(); if (chunkSize > 1) { int64_t blockedChunkSize = chunkSize; auto instData = layout.getInstData(); if (!instData.empty()) blockedChunkSize = instData.asArrayRef().back(); auto chunkSizeAttr = mlir::IntegerAttr::get( mlir::IntegerType::get(ctx, 64), blockedChunkSize); // To create a new attribute with a different chunk_size: auto newEncoding = xegpu::ScatterTensorDescAttr::get( ctx, scatterAttr.getMemorySpace(), chunkSizeAttr); encoding = newEncoding; } } if (layout) { if (layout.getLaneLayout() == nullptr) layout = xegpu::LayoutAttr(); else layout = layout.dropInstData(); } newTy = xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding, layout); } else { newTy = type.clone(tileShape, elemTy); } std::optional> ratio = computeShapeRatio(type.getShape(), tileShape); assert(ratio && "Expecting the ratio to be valid."); return SmallVector(computeProduct(*ratio), newTy); }); RewritePatternSet patterns(ctx); populateXeGPUUnrollPatterns(patterns, options); (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; } // namespace namespace mlir { namespace test { void registerTestXeGPULowerings() { PassRegistration(); } } // namespace test } // namespace mlir