[mlir] Add "mask" operand to vector.transfer_read/write.
Also factors out out-of-bounds mask generation from vector.transfer_read/write into a new MaterializeTransferMask pattern.

Differential Revision: https://reviews.llvm.org/D100001
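As a rough sketch of the new surface syntax (the values %A, %i, %p and %bound are illustrative placeholders; the exact form is whatever the ops' assembly formats accept), a 1-D transfer can now carry an optional i1-vector mask, written after the read's padding value or after the write's indices, e.g. one produced by vector.create_mask:

    %mask = vector.create_mask %bound : vector<16xi1>
    %v = vector.transfer_read %A[%i], %p, %mask : memref<?xf32>, vector<16xf32>
    vector.transfer_write %v, %A[%i], %mask : vector<16xf32>, memref<?xf32>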
@@ -104,66 +104,6 @@ static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
  return res;
}

static Value createCastToIndexLike(ConversionPatternRewriter &rewriter,
                                   Location loc, Type targetType, Value value) {
  if (targetType == value.getType())
    return value;

  bool targetIsIndex = targetType.isIndex();
  bool valueIsIndex = value.getType().isIndex();
  if (targetIsIndex ^ valueIsIndex)
    return rewriter.create<IndexCastOp>(loc, targetType, value);

  auto targetIntegerType = targetType.dyn_cast<IntegerType>();
  auto valueIntegerType = value.getType().dyn_cast<IntegerType>();
  assert(targetIntegerType && valueIntegerType &&
         "unexpected cast between types other than integers and index");
  assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());

  if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
    return rewriter.create<SignExtendIOp>(loc, targetIntegerType, value);
  return rewriter.create<TruncateIOp>(loc, targetIntegerType, value);
}

// Helper that returns a vector comparison that constructs a mask:
//     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
//
// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
//       much more compact, IR for this operation, but LLVM eventually
//       generates more elaborate instructions for this intrinsic since it
//       is very conservative on the boundary conditions.
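// For example (illustrative values only): with n = 4, offset o = 2 and
// bound b = 5, the mask is [0,1,2,3] + [2,2,2,2] < [5,5,5,5] = [1,1,1,0].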
static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
                                   Operation *op, bool enableIndexOptimizations,
                                   int64_t dim, Value b, Value *off = nullptr) {
  auto loc = op->getLoc();
  // If we can assume all indices fit in 32-bit, we perform the vector
  // comparison in 32-bit to get a higher degree of SIMD parallelism.
  // Otherwise we perform the vector comparison using 64-bit indices.
  Value indices;
  Type idxType;
  if (enableIndexOptimizations) {
    indices = rewriter.create<ConstantOp>(
        loc, rewriter.getI32VectorAttr(
                 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))));
    idxType = rewriter.getI32Type();
  } else {
    indices = rewriter.create<ConstantOp>(
        loc, rewriter.getI64VectorAttr(
                 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))));
    idxType = rewriter.getI64Type();
  }
  // Add in an offset if requested.
  if (off) {
    Value o = createCastToIndexLike(rewriter, loc, idxType, *off);
    Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
    indices = rewriter.create<AddIOp>(loc, ov, indices);
  }
  // Construct the vector comparison.
  Value bound = createCastToIndexLike(rewriter, loc, idxType, b);
  Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
  return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
}

// Helper that returns data layout alignment of a memref.
LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
                                 MemRefType memrefType, unsigned &align) {
@@ -250,7 +190,7 @@ replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
  if (failed(getMemRefAlignment(
          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
    return failure();
  auto adaptor = TransferWriteOpAdaptor(operands);
  auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
                                             align);
  return success();
@@ -266,7 +206,7 @@ replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
    return failure();

  auto adaptor = TransferWriteOpAdaptor(operands);
  auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      xferOp, adaptor.vector(), dataPtr, mask,
      rewriter.getI32IntegerAttr(align));
@@ -275,12 +215,12 @@ replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,

static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
                                                  ArrayRef<Value> operands) {
  return TransferReadOpAdaptor(operands);
  return TransferReadOpAdaptor(operands, xferOp->getAttrDictionary());
}

static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
                                                   ArrayRef<Value> operands) {
  return TransferWriteOpAdaptor(operands);
  return TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
}

namespace {
@@ -618,33 +558,6 @@ private:
  const bool reassociateFPReductions;
};

/// Conversion pattern for a vector.create_mask (1-D only).
class VectorCreateMaskOpConversion
    : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
public:
  explicit VectorCreateMaskOpConversion(LLVMTypeConverter &typeConv,
                                        bool enableIndexOpt)
      : ConvertOpToLLVMPattern<vector::CreateMaskOp>(typeConv),
        enableIndexOptimizations(enableIndexOpt) {}

  LogicalResult
  matchAndRewrite(vector::CreateMaskOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto dstType = op.getType();
    int64_t rank = dstType.getRank();
    if (rank == 1) {
      rewriter.replaceOp(
          op, buildVectorComparison(rewriter, op, enableIndexOptimizations,
                                    dstType.getDimSize(0), operands[0]));
      return success();
    }
    return failure();
  }

private:
  const bool enableIndexOptimizations;
};

class VectorShuffleOpConversion
    : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
@@ -1177,20 +1090,12 @@ public:
  }
};

/// Conversion pattern that converts a 1-D vector transfer read/write op in a
/// sequence of:
/// 1. Get the source/dst address as an LLVM vector pointer.
/// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
/// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
/// 4. Create a mask where offsetVector is compared against memref upper bound.
/// 5. Rewrite op as a masked read or write.
/// Conversion pattern that converts a 1-D vector transfer read/write op into a
/// masked or unmasked read/write.
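/// (For example, an unmasked transfer_write lowers to an LLVM::StoreOp on the
/// computed vector pointer, while one that carries a mask operand lowers to an
/// LLVM::MaskedStoreOp taking that mask; reads follow the same pattern.)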
template <typename ConcreteOp>
class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
public:
  explicit VectorTransferConversion(LLVMTypeConverter &typeConv,
                                    bool enableIndexOpt)
      : ConvertOpToLLVMPattern<ConcreteOp>(typeConv),
        enableIndexOptimizations(enableIndexOpt) {}
  using ConvertOpToLLVMPattern<ConcreteOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,
@@ -1212,6 +1117,9 @@ public:
    auto strides = computeContiguousStrides(memRefType);
    if (!strides)
      return failure();
    // Out-of-bounds dims are handled by MaterializeTransferMask.
    if (xferOp.hasOutOfBoundsDim())
      return failure();

    auto toLLVMTy = [&](Type t) {
      return this->getTypeConverter()->convertType(t);
@@ -1241,40 +1149,24 @@ public:
#endif // ifndef NDEBUG
    }

    // 1. Get the source/dst address as an LLVM vector pointer.
    // Get the source/dst address as an LLVM vector pointer.
    VectorType vtp = xferOp.getVectorType();
    Value dataPtr = this->getStridedElementPtr(
        loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
    Value vectorDataPtr =
        castDataPtr(rewriter, loc, dataPtr, memRefType, toLLVMTy(vtp));

    if (xferOp.isDimInBounds(0))
    // Rewrite as an unmasked read / write.
    if (!xferOp.mask())
      return replaceTransferOpWithLoadOrStore(rewriter,
                                              *this->getTypeConverter(), loc,
                                              xferOp, operands, vectorDataPtr);

    // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
    // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
    // 4. Let dim the memref dimension, compute the vector comparison mask
    //    (in-bounds mask):
    //      [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
    //
    // TODO: when the leaf transfer rank is k > 1, we need the last `k`
    //       dimensions here.
    unsigned vecWidth = LLVM::getVectorNumElements(vtp).getFixedValue();
    unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
    Value off = xferOp.indices()[lastIndex];
    Value dim = rewriter.create<memref::DimOp>(loc, xferOp.source(), lastIndex);
    Value mask = buildVectorComparison(
        rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off);

    // 5. Rewrite as a masked read / write.
    // Rewrite as a masked read / write.
    return replaceTransferOpWithMasked(rewriter, *this->getTypeConverter(), loc,
                                       xferOp, operands, vectorDataPtr, mask);
                                       xferOp, operands, vectorDataPtr,
                                       xferOp.mask());
  }

private:
  const bool enableIndexOptimizations;
};

class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
@@ -1484,17 +1376,13 @@ public:
/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns,
    bool reassociateFPReductions, bool enableIndexOptimizations) {
    bool reassociateFPReductions) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  patterns.add<VectorFMAOpNDRewritePattern,
               VectorInsertStridedSliceOpDifferentRankRewritePattern,
               VectorInsertStridedSliceOpSameRankRewritePattern,
               VectorExtractStridedSliceOpConversion>(ctx);
  patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
  patterns.add<VectorCreateMaskOpConversion,
               VectorTransferConversion<TransferReadOp>,
               VectorTransferConversion<TransferWriteOp>>(
      converter, enableIndexOptimizations);
  patterns
      .add<VectorBitCastOpConversion, VectorShuffleOpConversion,
           VectorExtractElementOpConversion, VectorExtractOpConversion,
@@ -1508,8 +1396,9 @@ void mlir::populateVectorToLLVMConversionPatterns(
           VectorLoadStoreConversion<vector::MaskedStoreOp,
                                     vector::MaskedStoreOpAdaptor>,
           VectorGatherOpConversion, VectorScatterOpConversion,
           VectorExpandLoadOpConversion, VectorCompressStoreOpConversion>(
          converter);
           VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
           VectorTransferConversion<TransferReadOp>,
           VectorTransferConversion<TransferWriteOp>>(converter);
}

void mlir::populateVectorToLLVMMatrixConversionPatterns(