[mlir][Vector] Enable create_mask for scalable vectors
The way vector.create_mask is currently lowered is vector-length-dependent, and therefore incompatible with scalable vector types. This patch adds an alternative lowering path for create_mask operations that return a scalable vector mask. Differential Revision: https://reviews.llvm.org/D118248
This commit is contained in:
@@ -10,6 +10,7 @@
|
||||
|
||||
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
|
||||
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
@@ -900,6 +901,40 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
|
||||
/// Non-scalable versions of this operation are handled in Vector Transforms.
|
||||
class VectorCreateMaskOpRewritePattern
|
||||
: public OpRewritePattern<vector::CreateMaskOp> {
|
||||
public:
|
||||
explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
|
||||
bool enableIndexOpt)
|
||||
: OpRewritePattern<vector::CreateMaskOp>(context),
|
||||
indexOptimizations(enableIndexOpt) {}
|
||||
|
||||
LogicalResult matchAndRewrite(vector::CreateMaskOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto dstType = op.getType();
|
||||
if (dstType.getRank() != 1 || !dstType.cast<VectorType>().isScalable())
|
||||
return failure();
|
||||
IntegerType idxType =
|
||||
indexOptimizations ? rewriter.getI32Type() : rewriter.getI64Type();
|
||||
auto loc = op->getLoc();
|
||||
Value indices = rewriter.create<LLVM::StepVectorOp>(
|
||||
loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
|
||||
/*isScalable=*/true));
|
||||
auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
|
||||
op.getOperand(0));
|
||||
Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
|
||||
Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
|
||||
indices, bounds);
|
||||
rewriter.replaceOp(op, comp);
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
const bool indexOptimizations;
|
||||
};
|
||||
|
||||
class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
|
||||
public:
|
||||
using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
|
||||
@@ -1157,13 +1192,15 @@ struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
|
||||
} // namespace
|
||||
|
||||
/// Populate the given list with patterns that convert from Vector to LLVM.
|
||||
void mlir::populateVectorToLLVMConversionPatterns(
|
||||
LLVMTypeConverter &converter, RewritePatternSet &patterns,
|
||||
bool reassociateFPReductions) {
|
||||
void mlir::populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns,
|
||||
bool reassociateFPReductions,
|
||||
bool indexOptimizations) {
|
||||
MLIRContext *ctx = converter.getDialect()->getContext();
|
||||
patterns.add<VectorFMAOpNDRewritePattern>(ctx);
|
||||
populateVectorInsertExtractStridedSliceTransforms(patterns);
|
||||
patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
|
||||
patterns.add<VectorCreateMaskOpRewritePattern>(ctx, indexOptimizations);
|
||||
patterns
|
||||
.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
|
||||
VectorExtractElementOpConversion, VectorExtractOpConversion,
|
||||
|
||||
Reference in New Issue
Block a user