Add the decompose patterns that lower higher-dimensional convolutions to lower-dimensional ones to CodegenStrategy, and use CodegenStrategy to test the decompose patterns. Additionally, remove the assertion in the CodegenStrategyTest pass that checks that the anchor op name is set. Removing the assertion allows us to simplify the pipelines used in the interchange and decompose tests. Depends On: D114797. Reviewed By: nicolasvasilache. Differential Revision: https://reviews.llvm.org/D114798
728 lines
30 KiB
C++
728 lines
30 KiB
C++
//===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic for testing Linalg transformations.
//
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::linalg;
|
|
|
|
namespace {
|
|
struct TestLinalgTransforms
|
|
: public PassWrapper<TestLinalgTransforms, FunctionPass> {
|
|
TestLinalgTransforms() = default;
|
|
TestLinalgTransforms(const TestLinalgTransforms &pass) {}
|
|
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
// clang-format off
|
|
registry.insert<AffineDialect,
|
|
memref::MemRefDialect,
|
|
scf::SCFDialect,
|
|
StandardOpsDialect,
|
|
vector::VectorDialect,
|
|
gpu::GPUDialect>();
|
|
// clang-format on
|
|
}
|
|
StringRef getArgument() const final {
|
|
return "test-linalg-transform-patterns";
|
|
}
|
|
StringRef getDescription() const final {
|
|
return "Test Linalg transformation patterns by applying them greedily.";
|
|
}
|
|
|
|
void runOnFunction() override;
|
|
|
|
Option<bool> testPatterns{*this, "test-patterns",
|
|
llvm::cl::desc("Test a mixed set of patterns"),
|
|
llvm::cl::init(false)};
|
|
Option<bool> testMatmulToVectorPatterns1dTiling{
|
|
*this, "test-matmul-to-vector-patterns-tile-1d",
|
|
llvm::cl::desc(
|
|
"Test a fused pass that applies patterns from matmul to vectors via "
|
|
"1-d tiling"),
|
|
llvm::cl::init(false)};
|
|
Option<bool> testMatmulToVectorPatterns2dTiling{
|
|
*this, "test-matmul-to-vector-patterns-tile-2d",
|
|
llvm::cl::desc(
|
|
"Test a fused pass that applies patterns from matmul to vectors via "
|
|
"2-d tiling"),
|
|
llvm::cl::init(false)};
|
|
Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options",
|
|
llvm::cl::desc("Test promotion options"),
|
|
llvm::cl::init(false)};
|
|
Option<bool> testTileAndDistributionOptions{
|
|
*this, "test-tile-and-distribute-options",
|
|
llvm::cl::desc("Test tile and distribute options"),
|
|
llvm::cl::init(false)};
|
|
Option<bool> testVectorTransferForwardingPatterns{
|
|
*this, "test-vector-transfer-forwarding-patterns",
|
|
llvm::cl::desc(
|
|
"Test a fused pass that forwards linalg.copy to vector.transfer"),
|
|
llvm::cl::init(false)};
|
|
Option<bool> testGenericToVectorPattern{
|
|
*this, "test-linalg-to-vector-patterns",
|
|
llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
|
|
"in vector.contract form"),
|
|
llvm::cl::init(false)};
|
|
Option<bool> testTilePattern{*this, "test-tile-pattern",
|
|
llvm::cl::desc("Test tile pattern"),
|
|
llvm::cl::init(false)};
|
|
Option<bool> testTileScalarizeDynamicDims{
|
|
*this, "test-tile-scalarize-dynamic-dims",
|
|
llvm::cl::desc("Test tiling of dynamic dims by 1"),
|
|
llvm::cl::init(false)};
|
|
Option<bool> testTransformPadTensor{
|
|
*this, "test-transform-pad-tensor",
|
|
llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
|
|
llvm::cl::init(false)};
|
|
Option<bool> testGeneralizePadTensor{
|
|
*this, "test-generalize-pad-tensor",
|
|
llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
|
|
llvm::cl::init(false)};
|
|
Option<bool> testSwapSubTensorPadTensor{
|
|
*this, "test-swap-subtensor-padtensor",
|
|
llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
|
|
"pad_tensor(subtensor)"),
|
|
llvm::cl::init(false)};
|
|
ListOption<int64_t> peeledLoops{
|
|
*this, "peeled-loops",
|
|
llvm::cl::desc("Loops to be peeled when test-tile-pattern"),
|
|
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
|
|
ListOption<int64_t> tileSizes{
|
|
*this, "tile-sizes",
|
|
llvm::cl::desc("Linalg tile sizes for test-tile-pattern"),
|
|
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
|
|
ListOption<unsigned> testTiledLoopPeeling{
|
|
*this, "test-tiled-loop-peeling",
|
|
llvm::cl::desc("Test peeling of linalg.tiled_loop ops"),
|
|
llvm::cl::OneOrMore, llvm::cl::MiscFlags::CommaSeparated};
|
|
Option<bool> skipPartial{
|
|
*this, "skip-partial",
|
|
llvm::cl::desc("Skip loops inside partial iterations during peeling"),
|
|
llvm::cl::init(false)};
|
|
Option<std::string> loopType{
|
|
*this, "loop-type",
|
|
llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
|
|
"tiled_loop"),
|
|
llvm::cl::init("for")};
|
|
};
|
|
} // end anonymous namespace
|
|
|
|
/// Populates one pattern set with a mix of tiling, tiling + interchange,
/// lowering-to-loops, vectorization, generic interchange and promotion
/// patterns, applies it greedily to `funcOp`, and finally strips the
/// transformation marker attribute from every LinalgOp in the function.
///
/// Most patterns are gated by a LinalgTransformationFilter: the first marker
/// (or list of markers) is what the pattern matches on, the second one is the
/// marker it leaves behind.
static void applyPatterns(FuncOp funcOp) {
  MLIRContext *ctx = funcOp.getContext();
  RewritePatternSet patterns(ctx);

  // Shorthand for building the marker attributes used by the filters.
  auto marker = [ctx](StringRef name) { return StringAttr::get(ctx, name); };

  //===--------------------------------------------------------------------===//
  // Linalg tiling patterns.
  //===--------------------------------------------------------------------===//
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
      LinalgTransformationFilter(marker("MEM"), marker("L3")));
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
      LinalgTransformationFilter(marker("L3"), marker("L2")));
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
      LinalgTransformationFilter(marker("L2"), marker("L1")));
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
      LinalgTransformationFilter(marker("L1"), marker("REG")));

  patterns.add<LinalgTilingPattern<MatvecOp>>(
      ctx,
      LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
          LinalgTilingLoopType::ParallelLoops),
      LinalgTransformationFilter(ArrayRef<StringAttr>{}, marker("L1")));

  patterns.add<LinalgTilingPattern<DotOp>>(
      ctx, LinalgTilingOptions().setTileSizes(8000),
      LinalgTransformationFilter(
          ArrayRef<StringAttr>{marker("MEM"), marker("L3"), marker("L2")},
          marker("REG")));

  //===--------------------------------------------------------------------===//
  // Linalg tiling and permutation patterns.
  //===--------------------------------------------------------------------===//
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx,
      LinalgTilingOptions()
          .setTileSizes({2000, 3000, 4000})
          .setInterchange({1, 2, 0}),
      LinalgTransformationFilter(marker("__with_perm__"),
                                 marker("L2__with_perm__")));
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx,
      LinalgTilingOptions()
          .setTileSizes({200, 300, 400})
          .setInterchange({1, 0, 2}),
      LinalgTransformationFilter(marker("L2__with_perm__"),
                                 marker("L1__with_perm__")));
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
      LinalgTransformationFilter(marker("L1__with_perm__"),
                                 marker("REG__with_perm__")));

  patterns.add<LinalgTilingPattern<MatvecOp>>(
      ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
      LinalgTransformationFilter(marker("__with_perm__"),
                                 marker("L1__with_perm__")));

  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx,
      LinalgTilingOptions()
          .setTileSizes({16, 8, 4})
          .setInterchange({1, 2, 0})
          .setLoopType(LinalgTilingLoopType::ParallelLoops),
      LinalgTransformationFilter(marker("par__with_perm__"),
                                 marker("after_par__with_perm__")));

  //===--------------------------------------------------------------------===//
  // Linalg to loops patterns.
  //===--------------------------------------------------------------------===//
  patterns.add<LinalgLoweringPattern<DotOp>>(
      ctx,
      /*loweringType=*/LinalgLoweringType::Loops,
      LinalgTransformationFilter(marker("REG")));

  //===--------------------------------------------------------------------===//
  // Linalg to vector contraction patterns.
  //===--------------------------------------------------------------------===//
  patterns.add<LinalgVectorizationPattern>(
      ctx, LinalgTransformationFilter(marker("VECTORIZE"))
               .addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>());

  //===--------------------------------------------------------------------===//
  // Linalg generic interchange pattern.
  //===--------------------------------------------------------------------===//
  patterns.add<GenericOpInterchangePattern>(
      ctx,
      /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
      LinalgTransformationFilter(ArrayRef<StringAttr>{}, marker("PERMUTED")));

  //===--------------------------------------------------------------------===//
  // Linalg subview operands promotion.
  //===--------------------------------------------------------------------===//
  patterns.add<LinalgPromotionPattern<MatmulOp>>(
      ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
      LinalgTransformationFilter(marker("_promote_views_"),
                                 marker("_views_promoted_")));
  patterns.add<LinalgPromotionPattern<MatmulOp>>(
      ctx,
      LinalgPromotionOptions()
          .setOperandsToPromote({0})
          .setUseFullTileBuffersByDefault(true),
      LinalgTransformationFilter(marker("_promote_first_view_"),
                                 marker("_first_view_promoted_")));
  patterns.add<LinalgPromotionPattern<FillOp>>(
      ctx,
      LinalgPromotionOptions()
          .setOperandsToPromote({1})
          .setUseFullTileBuffers({false, true})
          .setAlignment(32),
      LinalgTransformationFilter(marker("_promote_views_aligned_"),
                                 marker("_views_aligned_promoted_")));

  // The result is deliberately ignored, as in the rest of this file.
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));

  // Drop the marker.
  funcOp.walk([](LinalgOp op) {
    op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
  });
}
|
|
|
|
/// Appends three pattern sets to `patternsVector`, in this order:
///   1. tiling of matmul ops carrying `startMarker` (re-marked "L1"),
///   2. promotion of matmul ops marked "L1" (re-marked "VEC"),
///   3. vectorization of matmul ops marked "VEC", plus vectorization of any
///      FillOp / CopyOp regardless of marker.
static void fillL1TilingAndMatmulToVectorPatterns(
    FuncOp funcOp, StringRef startMarker,
    SmallVectorImpl<RewritePatternSet> &patternsVector) {
  MLIRContext *context = funcOp.getContext();

  // Set 1: tile with interchange.
  LinalgTilingOptions l1Tiling =
      LinalgTilingOptions().setTileSizes({8, 12, 16}).setInterchange({1, 0, 2});
  patternsVector.emplace_back(
      context, std::make_unique<LinalgTilingPattern<MatmulOp>>(
                   context, l1Tiling,
                   LinalgTransformationFilter(
                       StringAttr::get(context, startMarker),
                       StringAttr::get(context, "L1"))));

  // Set 2: promote the operands of the tiled matmul.
  LinalgPromotionOptions promotion =
      LinalgPromotionOptions().setUseFullTileBuffersByDefault(true);
  patternsVector.emplace_back(
      context, std::make_unique<LinalgPromotionPattern<MatmulOp>>(
                   context, promotion,
                   LinalgTransformationFilter(StringAttr::get(context, "L1"),
                                              StringAttr::get(context, "VEC"))));

  // Set 3: vectorize the promoted matmul ...
  patternsVector.emplace_back(
      context,
      std::make_unique<LinalgVectorizationPattern>(
          MatmulOp::getOperationName(), context, LinalgVectorizationOptions(),
          LinalgTransformationFilter(StringAttr::get(context, "VEC"))));
  // ... and, in the same set, every fill/copy op.
  RewritePatternSet &vectorizationSet = patternsVector.back();
  vectorizationSet.add<LinalgVectorizationPattern>(
      context, LinalgTransformationFilter().addFilter([](Operation *op) {
        return success(isa<FillOp, CopyOp>(op));
      }));
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Test promotion callbacks
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Allocation call back
|
|
static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView,
|
|
ArrayRef<Value> boundingSubViewSize,
|
|
DataLayout &layout) {
|
|
SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1);
|
|
return b
|
|
.create<memref::AllocOp>(
|
|
subView.getLoc(),
|
|
MemRefType::get(shape, subView.getType().getElementType(),
|
|
/*affineMapComposition =*/{}, 3),
|
|
boundingSubViewSize)
|
|
.getResult();
|
|
}
|
|
|
|
// Deallocation callback
|
|
static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) {
|
|
b.create<memref::DeallocOp>(buffer.getLoc(), buffer);
|
|
return success();
|
|
}
|
|
|
|
// Copy in call back
|
|
static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
|
|
bool isOutput) {
|
|
auto floatType = src.getType().cast<MemRefType>().getElementType();
|
|
if (!floatType.isa<FloatType>())
|
|
return failure();
|
|
if (!isOutput) {
|
|
Value cst = b.create<arith::ConstantOp>(src.getLoc(),
|
|
FloatAttr::get(floatType, 42.0));
|
|
b.create<FillOp>(src.getLoc(), cst, dst);
|
|
}
|
|
b.create<CopyOp>(src.getLoc(), src, dst);
|
|
return success();
|
|
}
|
|
|
|
/// Adds to `patterns` a matmul tiling pattern ("START" -> "PROMOTE") followed
/// by a matmul promotion pattern (matching "PROMOTE") that is configured with
/// the custom allocation, deallocation and copy callbacks defined in this
/// file.
static void fillPromotionCallBackPatterns(MLIRContext *ctx,
                                          RewritePatternSet &patterns) {
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}),
      LinalgTransformationFilter(StringAttr::get(ctx, "START"),
                                 StringAttr::get(ctx, "PROMOTE")));

  // Both copy directions forward to copyCallBackFn; only `isOutput` differs.
  auto copyIn = [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
    return copyCallBackFn(b, src, dst, false);
  };
  auto copyOut = [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
    return copyCallBackFn(b, src, dst, true);
  };

  patterns.add<LinalgPromotionPattern<MatmulOp>>(
      ctx,
      LinalgPromotionOptions()
          .setOperandsToPromote({0, 2})
          .setUseFullTileBuffers({false, false})
          .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn)
          .setCopyInOutFns(copyIn, copyOut),
      LinalgTransformationFilter(StringAttr::get(ctx, "PROMOTE")));
}
|
|
|
|
template <typename IdOp, typename NProcsOp>
|
|
static SmallVector<ProcInfo, 2>
|
|
getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
|
|
size_t count = std::min<size_t>(3, parallelLoopRanges.size());
|
|
SmallVector<ProcInfo, 2> procInfo(count);
|
|
const char *xyz[] = {"x", "y", "z"};
|
|
Type indexType = b.getIndexType();
|
|
for (unsigned i = 0; i < count; ++i) {
|
|
procInfo[count - 1 - i] = {
|
|
b.create<IdOp>(loc, indexType, b.getStringAttr(xyz[i])),
|
|
b.create<NProcsOp>(loc, indexType, b.getStringAttr(xyz[i]))};
|
|
}
|
|
return procInfo;
|
|
}
|
|
|
|
/// Adds seven matmul tiling patterns to `patterns`, each with tile sizes
/// {8, 8, 4} and a two-entry distribution configuration whose processor info
/// comes from gpu.block_id / gpu.grid_dim. The patterns only differ in the
/// distribution method per dimension, the loop type, and the marker pair of
/// their filter.
static void fillTileAndDistributePatterns(MLIRContext *context,
                                          RewritePatternSet &patterns) {
  // Registers one tiling pattern distributing the two outer dimensions with
  // `method0` and `method1`, matching `marker` and leaving `doneMarker`.
  auto addDistributedTiling = [&](DistributionMethod method0,
                                  DistributionMethod method1,
                                  LinalgTilingLoopType loopKind,
                                  StringRef marker, StringRef doneMarker) {
    LinalgLoopDistributionOptions distribution;
    distribution.distributionMethod = {method0, method1};
    distribution.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
    patterns.add<LinalgTilingPattern<MatmulOp>>(
        context,
        LinalgTilingOptions()
            .setTileSizes({8, 8, 4})
            .setLoopType(loopKind)
            .setDistributionOptions(distribution),
        LinalgTransformationFilter(StringAttr::get(context, marker),
                                   StringAttr::get(context, doneMarker)));
  };

  const LinalgTilingLoopType parallel = LinalgTilingLoopType::ParallelLoops;

  // Same method on both dimensions.
  addDistributedTiling(DistributionMethod::CyclicNumProcsEqNumIters,
                       DistributionMethod::CyclicNumProcsEqNumIters, parallel,
                       "distribute1", "after_distribute1");
  addDistributedTiling(DistributionMethod::CyclicNumProcsGeNumIters,
                       DistributionMethod::CyclicNumProcsGeNumIters, parallel,
                       "distribute2", "after_distribute2");
  addDistributedTiling(DistributionMethod::Cyclic, DistributionMethod::Cyclic,
                       parallel, "distribute3", "after_distribute3");

  // Mixed methods.
  addDistributedTiling(DistributionMethod::CyclicNumProcsEqNumIters,
                       DistributionMethod::CyclicNumProcsGeNumIters, parallel,
                       "distribute4", "after_distribute4");
  addDistributedTiling(DistributionMethod::CyclicNumProcsGeNumIters,
                       DistributionMethod::Cyclic, parallel, "distribute5",
                       "after_distribute5");
  addDistributedTiling(DistributionMethod::Cyclic,
                       DistributionMethod::CyclicNumProcsEqNumIters, parallel,
                       "distribute6", "after_distribute6");

  // Cyclic distribution with plain (non-parallel) loops.
  addDistributedTiling(DistributionMethod::Cyclic, DistributionMethod::Cyclic,
                       LinalgTilingLoopType::Loops, "tensors_distribute1",
                       "tensors_after_distribute1");
}
|
|
|
|
static void
|
|
applyMatmulToVectorPatterns(FuncOp funcOp,
|
|
bool testMatmulToVectorPatterns1dTiling,
|
|
bool testMatmulToVectorPatterns2dTiling) {
|
|
MLIRContext *ctx = funcOp.getContext();
|
|
SmallVector<RewritePatternSet, 4> stage1Patterns;
|
|
if (testMatmulToVectorPatterns1dTiling) {
|
|
fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns);
|
|
} else if (testMatmulToVectorPatterns2dTiling) {
|
|
stage1Patterns.emplace_back(
|
|
ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
|
|
ctx,
|
|
LinalgTilingOptions()
|
|
.setTileSizes({768, 264, 768})
|
|
.setInterchange({1, 2, 0}),
|
|
LinalgTransformationFilter(StringAttr::get(ctx, "START"),
|
|
StringAttr::get(ctx, "L2"))));
|
|
fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns);
|
|
}
|
|
{
|
|
// Canonicalization patterns
|
|
RewritePatternSet canonicalizationPatterns(funcOp.getContext());
|
|
vector::populateVectorTransferPermutationMapLoweringPatterns(
|
|
canonicalizationPatterns);
|
|
vector::populateVectorReductionToContractPatterns(canonicalizationPatterns);
|
|
stage1Patterns.push_back(std::move(canonicalizationPatterns));
|
|
}
|
|
SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
|
|
llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
|
|
FrozenRewritePatternSet stage2Patterns =
|
|
getLinalgTilingCanonicalizationPatterns(ctx);
|
|
(void)applyStagedPatterns(funcOp, frozenStage1Patterns,
|
|
std::move(stage2Patterns));
|
|
}
|
|
|
|
/// Greedily applies the two linalg.copy forwarding patterns (the
/// vector-transfer-read and the vector-transfer-write variant) to `funcOp`.
static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
  MLIRContext *context = funcOp.getContext();
  RewritePatternSet forwardingPatterns(context);
  forwardingPatterns.add<LinalgCopyVTRForwardingPattern>(context);
  forwardingPatterns.add<LinalgCopyVTWForwardingPattern>(context);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardingPatterns));
}
|
|
|
|
/// Greedily applies vectorization to `funcOp`: the generic Linalg
/// vectorization pattern restricted to contraction ops, fill, copy and generic
/// ops, plus the pad-tensor and convolution vectorization patterns.
static void applyLinalgToVectorPatterns(FuncOp funcOp) {
  MLIRContext *context = funcOp.getContext();
  RewritePatternSet vectorizationPatterns(context);

  // No marker is required; only the op kind is filtered.
  LinalgTransformationFilter opKindFilter =
      LinalgTransformationFilter()
          .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>();
  vectorizationPatterns.add<LinalgVectorizationPattern>(context, opKindFilter);

  populatePadTensorOpVectorizationPatterns(vectorizationPatterns);
  populateConvolutionVectorizationPatterns(vectorizationPatterns);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(vectorizationPatterns));
}
|
|
|
|
/// Greedily applies PadTensorOpTransformationPattern to `funcOp`.
static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
  MLIRContext *context = funcOp.getContext();
  RewritePatternSet padPatterns(context);
  padPatterns.add<PadTensorOpTransformationPattern>(context);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(padPatterns));
}
|
|
|
|
/// Greedily applies GeneralizePadTensorOpPattern to `funcOp`.
static void applyGeneralizePadTensorPatterns(FuncOp funcOp) {
  MLIRContext *context = funcOp.getContext();
  RewritePatternSet generalizePatterns(context);
  generalizePatterns.add<GeneralizePadTensorOpPattern>(context);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(generalizePatterns));
}
|
|
|
|
/// Greedily applies ExtractSliceOfPadTensorSwapPattern to `funcOp`.
static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
  MLIRContext *context = funcOp.getContext();
  RewritePatternSet swapPatterns(context);
  swapPatterns.add<ExtractSliceOfPadTensorSwapPattern>(context);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(swapPatterns));
}
|
|
|
|
/// Greedily tiles the matmul and generic ops of `funcOp` that carry the
/// "tile" marker.
///
/// \param loopType one of "for", "affine", "parallel" or "tiled_loop".
/// \param tileSizes tile sizes to use; must be empty when
///        `scalarizeDynamicDims` is set.
/// \param peeledLoops loops to peel, forwarded to the tiling options.
/// \param scalarizeDynamicDims when true, request scalarization of dynamic
///        dimensions instead of using explicit tile sizes.
static void applyTilePattern(FuncOp funcOp, std::string loopType,
                             ArrayRef<int64_t> tileSizes,
                             ArrayRef<int64_t> peeledLoops,
                             bool scalarizeDynamicDims) {
  MLIRContext *ctx = funcOp.getContext();

  // NOTE(review): there is no `.Default(...)` case here - confirm how an
  // unrecognized loop type is expected to be handled.
  LinalgTilingLoopType loopKind =
      llvm::StringSwitch<LinalgTilingLoopType>(loopType)
          .Case("for", LinalgTilingLoopType::Loops)
          .Case("affine", LinalgTilingLoopType::AffineLoops)
          .Case("parallel", LinalgTilingLoopType::ParallelLoops)
          .Case("tiled_loop", LinalgTilingLoopType::TiledLoops);

  linalg::LinalgTilingOptions options;
  options.setPeeledLoops(peeledLoops);
  options.setLoopType(loopKind);
  if (!scalarizeDynamicDims) {
    options.setTileSizes(tileSizes);
  } else {
    assert(tileSizes.empty() &&
           "tileSizes and scalarizeDynamicDims is mutually exclusive");
    options.scalarizeDynamicDims();
  }

  RewritePatternSet tilingPatterns(ctx);
  tilingPatterns.add<linalg::LinalgTilingPattern<linalg::MatmulOp>,
                     linalg::LinalgTilingPattern<linalg::GenericOp>>(
      ctx, options,
      linalg::LinalgTransformationFilter(StringAttr::get(ctx, "tile")));
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPatterns));
}
|
|
|
|
/// Attribute name under which TiledLoopPeelingPattern records the indices of
/// the loops that were already peeled.
static constexpr char kPeeledLoopsLabel[] = "__peeled_loops__";
/// Attribute name set on the loop produced by peeling for the partial
/// iteration.
static constexpr char kPartialIterationLabel[] = "__partial_iteration__";
|
|
|
|
namespace {
|
|
/// Peel TiledLoopOps, i.e., split them into two loops: One loop where the
|
|
/// `idx`-th loop contains only "full" iterations and a second loop for the
|
|
/// remaining partial iteration (if any).
|
|
struct TiledLoopPeelingPattern : public OpRewritePattern<TiledLoopOp> {
|
|
TiledLoopPeelingPattern(MLIRContext *ctx, int64_t idx, bool skipPartial)
|
|
: OpRewritePattern<TiledLoopOp>(ctx), idx(idx), skipPartial(skipPartial) {
|
|
}
|
|
|
|
LogicalResult matchAndRewrite(TiledLoopOp loopOp,
|
|
PatternRewriter &rewriter) const override {
|
|
SmallVector<int64_t> peeledLoops;
|
|
if (loopOp->hasAttr(kPeeledLoopsLabel)) {
|
|
auto attr = loopOp->getAttr(kPeeledLoopsLabel).cast<ArrayAttr>();
|
|
peeledLoops =
|
|
llvm::to_vector<4>(llvm::map_range(attr, [](Attribute attr) {
|
|
return attr.cast<IntegerAttr>().getInt();
|
|
}));
|
|
// Check if the loop was already peeled.
|
|
if (llvm::find(peeledLoops, idx) != peeledLoops.end())
|
|
return failure();
|
|
}
|
|
if (skipPartial && loopOp->hasAttr(kPartialIterationLabel))
|
|
// No peeling of loop nests with a partial iteration.
|
|
return failure();
|
|
|
|
if (static_cast<int64_t>(loopOp.iterator_types().size()) <= idx)
|
|
return failure();
|
|
|
|
// Peel loop and canonicalize.
|
|
TiledLoopOp result;
|
|
if (failed(linalg::peelAndCanonicalizeTiledLoop(rewriter, loopOp, idx,
|
|
result)))
|
|
return failure();
|
|
|
|
// Apply label, so that the same loop is not rewritten a second time.
|
|
peeledLoops.push_back(idx);
|
|
rewriter.updateRootInPlace(loopOp, [&]() {
|
|
loopOp->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
|
|
});
|
|
result->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
|
|
result->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Index of loop to peel.
|
|
int64_t idx;
|
|
|
|
/// If set to true, do not peel TiledLoopOps with a partial iteration.
|
|
bool skipPartial;
|
|
};
|
|
} // namespace
|
|
|
|
/// Registers one TiledLoopPeelingPattern per entry of `loops`, applies them
/// greedily to `funcOp`, and afterwards removes the bookkeeping attributes
/// from every TiledLoopOp in the function.
static void applyTiledLoopPeelingPattern(FuncOp funcOp,
                                         ArrayRef<unsigned> loops,
                                         bool skipPartial) {
  MLIRContext *context = funcOp.getContext();
  RewritePatternSet peelingPatterns(context);
  for (unsigned loopIdx : loops)
    peelingPatterns.add<TiledLoopPeelingPattern>(context, loopIdx,
                                                 skipPartial);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(peelingPatterns));

  // Drop the markers.
  funcOp.walk([](TiledLoopOp tiledLoop) {
    tiledLoop->removeAttr(kPeeledLoopsLabel);
    tiledLoop->removeAttr(kPartialIterationLabel);
  });
}
|
|
|
|
/// Apply transformations specified as patterns.
|
|
void TestLinalgTransforms::runOnFunction() {
|
|
auto lambda = [&](void *) {
|
|
getFunction().walk([](LinalgOp op) {
|
|
op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
|
|
});
|
|
};
|
|
std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};
|
|
|
|
if (testPromotionOptions) {
|
|
RewritePatternSet patterns(&getContext());
|
|
fillPromotionCallBackPatterns(&getContext(), patterns);
|
|
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
|
|
return;
|
|
}
|
|
if (testTileAndDistributionOptions) {
|
|
RewritePatternSet patterns(&getContext());
|
|
fillTileAndDistributePatterns(&getContext(), patterns);
|
|
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
|
|
return;
|
|
}
|
|
if (testPatterns)
|
|
return applyPatterns(getFunction());
|
|
if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
|
|
return applyMatmulToVectorPatterns(getFunction(),
|
|
testMatmulToVectorPatterns1dTiling,
|
|
testMatmulToVectorPatterns2dTiling);
|
|
if (testVectorTransferForwardingPatterns)
|
|
return applyVectorTransferForwardingPatterns(getFunction());
|
|
if (testGenericToVectorPattern)
|
|
return applyLinalgToVectorPatterns(getFunction());
|
|
if (testTransformPadTensor)
|
|
return applyPadTensorToGenericPatterns(getFunction());
|
|
if (testGeneralizePadTensor)
|
|
return applyGeneralizePadTensorPatterns(getFunction());
|
|
if (testSwapSubTensorPadTensor)
|
|
return applyExtractSliceOfPadTensorSwapPattern(getFunction());
|
|
if (testTiledLoopPeeling.hasValue())
|
|
return applyTiledLoopPeelingPattern(getFunction(), testTiledLoopPeeling,
|
|
skipPartial);
|
|
if (testTilePattern)
|
|
return applyTilePattern(getFunction(), loopType, tileSizes, peeledLoops,
|
|
/*scalarizeDynamicDims=*/false);
|
|
if (testTileScalarizeDynamicDims)
|
|
return applyTilePattern(getFunction(), loopType, tileSizes,
|
|
/*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
|
|
}
|
|
|
|
namespace mlir {
|
|
namespace test {
|
|
void registerTestLinalgTransforms() {
|
|
PassRegistration<TestLinalgTransforms>();
|
|
}
|
|
} // namespace test
|
|
} // namespace mlir
|