[mlir][VectorOps] Don't drop scalable dims when lowering transfer_reads/writes (in VectorToSCF)

This allows lowering transfer_reads/writes of rank > 1 to equivalent
lower-rank ones when the trailing dimension is scalable. The resulting
ops still cannot be completely lowered, as they depend on arrays of
scalable vectors being enabled, plus a few related fixes (see D158517).

This patch also explicitly disables lowering transfer_reads/writes with
a leading scalable dimension, as more changes would be needed to handle
that correctly, and it is unclear whether it is actually required.

Examples of ops that can now be further lowered:

  %vec = vector.transfer_read %arg0[%c0, %c0], %cst, %mask
		 {in_bounds = [true, true]} : memref<3x?xf32>, vector<3x[4]xf32>

  vector.transfer_write %vec, %arg0[%c0, %c0], %mask
		 {in_bounds = [true, true]} : vector<3x[4]xf32>, memref<3x?xf32>

Reviewed By: c-rhodes, awarzynski, dcaballe

Differential Revision: https://reviews.llvm.org/D158753
Benjamin Maxwell
2023-09-08 09:43:15 +00:00
parent eebf8faf3e
commit 2a82dfd704
2 changed files with 129 additions and 8 deletions


@@ -314,15 +314,18 @@ static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
 /// the VectorType into the MemRefType.
 ///
 /// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
-static MemRefType unpackOneDim(MemRefType type) {
+static FailureOr<MemRefType> unpackOneDim(MemRefType type) {
   auto vectorType = dyn_cast<VectorType>(type.getElementType());
+  // Vectors with leading scalable dims are not supported.
+  // It may be possible to support these in future by using dynamic memref dims.
+  if (vectorType.getScalableDims().front())
+    return failure();
   auto memrefShape = type.getShape();
   SmallVector<int64_t, 8> newMemrefShape;
   newMemrefShape.append(memrefShape.begin(), memrefShape.end());
   newMemrefShape.push_back(vectorType.getDimSize(0));
   return MemRefType::get(newMemrefShape,
-                         VectorType::get(vectorType.getShape().drop_front(),
-                                         vectorType.getElementType()));
+                         VectorType::Builder(vectorType).dropDim(0));
 }
 
 /// Given a transfer op, find the memref from which the mask is loaded. This
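
The core of the hunk above is how the element vector type is rebuilt when one
dimension is unpacked into the memref. As a rough standalone sketch of the
difference (hypothetical helper names; only the MLIR builtin vector type API
is assumed, not code from this patch):

    #include "mlir/IR/BuiltinTypes.h"

    using namespace mlir;

    // Old approach: rebuilding the type from the shape alone silently drops
    // the per-dimension scalability, so the trailing [4] of vector<3x[4]xf32>
    // comes back as a fixed 4.
    static VectorType dropLeadingDimLossy(VectorType vt) {
      return VectorType::get(vt.getShape().drop_front(), vt.getElementType());
    }

    // New approach: VectorType::Builder copies the scalable-dim flags, so
    // vector<3x[4]xf32> becomes vector<[4]xf32> as intended.
    static VectorType dropLeadingDimScalableAware(VectorType vt) {
      return VectorType::Builder(vt).dropDim(0);
    }
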
@@ -542,6 +545,10 @@ LogicalResult checkPrepareXferOp(OpTy xferOp,
     return failure();
   if (xferOp.getVectorType().getRank() <= options.targetRank)
     return failure();
+  // Currently the unpacking of the leading dimension into the memref is not
+  // supported for scalable dimensions.
+  if (xferOp.getVectorType().getScalableDims().front())
+    return failure();
   if (isTensorOp(xferOp) && !options.lowerTensors)
     return failure();
   // Transfer ops that modify the element type are not supported atm.
@@ -866,8 +873,11 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
     auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
     auto dataBufferType = dyn_cast<MemRefType>(dataBuffer.getType());
     auto castedDataType = unpackOneDim(dataBufferType);
+    if (failed(castedDataType))
+      return failure();
+
     auto castedDataBuffer =
-        locB.create<vector::TypeCastOp>(castedDataType, dataBuffer);
+        locB.create<vector::TypeCastOp>(*castedDataType, dataBuffer);
 
     // If the xferOp has a mask: Find and cast mask buffer.
     Value castedMaskBuffer;
@@ -882,7 +892,9 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
         // be broadcasted.)
         castedMaskBuffer = maskBuffer;
       } else {
-        auto castedMaskType = unpackOneDim(maskBufferType);
+        // It's safe to assume the mask buffer can be unpacked if the data
+        // buffer was unpacked.
+        auto castedMaskType = *unpackOneDim(maskBufferType);
         castedMaskBuffer =
             locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
       }
@@ -891,7 +903,7 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
     // Loop bounds and step.
     auto lb = locB.create<arith::ConstantIndexOp>(0);
     auto ub = locB.create<arith::ConstantIndexOp>(
-        castedDataType.getDimSize(castedDataType.getRank() - 1));
+        castedDataType->getDimSize(castedDataType->getRank() - 1));
    auto step = locB.create<arith::ConstantIndexOp>(1);
     // TransferWriteOps that operate on tensors return the modified tensor and
     // require a loop state.
@@ -1074,8 +1086,14 @@ struct UnrollTransferReadConversion
     auto vec = getResultVector(xferOp, rewriter);
     auto vecType = dyn_cast<VectorType>(vec.getType());
     auto xferVecType = xferOp.getVectorType();
-    auto newXferVecType = VectorType::get(xferVecType.getShape().drop_front(),
-                                          xferVecType.getElementType());
+
+    if (xferVecType.getScalableDims()[0]) {
+      // Cannot unroll a scalable dimension at compile time.
+      return failure();
+    }
+
+    VectorType newXferVecType = VectorType::Builder(xferVecType).dropDim(0);
+
     int64_t dimSize = xferVecType.getShape()[0];
 
     // Generate fully unrolled loop of transfer ops.
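
For context on the new bail-out in UnrollTransferReadConversion: unrolling
emits one transfer op per index of the leading dimension, so it needs a static
trip count. A small sketch of that constraint (hypothetical helper, not part
of the patch; assumes the vector has rank >= 1):

    #include "mlir/IR/BuiltinTypes.h"
    #include "mlir/Support/LogicalResult.h"

    using namespace mlir;

    // Returns the static trip count if dim 0 can be fully unrolled, or
    // failure if dim 0 is scalable: a leading [4] means 4 * vscale rows,
    // which is only known at runtime.
    static FailureOr<int64_t> getUnrollTripCount(VectorType vt) {
      if (vt.getScalableDims()[0])
        return failure();
      return vt.getShape()[0];
    }
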


@@ -635,3 +635,106 @@ func.func @vector_print_scalable_vector(%arg0: vector<[4]xi32>) {
// CHECK: vector.print
// CHECK: return
// CHECK: }

// -----

func.func @transfer_read_array_of_scalable(%arg0: memref<3x?xf32>) -> vector<3x[4]xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %cst = arith.constant 0.000000e+00 : f32
  %dim = memref.dim %arg0, %c1 : memref<3x?xf32>
  %mask = vector.create_mask %c1, %dim : vector<3x[4]xi1>
  %read = vector.transfer_read %arg0[%c0, %c0], %cst, %mask {in_bounds = [true, true]} : memref<3x?xf32>, vector<3x[4]xf32>
  return %read : vector<3x[4]xf32>
}

// CHECK-LABEL: func.func @transfer_read_array_of_scalable(
// CHECK-SAME: %[[ARG:.*]]: memref<3x?xf32>) -> vector<3x[4]xf32> {
// CHECK: %[[PADDING:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C3:.*]] = arith.constant 3 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[ALLOCA_VEC:.*]] = memref.alloca() : memref<vector<3x[4]xf32>>
// CHECK: %[[ALLOCA_MASK:.*]] = memref.alloca() : memref<vector<3x[4]xi1>>
// CHECK: %[[DIM_SIZE:.*]] = memref.dim %[[ARG]], %[[C1]] : memref<3x?xf32>
// CHECK: %[[MASK:.*]] = vector.create_mask %[[C1]], %[[DIM_SIZE]] : vector<3x[4]xi1>
// CHECK: memref.store %[[MASK]], %[[ALLOCA_MASK]][] : memref<vector<3x[4]xi1>>
// CHECK: %[[UNPACK_VECTOR:.*]] = vector.type_cast %[[ALLOCA_VEC]] : memref<vector<3x[4]xf32>> to memref<3xvector<[4]xf32>>
// CHECK: %[[UNPACK_MASK:.*]] = vector.type_cast %[[ALLOCA_MASK]] : memref<vector<3x[4]xi1>> to memref<3xvector<[4]xi1>>
// CHECK: scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {
// CHECK: %[[MASK_SLICE:.*]] = memref.load %[[UNPACK_MASK]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xi1>>
// CHECK: %[[READ_SLICE:.*]] = vector.transfer_read %[[ARG]]{{\[}}%[[VAL_11]], %[[C0]]], %[[PADDING]], %[[MASK_SLICE]] {in_bounds = [true]} : memref<3x?xf32>, vector<[4]xf32>
// CHECK: memref.store %[[READ_SLICE]], %[[UNPACK_VECTOR]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xf32>>
// CHECK: }
// CHECK: %[[RESULT:.*]] = memref.load %[[ALLOCA_VEC]][] : memref<vector<3x[4]xf32>>
// CHECK: return %[[RESULT]] : vector<3x[4]xf32>
// CHECK: }

// -----

func.func @transfer_write_array_of_scalable(%vec: vector<3x[4]xf32>, %arg0: memref<3x?xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %cst = arith.constant 0.000000e+00 : f32
  %dim = memref.dim %arg0, %c1 : memref<3x?xf32>
  %mask = vector.create_mask %c1, %dim : vector<3x[4]xi1>
  vector.transfer_write %vec, %arg0[%c0, %c0], %mask {in_bounds = [true, true]} : vector<3x[4]xf32>, memref<3x?xf32>
  return
}

// CHECK-LABEL: func.func @transfer_write_array_of_scalable(
// CHECK-SAME: %[[VEC:.*]]: vector<3x[4]xf32>,
// CHECK-SAME: %[[MEMREF:.*]]: memref<3x?xf32>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C3:.*]] = arith.constant 3 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[ALLOCA_VEC:.*]] = memref.alloca() : memref<vector<3x[4]xf32>>
// CHECK: %[[ALLOCA_MASK:.*]] = memref.alloca() : memref<vector<3x[4]xi1>>
// CHECK: %[[DIM_SIZE:.*]] = memref.dim %[[MEMREF]], %[[C1]] : memref<3x?xf32>
// CHECK: %[[MASK:.*]] = vector.create_mask %[[C1]], %[[DIM_SIZE]] : vector<3x[4]xi1>
// CHECK: memref.store %[[MASK]], %[[ALLOCA_MASK]][] : memref<vector<3x[4]xi1>>
// CHECK: memref.store %[[VEC]], %[[ALLOCA_VEC]][] : memref<vector<3x[4]xf32>>
// CHECK: %[[UNPACK_VECTOR:.*]] = vector.type_cast %[[ALLOCA_VEC]] : memref<vector<3x[4]xf32>> to memref<3xvector<[4]xf32>>
// CHECK: %[[UNPACK_MASK:.*]] = vector.type_cast %[[ALLOCA_MASK]] : memref<vector<3x[4]xi1>> to memref<3xvector<[4]xi1>>
// CHECK: scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {
// CHECK: %[[VECTOR_SLICE:.*]] = memref.load %[[UNPACK_VECTOR]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xf32>>
// CHECK: %[[MASK_SLICE:.*]] = memref.load %[[UNPACK_MASK]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xi1>>
// CHECK: vector.transfer_write %[[VECTOR_SLICE]], %[[MEMREF]]{{\[}}%[[VAL_11]], %[[C0]]], %[[MASK_SLICE]] {in_bounds = [true]} : vector<[4]xf32>, memref<3x?xf32>
// CHECK: }
// CHECK: return
// CHECK: }

// -----

/// The transfers in the following two tests cannot currently be lowered by unpacking
/// the leading dim, since it is scalable. It may be possible to special-case this via
/// a dynamic dim in the future.

func.func @cannot_lower_transfer_write_with_leading_scalable(%vec: vector<[4]x4xf32>, %arg0: memref<?x4xf32>) {
  %c0 = arith.constant 0 : index
  %c4 = arith.constant 4 : index
  %cst = arith.constant 0.000000e+00 : f32
  %dim = memref.dim %arg0, %c0 : memref<?x4xf32>
  %mask = vector.create_mask %dim, %c4 : vector<[4]x4xi1>
  vector.transfer_write %vec, %arg0[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[4]x4xf32>, memref<?x4xf32>
  return
}

// CHECK-LABEL: func.func @cannot_lower_transfer_write_with_leading_scalable(
// CHECK-SAME: %[[VEC:.*]]: vector<[4]x4xf32>,
// CHECK-SAME: %[[MEMREF:.*]]: memref<?x4xf32>)
// CHECK: vector.transfer_write %[[VEC]], %[[MEMREF]][%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true]} : vector<[4]x4xf32>, memref<?x4xf32>

// -----

func.func @cannot_lower_transfer_read_with_leading_scalable(%arg0: memref<?x4xf32>) -> vector<[4]x4xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %cst = arith.constant 0.000000e+00 : f32
  %dim = memref.dim %arg0, %c0 : memref<?x4xf32>
  %mask = vector.create_mask %dim, %c4 : vector<[4]x4xi1>
  %read = vector.transfer_read %arg0[%c0, %c0], %cst, %mask {in_bounds = [true, true]} : memref<?x4xf32>, vector<[4]x4xf32>
  return %read : vector<[4]x4xf32>
}

// CHECK-LABEL: func.func @cannot_lower_transfer_read_with_leading_scalable(
// CHECK-SAME: %[[MEMREF:.*]]: memref<?x4xf32>)
// CHECK: %{{.*}} = vector.transfer_read %[[MEMREF]][%{{.*}}, %{{.*}}], %{{.*}}, %{{.*}} {in_bounds = [true, true]} : memref<?x4xf32>, vector<[4]x4xf32>