Summary: This revision restructures how vector transforms are invoked, making it more flexible to request lowering through LLVM matrix intrinsics. It also makes sure we bail out in degenerate cases (i.e. size-1 dimensions) in which LLVM complains about not being able to scalarize. Differential Revision: https://reviews.llvm.org/D76266
81 lines
2.8 KiB
C++
81 lines
2.8 KiB
C++
//===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering ---===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include <type_traits>
|
|
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
#include "mlir/Dialect/Vector/VectorOps.h"
|
|
#include "mlir/Dialect/Vector/VectorTransforms.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::vector;
|
|
namespace {
|
|
|
|
#include "TestVectorTransformPatterns.h.inc"
|
|
|
|
struct TestVectorToVectorConversion
|
|
: public FunctionPass<TestVectorToVectorConversion> {
|
|
void runOnFunction() override {
|
|
OwningRewritePatternList patterns;
|
|
auto *context = &getContext();
|
|
populateWithGenerated(context, &patterns);
|
|
populateVectorToVectorCanonicalizationPatterns(patterns, context);
|
|
populateVectorToVectorTransformationPatterns(patterns, context);
|
|
applyPatternsGreedily(getFunction(), patterns);
|
|
}
|
|
};
|
|
|
|
struct TestVectorSlicesConversion
|
|
: public FunctionPass<TestVectorSlicesConversion> {
|
|
void runOnFunction() override {
|
|
OwningRewritePatternList patterns;
|
|
populateVectorSlicesLoweringPatterns(patterns, &getContext());
|
|
applyPatternsGreedily(getFunction(), patterns);
|
|
}
|
|
};
|
|
|
|
struct TestVectorContractionConversion
|
|
: public FunctionPass<TestVectorContractionConversion> {
|
|
TestVectorContractionConversion() = default;
|
|
TestVectorContractionConversion(const TestVectorContractionConversion &pass) {
|
|
}
|
|
|
|
Option<bool> lowerToLLVMMatrixIntrinsics{
|
|
*this, "vector-lower-matrix-intrinsics",
|
|
llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
|
|
llvm::cl::init(false)};
|
|
|
|
void runOnFunction() override {
|
|
OwningRewritePatternList patterns;
|
|
VectorTransformsOptions options{
|
|
/*lowerToLLVMMatrixIntrinsics=*/lowerToLLVMMatrixIntrinsics};
|
|
populateVectorContractLoweringPatterns(patterns, &getContext(), options);
|
|
applyPatternsGreedily(getFunction(), patterns);
|
|
}
|
|
};
|
|
|
|
} // end anonymous namespace
|
|
|
|
namespace mlir {
|
|
void registerTestVectorConversions() {
|
|
PassRegistration<TestVectorToVectorConversion> vectorToVectorPass(
|
|
"test-vector-to-vector-conversion",
|
|
"Test conversion patterns between ops in the vector dialect");
|
|
|
|
PassRegistration<TestVectorSlicesConversion> slicesPass(
|
|
"test-vector-slices-conversion",
|
|
"Test conversion patterns that lower slices ops in the vector dialect");
|
|
|
|
PassRegistration<TestVectorContractionConversion> contractionPass(
|
|
"test-vector-contraction-conversion",
|
|
"Test conversion patterns that lower contract ops in the vector dialect");
|
|
}
|
|
} // namespace mlir
|