[mlir][vector] Extend Transfer read/write ops to support tensor types.

Transfer ops can now work on both buffers and tensors. Lowering of the
tensor case is not yet supported.

Differential Revision: https://reviews.llvm.org/D93500
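
As a minimal sketch of the new surface syntax (the function and value
names below are illustrative, not taken from the commit): transfer_read
can now take a tensor source, and transfer_write on a tensor returns the
updated tensor as a new SSA value instead of mutating memory in place.

  func @transfer_on_tensor(%t: tensor<?xf32>, %i: index) -> tensor<?xf32> {
    %pad = constant 0.0 : f32
    // Reading from a tensor uses the same form as the memref case.
    %v = vector.transfer_read %t[%i], %pad : tensor<?xf32>, vector<8xf32>
    // Writing into a tensor yields a new tensor result.
    %r = vector.transfer_write %v, %t[%i] : vector<8xf32>, tensor<?xf32>
    return %r : tensor<?xf32>
  }

As the diff below shows, the LLVM lowering patterns now bail out (return
failure()) when the shaped source is not a memref, so only the buffer
forms reach LLVM.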
Thomas Raoux
2020-12-17 16:26:07 -08:00
parent 9a93f95fce
commit 26c8f9081b
16 changed files with 304 additions and 189 deletions


@@ -141,12 +141,10 @@ static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
   return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
 }
 
-// Helper that returns data layout alignment of an operation with memref.
-template <typename T>
-LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
-                                 unsigned &align) {
-  Type elementTy =
-      typeConverter.convertType(op.getMemRefType().getElementType());
+// Helper that returns data layout alignment of a memref.
+LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
+                                 MemRefType memrefType, unsigned &align) {
+  Type elementTy = typeConverter.convertType(memrefType.getElementType());
   if (!elementTy)
     return failure();
@@ -222,7 +220,8 @@ replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                  TransferReadOp xferOp,
                                  ArrayRef<Value> operands, Value dataPtr) {
   unsigned align;
-  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
+  if (failed(getMemRefAlignment(
+          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
     return failure();
   rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
   return success();
@@ -243,7 +242,8 @@ replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
     return failure();
   unsigned align;
-  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
+  if (failed(getMemRefAlignment(
+          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
     return failure();
   rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
@@ -258,7 +258,8 @@ replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                  TransferWriteOp xferOp,
                                  ArrayRef<Value> operands, Value dataPtr) {
   unsigned align;
-  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
+  if (failed(getMemRefAlignment(
+          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
     return failure();
   auto adaptor = TransferWriteOpAdaptor(operands);
   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
@@ -272,7 +273,8 @@ replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                             TransferWriteOp xferOp, ArrayRef<Value> operands,
                             Value dataPtr, Value mask) {
   unsigned align;
-  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
+  if (failed(getMemRefAlignment(
+          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
     return failure();
   auto adaptor = TransferWriteOpAdaptor(operands);
@@ -345,7 +347,8 @@ public:
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), load, align)))
+    if (failed(getMemRefAlignment(*getTypeConverter(), load.getMemRefType(),
+                                  align)))
       return failure();
 
     auto vtype = typeConverter->convertType(load.getResultVectorType());
@@ -375,7 +378,8 @@ public:
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), store, align)))
+    if (failed(getMemRefAlignment(*getTypeConverter(), store.getMemRefType(),
+                                  align)))
       return failure();
 
     auto vtype = typeConverter->convertType(store.getValueVectorType());
@@ -405,7 +409,8 @@ public:
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), gather, align)))
+    if (failed(getMemRefAlignment(*getTypeConverter(), gather.getMemRefType(),
+                                  align)))
       return failure();
 
     // Get index ptrs.
@@ -438,7 +443,8 @@ public:
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), scatter, align)))
+    if (failed(getMemRefAlignment(*getTypeConverter(), scatter.getMemRefType(),
+                                  align)))
       return failure();
 
     // Get index ptrs.
@@ -1182,8 +1188,11 @@ public:
                                        xferOp.getVectorType().getRank(),
                                        xferOp->getContext()))
       return failure();
+    auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
+    if (!memRefType)
+      return failure();
     // Only contiguous source tensors supported atm.
-    auto strides = computeContiguousStrides(xferOp.getMemRefType());
+    auto strides = computeContiguousStrides(memRefType);
     if (!strides)
       return failure();
@@ -1192,10 +1201,9 @@ public:
     };
     Location loc = xferOp->getLoc();
-    MemRefType memRefType = xferOp.getMemRefType();
     if (auto memrefVectorElementType =
-            memRefType.getElementType().dyn_cast<VectorType>()) {
+            memRefType.getElementType().template dyn_cast<VectorType>()) {
       // Memref has vector element type.
       if (memrefVectorElementType.getElementType() !=
           xferOp.getVectorType().getElementType())
@@ -1222,7 +1230,7 @@ public:
     // address space 0.
     // TODO: support alignment when possible.
     Value dataPtr = this->getStridedElementPtr(
-        loc, memRefType, adaptor.memref(), adaptor.indices(), rewriter);
+        loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
     auto vecTy =
         toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
     Value vectorDataPtr;
@@ -1248,7 +1256,7 @@ public:
     unsigned vecWidth = vecTy.getVectorNumElements();
     unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
     Value off = xferOp.indices()[lastIndex];
-    Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
+    Value dim = rewriter.create<DimOp>(loc, xferOp.source(), lastIndex);
     Value mask = buildVectorComparison(
         rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off);
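
For contrast with the tensor sketch above, a memref-source transfer like
the following (again illustrative, not from the commit) is what these
patterns continue to lower, via getStridedElementPtr and the masked
load/store path:

  func @transfer_on_memref(%m: memref<?xf32>, %i: index) -> vector<8xf32> {
    %pad = constant 0.0 : f32
    // Memref sources still take the load/maskedload path shown above.
    %v = vector.transfer_read %m[%i], %pad : memref<?xf32>, vector<8xf32>
    return %v : vector<8xf32>
  }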