When a value used in the forOp is defined outside the region but within the parent warpOp, we need to return and distribute that value so it can be passed to the new operations created within the loop. Also simplify the lambda interface. Differential Revision: https://reviews.llvm.org/D137146
824 lines
31 KiB
C++
824 lines
31 KiB
C++
//===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include <type_traits>
|
|
|
|
#include "mlir/Analysis/SliceAnalysis.h"
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
|
#include "mlir/Dialect/Linalg/Passes.h"
|
|
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
|
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
|
|
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Pass/PassManager.h"
|
|
#include "mlir/Support/LLVM.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::linalg;
|
|
using namespace mlir::vector;
|
|
|
|
namespace {
|
|
|
|
struct TestVectorToVectorLowering
|
|
: public PassWrapper<TestVectorToVectorLowering,
|
|
OperationPass<func::FuncOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToVectorLowering)
|
|
|
|
TestVectorToVectorLowering() = default;
|
|
TestVectorToVectorLowering(const TestVectorToVectorLowering &pass)
|
|
: PassWrapper(pass) {}
|
|
StringRef getArgument() const final {
|
|
return "test-vector-to-vector-lowering";
|
|
}
|
|
StringRef getDescription() const final {
|
|
return "Test lowering patterns between ops in the vector dialect";
|
|
}
|
|
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
registry.insert<AffineDialect>();
|
|
}
|
|
|
|
Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
|
|
llvm::cl::init(false)};
|
|
|
|
void runOnOperation() override {
|
|
auto *ctx = &getContext();
|
|
RewritePatternSet patterns(ctx);
|
|
if (unroll) {
|
|
populateVectorUnrollPatterns(
|
|
patterns,
|
|
UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
|
|
filter));
|
|
}
|
|
populateVectorToVectorCanonicalizationPatterns(patterns);
|
|
populateBubbleVectorBitCastOpPatterns(patterns);
|
|
populateCastAwayVectorLeadingOneDimPatterns(patterns);
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
|
|
private:
|
|
// Return the target shape based on op type.
|
|
static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
|
|
if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op))
|
|
return SmallVector<int64_t, 4>(2, 2);
|
|
if (isa<vector::ContractionOp>(op))
|
|
return SmallVector<int64_t, 4>(3, 2);
|
|
// For transfer ops, just propagate the shape coming from
|
|
// InsertStridedSlices/ExtractStridedSlices.
|
|
if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
|
|
VectorType dstVec;
|
|
for (Operation *users : readOp->getUsers()) {
|
|
auto extract = dyn_cast<ExtractStridedSliceOp>(users);
|
|
if (!extract)
|
|
return llvm::None;
|
|
auto vecType = extract.getResult().getType().cast<VectorType>();
|
|
if (dstVec && dstVec != vecType)
|
|
return llvm::None;
|
|
dstVec = vecType;
|
|
}
|
|
return SmallVector<int64_t, 4>(dstVec.getShape().begin(),
|
|
dstVec.getShape().end());
|
|
}
|
|
if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
|
|
auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>();
|
|
if (!insert)
|
|
return llvm::None;
|
|
ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
|
|
return SmallVector<int64_t, 4>(shape.begin(), shape.end());
|
|
}
|
|
return llvm::None;
|
|
}
|
|
|
|
static LogicalResult filter(Operation *op) {
|
|
return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp,
|
|
ContractionOp, TransferReadOp, TransferWriteOp>(op));
|
|
}
|
|
};
|
|
|
|
struct TestVectorContractionLowering
|
|
: public PassWrapper<TestVectorContractionLowering,
|
|
OperationPass<func::FuncOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorContractionLowering)
|
|
|
|
StringRef getArgument() const final {
|
|
return "test-vector-contraction-lowering";
|
|
}
|
|
StringRef getDescription() const final {
|
|
return "Test lowering patterns that lower contract ops in the vector "
|
|
"dialect";
|
|
}
|
|
TestVectorContractionLowering() = default;
|
|
TestVectorContractionLowering(const TestVectorContractionLowering &pass)
|
|
: PassWrapper(pass) {}
|
|
|
|
Option<bool> lowerToFlatMatrix{
|
|
*this, "vector-lower-matrix-intrinsics",
|
|
llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
|
|
llvm::cl::init(false)};
|
|
Option<bool> lowerToOuterProduct{
|
|
*this, "vector-outerproduct",
|
|
llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
|
|
llvm::cl::init(false)};
|
|
Option<bool> lowerToFilterOuterProduct{
|
|
*this, "vector-filter-outerproduct",
|
|
llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
|
|
"vectors of size 4."),
|
|
llvm::cl::init(false)};
|
|
Option<bool> lowerToParallelArith{
|
|
*this, "vector-parallel-arith",
|
|
llvm::cl::desc("Lower vector.contract to elementwise vector ops."),
|
|
llvm::cl::init(false)};
|
|
|
|
void runOnOperation() override {
|
|
RewritePatternSet patterns(&getContext());
|
|
|
|
// Test on one pattern in isolation.
|
|
if (lowerToOuterProduct) {
|
|
VectorContractLowering lowering = VectorContractLowering::OuterProduct;
|
|
VectorTransformsOptions options{lowering};
|
|
patterns.add<ContractionOpToOuterProductOpLowering>(options,
|
|
&getContext());
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
return;
|
|
}
|
|
|
|
// Test on one pattern in isolation.
|
|
if (lowerToFilterOuterProduct) {
|
|
VectorContractLowering lowering = VectorContractLowering::OuterProduct;
|
|
VectorTransformsOptions options{lowering};
|
|
patterns.add<ContractionOpToOuterProductOpLowering>(
|
|
options, &getContext(), /*benefit=*/1, [](vector::ContractionOp op) {
|
|
// Only lowers vector.contract where the lhs as a type vector<MxNx?>
|
|
// where M is not 4.
|
|
if (op.getRhsType().getShape()[0] == 4)
|
|
return failure();
|
|
return success();
|
|
});
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
return;
|
|
}
|
|
|
|
if (lowerToParallelArith) {
|
|
vector::populateVectorContractLoweringPatterns(
|
|
patterns,
|
|
vector::VectorTransformsOptions().setVectorTransformsOptions(
|
|
vector::VectorContractLowering::ParallelArith));
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
return;
|
|
}
|
|
|
|
// Test on all contract lowering patterns.
|
|
VectorContractLowering contractLowering = VectorContractLowering::Dot;
|
|
if (lowerToFlatMatrix)
|
|
contractLowering = VectorContractLowering::Matmul;
|
|
VectorMultiReductionLowering vectorMultiReductionLowering =
|
|
VectorMultiReductionLowering::InnerParallel;
|
|
VectorTransformsOptions options{contractLowering,
|
|
vectorMultiReductionLowering,
|
|
VectorTransposeLowering()};
|
|
populateVectorBroadcastLoweringPatterns(patterns);
|
|
populateVectorContractLoweringPatterns(patterns, options);
|
|
populateVectorMaskOpLoweringPatterns(patterns);
|
|
populateVectorShapeCastLoweringPatterns(patterns);
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
};
|
|
|
|
struct TestVectorTransposeLowering
|
|
: public PassWrapper<TestVectorTransposeLowering,
|
|
OperationPass<func::FuncOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransposeLowering)
|
|
|
|
StringRef getArgument() const final {
|
|
return "test-vector-transpose-lowering";
|
|
}
|
|
StringRef getDescription() const final {
|
|
return "Test lowering patterns that lower contract ops in the vector "
|
|
"dialect";
|
|
}
|
|
TestVectorTransposeLowering() = default;
|
|
TestVectorTransposeLowering(const TestVectorTransposeLowering &pass)
|
|
: PassWrapper(pass) {}
|
|
|
|
Option<bool> lowerToEltwise{
|
|
*this, "eltwise",
|
|
llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"),
|
|
llvm::cl::init(false)};
|
|
Option<bool> lowerToFlatTranspose{
|
|
*this, "flat",
|
|
llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
|
|
llvm::cl::init(false)};
|
|
Option<bool> lowerToShuffleTranspose{
|
|
*this, "shuffle",
|
|
llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"),
|
|
llvm::cl::init(false)};
|
|
Option<bool> lowerToAvx2{
|
|
*this, "avx2",
|
|
llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"),
|
|
llvm::cl::init(false)};
|
|
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
registry.insert<LLVM::LLVMDialect>();
|
|
}
|
|
|
|
void runOnOperation() override {
|
|
func::FuncOp funcOp = getOperation();
|
|
MLIRContext *context = funcOp.getContext();
|
|
RewritePatternSet patterns(context);
|
|
|
|
vector::VectorTransformsOptions vectorTransformOptions;
|
|
if (lowerToEltwise) {
|
|
vectorTransformOptions =
|
|
vectorTransformOptions.setVectorTransposeLowering(
|
|
VectorTransposeLowering::EltWise);
|
|
}
|
|
if (lowerToFlatTranspose) {
|
|
vectorTransformOptions =
|
|
vectorTransformOptions.setVectorTransposeLowering(
|
|
VectorTransposeLowering::Flat);
|
|
}
|
|
if (lowerToShuffleTranspose) {
|
|
vectorTransformOptions =
|
|
vectorTransformOptions.setVectorTransposeLowering(
|
|
VectorTransposeLowering::Shuffle);
|
|
}
|
|
vector::populateVectorTransposeLoweringPatterns(patterns,
|
|
vectorTransformOptions);
|
|
|
|
if (lowerToAvx2) {
|
|
auto avx2LoweringOptions =
|
|
x86vector::avx2::LoweringOptions().setTransposeOptions(
|
|
x86vector::avx2::TransposeLoweringOptions()
|
|
.lower4x8xf32()
|
|
.lower8x8xf32());
|
|
x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
|
|
patterns, avx2LoweringOptions, /*benefit=*/10);
|
|
}
|
|
|
|
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
|
|
return signalPassFailure();
|
|
}
|
|
};
|
|
|
|
struct TestVectorUnrollingPatterns
|
|
: public PassWrapper<TestVectorUnrollingPatterns,
|
|
OperationPass<func::FuncOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorUnrollingPatterns)
|
|
|
|
StringRef getArgument() const final {
|
|
return "test-vector-unrolling-patterns";
|
|
}
|
|
StringRef getDescription() const final {
|
|
return "Test lowering patterns to unroll contract ops in the vector "
|
|
"dialect";
|
|
}
|
|
TestVectorUnrollingPatterns() = default;
|
|
TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass)
|
|
: PassWrapper(pass) {}
|
|
void runOnOperation() override {
|
|
MLIRContext *ctx = &getContext();
|
|
RewritePatternSet patterns(ctx);
|
|
populateVectorUnrollPatterns(
|
|
patterns, UnrollVectorOptions()
|
|
.setNativeShape(ArrayRef<int64_t>{2, 2})
|
|
.setFilterConstraint([](Operation *op) {
|
|
return success(isa<arith::AddFOp, vector::FMAOp,
|
|
vector::MultiDimReductionOp>(op));
|
|
}));
|
|
populateVectorUnrollPatterns(
|
|
patterns, UnrollVectorOptions()
|
|
.setNativeShape(ArrayRef<int64_t>{2})
|
|
.setFilterConstraint([](Operation *op) {
|
|
return success(isa<vector::ReductionOp>(op));
|
|
}));
|
|
populateVectorUnrollPatterns(
|
|
patterns, UnrollVectorOptions()
|
|
.setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
|
|
.setFilterConstraint([](Operation *op) {
|
|
return success(isa<vector::TransposeOp>(op));
|
|
}));
|
|
|
|
if (unrollBasedOnType) {
|
|
UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
|
|
[](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
|
|
vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
|
|
SmallVector<int64_t, 4> nativeShape(
|
|
contractOp.getIteratorTypes().size(), 4);
|
|
Type lhsType = contractOp.getLhsType().getElementType();
|
|
nativeShape[nativeShape.size() - 1] = lhsType.isF16() ? 4 : 2;
|
|
return nativeShape;
|
|
};
|
|
|
|
UnrollVectorOptions opts;
|
|
opts.setNativeShapeFn(nativeShapeFn)
|
|
.setFilterConstraint(
|
|
[](Operation *op) { return success(isa<ContractionOp>(op)); });
|
|
|
|
if (!unrollOrder.empty()) {
|
|
opts.setUnrollTraversalOrderFn([this](Operation *op)
|
|
-> Optional<SmallVector<int64_t>> {
|
|
vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
|
|
if (contractOp.getIteratorTypes().size() == unrollOrder.size())
|
|
return SmallVector<int64_t>(unrollOrder.begin(), unrollOrder.end());
|
|
return None;
|
|
});
|
|
}
|
|
populateVectorUnrollPatterns(patterns, opts);
|
|
} else {
|
|
auto nativeShapeFn =
|
|
[](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
|
|
auto contractOp = dyn_cast<ContractionOp>(op);
|
|
if (!contractOp)
|
|
return None;
|
|
return SmallVector<int64_t, 4>(contractOp.getIteratorTypes().size(), 2);
|
|
};
|
|
populateVectorUnrollPatterns(patterns,
|
|
UnrollVectorOptions()
|
|
.setNativeShapeFn(nativeShapeFn)
|
|
.setFilterConstraint([](Operation *op) {
|
|
return success(isa<ContractionOp>(op));
|
|
}));
|
|
}
|
|
populateVectorToVectorCanonicalizationPatterns(patterns);
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
|
|
ListOption<int64_t> unrollOrder{*this, "unroll-order",
|
|
llvm::cl::desc("set the unroll order")};
|
|
|
|
Option<bool> unrollBasedOnType{
|
|
*this, "unroll-based-on-type",
|
|
llvm::cl::desc("Set the unroll factor based on type of the operation"),
|
|
llvm::cl::init(false)};
|
|
};
|
|
|
|
struct TestVectorTransferUnrollingPatterns
|
|
: public PassWrapper<TestVectorTransferUnrollingPatterns,
|
|
OperationPass<func::FuncOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
|
|
TestVectorTransferUnrollingPatterns)
|
|
|
|
TestVectorTransferUnrollingPatterns() = default;
|
|
TestVectorTransferUnrollingPatterns(
|
|
const TestVectorTransferUnrollingPatterns &pass)
|
|
: PassWrapper(pass) {}
|
|
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
registry.insert<AffineDialect>();
|
|
}
|
|
StringRef getArgument() const final {
|
|
return "test-vector-transfer-unrolling-patterns";
|
|
}
|
|
StringRef getDescription() const final {
|
|
return "Test lowering patterns to unroll transfer ops in the vector "
|
|
"dialect";
|
|
}
|
|
void runOnOperation() override {
|
|
MLIRContext *ctx = &getContext();
|
|
RewritePatternSet patterns(ctx);
|
|
UnrollVectorOptions opts;
|
|
opts.setNativeShape(ArrayRef<int64_t>{2, 2})
|
|
.setFilterConstraint([](Operation *op) {
|
|
return success(
|
|
isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
|
|
});
|
|
if (reverseUnrollOrder.getValue()) {
|
|
opts.setUnrollTraversalOrderFn(
|
|
[](Operation *op) -> Optional<SmallVector<int64_t>> {
|
|
int64_t numLoops = 0;
|
|
if (auto readOp = dyn_cast<vector::TransferReadOp>(op))
|
|
numLoops = readOp.getVectorType().getRank();
|
|
else if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op))
|
|
numLoops = writeOp.getVectorType().getRank();
|
|
else
|
|
return None;
|
|
auto order = llvm::reverse(llvm::seq<int64_t>(0, numLoops));
|
|
return llvm::to_vector(order);
|
|
});
|
|
}
|
|
populateVectorUnrollPatterns(patterns, opts);
|
|
populateVectorToVectorCanonicalizationPatterns(patterns);
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
|
|
Option<bool> reverseUnrollOrder{
|
|
*this, "reverse-unroll-order",
|
|
llvm::cl::desc(
|
|
"reverse the order of unrolling of vector transfer operations"),
|
|
llvm::cl::init(false)};
|
|
};
|
|
|
|
struct TestVectorTransferFullPartialSplitPatterns
|
|
: public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
|
|
OperationPass<func::FuncOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
|
|
TestVectorTransferFullPartialSplitPatterns)
|
|
|
|
StringRef getArgument() const final {
|
|
return "test-vector-transfer-full-partial-split";
|
|
}
|
|
StringRef getDescription() const final {
|
|
return "Test lowering patterns to split "
|
|
"transfer ops via scf.if + linalg ops";
|
|
}
|
|
TestVectorTransferFullPartialSplitPatterns() = default;
|
|
TestVectorTransferFullPartialSplitPatterns(
|
|
const TestVectorTransferFullPartialSplitPatterns &pass)
|
|
: PassWrapper(pass) {}
|
|
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
|
|
scf::SCFDialect>();
|
|
}
|
|
|
|
Option<bool> useLinalgOps{
|
|
*this, "use-memref-copy",
|
|
llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "
|
|
"memref.copy operations."),
|
|
llvm::cl::init(false)};
|
|
void runOnOperation() override {
|
|
MLIRContext *ctx = &getContext();
|
|
RewritePatternSet patterns(ctx);
|
|
VectorTransformsOptions options;
|
|
if (useLinalgOps)
|
|
options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
|
|
else
|
|
options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
|
|
patterns.add<VectorTransferFullPartialRewriter>(ctx, options);
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
};
|
|
|
|
struct TestVectorTransferOpt
|
|
: public PassWrapper<TestVectorTransferOpt, OperationPass<func::FuncOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt)
|
|
|
|
StringRef getArgument() const final { return "test-vector-transferop-opt"; }
|
|
StringRef getDescription() const final {
|
|
return "Test optimization transformations for transfer ops";
|
|
}
|
|
void runOnOperation() override { transferOpflowOpt(getOperation()); }
|
|
};
|
|
|
|
struct TestVectorTransferLoweringPatterns
|
|
: public PassWrapper<TestVectorTransferLoweringPatterns,
|
|
OperationPass<func::FuncOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
|
|
TestVectorTransferLoweringPatterns)
|
|
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
registry.insert<tensor::TensorDialect, memref::MemRefDialect>();
|
|
}
|
|
StringRef getArgument() const final {
|
|
return "test-vector-transfer-lowering-patterns";
|
|
}
|
|
StringRef getDescription() const final {
|
|
return "Test lowering patterns to lower transfer ops to other vector ops";
|
|
}
|
|
void runOnOperation() override {
|
|
RewritePatternSet patterns(&getContext());
|
|
populateVectorTransferLoweringPatterns(patterns);
|
|
populateVectorTransferPermutationMapLoweringPatterns(patterns);
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
};
|
|
|
|
struct TestVectorMultiReductionLoweringPatterns
|
|
: public PassWrapper<TestVectorMultiReductionLoweringPatterns,
|
|
OperationPass<func::FuncOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
|
|
TestVectorMultiReductionLoweringPatterns)
|
|
|
|
TestVectorMultiReductionLoweringPatterns() = default;
|
|
TestVectorMultiReductionLoweringPatterns(
|
|
const TestVectorMultiReductionLoweringPatterns &pass)
|
|
: PassWrapper(pass) {}
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
registry.insert<memref::MemRefDialect>();
|
|
}
|
|
StringRef getArgument() const final {
|
|
return "test-vector-multi-reduction-lowering-patterns";
|
|
}
|
|
StringRef getDescription() const final {
|
|
return "Test lowering patterns to lower vector.multi_reduction to other "
|
|
"vector ops";
|
|
}
|
|
Option<bool> useOuterReductions{
|
|
*this, "use-outer-reductions",
|
|
llvm::cl::desc("Move reductions to outer most dimensions"),
|
|
llvm::cl::init(false)};
|
|
void runOnOperation() override {
|
|
RewritePatternSet patterns(&getContext());
|
|
populateVectorMultiReductionLoweringPatterns(
|
|
patterns, useOuterReductions
|
|
? vector::VectorMultiReductionLowering::InnerParallel
|
|
: vector::VectorMultiReductionLowering::InnerReduction);
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
};
|
|
|
|
struct TestVectorTransferCollapseInnerMostContiguousDims
|
|
: public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
|
|
OperationPass<func::FuncOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
|
|
TestVectorTransferCollapseInnerMostContiguousDims)
|
|
|
|
TestVectorTransferCollapseInnerMostContiguousDims() = default;
|
|
TestVectorTransferCollapseInnerMostContiguousDims(
|
|
const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default;
|
|
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
registry.insert<memref::MemRefDialect, AffineDialect>();
|
|
}
|
|
|
|
StringRef getArgument() const final {
|
|
return "test-vector-transfer-collapse-inner-most-dims";
|
|
}
|
|
|
|
StringRef getDescription() const final {
|
|
return "Test lowering patterns that reducedes the rank of the vector "
|
|
"transfer memory and vector operands.";
|
|
}
|
|
|
|
void runOnOperation() override {
|
|
RewritePatternSet patterns(&getContext());
|
|
populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns);
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
};
|
|
|
|
struct TestVectorReduceToContractPatternsPatterns
|
|
: public PassWrapper<TestVectorReduceToContractPatternsPatterns,
|
|
OperationPass<func::FuncOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
|
|
TestVectorReduceToContractPatternsPatterns)
|
|
|
|
StringRef getArgument() const final {
|
|
return "test-vector-reduction-to-contract-patterns";
|
|
}
|
|
StringRef getDescription() const final {
|
|
return "Test patterns to convert multireduce op to contract and combine "
|
|
"broadcast/transpose to contract";
|
|
}
|
|
void runOnOperation() override {
|
|
RewritePatternSet patterns(&getContext());
|
|
populateVectorReductionToContractPatterns(patterns);
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
};
|
|
|
|
struct TestVectorTransferDropUnitDimsPatterns
|
|
: public PassWrapper<TestVectorTransferDropUnitDimsPatterns,
|
|
OperationPass<func::FuncOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
|
|
TestVectorTransferDropUnitDimsPatterns)
|
|
|
|
StringRef getArgument() const final {
|
|
return "test-vector-transfer-drop-unit-dims-patterns";
|
|
}
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
registry.insert<memref::MemRefDialect>();
|
|
}
|
|
void runOnOperation() override {
|
|
RewritePatternSet patterns(&getContext());
|
|
populateVectorTransferDropUnitDimsPatterns(patterns);
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
};
|
|
|
|
struct TestFlattenVectorTransferPatterns
|
|
: public PassWrapper<TestFlattenVectorTransferPatterns,
|
|
OperationPass<func::FuncOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
|
|
TestFlattenVectorTransferPatterns)
|
|
|
|
StringRef getArgument() const final {
|
|
return "test-vector-transfer-flatten-patterns";
|
|
}
|
|
StringRef getDescription() const final {
|
|
return "Test patterns to rewrite contiguous row-major N-dimensional "
|
|
"vector.transfer_{read,write} ops into 1D transfers";
|
|
}
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
registry.insert<memref::MemRefDialect>();
|
|
}
|
|
void runOnOperation() override {
|
|
RewritePatternSet patterns(&getContext());
|
|
populateFlattenVectorTransferPatterns(patterns);
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
};
|
|
|
|
struct TestVectorScanLowering
|
|
: public PassWrapper<TestVectorScanLowering, OperationPass<func::FuncOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorScanLowering)
|
|
|
|
StringRef getArgument() const final { return "test-vector-scan-lowering"; }
|
|
StringRef getDescription() const final {
|
|
return "Test lowering patterns that lower the scan op in the vector "
|
|
"dialect";
|
|
}
|
|
void runOnOperation() override {
|
|
RewritePatternSet patterns(&getContext());
|
|
populateVectorScanLoweringPatterns(patterns);
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
};
|
|
|
|
/// Allocate shared memory for a single warp to test lowering of
|
|
/// WarpExecuteOnLane0Op.
|
|
static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder,
|
|
WarpExecuteOnLane0Op warpOp,
|
|
Type type) {
|
|
static constexpr int64_t kSharedMemorySpace = 3;
|
|
// Compute type of shared memory buffer.
|
|
MemRefType memrefType;
|
|
if (auto vectorType = type.dyn_cast<VectorType>()) {
|
|
memrefType =
|
|
MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {},
|
|
kSharedMemorySpace);
|
|
} else {
|
|
memrefType = MemRefType::get({1}, type, {}, kSharedMemorySpace);
|
|
}
|
|
|
|
// Get symbol table holding all shared memory globals.
|
|
ModuleOp moduleOp = warpOp->getParentOfType<ModuleOp>();
|
|
SymbolTable symbolTable(moduleOp);
|
|
|
|
// Create a pretty name.
|
|
SmallString<64> buf;
|
|
llvm::raw_svector_ostream os(buf);
|
|
interleave(memrefType.getShape(), os, "x");
|
|
os << "x" << memrefType.getElementType();
|
|
std::string symbolName = (Twine("__shared_") + os.str()).str();
|
|
|
|
auto ip = builder.saveInsertionPoint();
|
|
builder.setInsertionPoint(moduleOp);
|
|
auto global = builder.create<memref::GlobalOp>(
|
|
loc,
|
|
/*sym_name=*/symbolName,
|
|
/*sym_visibility=*/builder.getStringAttr("private"),
|
|
/*type=*/memrefType,
|
|
/*initial_value=*/Attribute(),
|
|
/*constant=*/false,
|
|
/*alignment=*/IntegerAttr());
|
|
symbolTable.insert(global);
|
|
// The symbol table inserts at the end of the module, but globals are a bit
|
|
// nicer if they are at the beginning.
|
|
global->moveBefore(&moduleOp.front());
|
|
|
|
builder.restoreInsertionPoint(ip);
|
|
return builder.create<memref::GetGlobalOp>(loc, memrefType, symbolName);
|
|
}
|
|
|
|
/// Reduces `input` across `size` lanes using a butterfly pattern. At strides
/// 1, 2, 4, ... (< size), the running value is XOR-shuffled with width
/// `size`, and the shuffled value is folded into the running value with
/// `kind`.
static Value warpReduction(Location loc, OpBuilder &builder, Value input,
                           CombiningKind kind, uint32_t size) {
  Value acc = input;
  uint64_t stride = 1;
  while (stride < size) {
    auto shuffle = builder.create<gpu::ShuffleOp>(
        loc, acc, stride, /*width=*/size, /*mode=*/gpu::ShuffleMode::XOR);
    acc = makeArithReduction(builder, loc, kind, acc,
                             shuffle.getShuffleResult());
    stride <<= 1;
  }
  return acc;
}
|
|
|
|
struct TestVectorDistribution
|
|
: public PassWrapper<TestVectorDistribution, OperationPass<func::FuncOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistribution)
|
|
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
registry.insert<scf::SCFDialect, memref::MemRefDialect, gpu::GPUDialect,
|
|
AffineDialect>();
|
|
}
|
|
|
|
StringRef getArgument() const final { return "test-vector-warp-distribute"; }
|
|
StringRef getDescription() const final {
|
|
return "Test vector warp distribute transformation and lowering patterns";
|
|
}
|
|
TestVectorDistribution() = default;
|
|
TestVectorDistribution(const TestVectorDistribution &pass)
|
|
: PassWrapper(pass) {}
|
|
|
|
Option<bool> warpOpToSCF{
|
|
*this, "rewrite-warp-ops-to-scf-if",
|
|
llvm::cl::desc("Lower vector.warp_execute_on_lane0 to scf.if op"),
|
|
llvm::cl::init(false)};
|
|
|
|
Option<bool> distributeTransferWriteOps{
|
|
*this, "distribute-transfer-write",
|
|
llvm::cl::desc("Test distribution of transfer write"),
|
|
llvm::cl::init(false)};
|
|
|
|
Option<bool> hoistUniform{*this, "hoist-uniform",
|
|
llvm::cl::desc("Test hoist uniform"),
|
|
llvm::cl::init(false)};
|
|
|
|
Option<bool> propagateDistribution{
|
|
*this, "propagate-distribution",
|
|
llvm::cl::desc("Test distribution propgation"), llvm::cl::init(false)};
|
|
|
|
void runOnOperation() override {
|
|
RewritePatternSet patterns(&getContext());
|
|
|
|
getOperation().walk([&](Operation *op) {
|
|
if (auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(op)) {
|
|
if (hoistUniform) {
|
|
moveScalarUniformCode(warpOp);
|
|
}
|
|
WalkResult::interrupt();
|
|
}
|
|
});
|
|
MLIRContext *ctx = &getContext();
|
|
auto distributionFn = [](Value val) {
|
|
// Create a map (d0, d1) -> (d1) to distribute along the inner
|
|
// dimension. Once we support n-d distribution we can add more
|
|
// complex cases.
|
|
VectorType vecType = val.getType().dyn_cast<VectorType>();
|
|
int64_t vecRank = vecType ? vecType.getRank() : 0;
|
|
OpBuilder builder(val.getContext());
|
|
if (vecRank == 0)
|
|
return AffineMap::get(val.getContext());
|
|
return AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
|
|
};
|
|
if (distributeTransferWriteOps) {
|
|
RewritePatternSet patterns(ctx);
|
|
populateDistributeTransferWriteOpPatterns(patterns, distributionFn);
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
if (propagateDistribution) {
|
|
RewritePatternSet patterns(ctx);
|
|
vector::populatePropagateWarpVectorDistributionPatterns(patterns,
|
|
distributionFn);
|
|
vector::populateDistributeReduction(patterns, warpReduction);
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
WarpExecuteOnLane0LoweringOptions options;
|
|
options.warpAllocationFn = allocateGlobalSharedMemory;
|
|
options.warpSyncronizationFn = [](Location loc, OpBuilder &builder,
|
|
WarpExecuteOnLane0Op warpOp) {
|
|
builder.create<gpu::BarrierOp>(loc);
|
|
};
|
|
// Test on one pattern in isolation.
|
|
if (warpOpToSCF) {
|
|
populateWarpExecuteOnLane0OpToScfForPattern(patterns, options);
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
return;
|
|
}
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
namespace mlir {
|
|
namespace test {
|
|
void registerTestVectorLowerings() {
|
|
PassRegistration<TestVectorToVectorLowering>();
|
|
|
|
PassRegistration<TestVectorContractionLowering>();
|
|
|
|
PassRegistration<TestVectorTransposeLowering>();
|
|
|
|
PassRegistration<TestVectorUnrollingPatterns>();
|
|
|
|
PassRegistration<TestVectorTransferUnrollingPatterns>();
|
|
|
|
PassRegistration<TestVectorTransferFullPartialSplitPatterns>();
|
|
|
|
PassRegistration<TestVectorTransferOpt>();
|
|
|
|
PassRegistration<TestVectorTransferLoweringPatterns>();
|
|
|
|
PassRegistration<TestVectorMultiReductionLoweringPatterns>();
|
|
|
|
PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
|
|
|
|
PassRegistration<TestVectorReduceToContractPatternsPatterns>();
|
|
|
|
PassRegistration<TestVectorTransferDropUnitDimsPatterns>();
|
|
|
|
PassRegistration<TestFlattenVectorTransferPatterns>();
|
|
|
|
PassRegistration<TestVectorScanLowering>();
|
|
|
|
PassRegistration<TestVectorDistribution>();
|
|
}
|
|
} // namespace test
|
|
} // namespace mlir
|