[mlir][Vector] Enable create_mask for scalable vectors

The way vector.create_mask is currently lowered is
vector-length-dependent, and therefore incompatible with scalable vector
types. This patch adds an alternative lowering path for create_mask
operations that return a scalable vector mask.

Differential Revision: https://reviews.llvm.org/D118248
Author: Javier Setoain
Date:   2022-01-26 15:01:39 +00:00
Parent: 718aec209c
Commit: a75a46db89

13 changed files with 179 additions and 29 deletions
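As a quick illustration of what the patch enables (mirroring the genbool_var_1d_scalable test added further down; the i32 element type assumes the indexOptimizations path, otherwise i64 is used), here is a variable-length scalable mask and a rough sketch of its new lowering:

    func @genbool_var_1d_scalable(%arg0: index) -> vector<[11]xi1> {
      %0 = vector.create_mask %arg0 : vector<[11]xi1>
      return %0 : vector<[11]xi1>
    }

    // Sketch of the lowered form:
    //   %step   = llvm.intr.experimental.stepvector : vector<[11]xi32>
    //   %bound  = arith.index_cast %arg0 : index to i32
    //   %bounds = splat of %bound (llvm.insertelement + llvm.shufflevector) : vector<[11]xi32>
    //   %mask   = arith.cmpi slt, %step, %bounds : vector<[11]xi32>
    //   return %mask : vector<[11]xi1>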


@@ -63,9 +63,10 @@ void populateVectorToLLVMMatrixConversionPatterns(LLVMTypeConverter &converter,
                                                   RewritePatternSet &patterns);
 
 /// Collect a set of patterns to convert from the Vector dialect to LLVM.
+/// If `indexOptimizations` is set, assume indices fit in 32-bit.
 void populateVectorToLLVMConversionPatterns(
     LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    bool reassociateFPReductions = false);
+    bool reassociateFPReductions = false, bool indexOptimizations = false);
 
 /// Create a pass to convert vector operations to the LLVMIR dialect.
 std::unique_ptr<OperationPass<ModuleOp>> createConvertVectorToLLVMPass(


@@ -80,6 +80,12 @@ public:
 Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
                                       OpFoldResult ofr);
 
+/// Create a cast from an index-like value (index or integer) to another
+/// index-like value. If the value type and the target type are the same, it
+/// returns the original value.
+Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc,
+                                      Type targetType, Value value);
+
 /// Similar to the other overload, but converts multiple OpFoldResults into
 /// Values.
 SmallVector<Value>


@@ -1752,6 +1752,14 @@ def LLVM_masked_compressstore
 /// Create a call to vscale intrinsic.
 def LLVM_vscale : LLVM_IntrOp<"vscale", [0], [], [], 1>;
 
+/// Create a call to stepvector intrinsic.
+def LLVM_StepVectorOp
+    : LLVM_IntrOp<"experimental.stepvector", [0], [], [NoSideEffect], 1> {
+  let arguments = (ins);
+  let results = (outs LLVM_Type:$res);
+  let assemblyFormat = "attr-dict `:` type($res)";
+}
+
 // Atomic operations.
 //


@@ -10,6 +10,7 @@
 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -900,6 +901,40 @@ public:
   }
 };
 
+/// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
+/// Non-scalable versions of this operation are handled in Vector Transforms.
+class VectorCreateMaskOpRewritePattern
+    : public OpRewritePattern<vector::CreateMaskOp> {
+public:
+  explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
+                                            bool enableIndexOpt)
+      : OpRewritePattern<vector::CreateMaskOp>(context),
+        indexOptimizations(enableIndexOpt) {}
+
+  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
+                                PatternRewriter &rewriter) const override {
+    auto dstType = op.getType();
+    if (dstType.getRank() != 1 || !dstType.cast<VectorType>().isScalable())
+      return failure();
+    IntegerType idxType =
+        indexOptimizations ? rewriter.getI32Type() : rewriter.getI64Type();
+    auto loc = op->getLoc();
+    Value indices = rewriter.create<LLVM::StepVectorOp>(
+        loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
+                                 /*isScalable=*/true));
+    auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
+                                                 op.getOperand(0));
+    Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
+    Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
+                                                indices, bounds);
+    rewriter.replaceOp(op, comp);
+    return success();
+  }
+
+private:
+  const bool indexOptimizations;
+};
+
 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
 public:
   using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
@@ -1157,13 +1192,15 @@ struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
 
 } // namespace
 
 /// Populate the given list with patterns that convert from Vector to LLVM.
-void mlir::populateVectorToLLVMConversionPatterns(
-    LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    bool reassociateFPReductions) {
+void mlir::populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter,
+                                                  RewritePatternSet &patterns,
+                                                  bool reassociateFPReductions,
+                                                  bool indexOptimizations) {
   MLIRContext *ctx = converter.getDialect()->getContext();
   patterns.add<VectorFMAOpNDRewritePattern>(ctx);
   populateVectorInsertExtractStridedSliceTransforms(patterns);
   patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
+  patterns.add<VectorCreateMaskOpRewritePattern>(ctx, indexOptimizations);
   patterns
       .add<VectorBitCastOpConversion, VectorShuffleOpConversion,
            VectorExtractElementOpConversion, VectorExtractOpConversion,


@@ -80,8 +80,8 @@ void LowerVectorToLLVMPass::runOnOperation() {
   populateVectorMaskMaterializationPatterns(patterns, indexOptimizations);
   populateVectorTransferLoweringPatterns(patterns);
   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
-  populateVectorToLLVMConversionPatterns(converter, patterns,
-                                         reassociateFPReductions);
+  populateVectorToLLVMConversionPatterns(
+      converter, patterns, reassociateFPReductions, indexOptimizations);
   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
 
   // Architecture specific augmentations.


@@ -59,6 +59,27 @@ Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
   return b.create<arith::ConstantIndexOp>(loc, attr.getValue().getSExtValue());
 }
 
+Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc,
+                                            Type targetType, Value value) {
+  if (targetType == value.getType())
+    return value;
+
+  bool targetIsIndex = targetType.isIndex();
+  bool valueIsIndex = value.getType().isIndex();
+  if (targetIsIndex ^ valueIsIndex)
+    return b.create<arith::IndexCastOp>(loc, targetType, value);
+
+  auto targetIntegerType = targetType.dyn_cast<IntegerType>();
+  auto valueIntegerType = value.getType().dyn_cast<IntegerType>();
+  assert(targetIntegerType && valueIntegerType &&
+         "unexpected cast between types other than integers and index");
+  assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
+  if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
+    return b.create<arith::ExtSIOp>(loc, targetIntegerType, value);
+  return b.create<arith::TruncIOp>(loc, targetIntegerType, value);
+}
+
 SmallVector<Value>
 mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
                                       ArrayRef<OpFoldResult> valueOrAttrVec) {
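For reference, the kinds of casts the new helper emits, written as MLIR (the SSA names here are illustrative, not taken from the patch): an index/integer mismatch becomes an index_cast, while integer-to-integer casts sign-extend or truncate depending on the widths:

    %a = arith.index_cast %idx : index to i32    // index <-> integer
    %b = arith.extsi %v32 : i32 to i64           // widening integer cast
    %c = arith.trunci %v64 : i64 to i32          // narrowing integer cast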


@@ -4232,6 +4232,14 @@ LogicalResult ConstantMaskOp::verify() {
   if (anyZeros && !allZeros)
     return emitOpError("expected all mask dim sizes to be zeros, "
                        "as a result of conjunction with zero mask dim");
 
+  // Verify that if the mask type is scalable, dimensions should be zero because
+  // constant scalable masks can only be defined for the "none set" or "all set"
+  // cases, and there is no VLA way to define an "all set" case for
+  // `vector.constant_mask`. In the future, a convention could be established
+  // to decide if a specific dimension value could be considered as "all set".
+  if (resultType.isScalable() &&
+      mask_dim_sizes()[0].cast<IntegerAttr>().getInt() != 0)
+    return emitOpError("expected mask dim sizes for scalable masks to be 0");
   return success();
 }
@@ -4269,6 +4277,19 @@ public:
     };
     if (llvm::any_of(createMaskOp.operands(), isNotDefByConstant))
       return failure();
 
+    // CreateMaskOp for scalable vectors can be folded only if all dimensions
+    // are negative or zero.
+    if (auto vType = createMaskOp.getType().dyn_cast<VectorType>()) {
+      if (vType.isScalable())
+        for (auto opDim : createMaskOp.getOperands()) {
+          APInt intVal;
+          if (matchPattern(opDim, m_ConstantInt(&intVal)) &&
+              intVal.isStrictlyPositive())
+            return failure();
+        }
+    }
+
     // Gather constant mask dimension sizes.
     SmallVector<int64_t, 4> maskDimSizes;
     for (auto it : llvm::zip(createMaskOp.operands(),
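In IR terms, the guard above limits the CreateMaskOp-to-ConstantMaskOp fold, for scalable vectors, to non-positive bounds, which fold to an all-false mask; the canonicalization test added later in this patch exercises exactly this case:

    %c-1 = arith.constant -1 : index
    %0 = vector.create_mask %c-1 : vector<[8]xi1>
    // ---> canonicalizes to:
    %0 = vector.constant_mask [0] : vector<[8]xi1>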


@@ -16,6 +16,8 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
@@ -602,6 +604,13 @@ public:
       return success();
     }
 
+    // Scalable constant masks can only be lowered for the "none set" case.
+    if (dstType.cast<VectorType>().isScalable()) {
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+          op, DenseElementsAttr::get(dstType, false));
+      return success();
+    }
+
     int64_t trueDim = std::min(dstType.getDimSize(0),
                                dimSizes[0].cast<IntegerAttr>().getInt());
@@ -2161,27 +2170,6 @@ struct BubbleUpBitCastForStridedSliceInsert
   }
 };
 
-static Value createCastToIndexLike(PatternRewriter &rewriter, Location loc,
-                                   Type targetType, Value value) {
-  if (targetType == value.getType())
-    return value;
-
-  bool targetIsIndex = targetType.isIndex();
-  bool valueIsIndex = value.getType().isIndex();
-  if (targetIsIndex ^ valueIsIndex)
-    return rewriter.create<arith::IndexCastOp>(loc, targetType, value);
-
-  auto targetIntegerType = targetType.dyn_cast<IntegerType>();
-  auto valueIntegerType = value.getType().dyn_cast<IntegerType>();
-  assert(targetIntegerType && valueIntegerType &&
-         "unexpected cast between types other than integers and index");
-  assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
-  if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
-    return rewriter.create<arith::ExtSIOp>(loc, targetIntegerType, value);
-  return rewriter.create<arith::TruncIOp>(loc, targetIntegerType, value);
-}
-
 // Helper that returns a vector comparison that constructs a mask:
 //   mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
 //
@@ -2217,12 +2205,12 @@ static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
   Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
   // Add in an offset if requested.
   if (off) {
-    Value o = createCastToIndexLike(rewriter, loc, idxType, *off);
+    Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
     Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
     indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
   }
   // Construct the vector comparison.
-  Value bound = createCastToIndexLike(rewriter, loc, idxType, b);
+  Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
   Value bounds =
       rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
   return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
@@ -2292,6 +2280,8 @@ public:
   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                 PatternRewriter &rewriter) const override {
     auto dstType = op.getType();
+    if (dstType.cast<VectorType>().isScalable())
+      return failure();
     int64_t rank = dstType.getRank();
     if (rank > 1)
       return failure();


@@ -24,6 +24,29 @@ func @genbool_var_1d(%arg0: index) -> vector<11xi1> {
   return %0 : vector<11xi1>
 }
 
+// CMP32-LABEL: @genbool_var_1d_scalable(
+// CMP32-SAME: %[[ARG:.*]]: index)
+// CMP32: %[[T0:.*]] = llvm.intr.experimental.stepvector : vector<[11]xi32>
+// CMP32: %[[T1:.*]] = arith.index_cast %[[ARG]] : index to i32
+// CMP32: %[[T2:.*]] = llvm.insertelement %[[T1]], %{{.*}}[%{{.*}} : i32] : vector<[11]xi32>
+// CMP32: %[[T3:.*]] = llvm.shufflevector %[[T2]], %{{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : vector<[11]xi32>, vector<[11]xi32>
+// CMP32: %[[T4:.*]] = arith.cmpi slt, %[[T0]], %[[T3]] : vector<[11]xi32>
+// CMP32: return %[[T4]] : vector<[11]xi1>
+
+// CMP64-LABEL: @genbool_var_1d_scalable(
+// CMP64-SAME: %[[ARG:.*]]: index)
+// CMP64: %[[T0:.*]] = llvm.intr.experimental.stepvector : vector<[11]xi64>
+// CMP64: %[[T1:.*]] = arith.index_cast %[[ARG]] : index to i64
+// CMP64: %[[T2:.*]] = llvm.insertelement %[[T1]], %{{.*}}[%{{.*}} : i32] : vector<[11]xi64>
+// CMP64: %[[T3:.*]] = llvm.shufflevector %[[T2]], %{{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : vector<[11]xi64>, vector<[11]xi64>
+// CMP64: %[[T4:.*]] = arith.cmpi slt, %[[T0]], %[[T3]] : vector<[11]xi64>
+// CMP64: return %[[T4]] : vector<[11]xi1>
+func @genbool_var_1d_scalable(%arg0: index) -> vector<[11]xi1> {
+  %0 = vector.create_mask %arg0 : vector<[11]xi1>
+  return %0 : vector<[11]xi1>
+}
+
 // CMP32-LABEL: @transfer_read_1d
 // CMP32: %[[MEM:.*]]: memref<?xf32>, %[[OFF:.*]]: index) -> vector<16xf32> {
 // CMP32: %[[D:.*]] = memref.dim %[[MEM]], %{{.*}} : memref<?xf32>


@@ -1459,6 +1459,16 @@ func @genbool_1d() -> vector<8xi1> {
 
 // -----
 
+func @genbool_1d_scalable() -> vector<[8]xi1> {
+  %0 = vector.constant_mask [0] : vector<[8]xi1>
+  return %0 : vector<[8]xi1>
+}
+// CHECK-LABEL: func @genbool_1d_scalable
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<false> : vector<[8]xi1>
+// CHECK: return %[[VAL_0]] : vector<[8]xi1>
+
+// -----
+
 func @genbool_2d() -> vector<4x4xi1> {
   %v = vector.constant_mask [2, 2] : vector<4x4xi1>
   return %v: vector<4x4xi1>
@@ -1505,6 +1515,20 @@ func @create_mask_1d(%a : index) -> vector<4xi1> {
 // CHECK: %[[result:.*]] = arith.cmpi slt, %[[indices]], %[[bounds]] : vector<4xi32>
 // CHECK: return %[[result]] : vector<4xi1>
 
+func @create_mask_1d_scalable(%a : index) -> vector<[4]xi1> {
+  %v = vector.create_mask %a : vector<[4]xi1>
+  return %v: vector<[4]xi1>
+}
+// CHECK-LABEL: func @create_mask_1d_scalable
+// CHECK-SAME: %[[arg:.*]]: index
+// CHECK: %[[indices:.*]] = llvm.intr.experimental.stepvector : vector<[4]xi32>
+// CHECK: %[[arg_i32:.*]] = arith.index_cast %[[arg]] : index to i32
+// CHECK: %[[boundsInsert:.*]] = llvm.insertelement %[[arg_i32]], {{.*}} : vector<[4]xi32>
+// CHECK: %[[bounds:.*]] = llvm.shufflevector %[[boundsInsert]], {{.*}} : vector<[4]xi32>, vector<[4]xi32>
+// CHECK: %[[result:.*]] = arith.cmpi slt, %[[indices]], %[[bounds]] : vector<[4]xi32>
+// CHECK: return %[[result]] : vector<[4]xi1>
+
 // -----
 
 func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {


@@ -13,6 +13,16 @@ func @create_vector_mask_to_constant_mask() -> (vector<4x3xi1>) {
 
 // -----
 
+// CHECK-LABEL: create_scalable_vector_mask_to_constant_mask
+func @create_scalable_vector_mask_to_constant_mask() -> (vector<[8]xi1>) {
+  %c-1 = arith.constant -1 : index
+  // CHECK: vector.constant_mask [0] : vector<[8]xi1>
+  %0 = vector.create_mask %c-1 : vector<[8]xi1>
+  return %0 : vector<[8]xi1>
+}
+
+// -----
+
 // CHECK-LABEL: create_vector_mask_to_constant_mask_truncation
 func @create_vector_mask_to_constant_mask_truncation() -> (vector<4x3xi1>) {
   %c2 = arith.constant 2 : index


@@ -944,6 +944,13 @@ func @constant_mask_with_zero_mask_dim_size() {
 
 // -----
 
+func @constant_mask_scalable_non_zero_dim_size() {
+  // expected-error@+1 {{expected mask dim sizes for scalable masks to be 0}}
+  %0 = vector.constant_mask [2] : vector<[8]xi1>
+}
+
+// -----
+
 func @print_no_result(%arg0 : f32) -> i32 {
   // expected-error@+1 {{cannot name an operation with no results}}
   %0 = vector.print %arg0 : f32


@@ -389,6 +389,8 @@ func @constant_vector_mask_0d() {
 func @constant_vector_mask() {
   // CHECK: vector.constant_mask [3, 2] : vector<4x3xi1>
   %0 = vector.constant_mask [3, 2] : vector<4x3xi1>
+  // CHECK: vector.constant_mask [0] : vector<[4]xi1>
+  %1 = vector.constant_mask [0] : vector<[4]xi1>
   return
 }