[mlir][Vector] Enable create_mask for scalable vectors
The way vector.create_mask is currently lowered is vector-length-dependent,
and therefore incompatible with scalable vector types. This patch adds an
alternative lowering path for create_mask operations that return a scalable
vector mask.

Differential Revision: https://reviews.llvm.org/D118248
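To illustrate, here is a minimal sketch of the new lowering for a 1-D scalable
mask, mirroring the CHECK lines in the tests added below (SSA names are
illustrative, and the i32 element type assumes `indexOptimizations` is set):

    // Input: set the first %arg0 lanes of a scalable 1-D mask.
    %0 = vector.create_mask %arg0 : vector<[11]xi1>

    // After lowering: compare a step vector against the splatted bound.
    %step  = llvm.intr.experimental.stepvector : vector<[11]xi32>
    %bound = arith.index_cast %arg0 : index to i32
    %splat = ...  // %bound broadcast via llvm.insertelement + llvm.shufflevector
    %mask  = arith.cmpi slt, %step, %splat : vector<[11]xi32>

Because the step vector scales with the hardware vector length at runtime, no
constant index vector (which would be vector-length-dependent) is needed.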
@@ -63,9 +63,10 @@ void populateVectorToLLVMMatrixConversionPatterns(LLVMTypeConverter &converter,
                                                   RewritePatternSet &patterns);
 
 /// Collect a set of patterns to convert from the Vector dialect to LLVM.
+/// If `indexOptimizations` is set, assume indices fit in 32-bit.
 void populateVectorToLLVMConversionPatterns(
     LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    bool reassociateFPReductions = false);
+    bool reassociateFPReductions = false, bool indexOptimizations = false);
 
 /// Create a pass to convert vector operations to the LLVMIR dialect.
 std::unique_ptr<OperationPass<ModuleOp>> createConvertVectorToLLVMPass(

@@ -80,6 +80,12 @@ public:
 Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
                                       OpFoldResult ofr);
 
+/// Create a cast from an index-like value (index or integer) to another
+/// index-like value. If the value type and the target type are the same, it
+/// returns the original value.
+Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc,
+                                      Type targetType, Value value);
+
 /// Similar to the other overload, but converts multiple OpFoldResults into
 /// Values.
 SmallVector<Value>

@@ -1752,6 +1752,14 @@ def LLVM_masked_compressstore
 /// Create a call to vscale intrinsic.
 def LLVM_vscale : LLVM_IntrOp<"vscale", [0], [], [], 1>;
 
+/// Create a call to stepvector intrinsic.
+def LLVM_StepVectorOp
+    : LLVM_IntrOp<"experimental.stepvector", [0], [], [NoSideEffect], 1> {
+  let arguments = (ins);
+  let results = (outs LLVM_Type:$res);
+  let assemblyFormat = "attr-dict `:` type($res)";
+}
+
 // Atomic operations.
 //

@@ -10,6 +10,7 @@
 
 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"

@@ -900,6 +901,40 @@ public:
   }
 };
 
+/// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
+/// Non-scalable versions of this operation are handled in Vector Transforms.
+class VectorCreateMaskOpRewritePattern
+    : public OpRewritePattern<vector::CreateMaskOp> {
+public:
+  explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
+                                            bool enableIndexOpt)
+      : OpRewritePattern<vector::CreateMaskOp>(context),
+        indexOptimizations(enableIndexOpt) {}
+
+  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
+                                PatternRewriter &rewriter) const override {
+    auto dstType = op.getType();
+    if (dstType.getRank() != 1 || !dstType.cast<VectorType>().isScalable())
+      return failure();
+    IntegerType idxType =
+        indexOptimizations ? rewriter.getI32Type() : rewriter.getI64Type();
+    auto loc = op->getLoc();
+    Value indices = rewriter.create<LLVM::StepVectorOp>(
+        loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
+                                 /*isScalable=*/true));
+    auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
+                                                 op.getOperand(0));
+    Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
+    Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
+                                                indices, bounds);
+    rewriter.replaceOp(op, comp);
+    return success();
+  }
+
+private:
+  const bool indexOptimizations;
+};
+
 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
 public:
   using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;

@@ -1157,13 +1192,15 @@ struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
 } // namespace
 
 /// Populate the given list with patterns that convert from Vector to LLVM.
-void mlir::populateVectorToLLVMConversionPatterns(
-    LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    bool reassociateFPReductions) {
+void mlir::populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter,
+                                                  RewritePatternSet &patterns,
+                                                  bool reassociateFPReductions,
+                                                  bool indexOptimizations) {
   MLIRContext *ctx = converter.getDialect()->getContext();
   patterns.add<VectorFMAOpNDRewritePattern>(ctx);
   populateVectorInsertExtractStridedSliceTransforms(patterns);
   patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
+  patterns.add<VectorCreateMaskOpRewritePattern>(ctx, indexOptimizations);
   patterns
       .add<VectorBitCastOpConversion, VectorShuffleOpConversion,
            VectorExtractElementOpConversion, VectorExtractOpConversion,

@@ -80,8 +80,8 @@ void LowerVectorToLLVMPass::runOnOperation() {
   populateVectorMaskMaterializationPatterns(patterns, indexOptimizations);
   populateVectorTransferLoweringPatterns(patterns);
   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
-  populateVectorToLLVMConversionPatterns(converter, patterns,
-                                         reassociateFPReductions);
+  populateVectorToLLVMConversionPatterns(
+      converter, patterns, reassociateFPReductions, indexOptimizations);
   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
 
   // Architecture specific augmentations.

@@ -59,6 +59,27 @@ Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
   return b.create<arith::ConstantIndexOp>(loc, attr.getValue().getSExtValue());
 }
 
+Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc,
+                                            Type targetType, Value value) {
+  if (targetType == value.getType())
+    return value;
+
+  bool targetIsIndex = targetType.isIndex();
+  bool valueIsIndex = value.getType().isIndex();
+  if (targetIsIndex ^ valueIsIndex)
+    return b.create<arith::IndexCastOp>(loc, targetType, value);
+
+  auto targetIntegerType = targetType.dyn_cast<IntegerType>();
+  auto valueIntegerType = value.getType().dyn_cast<IntegerType>();
+  assert(targetIntegerType && valueIntegerType &&
+         "unexpected cast between types other than integers and index");
+  assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
+
+  if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
+    return b.create<arith::ExtSIOp>(loc, targetIntegerType, value);
+  return b.create<arith::TruncIOp>(loc, targetIntegerType, value);
+}
+
 SmallVector<Value>
 mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
                                       ArrayRef<OpFoldResult> valueOrAttrVec) {

@@ -4232,6 +4232,14 @@ LogicalResult ConstantMaskOp::verify() {
   if (anyZeros && !allZeros)
     return emitOpError("expected all mask dim sizes to be zeros, "
                        "as a result of conjunction with zero mask dim");
+  // Verify that if the mask type is scalable, dimensions should be zero because
+  // constant scalable masks can only be defined for the "none set" or "all set"
+  // cases, and there is no VLA way to define an "all set" case for
+  // `vector.constant_mask`. In the future, a convention could be established
+  // to decide if a specific dimension value could be considered as "all set".
+  if (resultType.isScalable() &&
+      mask_dim_sizes()[0].cast<IntegerAttr>().getInt() != 0)
+    return emitOpError("expected mask dim sizes for scalable masks to be 0");
   return success();
 }

@@ -4269,6 +4277,19 @@ public:
     };
     if (llvm::any_of(createMaskOp.operands(), isNotDefByConstant))
       return failure();
+
+    // CreateMaskOp for scalable vectors can be folded only if all dimensions
+    // are negative or zero.
+    if (auto vType = createMaskOp.getType().dyn_cast<VectorType>()) {
+      if (vType.isScalable())
+        for (auto opDim : createMaskOp.getOperands()) {
+          APInt intVal;
+          if (matchPattern(opDim, m_ConstantInt(&intVal)) &&
+              intVal.isStrictlyPositive())
+            return failure();
+        }
+    }
+
     // Gather constant mask dimension sizes.
     SmallVector<int64_t, 4> maskDimSizes;
    for (auto it : llvm::zip(createMaskOp.operands(),

@@ -16,6 +16,8 @@
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"

@@ -602,6 +604,13 @@ public:
       return success();
     }
 
+    // Scalable constant masks can only be lowered for the "none set" case.
+    if (dstType.cast<VectorType>().isScalable()) {
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+          op, DenseElementsAttr::get(dstType, false));
+      return success();
+    }
+
     int64_t trueDim = std::min(dstType.getDimSize(0),
                                dimSizes[0].cast<IntegerAttr>().getInt());

@@ -2161,27 +2170,6 @@ struct BubbleUpBitCastForStridedSliceInsert
   }
 };
 
-static Value createCastToIndexLike(PatternRewriter &rewriter, Location loc,
-                                   Type targetType, Value value) {
-  if (targetType == value.getType())
-    return value;
-
-  bool targetIsIndex = targetType.isIndex();
-  bool valueIsIndex = value.getType().isIndex();
-  if (targetIsIndex ^ valueIsIndex)
-    return rewriter.create<arith::IndexCastOp>(loc, targetType, value);
-
-  auto targetIntegerType = targetType.dyn_cast<IntegerType>();
-  auto valueIntegerType = value.getType().dyn_cast<IntegerType>();
-  assert(targetIntegerType && valueIntegerType &&
-         "unexpected cast between types other than integers and index");
-  assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
-
-  if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
-    return rewriter.create<arith::ExtSIOp>(loc, targetIntegerType, value);
-  return rewriter.create<arith::TruncIOp>(loc, targetIntegerType, value);
-}
-
 // Helper that returns a vector comparison that constructs a mask:
 //   mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
 //

@@ -2217,12 +2205,12 @@ static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
   Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
   // Add in an offset if requested.
   if (off) {
-    Value o = createCastToIndexLike(rewriter, loc, idxType, *off);
+    Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
     Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
     indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
   }
   // Construct the vector comparison.
-  Value bound = createCastToIndexLike(rewriter, loc, idxType, b);
+  Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
   Value bounds =
       rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
   return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,

@@ -2292,6 +2280,8 @@ public:
   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                 PatternRewriter &rewriter) const override {
     auto dstType = op.getType();
+    if (dstType.cast<VectorType>().isScalable())
+      return failure();
     int64_t rank = dstType.getRank();
     if (rank > 1)
       return failure();

@@ -24,6 +24,29 @@ func @genbool_var_1d(%arg0: index) -> vector<11xi1> {
   return %0 : vector<11xi1>
 }
 
+// CMP32-LABEL: @genbool_var_1d_scalable(
+// CMP32-SAME: %[[ARG:.*]]: index)
+// CMP32: %[[T0:.*]] = llvm.intr.experimental.stepvector : vector<[11]xi32>
+// CMP32: %[[T1:.*]] = arith.index_cast %[[ARG]] : index to i32
+// CMP32: %[[T2:.*]] = llvm.insertelement %[[T1]], %{{.*}}[%{{.*}} : i32] : vector<[11]xi32>
+// CMP32: %[[T3:.*]] = llvm.shufflevector %[[T2]], %{{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : vector<[11]xi32>, vector<[11]xi32>
+// CMP32: %[[T4:.*]] = arith.cmpi slt, %[[T0]], %[[T3]] : vector<[11]xi32>
+// CMP32: return %[[T4]] : vector<[11]xi1>
+
+// CMP64-LABEL: @genbool_var_1d_scalable(
+// CMP64-SAME: %[[ARG:.*]]: index)
+// CMP64: %[[T0:.*]] = llvm.intr.experimental.stepvector : vector<[11]xi64>
+// CMP64: %[[T1:.*]] = arith.index_cast %[[ARG]] : index to i64
+// CMP64: %[[T2:.*]] = llvm.insertelement %[[T1]], %{{.*}}[%{{.*}} : i32] : vector<[11]xi64>
+// CMP64: %[[T3:.*]] = llvm.shufflevector %[[T2]], %{{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : vector<[11]xi64>, vector<[11]xi64>
+// CMP64: %[[T4:.*]] = arith.cmpi slt, %[[T0]], %[[T3]] : vector<[11]xi64>
+// CMP64: return %[[T4]] : vector<[11]xi1>
+
+func @genbool_var_1d_scalable(%arg0: index) -> vector<[11]xi1> {
+  %0 = vector.create_mask %arg0 : vector<[11]xi1>
+  return %0 : vector<[11]xi1>
+}
+
 // CMP32-LABEL: @transfer_read_1d
 // CMP32: %[[MEM:.*]]: memref<?xf32>, %[[OFF:.*]]: index) -> vector<16xf32> {
 // CMP32: %[[D:.*]] = memref.dim %[[MEM]], %{{.*}} : memref<?xf32>

@@ -1459,6 +1459,16 @@ func @genbool_1d() -> vector<8xi1> {
 
 // -----
 
+func @genbool_1d_scalable() -> vector<[8]xi1> {
+  %0 = vector.constant_mask [0] : vector<[8]xi1>
+  return %0 : vector<[8]xi1>
+}
+// CHECK-LABEL: func @genbool_1d_scalable
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<false> : vector<[8]xi1>
+// CHECK: return %[[VAL_0]] : vector<[8]xi1>
+
+// -----
+
 func @genbool_2d() -> vector<4x4xi1> {
   %v = vector.constant_mask [2, 2] : vector<4x4xi1>
   return %v: vector<4x4xi1>

@@ -1505,6 +1515,20 @@ func @create_mask_1d(%a : index) -> vector<4xi1> {
 // CHECK: %[[result:.*]] = arith.cmpi slt, %[[indices]], %[[bounds]] : vector<4xi32>
 // CHECK: return %[[result]] : vector<4xi1>
 
+func @create_mask_1d_scalable(%a : index) -> vector<[4]xi1> {
+  %v = vector.create_mask %a : vector<[4]xi1>
+  return %v: vector<[4]xi1>
+}
+
+// CHECK-LABEL: func @create_mask_1d_scalable
+// CHECK-SAME: %[[arg:.*]]: index
+// CHECK: %[[indices:.*]] = llvm.intr.experimental.stepvector : vector<[4]xi32>
+// CHECK: %[[arg_i32:.*]] = arith.index_cast %[[arg]] : index to i32
+// CHECK: %[[boundsInsert:.*]] = llvm.insertelement %[[arg_i32]], {{.*}} : vector<[4]xi32>
+// CHECK: %[[bounds:.*]] = llvm.shufflevector %[[boundsInsert]], {{.*}} : vector<[4]xi32>, vector<[4]xi32>
+// CHECK: %[[result:.*]] = arith.cmpi slt, %[[indices]], %[[bounds]] : vector<[4]xi32>
+// CHECK: return %[[result]] : vector<[4]xi1>
+
 // -----
 
 func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {

@@ -13,6 +13,16 @@ func @create_vector_mask_to_constant_mask() -> (vector<4x3xi1>) {
 
 // -----
 
+// CHECK-LABEL: create_scalable_vector_mask_to_constant_mask
+func @create_scalable_vector_mask_to_constant_mask() -> (vector<[8]xi1>) {
+  %c-1 = arith.constant -1 : index
+  // CHECK: vector.constant_mask [0] : vector<[8]xi1>
+  %0 = vector.create_mask %c-1 : vector<[8]xi1>
+  return %0 : vector<[8]xi1>
+}
+
+// -----
+
 // CHECK-LABEL: create_vector_mask_to_constant_mask_truncation
 func @create_vector_mask_to_constant_mask_truncation() -> (vector<4x3xi1>) {
   %c2 = arith.constant 2 : index

@@ -944,6 +944,13 @@ func @constant_mask_with_zero_mask_dim_size() {
 
 // -----
 
+func @constant_mask_scalable_non_zero_dim_size() {
+  // expected-error@+1 {{expected mask dim sizes for scalable masks to be 0}}
+  %0 = vector.constant_mask [2] : vector<[8]xi1>
+}
+
+// -----
+
 func @print_no_result(%arg0 : f32) -> i32 {
   // expected-error@+1 {{cannot name an operation with no results}}
   %0 = vector.print %arg0 : f32

@@ -389,6 +389,8 @@ func @constant_vector_mask_0d() {
 func @constant_vector_mask() {
   // CHECK: vector.constant_mask [3, 2] : vector<4x3xi1>
   %0 = vector.constant_mask [3, 2] : vector<4x3xi1>
+  // CHECK: vector.constant_mask [0] : vector<[4]xi1>
+  %1 = vector.constant_mask [0] : vector<[4]xi1>
   return
 }