[mlir[[vector] Extend Transfer read/write ops to support tensor types.
Transfer_ops can now work on both buffers and tensor. Right now, lowering of the tensor case is not supported yet. Differential Revision: https://reviews.llvm.org/D93500
This commit is contained in:
@@ -141,12 +141,10 @@ static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
|
||||
return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
|
||||
}
|
||||
|
||||
// Helper that returns data layout alignment of an operation with memref.
|
||||
template <typename T>
|
||||
LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
|
||||
unsigned &align) {
|
||||
Type elementTy =
|
||||
typeConverter.convertType(op.getMemRefType().getElementType());
|
||||
// Helper that returns data layout alignment of a memref.
|
||||
LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
|
||||
MemRefType memrefType, unsigned &align) {
|
||||
Type elementTy = typeConverter.convertType(memrefType.getElementType());
|
||||
if (!elementTy)
|
||||
return failure();
|
||||
|
||||
@@ -222,7 +220,8 @@ replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
|
||||
TransferReadOp xferOp,
|
||||
ArrayRef<Value> operands, Value dataPtr) {
|
||||
unsigned align;
|
||||
if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
|
||||
if (failed(getMemRefAlignment(
|
||||
typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
|
||||
return failure();
|
||||
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
|
||||
return success();
|
||||
@@ -243,7 +242,8 @@ replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
|
||||
return failure();
|
||||
|
||||
unsigned align;
|
||||
if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
|
||||
if (failed(getMemRefAlignment(
|
||||
typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
|
||||
@@ -258,7 +258,8 @@ replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
|
||||
TransferWriteOp xferOp,
|
||||
ArrayRef<Value> operands, Value dataPtr) {
|
||||
unsigned align;
|
||||
if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
|
||||
if (failed(getMemRefAlignment(
|
||||
typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
|
||||
return failure();
|
||||
auto adaptor = TransferWriteOpAdaptor(operands);
|
||||
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
|
||||
@@ -272,7 +273,8 @@ replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
|
||||
TransferWriteOp xferOp, ArrayRef<Value> operands,
|
||||
Value dataPtr, Value mask) {
|
||||
unsigned align;
|
||||
if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
|
||||
if (failed(getMemRefAlignment(
|
||||
typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
|
||||
return failure();
|
||||
|
||||
auto adaptor = TransferWriteOpAdaptor(operands);
|
||||
@@ -345,7 +347,8 @@ public:
|
||||
|
||||
// Resolve alignment.
|
||||
unsigned align;
|
||||
if (failed(getMemRefAlignment(*getTypeConverter(), load, align)))
|
||||
if (failed(getMemRefAlignment(*getTypeConverter(), load.getMemRefType(),
|
||||
align)))
|
||||
return failure();
|
||||
|
||||
auto vtype = typeConverter->convertType(load.getResultVectorType());
|
||||
@@ -375,7 +378,8 @@ public:
|
||||
|
||||
// Resolve alignment.
|
||||
unsigned align;
|
||||
if (failed(getMemRefAlignment(*getTypeConverter(), store, align)))
|
||||
if (failed(getMemRefAlignment(*getTypeConverter(), store.getMemRefType(),
|
||||
align)))
|
||||
return failure();
|
||||
|
||||
auto vtype = typeConverter->convertType(store.getValueVectorType());
|
||||
@@ -405,7 +409,8 @@ public:
|
||||
|
||||
// Resolve alignment.
|
||||
unsigned align;
|
||||
if (failed(getMemRefAlignment(*getTypeConverter(), gather, align)))
|
||||
if (failed(getMemRefAlignment(*getTypeConverter(), gather.getMemRefType(),
|
||||
align)))
|
||||
return failure();
|
||||
|
||||
// Get index ptrs.
|
||||
@@ -438,7 +443,8 @@ public:
|
||||
|
||||
// Resolve alignment.
|
||||
unsigned align;
|
||||
if (failed(getMemRefAlignment(*getTypeConverter(), scatter, align)))
|
||||
if (failed(getMemRefAlignment(*getTypeConverter(), scatter.getMemRefType(),
|
||||
align)))
|
||||
return failure();
|
||||
|
||||
// Get index ptrs.
|
||||
@@ -1182,8 +1188,11 @@ public:
|
||||
xferOp.getVectorType().getRank(),
|
||||
xferOp->getContext()))
|
||||
return failure();
|
||||
auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
|
||||
if (!memRefType)
|
||||
return failure();
|
||||
// Only contiguous source tensors supported atm.
|
||||
auto strides = computeContiguousStrides(xferOp.getMemRefType());
|
||||
auto strides = computeContiguousStrides(memRefType);
|
||||
if (!strides)
|
||||
return failure();
|
||||
|
||||
@@ -1192,10 +1201,9 @@ public:
|
||||
};
|
||||
|
||||
Location loc = xferOp->getLoc();
|
||||
MemRefType memRefType = xferOp.getMemRefType();
|
||||
|
||||
if (auto memrefVectorElementType =
|
||||
memRefType.getElementType().dyn_cast<VectorType>()) {
|
||||
memRefType.getElementType().template dyn_cast<VectorType>()) {
|
||||
// Memref has vector element type.
|
||||
if (memrefVectorElementType.getElementType() !=
|
||||
xferOp.getVectorType().getElementType())
|
||||
@@ -1222,7 +1230,7 @@ public:
|
||||
// address space 0.
|
||||
// TODO: support alignment when possible.
|
||||
Value dataPtr = this->getStridedElementPtr(
|
||||
loc, memRefType, adaptor.memref(), adaptor.indices(), rewriter);
|
||||
loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
|
||||
auto vecTy =
|
||||
toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
|
||||
Value vectorDataPtr;
|
||||
@@ -1248,7 +1256,7 @@ public:
|
||||
unsigned vecWidth = vecTy.getVectorNumElements();
|
||||
unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
|
||||
Value off = xferOp.indices()[lastIndex];
|
||||
Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
|
||||
Value dim = rewriter.create<DimOp>(loc, xferOp.source(), lastIndex);
|
||||
Value mask = buildVectorComparison(
|
||||
rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user