clang-p2996/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Tres Popp 5550c82189 [mlir] Move casting calls from methods to function calls
The MLIR classes Type/Attribute/Operation/Op/Value support
cast/dyn_cast/isa/dyn_cast_or_null through LLVM's doCast machinery, in
addition to defining member functions with the same names.
This change begins migrating uses of the member functions to the
corresponding free function calls, which have been decided on as the more
consistent style.
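
As a minimal sketch of the kind of rewrite the check performs (the helper
functions and the VectorType example below are illustrative, not taken from
the patch):

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"

// Old style: member function on Type (hypothetical helper).
static bool isVectorValueOld(mlir::Value value) {
  return value.getType().isa<mlir::VectorType>();
}

// New style: free function, re-exported into the mlir namespace.
static bool isVectorValueNew(mlir::Value value) {
  return mlir::isa<mlir::VectorType>(value.getType());
}
```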

Note that some classes, such as AffineExpr, still only define these as
methods directly, and this change does not include the work to support
free cast/isa calls for them.

Caveats include:
- The clang-tidy check used here probably has remaining problems.
- This only touches hand-written C++ code, so nothing that is generated.

Context:
- https://mlir.llvm.org/deprecation/ at "Use the free function variants
  for dyn_cast/cast/isa/…"
- Original discussion at https://discourse.llvm.org/t/preferred-casting-style-going-forward/68443

Implementation:
This first patch was created with the following steps. The intention is
to make only automated changes at first, so that less time is wasted if
it is reverted, and so that this first mass change is a clearer example
for other teams that will need to follow similar steps.

Steps are described here, one per command below, since comments would be stripped by git:
0. Retrieve the change from the following to build clang-tidy with an
   additional check:
   https://github.com/llvm/llvm-project/compare/main...tpopp:llvm-project:tidy-cast-check
1. Build clang-tidy
2. Run clang-tidy over your entire codebase while disabling all checks
   and enabling only the relevant one. Also run it on all header files.
3. Delete .inc files that were also modified, so the next build rebuilds
   them to a pure state.
4. Some changes have been deleted for the following reasons:
   - Some files had a variable also named cast
   - Some files had not included a header file that defines the cast
     functions
   - Some files contain the definitions of the classes that provide the
     casting methods, so their code must keep referring to the method
     instead of the free function unless a namespace prefix is added or
     the method declaration is removed at the same time.

```
ninja -C $BUILD_DIR clang-tidy

run-clang-tidy -clang-tidy-binary=$BUILD_DIR/bin/clang-tidy -checks='-*,misc-cast-functions'\
               -header-filter=mlir/ mlir/* -fix

rm -rf $BUILD_DIR/tools/mlir/**/*.inc

git restore mlir/lib/IR mlir/lib/Dialect/DLTI/DLTI.cpp\
            mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp\
            mlir/lib/**/IR/\
            mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp\
            mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp\
            mlir/test/lib/Dialect/Test/TestTypes.cpp\
            mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp\
            mlir/test/lib/Dialect/Test/TestAttributes.cpp\
            mlir/unittests/TableGen/EnumsGenTest.cpp\
            mlir/test/python/lib/PythonTestCAPI.cpp\
            mlir/include/mlir/IR/
```

Differential Revision: https://reviews.llvm.org/D150123
2023-05-12 11:21:25 +02:00


//===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include <optional>
#include <type_traits>
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::vector;
namespace {
struct TestVectorToVectorLowering
: public PassWrapper<TestVectorToVectorLowering,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToVectorLowering)
TestVectorToVectorLowering() = default;
TestVectorToVectorLowering(const TestVectorToVectorLowering &pass)
: PassWrapper(pass) {}
StringRef getArgument() const final {
return "test-vector-to-vector-lowering";
}
StringRef getDescription() const final {
return "Test lowering patterns between ops in the vector dialect";
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<affine::AffineDialect>();
}
Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
llvm::cl::init(false)};
void runOnOperation() override {
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
if (unroll) {
populateVectorUnrollPatterns(
patterns,
UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
filter));
}
populateVectorToVectorCanonicalizationPatterns(patterns);
populateBubbleVectorBitCastOpPatterns(patterns);
populateCastAwayVectorLeadingOneDimPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
private:
// Return the target shape based on op type.
static std::optional<SmallVector<int64_t>> getShape(Operation *op) {
if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op))
return SmallVector<int64_t>(2, 2);
if (isa<vector::ContractionOp>(op))
return SmallVector<int64_t>(3, 2);
// For transfer ops, just propagate the shape coming from
// InsertStridedSlices/ExtractStridedSlices.
if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
VectorType dstVec;
for (Operation *users : readOp->getUsers()) {
auto extract = dyn_cast<ExtractStridedSliceOp>(users);
if (!extract)
return std::nullopt;
auto vecType = cast<VectorType>(extract.getResult().getType());
if (dstVec && dstVec != vecType)
return std::nullopt;
dstVec = vecType;
}
return SmallVector<int64_t>(dstVec.getShape().begin(),
dstVec.getShape().end());
}
if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>();
if (!insert)
return std::nullopt;
ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
return SmallVector<int64_t>(shape.begin(), shape.end());
}
return std::nullopt;
}
static LogicalResult filter(Operation *op) {
return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp,
ContractionOp, TransferReadOp, TransferWriteOp>(op));
}
};
struct TestVectorContractionPrepareForMMTLowering
: public PassWrapper<TestVectorContractionPrepareForMMTLowering,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
TestVectorContractionPrepareForMMTLowering)
StringRef getArgument() const final {
return "test-vector-contraction-prepare-for-mmt-lowering";
}
StringRef getDescription() const final {
return "Test vector.contraction matmul canonicalization for MMT lowering.";
}
TestVectorContractionPrepareForMMTLowering() = default;
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<affine::AffineDialect, arith::ArithDialect,
vector::VectorDialect>();
}
void runOnOperation() override {
MLIRContext *ctx = &getContext();
RewritePatternSet patterns(ctx);
vector::populateVectorContractCanonicalizeMatmulToMMT(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
struct TestVectorUnrollingPatterns
: public PassWrapper<TestVectorUnrollingPatterns,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorUnrollingPatterns)
StringRef getArgument() const final {
return "test-vector-unrolling-patterns";
}
StringRef getDescription() const final {
return "Test lowering patterns to unroll contract ops in the vector "
"dialect";
}
TestVectorUnrollingPatterns() = default;
TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass)
: PassWrapper(pass) {}
void runOnOperation() override {
MLIRContext *ctx = &getContext();
RewritePatternSet patterns(ctx);
populateVectorUnrollPatterns(
patterns, UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{2, 2})
.setFilterConstraint([](Operation *op) {
return success(isa<arith::AddFOp, vector::FMAOp,
vector::MultiDimReductionOp>(op));
}));
populateVectorUnrollPatterns(
patterns, UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{2})
.setFilterConstraint([](Operation *op) {
return success(isa<vector::ReductionOp>(op));
}));
populateVectorUnrollPatterns(
patterns, UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
.setFilterConstraint([](Operation *op) {
return success(isa<vector::TransposeOp>(op));
}));
if (unrollBasedOnType) {
UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
[](Operation *op) -> std::optional<SmallVector<int64_t>> {
vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
SmallVector<int64_t> nativeShape(contractOp.getIteratorTypes().size(),
4);
Type lhsType = contractOp.getLhsType().getElementType();
nativeShape[nativeShape.size() - 1] = lhsType.isF16() ? 4 : 2;
return nativeShape;
};
UnrollVectorOptions opts;
opts.setNativeShapeFn(nativeShapeFn)
.setFilterConstraint(
[](Operation *op) { return success(isa<ContractionOp>(op)); });
if (!unrollOrder.empty()) {
opts.setUnrollTraversalOrderFn(
[this](Operation *op) -> std::optional<SmallVector<int64_t>> {
vector::ContractionOp contractOp =
cast<vector::ContractionOp>(op);
if (contractOp.getIteratorTypes().size() == unrollOrder.size())
return SmallVector<int64_t>(unrollOrder.begin(),
unrollOrder.end());
return std::nullopt;
});
}
populateVectorUnrollPatterns(patterns, opts);
} else {
auto nativeShapeFn =
[](Operation *op) -> std::optional<SmallVector<int64_t>> {
auto contractOp = dyn_cast<ContractionOp>(op);
if (!contractOp)
return std::nullopt;
return SmallVector<int64_t>(contractOp.getIteratorTypes().size(), 2);
};
populateVectorUnrollPatterns(patterns,
UnrollVectorOptions()
.setNativeShapeFn(nativeShapeFn)
.setFilterConstraint([](Operation *op) {
return success(isa<ContractionOp>(op));
}));
}
populateVectorToVectorCanonicalizationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
ListOption<int64_t> unrollOrder{*this, "unroll-order",
llvm::cl::desc("set the unroll order")};
Option<bool> unrollBasedOnType{
*this, "unroll-based-on-type",
llvm::cl::desc("Set the unroll factor based on type of the operation"),
llvm::cl::init(false)};
};
struct TestVectorTransferUnrollingPatterns
: public PassWrapper<TestVectorTransferUnrollingPatterns,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
TestVectorTransferUnrollingPatterns)
TestVectorTransferUnrollingPatterns() = default;
TestVectorTransferUnrollingPatterns(
const TestVectorTransferUnrollingPatterns &pass)
: PassWrapper(pass) {}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<affine::AffineDialect>();
}
StringRef getArgument() const final {
return "test-vector-transfer-unrolling-patterns";
}
StringRef getDescription() const final {
return "Test lowering patterns to unroll transfer ops in the vector "
"dialect";
}
void runOnOperation() override {
MLIRContext *ctx = &getContext();
RewritePatternSet patterns(ctx);
UnrollVectorOptions opts;
opts.setNativeShape(ArrayRef<int64_t>{2, 2})
.setFilterConstraint([](Operation *op) {
return success(isa<vector::TransferReadOp, vector::TransferWriteOp,
vector::GatherOp>(op));
});
if (reverseUnrollOrder.getValue()) {
opts.setUnrollTraversalOrderFn(
[](Operation *op) -> std::optional<SmallVector<int64_t>> {
int64_t numLoops = 0;
if (auto readOp = dyn_cast<vector::TransferReadOp>(op))
numLoops = readOp.getVectorType().getRank();
else if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op))
numLoops = writeOp.getVectorType().getRank();
else if (auto gatherOp = dyn_cast<vector::GatherOp>(op))
numLoops = gatherOp.getVectorType().getRank();
else
return std::nullopt;
auto order = llvm::reverse(llvm::seq<int64_t>(0, numLoops));
return llvm::to_vector(order);
});
}
populateVectorUnrollPatterns(patterns, opts);
populateVectorToVectorCanonicalizationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
Option<bool> reverseUnrollOrder{
*this, "reverse-unroll-order",
llvm::cl::desc(
"reverse the order of unrolling of vector transfer operations"),
llvm::cl::init(false)};
};
struct TestScalarVectorTransferLoweringPatterns
: public PassWrapper<TestScalarVectorTransferLoweringPatterns,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
TestScalarVectorTransferLoweringPatterns)
StringRef getArgument() const final {
return "test-scalar-vector-transfer-lowering";
}
StringRef getDescription() const final {
return "Test lowering of scalar vector transfers to memref loads/stores.";
}
TestScalarVectorTransferLoweringPatterns() = default;
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<affine::AffineDialect, memref::MemRefDialect,
tensor::TensorDialect, vector::VectorDialect>();
}
void runOnOperation() override {
MLIRContext *ctx = &getContext();
RewritePatternSet patterns(ctx);
vector::populateScalarVectorTransferLoweringPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
struct TestVectorTransferOpt
: public PassWrapper<TestVectorTransferOpt, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt)
StringRef getArgument() const final { return "test-vector-transferop-opt"; }
StringRef getDescription() const final {
return "Test optimization transformations for transfer ops";
}
void runOnOperation() override {
IRRewriter rewriter(&getContext());
transferOpflowOpt(rewriter, getOperation());
}
};
struct TestVectorTransferCollapseInnerMostContiguousDims
: public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
TestVectorTransferCollapseInnerMostContiguousDims)
TestVectorTransferCollapseInnerMostContiguousDims() = default;
TestVectorTransferCollapseInnerMostContiguousDims(
const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default;
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<memref::MemRefDialect, affine::AffineDialect>();
}
StringRef getArgument() const final {
return "test-vector-transfer-collapse-inner-most-dims";
}
StringRef getDescription() const final {
return "Test lowering patterns that reducedes the rank of the vector "
"transfer memory and vector operands.";
}
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
struct TestVectorReduceToContractPatternsPatterns
: public PassWrapper<TestVectorReduceToContractPatternsPatterns,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
TestVectorReduceToContractPatternsPatterns)
StringRef getArgument() const final {
return "test-vector-reduction-to-contract-patterns";
}
StringRef getDescription() const final {
return "Test patterns to convert multireduce op to contract and combine "
"broadcast/transpose to contract";
}
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateVectorReductionToContractPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
struct TestFlattenVectorTransferPatterns
: public PassWrapper<TestFlattenVectorTransferPatterns,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
TestFlattenVectorTransferPatterns)
StringRef getArgument() const final {
return "test-vector-transfer-flatten-patterns";
}
StringRef getDescription() const final {
return "Test patterns to rewrite contiguous row-major N-dimensional "
"vector.transfer_{read,write} ops into 1D transfers";
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<memref::MemRefDialect>();
}
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateFlattenVectorTransferPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
struct TestVectorScanLowering
: public PassWrapper<TestVectorScanLowering, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorScanLowering)
StringRef getArgument() const final { return "test-vector-scan-lowering"; }
StringRef getDescription() const final {
return "Test lowering patterns that lower the scan op in the vector "
"dialect";
}
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateVectorScanLoweringPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
/// Allocate shared memory for a single warp to test lowering of
/// WarpExecuteOnLane0Op.
static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder,
WarpExecuteOnLane0Op warpOp,
Type type) {
static constexpr int64_t kSharedMemorySpace = 3;
// Compute type of shared memory buffer.
MemRefType memrefType;
if (auto vectorType = dyn_cast<VectorType>(type)) {
memrefType =
MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {},
kSharedMemorySpace);
} else {
memrefType = MemRefType::get({1}, type, {}, kSharedMemorySpace);
}
// Get symbol table holding all shared memory globals.
ModuleOp moduleOp = warpOp->getParentOfType<ModuleOp>();
SymbolTable symbolTable(moduleOp);
// Create a pretty name.
SmallString<64> buf;
llvm::raw_svector_ostream os(buf);
interleave(memrefType.getShape(), os, "x");
os << "x" << memrefType.getElementType();
std::string symbolName = (Twine("__shared_") + os.str()).str();
auto ip = builder.saveInsertionPoint();
builder.setInsertionPoint(moduleOp);
auto global = builder.create<memref::GlobalOp>(
loc,
/*sym_name=*/symbolName,
/*sym_visibility=*/builder.getStringAttr("private"),
/*type=*/memrefType,
/*initial_value=*/Attribute(),
/*constant=*/false,
/*alignment=*/IntegerAttr());
symbolTable.insert(global);
// The symbol table inserts at the end of the module, but globals are a bit
// nicer if they are at the beginning.
global->moveBefore(&moduleOp.front());
builder.restoreInsertionPoint(ip);
return builder.create<memref::GetGlobalOp>(loc, memrefType, symbolName);
}
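/// Perform a warp-level reduction: reduce the vector on each lane first, then
/// combine the per-lane results with butterfly (xor) gpu.shuffle exchanges.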
static Value warpReduction(Location loc, OpBuilder &builder, Value input,
CombiningKind kind, uint32_t size) {
// First reduce on a single thread to get per lane reduction value.
Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input);
// Parallel reduction using butterfly shuffles.
for (uint64_t i = 1; i < size; i <<= 1) {
Value shuffled = builder
.create<gpu::ShuffleOp>(loc, laneVal, i,
/*width=*/size,
/*mode=*/gpu::ShuffleMode::XOR)
.getShuffleResult();
laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
}
return laneVal;
}
struct TestVectorDistribution
: public PassWrapper<TestVectorDistribution, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistribution)
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<scf::SCFDialect, memref::MemRefDialect, gpu::GPUDialect,
affine::AffineDialect>();
}
StringRef getArgument() const final { return "test-vector-warp-distribute"; }
StringRef getDescription() const final {
return "Test vector warp distribute transformation and lowering patterns";
}
TestVectorDistribution() = default;
TestVectorDistribution(const TestVectorDistribution &pass)
: PassWrapper(pass) {}
Option<bool> warpOpToSCF{
*this, "rewrite-warp-ops-to-scf-if",
llvm::cl::desc("Lower vector.warp_execute_on_lane0 to scf.if op"),
llvm::cl::init(false)};
Option<bool> distributeTransferWriteOps{
*this, "distribute-transfer-write",
llvm::cl::desc("Test distribution of transfer write"),
llvm::cl::init(false)};
Option<bool> hoistUniform{*this, "hoist-uniform",
llvm::cl::desc("Test hoist uniform"),
llvm::cl::init(false)};
Option<bool> propagateDistribution{
*this, "propagate-distribution",
llvm::cl::desc("Test distribution propgation"), llvm::cl::init(false)};
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
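// When requested, hoist code that is uniform across the warp out of each
// vector.warp_execute_on_lane_0 op before applying distribution patterns.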
getOperation().walk([&](Operation *op) {
if (auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(op)) {
if (hoistUniform) {
moveScalarUniformCode(warpOp);
}
WalkResult::interrupt();
}
});
MLIRContext *ctx = &getContext();
auto distributionFn = [](Value val) {
// Create a map (d0, d1) -> (d1) to distribute along the inner
// dimension. Once we support n-d distribution we can add more
// complex cases.
VectorType vecType = dyn_cast<VectorType>(val.getType());
int64_t vecRank = vecType ? vecType.getRank() : 0;
OpBuilder builder(val.getContext());
if (vecRank == 0)
return AffineMap::get(val.getContext());
return AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
};
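// Lower cross-lane value exchanges to gpu.shuffle idx; only 32-bit float and
// integer values are supported here.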
auto shuffleFn = [](Location loc, OpBuilder &builder, Value val,
Value srcIdx, int64_t warpSz) {
assert((val.getType().isF32() || val.getType().isInteger(32)) &&
"unsupported shuffle type");
Type i32Type = builder.getIntegerType(32);
Value srcIdxI32 =
builder.create<arith::IndexCastOp>(loc, i32Type, srcIdx);
Value warpSzI32 = builder.create<arith::ConstantOp>(
loc, builder.getIntegerAttr(i32Type, warpSz));
Value result = builder
.create<gpu::ShuffleOp>(loc, val, srcIdxI32, warpSzI32,
gpu::ShuffleMode::IDX)
.getResult(0);
return result;
};
if (distributeTransferWriteOps) {
RewritePatternSet patterns(ctx);
populateDistributeTransferWriteOpPatterns(patterns, distributionFn);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
if (propagateDistribution) {
RewritePatternSet patterns(ctx);
vector::populatePropagateWarpVectorDistributionPatterns(
patterns, distributionFn, shuffleFn);
vector::populateDistributeReduction(patterns, warpReduction);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
WarpExecuteOnLane0LoweringOptions options;
options.warpAllocationFn = allocateGlobalSharedMemory;
options.warpSyncronizationFn = [](Location loc, OpBuilder &builder,
WarpExecuteOnLane0Op warpOp) {
builder.create<gpu::BarrierOp>(loc);
};
// Test on one pattern in isolation.
if (warpOpToSCF) {
populateWarpExecuteOnLane0OpToScfForPattern(patterns, options);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
return;
}
}
};
struct TestVectorExtractStridedSliceLowering
: public PassWrapper<TestVectorExtractStridedSliceLowering,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
TestVectorExtractStridedSliceLowering)
StringRef getArgument() const final {
return "test-vector-extract-strided-slice-lowering";
}
StringRef getDescription() const final {
return "Test lowering patterns that converts vector.extract_strided_slice "
"into a chain of vector.extract and vector.insert ops";
}
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateVectorExtractStridedSliceToExtractInsertChainPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
struct TestVectorBreakDownBitCast
: public PassWrapper<TestVectorBreakDownBitCast,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBreakDownBitCast)
StringRef getArgument() const final {
return "test-vector-break-down-bitcast";
}
StringRef getDescription() const final {
return "Test pattern that breaks down vector.bitcast ops ";
}
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateBreakDownVectorBitCastOpPatterns(patterns, [](BitCastOp op) {
return op.getSourceVectorType().getShape().back() > 4;
});
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
struct TestCreateVectorBroadcast
: public PassWrapper<TestCreateVectorBroadcast,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestCreateVectorBroadcast)
StringRef getArgument() const final { return "test-create-vector-broadcast"; }
StringRef getDescription() const final {
return "Test optimization transformations for transfer ops";
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<vector::VectorDialect>();
}
void runOnOperation() override {
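// Replace each `test_create_broadcast` op with a broadcast of its operand
// to the result vector type, using the `broadcast_dims` attribute as the
// set of broadcasted dimensions.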
getOperation()->walk([](Operation *op) {
if (op->getName().getStringRef() != "test_create_broadcast")
return;
auto targetShape =
cast<VectorType>(op->getResult(0).getType()).getShape();
auto arrayAttr =
cast<DenseI64ArrayAttr>(op->getAttr("broadcast_dims")).asArrayRef();
llvm::SetVector<int64_t> broadcastedDims;
broadcastedDims.insert(arrayAttr.begin(), arrayAttr.end());
OpBuilder b(op);
Value bcast = vector::BroadcastOp::createOrFoldBroadcastOp(
b, op->getOperand(0), targetShape, broadcastedDims);
op->getResult(0).replaceAllUsesWith(bcast);
op->erase();
});
}
};
struct TestVectorGatherLowering
: public PassWrapper<TestVectorGatherLowering,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorGatherLowering)
StringRef getArgument() const final { return "test-vector-gather-lowering"; }
StringRef getDescription() const final {
return "Test patterns that lower the gather op in the vector conditional "
"loads";
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<arith::ArithDialect, func::FuncDialect,
memref::MemRefDialect, scf::SCFDialect,
tensor::TensorDialect, vector::VectorDialect>();
}
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateVectorGatherLoweringPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
} // namespace
namespace mlir {
namespace test {
void registerTestVectorLowerings() {
PassRegistration<TestVectorToVectorLowering>();
PassRegistration<TestVectorContractionPrepareForMMTLowering>();
PassRegistration<TestVectorUnrollingPatterns>();
PassRegistration<TestVectorTransferUnrollingPatterns>();
PassRegistration<TestScalarVectorTransferLoweringPatterns>();
PassRegistration<TestVectorTransferOpt>();
PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
PassRegistration<TestVectorReduceToContractPatternsPatterns>();
PassRegistration<TestFlattenVectorTransferPatterns>();
PassRegistration<TestVectorScanLowering>();
PassRegistration<TestVectorDistribution>();
PassRegistration<TestVectorExtractStridedSliceLowering>();
PassRegistration<TestVectorBreakDownBitCast>();
PassRegistration<TestCreateVectorBroadcast>();
PassRegistration<TestVectorGatherLowering>();
}
} // namespace test
} // namespace mlir