[mlir] Add "mask" operand to vector.transfer_read/write.

Also factors out out-of-bounds mask generation from vector.transfer_read/write into a new MaterializeTransferMask pattern.

Differential Revision: https://reviews.llvm.org/D100001
Author: Matthias Springer
Date:   2021-04-07 21:11:55 +09:00
Commit: 65a3f28939 (parent c0ef93bec8)
13 changed files with 389 additions and 186 deletions
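The semantics behind the new operand can be summarized with a small scalar model. The sketch below is illustrative only (the function name and types are invented for this example; it is not MLIR API code): lanes whose mask bit is unset take the padding value instead of reading memory. A masked transfer_write is symmetric, leaving masked-off memory untouched; see the sketch next to the lowering pattern further down.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical scalar model of a masked 1-D vector.transfer_read: lanes whose
// mask bit is false receive the padding value instead of loading from memory.
std::vector<float> maskedTransferRead(const std::vector<float> &memref,
                                      size_t offset, size_t vectorLength,
                                      const std::vector<bool> &mask,
                                      float padding) {
  std::vector<float> result(vectorLength, padding);
  for (size_t i = 0; i < vectorLength; ++i)
    if (mask[i])
      result[i] = memref[offset + i];
  return result;
}
```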


@@ -104,66 +104,6 @@ static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
return res;
}
-static Value createCastToIndexLike(ConversionPatternRewriter &rewriter,
-Location loc, Type targetType, Value value) {
-if (targetType == value.getType())
-return value;
-bool targetIsIndex = targetType.isIndex();
-bool valueIsIndex = value.getType().isIndex();
-if (targetIsIndex ^ valueIsIndex)
-return rewriter.create<IndexCastOp>(loc, targetType, value);
-auto targetIntegerType = targetType.dyn_cast<IntegerType>();
-auto valueIntegerType = value.getType().dyn_cast<IntegerType>();
-assert(targetIntegerType && valueIntegerType &&
-"unexpected cast between types other than integers and index");
-assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
-if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
-return rewriter.create<SignExtendIOp>(loc, targetIntegerType, value);
-return rewriter.create<TruncateIOp>(loc, targetIntegerType, value);
-}
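The removed cast helper above encodes a small decision table: identical types need no cast, an index/integer mismatch uses IndexCastOp, widening uses SignExtendIOp, and narrowing uses TruncateIOp. The standalone sketch below (illustrative only; the string results simply name those ops) mirrors that logic, with width 0 standing in for the index type:

```cpp
#include <string>

// Illustrative decision table for the removed cast helper (not patch code):
// width 0 stands in for the 'index' type.
std::string pickCastOp(unsigned fromWidth, unsigned toWidth) {
  if (fromWidth == toWidth)
    return "no cast";                     // types already match
  if ((fromWidth == 0) != (toWidth == 0))
    return "IndexCastOp";                 // index <-> integer
  return toWidth > fromWidth ? "SignExtendIOp" : "TruncateIOp";
}
```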
-// Helper that returns a vector comparison that constructs a mask:
-// mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
-//
-// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
-// much more compact, IR for this operation, but LLVM eventually
-// generates more elaborate instructions for this intrinsic since it
-// is very conservative on the boundary conditions.
-static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
-Operation *op, bool enableIndexOptimizations,
-int64_t dim, Value b, Value *off = nullptr) {
-auto loc = op->getLoc();
-// If we can assume all indices fit in 32-bit, we perform the vector
-// comparison in 32-bit to get a higher degree of SIMD parallelism.
-// Otherwise we perform the vector comparison using 64-bit indices.
-Value indices;
-Type idxType;
-if (enableIndexOptimizations) {
-indices = rewriter.create<ConstantOp>(
-loc, rewriter.getI32VectorAttr(
-llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))));
-idxType = rewriter.getI32Type();
-} else {
-indices = rewriter.create<ConstantOp>(
-loc, rewriter.getI64VectorAttr(
-llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))));
-idxType = rewriter.getI64Type();
-}
-// Add in an offset if requested.
-if (off) {
-Value o = createCastToIndexLike(rewriter, loc, idxType, *off);
-Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
-indices = rewriter.create<AddIOp>(loc, ov, indices);
-}
-// Construct the vector comparison.
-Value bound = createCastToIndexLike(rewriter, loc, idxType, b);
-Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
-return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
-}
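The removed comparison helper is easiest to follow with concrete numbers. The standalone sketch below (illustrative only, not patch code) computes the same mask with 32-bit indices, as in the enableIndexOptimizations path:

```cpp
#include <cstdint>
#include <vector>

// mask = [0, 1, ..., n-1] + [off, ..., off] < [bound, ..., bound]
// e.g. n = 4, off = 2, bound = 5:  [2, 3, 4, 5] < [5, 5, 5, 5]  ->  [1, 1, 1, 0]
std::vector<bool> buildMask(int32_t n, int32_t off, int32_t bound) {
  std::vector<bool> mask(n);
  for (int32_t i = 0; i < n; ++i)
    mask[i] = (i + off) < bound; // signed comparison, matching CmpIPredicate::slt
  return mask;
}
```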
// Helper that returns data layout alignment of a memref.
LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
MemRefType memrefType, unsigned &align) {
@@ -250,7 +190,7 @@ replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
if (failed(getMemRefAlignment(
typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
return failure();
-auto adaptor = TransferWriteOpAdaptor(operands);
+auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
align);
return success();
@@ -266,7 +206,7 @@ replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
return failure();
-auto adaptor = TransferWriteOpAdaptor(operands);
+auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
xferOp, adaptor.vector(), dataPtr, mask,
rewriter.getI32IntegerAttr(align));
@@ -275,12 +215,12 @@ replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
ArrayRef<Value> operands) {
-return TransferReadOpAdaptor(operands);
+return TransferReadOpAdaptor(operands, xferOp->getAttrDictionary());
}
static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
ArrayRef<Value> operands) {
-return TransferWriteOpAdaptor(operands);
+return TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
}
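A note on why the adaptors now take the op's attribute dictionary: with an optional mask operand next to the variadic index list, the flat operand range alone no longer determines where each operand group starts, so the generated adaptor consults the operand-segment sizes kept in the op's attributes. The sketch below is a simplified, hypothetical illustration of that slicing, not the ODS-generated adaptor:

```cpp
#include <array>
#include <cstdint>

// Hypothetical sketch: per-segment sizes for {source, indices, padding, mask}
// locate the optional mask in a flat operand list (size 0 means "no mask").
struct TransferOperandSegments {
  std::array<int32_t, 4> sizes; // e.g. {1, 2, 1, 1} for a masked 2-D transfer
  bool hasMask() const { return sizes[3] != 0; }
  int32_t maskIndex() const { return sizes[0] + sizes[1] + sizes[2]; }
};
```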
namespace {
@@ -618,33 +558,6 @@ private:
const bool reassociateFPReductions;
};
-/// Conversion pattern for a vector.create_mask (1-D only).
-class VectorCreateMaskOpConversion
-: public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
-public:
-explicit VectorCreateMaskOpConversion(LLVMTypeConverter &typeConv,
-bool enableIndexOpt)
-: ConvertOpToLLVMPattern<vector::CreateMaskOp>(typeConv),
-enableIndexOptimizations(enableIndexOpt) {}
-LogicalResult
-matchAndRewrite(vector::CreateMaskOp op, ArrayRef<Value> operands,
-ConversionPatternRewriter &rewriter) const override {
-auto dstType = op.getType();
-int64_t rank = dstType.getRank();
-if (rank == 1) {
-rewriter.replaceOp(
-op, buildVectorComparison(rewriter, op, enableIndexOptimizations,
-dstType.getDimSize(0), operands[0]));
-return success();
-}
-return failure();
-}
-private:
-const bool enableIndexOptimizations;
-};
class VectorShuffleOpConversion
: public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
@@ -1177,20 +1090,12 @@ public:
}
};
-/// Conversion pattern that converts a 1-D vector transfer read/write op in a
-/// sequence of:
-/// 1. Get the source/dst address as an LLVM vector pointer.
-/// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
-/// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
-/// 4. Create a mask where offsetVector is compared against memref upper bound.
-/// 5. Rewrite op as a masked read or write.
+/// Conversion pattern that converts a 1-D vector transfer read/write op into
+/// a masked or unmasked read/write.
template <typename ConcreteOp>
class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
public:
-explicit VectorTransferConversion(LLVMTypeConverter &typeConv,
-bool enableIndexOpt)
-: ConvertOpToLLVMPattern<ConcreteOp>(typeConv),
-enableIndexOptimizations(enableIndexOpt) {}
+using ConvertOpToLLVMPattern<ConcreteOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,
@@ -1212,6 +1117,9 @@ public:
auto strides = computeContiguousStrides(memRefType);
if (!strides)
return failure();
+// Out-of-bounds dims are handled by MaterializeTransferMask.
+if (xferOp.hasOutOfBoundsDim())
+return failure();
auto toLLVMTy = [&](Type t) {
return this->getTypeConverter()->convertType(t);
@@ -1241,40 +1149,24 @@ public:
#endif // ifndef NDEBUG
}
-// 1. Get the source/dst address as an LLVM vector pointer.
+// Get the source/dst address as an LLVM vector pointer.
VectorType vtp = xferOp.getVectorType();
Value dataPtr = this->getStridedElementPtr(
loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
Value vectorDataPtr =
castDataPtr(rewriter, loc, dataPtr, memRefType, toLLVMTy(vtp));
-if (xferOp.isDimInBounds(0))
+// Rewrite as an unmasked read / write.
+if (!xferOp.mask())
return replaceTransferOpWithLoadOrStore(rewriter,
*this->getTypeConverter(), loc,
xferOp, operands, vectorDataPtr);
-// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
-// 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
-// 4. Let dim the memref dimension, compute the vector comparison mask
-// (in-bounds mask):
-// [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
-//
-// TODO: when the leaf transfer rank is k > 1, we need the last `k`
-// dimensions here.
-unsigned vecWidth = LLVM::getVectorNumElements(vtp).getFixedValue();
-unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
-Value off = xferOp.indices()[lastIndex];
-Value dim = rewriter.create<memref::DimOp>(loc, xferOp.source(), lastIndex);
-Value mask = buildVectorComparison(
-rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off);
-// 5. Rewrite as a masked read / write.
+// Rewrite as a masked read / write.
return replaceTransferOpWithMasked(rewriter, *this->getTypeConverter(), loc,
-xferOp, operands, vectorDataPtr, mask);
+xferOp, operands, vectorDataPtr,
+xferOp.mask());
}
-private:
-const bool enableIndexOptimizations;
};
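With mask generation factored out into MaterializeTransferMask, this pattern only has to choose between a plain and a masked access. As a scalar illustration of the masked-write side (illustrative only, not the lowering itself; names are invented), masked-off lanes leave memory untouched:

```cpp
#include <cstddef>
#include <vector>

// Hypothetical scalar model of a masked 1-D vector.transfer_write: only lanes
// whose mask bit is set are stored back to memory.
void maskedTransferWrite(std::vector<float> &memref, size_t offset,
                         const std::vector<float> &vec,
                         const std::vector<bool> &mask) {
  for (size_t i = 0; i < vec.size(); ++i)
    if (mask[i])
      memref[offset + i] = vec[i];
}
```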
class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
@@ -1484,17 +1376,13 @@ public:
/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns,
-bool reassociateFPReductions, bool enableIndexOptimizations) {
+bool reassociateFPReductions) {
MLIRContext *ctx = converter.getDialect()->getContext();
patterns.add<VectorFMAOpNDRewritePattern,
VectorInsertStridedSliceOpDifferentRankRewritePattern,
VectorInsertStridedSliceOpSameRankRewritePattern,
VectorExtractStridedSliceOpConversion>(ctx);
patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
-patterns.add<VectorCreateMaskOpConversion,
-VectorTransferConversion<TransferReadOp>,
-VectorTransferConversion<TransferWriteOp>>(
-converter, enableIndexOptimizations);
patterns
.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
VectorExtractElementOpConversion, VectorExtractOpConversion,
@@ -1508,8 +1396,9 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorLoadStoreConversion<vector::MaskedStoreOp,
vector::MaskedStoreOpAdaptor>,
VectorGatherOpConversion, VectorScatterOpConversion,
-VectorExpandLoadOpConversion, VectorCompressStoreOpConversion>(
-converter);
+VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
+VectorTransferConversion<TransferReadOp>,
+VectorTransferConversion<TransferWriteOp>>(converter);
}
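For downstream users, the visible change is the dropped enableIndexOptimizations parameter. A minimal driver using the updated entry point might look like the sketch below; the includes and pass boilerplate are assumptions that may differ by revision, and only populateVectorToLLVMConversionPatterns with its remaining flag comes from this file:

```cpp
// Hypothetical driver (not part of the patch); header paths may differ by revision.
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

LogicalResult runVectorToLLVM(ModuleOp module) {
  MLIRContext *ctx = module.getContext();
  LLVMTypeConverter converter(ctx);
  RewritePatternSet patterns(ctx);
  // Only reassociateFPReductions remains; enableIndexOptimizations is gone
  // because mask materialization now happens before this lowering.
  populateVectorToLLVMConversionPatterns(converter, patterns,
                                         /*reassociateFPReductions=*/false);
  LLVMConversionTarget target(*ctx);
  return applyPartialConversion(module, target, std::move(patterns));
}
```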
void mlir::populateVectorToLLVMMatrixConversionPatterns(