[mlir][vector] Address linearization comments (post commit) (#138075)
This PR adds some documentation to address comments in https://github.com/llvm/llvm-project/pull/136581. It also adds a test for linearization across scf.for. This new test might be considered redundant by more experienced MLIRers, but it might help newer users understand how to easily linearize scf/cf/func operations. The documentation added in this PR also tightens our definition of linearization, to now exclude unrolling (which creates multiple ops from 1 op). We hadn't really specified what linearization meant before.
This commit is contained in:
@@ -407,13 +407,22 @@ void populateVectorTransposeNarrowTypeRewritePatterns(
|
||||
RewritePatternSet &patterns, PatternBenefit benefit = 1);
|
||||
|
||||
/// Initialize `typeConverter` and `conversionTarget` for vector linearization.
|
||||
/// This registers (1) which operations are legal and hence should not be
|
||||
/// linearized, (2) what converted types are (rank-1 vectors) and how to
|
||||
///
|
||||
/// Definition: here 'linearization' means converting a single operation with
|
||||
/// 1+ vector operand/result of rank>1, into a new single operation whose
|
||||
/// vector operands and results are all of rank<=1.
|
||||
///
|
||||
/// This function registers (1) which operations are legal, and hence should not
|
||||
/// be linearized, (2) what the converted types are (rank-1 vectors) and how to
|
||||
/// materialize the conversion (with shape_cast).
|
||||
///
|
||||
/// Note: the set of legal operations can be extended by a user if for example
|
||||
/// certain rank>1 vectors are considered valid, but adding additional
|
||||
/// certain rank>1 vectors are considered valid, by adding additional
|
||||
/// dynamically legal ops to `conversionTarget`.
|
||||
///
|
||||
/// Further note: the choice to use a dialect conversion design for
|
||||
/// linearization is to make it easy to reuse generic structural type
|
||||
/// conversions for linearizing scf/cf/func operations.
|
||||
void populateForVectorLinearize(TypeConverter &typeConverter,
|
||||
ConversionTarget &conversionTarget);
|
||||
|
||||
|
||||
@@ -99,7 +99,7 @@ public:
|
||||
// PR47938 tracks this issue, but it seems hard to fix. Instead, we need
|
||||
// to clone the op.
|
||||
//
|
||||
// 2. We need to resue the original region instead of cloning it, otherwise
|
||||
// 2. We need to reuse the original region instead of cloning it, otherwise
|
||||
// the dialect conversion framework thinks that we just inserted all the
|
||||
// cloned child ops. But what we want is to "take" the child regions and let
|
||||
// the dialect conversion framework continue recursively into ops inside
|
||||
|
||||
@@ -626,45 +626,49 @@ struct LinearizeVectorCreateMask final
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Return true if the operation `op` does not support scalable vectors and
|
||||
/// has at least 1 scalable vector result. These ops should all eventually
|
||||
/// support scalable vectors, and this function should be removed.
|
||||
static bool isNotLinearizableBecauseScalable(Operation *op) {
|
||||
|
||||
bool unsupported =
|
||||
isa<vector::ExtractStridedSliceOp, vector::InsertStridedSliceOp,
|
||||
vector::ExtractOp, vector::InsertOp>(op);
|
||||
if (!unsupported)
|
||||
return false;
|
||||
|
||||
// Check if any of the results is a scalable vector type.
|
||||
auto types = op->getResultTypes();
|
||||
bool containsScalableResult =
|
||||
std::any_of(types.begin(), types.end(), [](Type type) {
|
||||
auto vecType = dyn_cast<VectorType>(type);
|
||||
return vecType && vecType.isScalable();
|
||||
});
|
||||
|
||||
return containsScalableResult;
|
||||
}
|
||||
|
||||
static bool isNotLinearizable(Operation *op) {
|
||||
/// This method defines the set of operations that are linearizable, and hence
|
||||
/// that are considered illegal for the conversion target.
|
||||
static bool isLinearizable(Operation *op) {
|
||||
|
||||
// Only ops that are in the vector dialect, are ConstantLike, or
|
||||
// are Vectorizable might be linearized currently.
|
||||
StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace();
|
||||
StringRef opDialect = op->getDialect()->getNamespace();
|
||||
bool unsupported = (opDialect != vectorDialect) &&
|
||||
!op->hasTrait<OpTrait::ConstantLike>() &&
|
||||
!op->hasTrait<OpTrait::Vectorizable>();
|
||||
if (unsupported)
|
||||
return true;
|
||||
bool supported = (opDialect == vectorDialect) ||
|
||||
op->hasTrait<OpTrait::ConstantLike>() ||
|
||||
op->hasTrait<OpTrait::Vectorizable>();
|
||||
if (!supported)
|
||||
return false;
|
||||
|
||||
// Some ops currently don't support scalable vectors.
|
||||
if (isNotLinearizableBecauseScalable(op))
|
||||
return true;
|
||||
|
||||
return false;
|
||||
return TypeSwitch<Operation *, bool>(op)
|
||||
// As type legalization is done with vector.shape_cast, shape_cast
|
||||
// itself cannot be linearized (will create new shape_casts to linearize
|
||||
// ad infinitum).
|
||||
.Case<vector::ShapeCastOp>([&](auto) { return false; })
|
||||
// The operations
|
||||
// - vector.extract_strided_slice
|
||||
// - vector.extract
|
||||
// - vector.insert_strided_slice
|
||||
// - vector.insert
|
||||
// are linearized to a rank-1 vector.shuffle by the current patterns.
|
||||
// vector.shuffle only supports fixed size vectors, so it is impossible to
|
||||
// use this approach to linearize these ops if they operate on scalable
|
||||
// vectors.
|
||||
.Case<vector::ExtractStridedSliceOp>(
|
||||
[&](vector::ExtractStridedSliceOp extractOp) {
|
||||
return !extractOp.getType().isScalable();
|
||||
})
|
||||
.Case<vector::InsertStridedSliceOp>(
|
||||
[&](vector::InsertStridedSliceOp insertOp) {
|
||||
return !insertOp.getType().isScalable();
|
||||
})
|
||||
.Case<vector::InsertOp>([&](vector::InsertOp insertOp) {
|
||||
return !insertOp.getType().isScalable();
|
||||
})
|
||||
.Case<vector::ExtractOp>([&](vector::ExtractOp extractOp) {
|
||||
return !extractOp.getSourceVectorType().isScalable();
|
||||
})
|
||||
.Default([&](auto) { return true; });
|
||||
}
|
||||
|
||||
void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
|
||||
@@ -698,7 +702,7 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
|
||||
|
||||
target.markUnknownOpDynamicallyLegal(
|
||||
[=](Operation *op) -> std::optional<bool> {
|
||||
if (isNotLinearizable(op))
|
||||
if (!isLinearizable(op))
|
||||
return true;
|
||||
// This will return true if, for all operand and result types `t`,
|
||||
// convertType(t) = t. This is true if there are no rank>=2 vectors.
|
||||
|
||||
@@ -392,6 +392,28 @@ func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: test_linearize_across_for
|
||||
func.func @test_linearize_across_for(%arg0 : vector<4xi8>) -> vector<4xi8> {
|
||||
%0 = vector.shape_cast %arg0 : vector<4xi8> to vector<2x2xi8>
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%c4 = arith.constant 4 : index
|
||||
|
||||
// CHECK: scf.for {{.*}} -> (vector<4xi8>)
|
||||
%1 = scf.for %i = %c0 to %c4 step %c1 iter_args(%arg1 = %0) -> (vector<2x2xi8>) {
|
||||
|
||||
// CHECK: arith.addi {{.*}} : vector<4xi8>
|
||||
%2 = arith.addi %arg1, %0 : vector<2x2xi8>
|
||||
|
||||
// CHECK: scf.yield {{.*}} : vector<4xi8>
|
||||
scf.yield %2 : vector<2x2xi8>
|
||||
}
|
||||
%3 = vector.shape_cast %1 : vector<2x2xi8> to vector<4xi8>
|
||||
return %3 : vector<4xi8>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: linearize_vector_splat
|
||||
// CHECK-SAME: (%[[ARG:.*]]: i32) -> vector<4x2xi32>
|
||||
func.func @linearize_vector_splat(%arg0: i32) -> vector<4x2xi32> {
|
||||
@@ -414,6 +436,7 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
|
||||
// CHECK: return %[[CAST]] : vector<4x[2]xi32>
|
||||
%0 = vector.splat %arg0 : vector<4x[2]xi32>
|
||||
return %0 : vector<4x[2]xi32>
|
||||
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||||
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
|
||||
@@ -836,9 +837,6 @@ struct TestVectorEmulateMaskedLoadStore final
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: move this code into the user project.
|
||||
namespace vendor {
|
||||
|
||||
/// Get the set of operand/result types to check for sufficiently
|
||||
/// small inner-most dimension size.
|
||||
static SmallVector<std::pair<Type, unsigned>>
|
||||
@@ -960,8 +958,6 @@ struct TestVectorBitWidthLinearize final
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace vendor
|
||||
|
||||
struct TestVectorLinearize final
|
||||
: public PassWrapper<TestVectorLinearize, OperationPass<>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
|
||||
@@ -987,6 +983,8 @@ struct TestVectorLinearize final
|
||||
vector::populateVectorLinearizeBasePatterns(converter, target, patterns);
|
||||
vector::populateVectorLinearizeShuffleLikeOpsPatterns(converter, target,
|
||||
patterns);
|
||||
mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
|
||||
converter, patterns, target);
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
@@ -1067,7 +1065,7 @@ void registerTestVectorLowerings() {
|
||||
|
||||
PassRegistration<TestVectorLinearize>();
|
||||
|
||||
PassRegistration<vendor::TestVectorBitWidthLinearize>();
|
||||
PassRegistration<TestVectorBitWidthLinearize>();
|
||||
|
||||
PassRegistration<TestEliminateVectorMasks>();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user