Files
clang-p2996/mlir/test/lib/Transforms/TestVectorTransforms.cpp
Thomas Raoux 6ad31c0f4a [mlir][vector] Support N-D vector in InsertMap/ExtractMap op
Support multi-dimension vector for InsertMap/ExtractMap op and update the
transformations. Currently the relation between IDs and dimension is implicitly
deduced from the types. We can then calculate an AffineMap based on it. In the
future the AffineMap could be part of the operation itself.

Differential Revision: https://reviews.llvm.org/D90995
2020-11-13 12:40:17 -08:00

354 lines
15 KiB
C++

//===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include <type_traits>
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::vector;
namespace {
struct TestVectorToVectorConversion
: public PassWrapper<TestVectorToVectorConversion, FunctionPass> {
void runOnFunction() override {
OwningRewritePatternList patterns;
auto *ctx = &getContext();
patterns.insert<UnrollVectorPattern<AddFOp>>(
ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2, 2}));
populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
populateVectorToVectorTransformationPatterns(patterns, ctx);
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
};
struct TestVectorSlicesConversion
: public PassWrapper<TestVectorSlicesConversion, FunctionPass> {
void runOnFunction() override {
OwningRewritePatternList patterns;
populateVectorSlicesLoweringPatterns(patterns, &getContext());
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
};
struct TestVectorContractionConversion
: public PassWrapper<TestVectorContractionConversion, FunctionPass> {
TestVectorContractionConversion() = default;
TestVectorContractionConversion(const TestVectorContractionConversion &pass) {
}
Option<bool> lowerToFlatMatrix{
*this, "vector-lower-matrix-intrinsics",
llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
llvm::cl::init(false)};
Option<bool> lowerToFlatTranspose{
*this, "vector-flat-transpose",
llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
llvm::cl::init(false)};
Option<bool> lowerToOuterProduct{
*this, "vector-outerproduct",
llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
llvm::cl::init(false)};
Option<bool> lowerToFilterOuterProduct{
*this, "vector-filter-outerproduct",
llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
"vectors of size 4."),
llvm::cl::init(false)};
void runOnFunction() override {
OwningRewritePatternList patterns;
// Test on one pattern in isolation.
if (lowerToOuterProduct) {
VectorContractLowering lowering = VectorContractLowering::OuterProduct;
VectorTransformsOptions options{lowering};
patterns.insert<ContractionOpToOuterProductOpLowering>(options,
&getContext());
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
return;
}
// Test on one pattern in isolation.
if (lowerToFilterOuterProduct) {
VectorContractLowering lowering = VectorContractLowering::OuterProduct;
VectorTransformsOptions options{lowering};
patterns.insert<ContractionOpToOuterProductOpLowering>(
options, &getContext(), [](vector::ContractionOp op) {
// Only lowers vector.contract where the lhs as a type vector<MxNx?>
// where M is not 4.
if (op.getRhsType().getShape()[0] == 4)
return failure();
return success();
});
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
return;
}
// Test on all contract lowering patterns.
VectorContractLowering contractLowering = VectorContractLowering::Dot;
if (lowerToFlatMatrix)
contractLowering = VectorContractLowering::Matmul;
VectorTransposeLowering transposeLowering =
VectorTransposeLowering::EltWise;
if (lowerToFlatTranspose)
transposeLowering = VectorTransposeLowering::Flat;
VectorTransformsOptions options{contractLowering, transposeLowering};
populateVectorContractLoweringPatterns(patterns, &getContext(), options);
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
};
struct TestVectorUnrollingPatterns
: public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> {
TestVectorUnrollingPatterns() = default;
TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) {}
void runOnFunction() override {
MLIRContext *ctx = &getContext();
OwningRewritePatternList patterns;
patterns.insert<UnrollVectorPattern<AddFOp>>(
ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
if (unrollBasedOnType) {
UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
[](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
SmallVector<int64_t, 4> nativeShape = {4, 4, 2};
if (auto floatType = contractOp.getLhsType()
.getElementType()
.dyn_cast<FloatType>()) {
if (floatType.getWidth() == 16) {
nativeShape[2] = 4;
}
}
return nativeShape;
};
patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
ctx, UnrollVectorOptions().setNativeShapeFn(nativeShapeFn));
} else {
patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
ctx,
UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2, 2}));
}
populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
populateVectorToVectorTransformationPatterns(patterns, ctx);
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
Option<bool> unrollBasedOnType{
*this, "unroll-based-on-type",
llvm::cl::desc("Set the unroll factor based on type of the operation"),
llvm::cl::init(false)};
};
struct TestVectorDistributePatterns
: public PassWrapper<TestVectorDistributePatterns, FunctionPass> {
TestVectorDistributePatterns() = default;
TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) {}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<VectorDialect>();
registry.insert<AffineDialect>();
}
ListOption<int32_t> multiplicity{
*this, "distribution-multiplicity", llvm::cl::MiscFlags::CommaSeparated,
llvm::cl::desc("Set the multiplicity used for distributing vector")};
void runOnFunction() override {
MLIRContext *ctx = &getContext();
OwningRewritePatternList patterns;
FuncOp func = getFunction();
func.walk([&](AddFOp op) {
OpBuilder builder(op);
if (auto vecType = op.getType().dyn_cast<VectorType>()) {
SmallVector<int64_t, 2> mul;
SmallVector<AffineExpr, 2> perm;
SmallVector<Value, 2> ids;
unsigned count = 0;
// Remove the multiplicity of 1 and calculate the affine map based on
// the multiplicity.
SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end());
for (unsigned i = 0, e = vecType.getRank(); i < e; i++) {
if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) {
mul.push_back(m[i]);
ids.push_back(func.getArgument(count++));
perm.push_back(getAffineDimExpr(i, ctx));
}
}
auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0,
perm, ctx);
Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
builder, op.getOperation(), ids, mul, map);
if (ops.hasValue()) {
SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
op.getResult().replaceAllUsesExcept(ops->insert.getResult(),
extractOp);
}
}
});
patterns.insert<PointwiseExtractPattern>(ctx);
populateVectorToVectorTransformationPatterns(patterns, ctx);
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
};
struct TestVectorToLoopPatterns
: public PassWrapper<TestVectorToLoopPatterns, FunctionPass> {
TestVectorToLoopPatterns() = default;
TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) {}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<VectorDialect>();
registry.insert<AffineDialect>();
}
Option<int32_t> multiplicity{
*this, "distribution-multiplicity",
llvm::cl::desc("Set the multiplicity used for distributing vector"),
llvm::cl::init(32)};
void runOnFunction() override {
MLIRContext *ctx = &getContext();
OwningRewritePatternList patterns;
FuncOp func = getFunction();
func.walk([&](AddFOp op) {
// Check that the operation type can be broken down into a loop.
VectorType type = op.getType().dyn_cast<VectorType>();
if (!type || type.getRank() != 1 ||
type.getNumElements() % multiplicity != 0)
return mlir::WalkResult::advance();
auto filterAlloc = [](Operation *op) {
if (isa<ConstantOp, AllocOp, CallOp>(op))
return false;
return true;
};
auto dependentOps = getSlice(op, filterAlloc);
// Create a loop and move instructions from the Op slice into the loop.
OpBuilder builder(op);
auto zero = builder.create<ConstantOp>(
op.getLoc(), builder.getIndexType(),
builder.getIntegerAttr(builder.getIndexType(), 0));
auto one = builder.create<ConstantOp>(
op.getLoc(), builder.getIndexType(),
builder.getIntegerAttr(builder.getIndexType(), 1));
auto numIter = builder.create<ConstantOp>(
op.getLoc(), builder.getIndexType(),
builder.getIntegerAttr(builder.getIndexType(), multiplicity));
auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one);
for (Operation *it : dependentOps) {
it->moveBefore(forOp.getBody()->getTerminator());
}
auto map = AffineMap::getMultiDimIdentityMap(1, ctx);
// break up the original op and let the patterns propagate.
Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity},
map);
if (ops.hasValue()) {
SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
}
return mlir::WalkResult::interrupt();
});
patterns.insert<PointwiseExtractPattern>(ctx);
populateVectorToVectorTransformationPatterns(patterns, ctx);
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
};
struct TestVectorTransferUnrollingPatterns
: public PassWrapper<TestVectorTransferUnrollingPatterns, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AffineDialect>();
}
void runOnFunction() override {
MLIRContext *ctx = &getContext();
OwningRewritePatternList patterns;
patterns.insert<UnrollVectorPattern<vector::TransferReadOp>>(
ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
patterns.insert<UnrollVectorPattern<vector::TransferWriteOp>>(
ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
populateVectorToVectorTransformationPatterns(patterns, ctx);
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
};
struct TestVectorTransferFullPartialSplitPatterns
: public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
FunctionPass> {
TestVectorTransferFullPartialSplitPatterns() = default;
TestVectorTransferFullPartialSplitPatterns(
const TestVectorTransferFullPartialSplitPatterns &pass) {}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect>();
}
Option<bool> useLinalgOps{
*this, "use-linalg-copy",
llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "
"linalg.copy operations."),
llvm::cl::init(false)};
void runOnFunction() override {
MLIRContext *ctx = &getContext();
OwningRewritePatternList patterns;
VectorTransformsOptions options;
if (useLinalgOps)
options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
else
options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
patterns.insert<VectorTransferFullPartialRewriter>(ctx, options);
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
};
} // end anonymous namespace
namespace mlir {
namespace test {
void registerTestVectorConversions() {
PassRegistration<TestVectorToVectorConversion> vectorToVectorPass(
"test-vector-to-vector-conversion",
"Test conversion patterns between ops in the vector dialect");
PassRegistration<TestVectorSlicesConversion> slicesPass(
"test-vector-slices-conversion",
"Test conversion patterns that lower slices ops in the vector dialect");
PassRegistration<TestVectorContractionConversion> contractionPass(
"test-vector-contraction-conversion",
"Test conversion patterns that lower contract ops in the vector dialect");
PassRegistration<TestVectorUnrollingPatterns> contractionUnrollingPass(
"test-vector-unrolling-patterns",
"Test conversion patterns to unroll contract ops in the vector dialect");
PassRegistration<TestVectorTransferUnrollingPatterns> transferOpUnrollingPass(
"test-vector-transfer-unrolling-patterns",
"Test conversion patterns to unroll transfer ops in the vector dialect");
PassRegistration<TestVectorTransferFullPartialSplitPatterns>
vectorTransformFullPartialPass("test-vector-transfer-full-partial-split",
"Test conversion patterns to split "
"transfer ops via scf.if + linalg ops");
PassRegistration<TestVectorDistributePatterns> distributePass(
"test-vector-distribute-patterns",
"Test conversion patterns to distribute vector ops in the vector "
"dialect");
PassRegistration<TestVectorToLoopPatterns> vectorToForLoop(
"test-vector-to-forloop",
"Test conversion patterns to break up a vector op into a for loop");
}
} // namespace test
} // namespace mlir