Files
clang-p2996/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
stanley-nod d2061530dc [mlir][vector] Modify constraint and interface for warp reduce on f16 and i8
Quantization methods are crucial and ubiquitous in accelerating machine
learning workloads. Most of these methods use f16 and i8 types.

This patch relaxes the type constraints on warp reduce distribution to
allow these types. Furthermore, this patch also changes the interface
and moves the initial reduction of data to a single thread into the
distributedReductionFn; this gives developers the flexibility to control
how the initial lane value is obtained, which might differ based
on the input types. (i.e. to shuffle a 32-bit-wide type, we need to reduce f16
to 2xf16 types rather than a single element).

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D137691
2022-11-09 11:52:17 -08:00

825 lines
32 KiB
C++

//===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include <type_traits>
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::vector;
namespace {
/// Test pass exercising vector-to-vector lowering patterns (canonicalization,
/// bitcast bubbling, leading-unit-dim removal), optionally preceded by
/// unrolling of elementwise/contract/transfer ops.
struct TestVectorToVectorLowering
    : public PassWrapper<TestVectorToVectorLowering,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToVectorLowering)
  TestVectorToVectorLowering() = default;
  TestVectorToVectorLowering(const TestVectorToVectorLowering &pass)
      : PassWrapper(pass) {}
  StringRef getArgument() const final {
    return "test-vector-to-vector-lowering";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns between ops in the vector dialect";
  }
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<AffineDialect>();
  }

  Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
                      llvm::cl::init(false)};

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    if (unroll) {
      populateVectorUnrollPatterns(
          patterns,
          UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
              filter));
    }
    populateVectorToVectorCanonicalizationPatterns(patterns);
    populateBubbleVectorBitCastOpPatterns(patterns);
    populateCastAwayVectorLeadingOneDimPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }

private:
  /// Return the target unroll shape based on the op type; llvm::None means
  /// "do not unroll this op".
  static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
    if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op))
      return SmallVector<int64_t, 4>(2, 2);
    if (isa<vector::ContractionOp>(op))
      return SmallVector<int64_t, 4>(3, 2);
    // For transfer ops, just propagate the shape coming from
    // InsertStridedSlices/ExtractStridedSlices.
    if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
      VectorType dstVec;
      for (Operation *users : readOp->getUsers()) {
        auto extract = dyn_cast<ExtractStridedSliceOp>(users);
        if (!extract)
          return llvm::None;
        auto vecType = extract.getResult().getType().cast<VectorType>();
        if (dstVec && dstVec != vecType)
          return llvm::None;
        dstVec = vecType;
      }
      // Guard against a transfer_read with no users: `dstVec` would still be
      // null here and dereferencing it below would crash.
      if (!dstVec)
        return llvm::None;
      return SmallVector<int64_t, 4>(dstVec.getShape().begin(),
                                     dstVec.getShape().end());
    }
    if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
      auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>();
      if (!insert)
        return llvm::None;
      ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
      return SmallVector<int64_t, 4>(shape.begin(), shape.end());
    }
    return llvm::None;
  }

  /// Restrict unrolling to the op kinds getShape knows how to handle.
  static LogicalResult filter(Operation *op) {
    return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp,
                       ContractionOp, TransferReadOp, TransferWriteOp>(op));
  }
};
/// Test pass for the various vector.contract lowering strategies. Exactly one
/// of the mutually exclusive flags below selects a strategy tested in
/// isolation; with no flag set, all contract lowering patterns are applied.
struct TestVectorContractionLowering
    : public PassWrapper<TestVectorContractionLowering,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorContractionLowering)
  StringRef getArgument() const final {
    return "test-vector-contraction-lowering";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns that lower contract ops in the vector "
           "dialect";
  }
  TestVectorContractionLowering() = default;
  TestVectorContractionLowering(const TestVectorContractionLowering &pass)
      : PassWrapper(pass) {}
  Option<bool> lowerToFlatMatrix{
      *this, "vector-lower-matrix-intrinsics",
      llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
      llvm::cl::init(false)};
  Option<bool> lowerToOuterProduct{
      *this, "vector-outerproduct",
      llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
      llvm::cl::init(false)};
  Option<bool> lowerToFilterOuterProduct{
      *this, "vector-filter-outerproduct",
      llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
                     "vectors of size 4."),
      llvm::cl::init(false)};
  Option<bool> lowerToParallelArith{
      *this, "vector-parallel-arith",
      llvm::cl::desc("Lower vector.contract to elementwise vector ops."),
      llvm::cl::init(false)};
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());

    // Test on one pattern in isolation.
    if (lowerToOuterProduct) {
      VectorContractLowering lowering = VectorContractLowering::OuterProduct;
      VectorTransformsOptions options{lowering};
      patterns.add<ContractionOpToOuterProductOpLowering>(options,
                                                          &getContext());
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
      return;
    }

    // Test on one pattern in isolation, with a filter constraint.
    if (lowerToFilterOuterProduct) {
      VectorContractLowering lowering = VectorContractLowering::OuterProduct;
      VectorTransformsOptions options{lowering};
      patterns.add<ContractionOpToOuterProductOpLowering>(
          options, &getContext(), /*benefit=*/1, [](vector::ContractionOp op) {
            // Only lowers vector.contract ops whose rhs has a type
            // vector<MxNx?> where M (the leading rhs dimension) is not 4.
            if (op.getRhsType().getShape()[0] == 4)
              return failure();
            return success();
          });
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
      return;
    }

    if (lowerToParallelArith) {
      vector::populateVectorContractLoweringPatterns(
          patterns,
          vector::VectorTransformsOptions().setVectorTransformsOptions(
              vector::VectorContractLowering::ParallelArith));
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
      return;
    }

    // Test on all contract lowering patterns.
    VectorContractLowering contractLowering = VectorContractLowering::Dot;
    if (lowerToFlatMatrix)
      contractLowering = VectorContractLowering::Matmul;
    VectorMultiReductionLowering vectorMultiReductionLowering =
        VectorMultiReductionLowering::InnerParallel;
    VectorTransformsOptions options{contractLowering,
                                    vectorMultiReductionLowering,
                                    VectorTransposeLowering()};
    populateVectorBroadcastLoweringPatterns(patterns);
    populateVectorContractLoweringPatterns(patterns, options);
    populateVectorMaskOpLoweringPatterns(patterns);
    populateVectorShapeCastLoweringPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
/// Test pass for the vector.transpose lowering strategies (eltwise
/// insert/extract, vector.flat_transpose, shape_cast + shuffle, and
/// AVX2-specialized patterns).
struct TestVectorTransposeLowering
    : public PassWrapper<TestVectorTransposeLowering,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransposeLowering)
  StringRef getArgument() const final {
    return "test-vector-transpose-lowering";
  }
  StringRef getDescription() const final {
    // Fixed copy-paste from the contraction-lowering pass: this pass lowers
    // transpose ops, not contract ops.
    return "Test lowering patterns that lower the transpose op in the vector "
           "dialect";
  }
  TestVectorTransposeLowering() = default;
  TestVectorTransposeLowering(const TestVectorTransposeLowering &pass)
      : PassWrapper(pass) {}

  Option<bool> lowerToEltwise{
      *this, "eltwise",
      llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"),
      llvm::cl::init(false)};
  Option<bool> lowerToFlatTranspose{
      *this, "flat",
      llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
      llvm::cl::init(false)};
  Option<bool> lowerToShuffleTranspose{
      *this, "shuffle",
      llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"),
      llvm::cl::init(false)};
  Option<bool> lowerToAvx2{
      *this, "avx2",
      llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"),
      llvm::cl::init(false)};

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<LLVM::LLVMDialect>();
  }

  void runOnOperation() override {
    func::FuncOp funcOp = getOperation();
    MLIRContext *context = funcOp.getContext();
    RewritePatternSet patterns(context);

    // Select the transpose lowering strategy from the flags; the last flag
    // set wins if several are given.
    vector::VectorTransformsOptions vectorTransformOptions;
    if (lowerToEltwise) {
      vectorTransformOptions =
          vectorTransformOptions.setVectorTransposeLowering(
              VectorTransposeLowering::EltWise);
    }
    if (lowerToFlatTranspose) {
      vectorTransformOptions =
          vectorTransformOptions.setVectorTransposeLowering(
              VectorTransposeLowering::Flat);
    }
    if (lowerToShuffleTranspose) {
      vectorTransformOptions =
          vectorTransformOptions.setVectorTransposeLowering(
              VectorTransposeLowering::Shuffle);
    }
    vector::populateVectorTransposeLoweringPatterns(patterns,
                                                    vectorTransformOptions);

    if (lowerToAvx2) {
      auto avx2LoweringOptions =
          x86vector::avx2::LoweringOptions().setTransposeOptions(
              x86vector::avx2::TransposeLoweringOptions()
                  .lower4x8xf32()
                  .lower8x8xf32());
      // Higher benefit so the specialized patterns win over the generic ones.
      x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
          patterns, avx2LoweringOptions, /*benefit=*/10);
    }

    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
      return signalPassFailure();
  }
};
/// Test pass for vector unrolling: registers fixed native-shape unrolling for
/// elementwise/reduction/transpose ops, plus contract unrolling whose native
/// shape is either fixed or derived from the operand element type.
struct TestVectorUnrollingPatterns
    : public PassWrapper<TestVectorUnrollingPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorUnrollingPatterns)
  StringRef getArgument() const final {
    return "test-vector-unrolling-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to unroll contract ops in the vector "
           "dialect";
  }
  TestVectorUnrollingPatterns() = default;
  TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass)
      : PassWrapper(pass) {}
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    // Elementwise, FMA, and multi-reduction ops are unrolled to 2x2 tiles.
    populateVectorUnrollPatterns(
        patterns, UnrollVectorOptions()
                      .setNativeShape(ArrayRef<int64_t>{2, 2})
                      .setFilterConstraint([](Operation *op) {
                        return success(isa<arith::AddFOp, vector::FMAOp,
                                           vector::MultiDimReductionOp>(op));
                      }));
    // 1-D vector.reduction ops are unrolled to vectors of 2 elements.
    populateVectorUnrollPatterns(
        patterns, UnrollVectorOptions()
                      .setNativeShape(ArrayRef<int64_t>{2})
                      .setFilterConstraint([](Operation *op) {
                        return success(isa<vector::ReductionOp>(op));
                      }));
    // vector.transpose ops are unrolled to 1x3x4x2 tiles.
    populateVectorUnrollPatterns(
        patterns, UnrollVectorOptions()
                      .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
                      .setFilterConstraint([](Operation *op) {
                        return success(isa<vector::TransposeOp>(op));
                      }));
    if (unrollBasedOnType) {
      // Derive the contract native shape from the lhs element type: unroll
      // the innermost iterator to 4 for f16 and to 2 otherwise, keeping 4
      // for all other iterators.
      UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
          [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
        vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
        SmallVector<int64_t, 4> nativeShape(
            contractOp.getIteratorTypes().size(), 4);
        Type lhsType = contractOp.getLhsType().getElementType();
        nativeShape[nativeShape.size() - 1] = lhsType.isF16() ? 4 : 2;
        return nativeShape;
      };
      UnrollVectorOptions opts;
      opts.setNativeShapeFn(nativeShapeFn)
          .setFilterConstraint(
              [](Operation *op) { return success(isa<ContractionOp>(op)); });
      if (!unrollOrder.empty()) {
        // Honor the user-provided traversal order only when its length
        // matches the contract's iterator count; otherwise fall back to the
        // default order.
        opts.setUnrollTraversalOrderFn([this](Operation *op)
                                           -> Optional<SmallVector<int64_t>> {
          vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
          if (contractOp.getIteratorTypes().size() == unrollOrder.size())
            return SmallVector<int64_t>(unrollOrder.begin(), unrollOrder.end());
          return None;
        });
      }
      populateVectorUnrollPatterns(patterns, opts);
    } else {
      // Fixed native shape: unroll every contract iterator to 2.
      auto nativeShapeFn =
          [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
        auto contractOp = dyn_cast<ContractionOp>(op);
        if (!contractOp)
          return None;
        return SmallVector<int64_t, 4>(contractOp.getIteratorTypes().size(), 2);
      };
      populateVectorUnrollPatterns(patterns,
                                   UnrollVectorOptions()
                                       .setNativeShapeFn(nativeShapeFn)
                                       .setFilterConstraint([](Operation *op) {
                                         return success(isa<ContractionOp>(op));
                                       }));
    }
    populateVectorToVectorCanonicalizationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }

  ListOption<int64_t> unrollOrder{*this, "unroll-order",
                                  llvm::cl::desc("set the unroll order")};
  Option<bool> unrollBasedOnType{
      *this, "unroll-based-on-type",
      llvm::cl::desc("Set the unroll factor based on type of the operation"),
      llvm::cl::init(false)};
};
/// Test pass that unrolls vector transfer_read/transfer_write ops to 2x2
/// native tiles, optionally traversing the unrolled tiles in reverse
/// dimension order.
struct TestVectorTransferUnrollingPatterns
    : public PassWrapper<TestVectorTransferUnrollingPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorTransferUnrollingPatterns)
  TestVectorTransferUnrollingPatterns() = default;
  TestVectorTransferUnrollingPatterns(
      const TestVectorTransferUnrollingPatterns &pass)
      : PassWrapper(pass) {}
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<AffineDialect>();
  }
  StringRef getArgument() const final {
    return "test-vector-transfer-unrolling-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to unroll transfer ops in the vector "
           "dialect";
  }
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);
    UnrollVectorOptions unrollOptions;
    unrollOptions.setNativeShape(ArrayRef<int64_t>{2, 2});
    unrollOptions.setFilterConstraint([](Operation *op) {
      bool isTransfer =
          isa<vector::TransferReadOp, vector::TransferWriteOp>(op);
      return success(isTransfer);
    });
    if (reverseUnrollOrder.getValue()) {
      // Traverse the unrolled tiles from the innermost dimension outwards.
      unrollOptions.setUnrollTraversalOrderFn(
          [](Operation *op) -> Optional<SmallVector<int64_t>> {
            int64_t numLoops = 0;
            if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
              numLoops = readOp.getVectorType().getRank();
            } else if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
              numLoops = writeOp.getVectorType().getRank();
            } else {
              return None;
            }
            SmallVector<int64_t> order;
            for (int64_t dim = numLoops - 1; dim >= 0; --dim)
              order.push_back(dim);
            return order;
          });
    }
    populateVectorUnrollPatterns(patterns, unrollOptions);
    populateVectorToVectorCanonicalizationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
  Option<bool> reverseUnrollOrder{
      *this, "reverse-unroll-order",
      llvm::cl::desc(
          "reverse the order of unrolling of vector transfer operations"),
      llvm::cl::init(false)};
};
/// Test pass splitting vector transfers into a fast full path and a slow
/// partial path, either via unmasked transfers or via linalg/memref copies.
struct TestVectorTransferFullPartialSplitPatterns
    : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorTransferFullPartialSplitPatterns)
  StringRef getArgument() const final {
    return "test-vector-transfer-full-partial-split";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to split "
           "transfer ops via scf.if + linalg ops";
  }
  TestVectorTransferFullPartialSplitPatterns() = default;
  TestVectorTransferFullPartialSplitPatterns(
      const TestVectorTransferFullPartialSplitPatterns &pass)
      : PassWrapper(pass) {}
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
                    scf::SCFDialect>();
  }
  Option<bool> useLinalgOps{
      *this, "use-memref-copy",
      llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "
                     "memref.copy operations."),
      llvm::cl::init(false)};
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    // Choose the split strategy from the command-line flag.
    VectorTransformsOptions vectorOptions;
    vectorOptions.setVectorTransferSplit(
        useLinalgOps ? VectorTransferSplit::LinalgCopy
                     : VectorTransferSplit::VectorTransfer);
    RewritePatternSet splitPatterns(context);
    splitPatterns.add<VectorTransferFullPartialRewriter>(context,
                                                         vectorOptions);
    (void)applyPatternsAndFoldGreedily(getOperation(),
                                       std::move(splitPatterns));
  }
};
/// Test pass running the transfer-op dataflow optimization on a function.
struct TestVectorTransferOpt
    : public PassWrapper<TestVectorTransferOpt, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt)
  StringRef getArgument() const final { return "test-vector-transferop-opt"; }
  StringRef getDescription() const final {
    return "Test optimization transformations for transfer ops";
  }
  void runOnOperation() override {
    func::FuncOp funcOp = getOperation();
    transferOpflowOpt(funcOp);
  }
};
/// Test pass applying the transfer-op lowering and permutation-map
/// simplification patterns in a single greedy rewrite.
struct TestVectorTransferLoweringPatterns
    : public PassWrapper<TestVectorTransferLoweringPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorTransferLoweringPatterns)
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<tensor::TensorDialect, memref::MemRefDialect>();
  }
  StringRef getArgument() const final {
    return "test-vector-transfer-lowering-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to lower transfer ops to other vector ops";
  }
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet loweringPatterns(context);
    populateVectorTransferLoweringPatterns(loweringPatterns);
    populateVectorTransferPermutationMapLoweringPatterns(loweringPatterns);
    (void)applyPatternsAndFoldGreedily(getOperation(),
                                       std::move(loweringPatterns));
  }
};
/// Test pass lowering vector.multi_reduction, with a flag selecting between
/// the inner-parallel and inner-reduction strategies.
struct TestVectorMultiReductionLoweringPatterns
    : public PassWrapper<TestVectorMultiReductionLoweringPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorMultiReductionLoweringPatterns)
  TestVectorMultiReductionLoweringPatterns() = default;
  TestVectorMultiReductionLoweringPatterns(
      const TestVectorMultiReductionLoweringPatterns &pass)
      : PassWrapper(pass) {}
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
  }
  StringRef getArgument() const final {
    return "test-vector-multi-reduction-lowering-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to lower vector.multi_reduction to other "
           "vector ops";
  }
  Option<bool> useOuterReductions{
      *this, "use-outer-reductions",
      llvm::cl::desc("Move reductions to outer most dimensions"),
      llvm::cl::init(false)};
  void runOnOperation() override {
    // Select the lowering strategy from the command-line flag.
    vector::VectorMultiReductionLowering lowering =
        useOuterReductions
            ? vector::VectorMultiReductionLowering::InnerParallel
            : vector::VectorMultiReductionLowering::InnerReduction;
    RewritePatternSet patterns(&getContext());
    populateVectorMultiReductionLoweringPatterns(patterns, lowering);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
/// Test pass collapsing contiguous inner-most dims of vector transfers.
struct TestVectorTransferCollapseInnerMostContiguousDims
    : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorTransferCollapseInnerMostContiguousDims)
  TestVectorTransferCollapseInnerMostContiguousDims() = default;
  TestVectorTransferCollapseInnerMostContiguousDims(
      const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default;
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, AffineDialect>();
  }
  StringRef getArgument() const final {
    return "test-vector-transfer-collapse-inner-most-dims";
  }
  StringRef getDescription() const final {
    // Fixed typo/grammar in the user-facing help text ("reducedes").
    return "Test lowering patterns that reduce the rank of the vector "
           "transfer memory and vector operands.";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
/// Test pass converting multi-reductions to contracts and folding
/// broadcast/transpose ops into contracts.
struct TestVectorReduceToContractPatternsPatterns
    : public PassWrapper<TestVectorReduceToContractPatternsPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorReduceToContractPatternsPatterns)
  StringRef getArgument() const final {
    return "test-vector-reduction-to-contract-patterns";
  }
  StringRef getDescription() const final {
    return "Test patterns to convert multireduce op to contract and combine "
           "broadcast/transpose to contract";
  }
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet rewrites(context);
    populateVectorReductionToContractPatterns(rewrites);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(rewrites));
  }
};
/// Test pass dropping unit dims from vector transfer operations.
struct TestVectorTransferDropUnitDimsPatterns
    : public PassWrapper<TestVectorTransferDropUnitDimsPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorTransferDropUnitDimsPatterns)
  StringRef getArgument() const final {
    return "test-vector-transfer-drop-unit-dims-patterns";
  }
  // Added for consistency: every other test pass in this file documents
  // itself via getDescription().
  StringRef getDescription() const final {
    return "Test patterns to drop unit dims from vector transfer operations";
  }
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorTransferDropUnitDimsPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
/// Test pass flattening contiguous N-D vector transfers into 1-D transfers.
struct TestFlattenVectorTransferPatterns
    : public PassWrapper<TestFlattenVectorTransferPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestFlattenVectorTransferPatterns)
  StringRef getArgument() const final {
    return "test-vector-transfer-flatten-patterns";
  }
  StringRef getDescription() const final {
    return "Test patterns to rewrite contiguous row-major N-dimensional "
           "vector.transfer_{read,write} ops into 1D transfers";
  }
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
  }
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet flattenPatterns(context);
    populateFlattenVectorTransferPatterns(flattenPatterns);
    (void)applyPatternsAndFoldGreedily(getOperation(),
                                       std::move(flattenPatterns));
  }
};
/// Test pass lowering vector.scan ops.
struct TestVectorScanLowering
    : public PassWrapper<TestVectorScanLowering, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorScanLowering)
  StringRef getArgument() const final { return "test-vector-scan-lowering"; }
  StringRef getDescription() const final {
    return "Test lowering patterns that lower the scan op in the vector "
           "dialect";
  }
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet scanPatterns(context);
    populateVectorScanLoweringPatterns(scanPatterns);
    (void)applyPatternsAndFoldGreedily(getOperation(),
                                       std::move(scanPatterns));
  }
};
/// Allocate shared memory for a single warp to test lowering of
/// WarpExecuteOnLane0Op.
///
/// Creates (or reuses the name of) a module-level memref.global in memory
/// space 3 sized for `type`, and returns a memref.get_global for it at the
/// builder's current insertion point.
static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder,
                                        WarpExecuteOnLane0Op warpOp,
                                        Type type) {
  // NOTE(review): space 3 presumably maps to GPU workgroup/shared memory in
  // the target lowering — confirm against the target's address-space mapping.
  static constexpr int64_t kSharedMemorySpace = 3;
  // Compute type of shared memory buffer: the vector's own shape for vector
  // payloads, otherwise a single-element buffer.
  MemRefType memrefType;
  if (auto vectorType = type.dyn_cast<VectorType>()) {
    memrefType =
        MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {},
                        kSharedMemorySpace);
  } else {
    memrefType = MemRefType::get({1}, type, {}, kSharedMemorySpace);
  }

  // Get symbol table holding all shared memory globals.
  ModuleOp moduleOp = warpOp->getParentOfType<ModuleOp>();
  SymbolTable symbolTable(moduleOp);

  // Create a pretty name, e.g. "__shared_4x8xf32", derived from the buffer
  // shape and element type.
  SmallString<64> buf;
  llvm::raw_svector_ostream os(buf);
  interleave(memrefType.getShape(), os, "x");
  os << "x" << memrefType.getElementType();
  std::string symbolName = (Twine("__shared_") + os.str()).str();

  // The global must be created at module scope: save the insertion point,
  // hop to the module, and restore it once the global is in place.
  auto ip = builder.saveInsertionPoint();
  builder.setInsertionPoint(moduleOp);
  auto global = builder.create<memref::GlobalOp>(
      loc,
      /*sym_name=*/symbolName,
      /*sym_visibility=*/builder.getStringAttr("private"),
      /*type=*/memrefType,
      /*initial_value=*/Attribute(),
      /*constant=*/false,
      /*alignment=*/IntegerAttr());
  symbolTable.insert(global);
  // The symbol table inserts at the end of the module, but globals are a bit
  // nicer if they are at the beginning.
  global->moveBefore(&moduleOp.front());
  builder.restoreInsertionPoint(ip);

  return builder.create<memref::GetGlobalOp>(loc, memrefType, symbolName);
}
/// Emit a full-warp reduction of `input` using `kind` as the combining
/// operation: first reduce the per-lane vector to a scalar, then combine
/// scalars across lanes with a butterfly (XOR) shuffle pattern.
static Value warpReduction(Location loc, OpBuilder &builder, Value input,
                           CombiningKind kind, uint32_t size) {
  // Reduce the per-thread vector down to one value on each lane.
  Value result = builder.create<vector::ReductionOp>(loc, kind, input);
  // Butterfly shuffle reduction: combine with the lane at XOR distance
  // `offset`, doubling the distance each step.
  for (uint64_t offset = 1; offset < size; offset <<= 1) {
    auto shuffleOp =
        builder.create<gpu::ShuffleOp>(loc, result, offset,
                                       /*width=*/size,
                                       /*mode=*/gpu::ShuffleMode::XOR);
    result = makeArithReduction(builder, loc, kind, result,
                                shuffleOp.getShuffleResult());
  }
  return result;
}
/// Test pass for warp distribution: hoisting uniform code out of
/// WarpExecuteOnLane0Op, distributing transfer writes, propagating
/// distribution, and lowering the warp op to scf.if.
struct TestVectorDistribution
    : public PassWrapper<TestVectorDistribution, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistribution)
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<scf::SCFDialect, memref::MemRefDialect, gpu::GPUDialect,
                    AffineDialect>();
  }
  StringRef getArgument() const final { return "test-vector-warp-distribute"; }
  StringRef getDescription() const final {
    return "Test vector warp distribute transformation and lowering patterns";
  }
  TestVectorDistribution() = default;
  TestVectorDistribution(const TestVectorDistribution &pass)
      : PassWrapper(pass) {}

  Option<bool> warpOpToSCF{
      *this, "rewrite-warp-ops-to-scf-if",
      llvm::cl::desc("Lower vector.warp_execute_on_lane0 to scf.if op"),
      llvm::cl::init(false)};
  Option<bool> distributeTransferWriteOps{
      *this, "distribute-transfer-write",
      llvm::cl::desc("Test distribution of transfer write"),
      llvm::cl::init(false)};
  Option<bool> hoistUniform{*this, "hoist-uniform",
                            llvm::cl::desc("Test hoist uniform"),
                            llvm::cl::init(false)};
  Option<bool> propagateDistribution{
      *this, "propagate-distribution",
      // Fixed typo in the help string ("propgation").
      llvm::cl::desc("Test distribution propagation"), llvm::cl::init(false)};

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    // Process the first warp op only. The original code called
    // WalkResult::interrupt() and discarded the result, so the walk never
    // actually stopped; returning it makes the intended early exit work.
    getOperation().walk([&](Operation *op) -> WalkResult {
      if (auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(op)) {
        if (hoistUniform) {
          moveScalarUniformCode(warpOp);
        }
        return WalkResult::interrupt();
      }
      return WalkResult::advance();
    });
    MLIRContext *ctx = &getContext();
    // Create a map (d0, d1) -> (d1) to distribute along the inner
    // dimension. Once we support n-d distribution we can add more
    // complex cases.
    auto distributionFn = [](Value val) {
      VectorType vecType = val.getType().dyn_cast<VectorType>();
      int64_t vecRank = vecType ? vecType.getRank() : 0;
      OpBuilder builder(val.getContext());
      if (vecRank == 0)
        return AffineMap::get(val.getContext());
      return AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
    };
    if (distributeTransferWriteOps) {
      RewritePatternSet patterns(ctx);
      populateDistributeTransferWriteOpPatterns(patterns, distributionFn);
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
    }
    if (propagateDistribution) {
      RewritePatternSet patterns(ctx);
      vector::populatePropagateWarpVectorDistributionPatterns(patterns,
                                                              distributionFn);
      vector::populateDistributeReduction(patterns, warpReduction);
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
    }
    WarpExecuteOnLane0LoweringOptions options;
    options.warpAllocationFn = allocateGlobalSharedMemory;
    options.warpSyncronizationFn = [](Location loc, OpBuilder &builder,
                                      WarpExecuteOnLane0Op warpOp) {
      builder.create<gpu::BarrierOp>(loc);
    };
    // Test on one pattern in isolation.
    if (warpOpToSCF) {
      populateWarpExecuteOnLane0OpToScfForPattern(patterns, options);
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
      return;
    }
  }
};
} // namespace
namespace mlir {
namespace test {
/// Registers every vector-dialect test pass defined in this file with the
/// global pass registry so it can be invoked from mlir-opt.
void registerTestVectorLowerings() {
  PassRegistration<TestVectorToVectorLowering>();
  PassRegistration<TestVectorContractionLowering>();
  PassRegistration<TestVectorTransposeLowering>();
  PassRegistration<TestVectorUnrollingPatterns>();
  PassRegistration<TestVectorTransferUnrollingPatterns>();
  PassRegistration<TestVectorTransferFullPartialSplitPatterns>();
  PassRegistration<TestVectorTransferOpt>();
  PassRegistration<TestVectorTransferLoweringPatterns>();
  PassRegistration<TestVectorMultiReductionLoweringPatterns>();
  PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
  PassRegistration<TestVectorReduceToContractPatternsPatterns>();
  PassRegistration<TestVectorTransferDropUnitDimsPatterns>();
  PassRegistration<TestFlattenVectorTransferPatterns>();
  PassRegistration<TestVectorScanLowering>();
  PassRegistration<TestVectorDistribution>();
}
} // namespace test
} // namespace mlir