//===- TestLinalgCodegenStrategy.cpp - Test Linalg codegen strategy -------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements logic for testing the Linalg codegen strategy. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "llvm/ADT/SetVector.h" using namespace mlir; using namespace mlir::linalg; namespace { struct TestLinalgCodegenStrategy : public PassWrapper { StringRef getArgument() const final { return "test-linalg-codegen-strategy"; } StringRef getDescription() const final { return "Test Linalg Codegen Strategy."; } TestLinalgCodegenStrategy() = default; TestLinalgCodegenStrategy(const TestLinalgCodegenStrategy &pass) {} void getDependentDialects(DialectRegistry ®istry) const override { // clang-format off registry.insert(); // clang-format on } template void applyStrategyToNamedLinalgOp(); void runOnFunction() override; void runStrategy(LinalgTilingOptions tilingOptions, LinalgTilingOptions registerTilingOptions, LinalgPaddingOptions paddingOptions, vector::VectorContractLowering vectorContractLowering, vector::VectorTransferSplit vectorTransferSplit); ListOption tileSizes{*this, "tile-sizes", llvm::cl::MiscFlags::CommaSeparated, llvm::cl::desc("Specifies the tile sizes.")}; ListOption tileInterchange{ *this, "tile-interchange", llvm::cl::MiscFlags::CommaSeparated, llvm::cl::desc("Specifies the tile interchange.")}; Option promote{ *this, "promote", llvm::cl::desc("Promote the tile into a small aligned memory buffer."), llvm::cl::init(false)}; Option promoteFullTile{ *this, "promote-full-tile-pad", llvm::cl::desc("Pad the small aligned memory buffer to the tile sizes."), llvm::cl::init(false)}; ListOption registerTileSizes{ *this, "register-tile-sizes", llvm::cl::MiscFlags::CommaSeparated, llvm::cl::desc( "Specifies the size of the register tile that will be used " " to vectorize")}; Option registerPromote{ *this, "register-promote", llvm::cl::desc( "Promote the register tile into a small aligned memory buffer."), llvm::cl::init(false)}; Option registerPromoteFullTile{ *this, "register-promote-full-tile-pad", llvm::cl::desc("Pad the small aligned memory buffer to the tile sizes."), llvm::cl::init(false)}; Option pad{*this, "pad", llvm::cl::desc("Pad the operands."), llvm::cl::init(false)}; ListOption packPaddings{ *this, "pack-paddings", llvm::cl::desc("Operand packing flags when test-pad-pattern"), llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; ListOption hoistPaddings{ *this, "hoist-paddings", llvm::cl::desc("Operand hoisting depths when test-pad-pattern"), llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; Option generalize{*this, "generalize", llvm::cl::desc("Generalize named operations."), llvm::cl::init(false)}; ListOption iteratorInterchange{ *this, "iterator-interchange", llvm::cl::MiscFlags::CommaSeparated, llvm::cl::desc("Specifies the iterator interchange.")}; Option vectorize{ *this, "vectorize", llvm::cl::desc("Rewrite the linalg op as a vector operation."), llvm::cl::init(false)}; Option splitVectorTransfersTo{ *this, "split-transfers", llvm::cl::desc( "Split vector transfers between slow (masked) and fast " "(unmasked) variants. Possible options are:\n" "\tnone: keep unsplit vector.transfer and pay the full price\n" "\tlinalg-copy: use linalg.fill + linalg.copy for the slow path\n" "\tvector-transfers: use extra small unmasked vector.transfer for" " the slow path\n"), llvm::cl::init("none")}; Option vectorizeContractionTo{ *this, "vectorize-contraction-to", llvm::cl::desc("the type of vector op to use for linalg contractions"), llvm::cl::init("outerproduct")}; Option unrollVectorTransfers{ *this, "unroll-vector-transfers", llvm::cl::desc("Enable full unrolling of vector.transfer operations"), llvm::cl::init(false)}; Option runEnablePass{ *this, "run-enable-pass", llvm::cl::desc("Run the enable pass between transformations"), llvm::cl::init(true)}; Option anchorOpName{ *this, "anchor-op", llvm::cl::desc( "Which single linalg op is the anchor for the codegen strategy to " "latch on:\n" "\tlinalg.matmul: anchor on linalg.matmul\n" "\tlinalg.matmul_column_major: anchor on linalg.matmul_column_major\n" "\tlinalg.copy: anchor on linalg.copy\n" "\tlinalg.fill: anchor on linalg.fill\n"), llvm::cl::init("")}; Option anchorFuncOpName{ *this, "anchor-func", llvm::cl::desc( "Which single func op is the anchor for the codegen strategy to " "latch on."), llvm::cl::init("")}; }; void TestLinalgCodegenStrategy::runStrategy( LinalgTilingOptions tilingOptions, LinalgTilingOptions registerTilingOptions, LinalgPaddingOptions paddingOptions, vector::VectorContractLowering vectorContractLowering, vector::VectorTransferSplit vectorTransferSplit) { assert(!anchorOpName.empty()); CodegenStrategy strategy; StringRef genericOpName = GenericOp::getOperationName(); strategy.tileIf(!tileSizes.empty(), anchorOpName, tilingOptions) .promoteIf(promote, anchorOpName, LinalgPromotionOptions() .setAlignment(16) .setUseFullTileBuffersByDefault(promoteFullTile)) .tileIf(!registerTileSizes.empty(), anchorOpName, registerTilingOptions) .promoteIf(registerPromote, anchorOpName, LinalgPromotionOptions() .setAlignment(16) .setUseFullTileBuffersByDefault(registerPromoteFullTile)) .padIf(pad, anchorOpName, paddingOptions) .generalizeIf(generalize, anchorOpName) .interchangeIf(!iteratorInterchange.empty(), iteratorInterchange) .vectorizeIf(vectorize, generalize ? genericOpName : anchorOpName) .vectorLowering( LinalgVectorLoweringOptions() .setVectorTransformsOptions( vector::VectorTransformsOptions() .setVectorTransformsOptions(vectorContractLowering) .setVectorTransferSplit(vectorTransferSplit)) .setVectorTransferToSCFOptions( VectorTransferToSCFOptions().enableFullUnroll( unrollVectorTransfers)) .enableTransferPartialRewrite() .enableContractionLowering() .enableTransferToSCFConversion()); // Created a nested OpPassManager and run. FuncOp funcOp = getFunction(); OpPassManager dynamicPM("builtin.func"); strategy.configurePassPipeline(dynamicPM, funcOp.getContext(), runEnablePass); if (failed(runPipeline(dynamicPM, funcOp))) return signalPassFailure(); } } // end anonymous namespace // For now, just assume it is the zero of type. // In the future, it should be the zero of type + op. static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) { auto t = getElementTypeOrSelf(op.get()); return b.create(op.getOwner()->getLoc(), t, b.getZeroAttr(t)); } /// Apply transformations specified as patterns. void TestLinalgCodegenStrategy::runOnFunction() { if (!anchorFuncOpName.empty() && anchorFuncOpName != getFunction().getName()) return; LinalgTilingOptions tilingOptions; if (!tileSizes.empty()) tilingOptions = tilingOptions.setTileSizes(tileSizes); if (!tileInterchange.empty()) tilingOptions = tilingOptions.setInterchange(tileInterchange); LinalgTilingOptions registerTilingOptions; if (!registerTileSizes.empty()) registerTilingOptions = registerTilingOptions.setTileSizes(registerTileSizes); LinalgPaddingOptions paddingOptions; auto packFunc = [&](OpOperand &opOperand) { return opOperand.getOperandNumber() < packPaddings.size() ? packPaddings[opOperand.getOperandNumber()] : false; }; auto hoistingFunc = [&](OpOperand &opOperand) { return opOperand.getOperandNumber() < hoistPaddings.size() ? hoistPaddings[opOperand.getOperandNumber()] : 0; }; paddingOptions.setPaddingValueComputationFunction(getNeutralOfLinalgOp); paddingOptions.setPaddingNoFoldComputationFunction(packFunc); paddingOptions.setPaddingHoistComputationFunction(hoistingFunc); vector::VectorContractLowering vectorContractLowering = llvm::StringSwitch( vectorizeContractionTo.getValue()) .Case("matrixintrinsics", vector::VectorContractLowering::Matmul) .Case("dot", vector::VectorContractLowering::Dot) .Case("outerproduct", vector::VectorContractLowering::OuterProduct) .Default(vector::VectorContractLowering::OuterProduct); vector::VectorTransferSplit vectorTransferSplit = llvm::StringSwitch( splitVectorTransfersTo.getValue()) .Case("none", vector::VectorTransferSplit::None) .Case("linalg-copy", vector::VectorTransferSplit::LinalgCopy) .Case("vector-transfers", vector::VectorTransferSplit::VectorTransfer) .Default(vector::VectorTransferSplit::None); runStrategy(tilingOptions, registerTilingOptions, paddingOptions, vectorContractLowering, vectorTransferSplit); } namespace mlir { namespace test { void registerTestLinalgCodegenStrategy() { PassRegistration(); } } // namespace test } // namespace mlir