[mlir][vector] Avoid setting padding by default to 0 in vector.transfer_read; prefer ub.poison (#146088)

Context:
`vector.transfer_read` always requires a padding value. Most of its
builders take no `padding` value and assume the safe default of `0`.
However, padding should be a conscious choice by the API user, because the
implicit default makes it easy to introduce bugs.
For example, while preparing this patch I found several places where the
padding value was not propagated when one `vector.transfer_read` was
transformed into another `vector.transfer_read`. These bugs were always
caused by builders that don't require specifying the padding.

Additionally, using `ub.poison` as the default value is better, as it
indicates the user "doesn't care" about the actual padding value; anyone who
does care is now forced to spell out the padding semantics they want.

With that in mind, this patch changes the builders of
`vector.transfer_read` to always take a `std::optional<Value> padding`
argument. The argument itself can never be omitted, but for convenience
users can pass `std::nullopt`, in which case the transfer read is padded
with `ub.poison`.
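
As a rough sketch of the resulting API (assuming an `OpBuilder builder`, a
`Location loc`, and suitable `vectorType`, `source`, and `indices` values are
in scope), the two call styles look like:

  // Pad with ub.poison: the caller explicitly does not care about the
  // padding value.
  Value readPoison = builder.create<vector::TransferReadOp>(
      loc, vectorType, source, indices, /*padding=*/std::nullopt);

  // Pad with an explicit zero, preserving the previous default behavior.
  Value readZero = builder.create<vector::TransferReadOp>(
      loc, vectorType, source, indices,
      /*padding=*/arith::getZeroConstant(builder, loc,
                                         vectorType.getElementType()));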

---------

Signed-off-by: Fabian Mora <fabian.mora-cordero@amd.com>
Author: Fabian Mora
Date: 2025-06-30 15:20:42 -04:00
Committed by: GitHub
Parent: 6a57af8d03
Commit: 878d3594ed
15 changed files with 108 additions and 79 deletions

View File

@@ -154,6 +154,11 @@ Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
Value lhs, Value rhs);
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred);
/// Creates an `arith.constant` operation with a zero value of type `type`. This
/// method asserts if `type` is invalid for representing zero with
/// `arith.constant`.
Value getZeroConstant(OpBuilder &builder, Location loc, Type type);
} // namespace arith
} // namespace mlir
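
A minimal usage sketch of the new helper (hypothetical; assumes an
`OpBuilder builder` and a `Location loc` are available):

  // Materializes `arith.constant 0.000000e+00 : f32`. The helper asserts on
  // types without a zero representation (e.g. Float8E8M0FNU).
  Value zero = arith::getZeroConstant(builder, loc, builder.getF32Type());

Callers in this patch use it to keep explicit zero padding where the old
builders used to create the zero constant implicitly.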

View File

@@ -21,7 +21,10 @@ def Vector_Dialect : Dialect {
let useDefaultAttributePrinterParser = 1;
let hasConstantMaterializer = 1;
let dependentDialects = ["arith::ArithDialect"];
let dependentDialects = [
"arith::ArithDialect",
"ub::UBDialect"
];
}
// Base class for Vector dialect ops.

View File

@@ -1543,30 +1543,29 @@ def Vector_TransferReadOp :
}];
let builders = [
/// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
/// 1. Builder that sets padding to `padding` or poison if not provided and
/// an empty mask (variant with attrs).
OpBuilder<(ins "VectorType":$vectorType,
"Value":$source,
"ValueRange":$indices,
"std::optional<Value>":$padding,
"AffineMapAttr":$permutationMapAttr,
"ArrayAttr":$inBoundsAttr)>,
/// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
/// 2. Builder that sets padding to `padding` or poison if not provided and
/// an empty mask (variant without attrs).
OpBuilder<(ins "VectorType":$vectorType,
"Value":$source,
"ValueRange":$indices,
"std::optional<Value>":$padding,
"AffineMap":$permutationMap,
CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>,
/// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
/// 3. Builder that sets padding to `padding` or poison if not provided and
/// permutation map to 'getMinorIdentityMap'.
OpBuilder<(ins "VectorType":$vectorType,
"Value":$source,
"ValueRange":$indices,
"Value":$padding,
CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>,
/// 4. Builder that sets padding to zero and permutation map to
/// 'getMinorIdentityMap'.
OpBuilder<(ins "VectorType":$vectorType,
"Value":$source,
"ValueRange":$indices,
CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>,
"std::optional<Value>":$padding,
CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>
];
let extraClassDeclaration = [{
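
For illustration, a hedged sketch of variant 2 (assuming `rewriter`, `loc`,
`vectorType`, `memref`, `indices`, and `permutationMap` already exist), as
used by the affine vectorizer change below:

  auto read = rewriter.create<vector::TransferReadOp>(
      loc, vectorType, memref, indices,
      /*padding=*/std::nullopt, permutationMap);
  // With std::nullopt, the builder materializes the padding value as
  // `ub.poison` of the source element type.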

View File

@@ -1257,7 +1257,8 @@ static Operation *vectorizeAffineLoad(AffineLoadOp loadOp,
LLVM_DEBUG(permutationMap.print(dbgs()));
auto transfer = state.builder.create<vector::TransferReadOp>(
loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices, permutationMap);
loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices,
/*padding=*/std::nullopt, permutationMap);
// Register replacement for future uses in the scope.
state.registerOpVectorReplacement(loadOp, transfer);

View File

@@ -292,6 +292,16 @@ bool arith::ConstantIndexOp::classof(Operation *op) {
return false;
}
Value mlir::arith::getZeroConstant(OpBuilder &builder, Location loc,
Type type) {
// TODO: Incorporate this check to `FloatAttr::get*`.
assert(!isa<Float8E8M0FNUType>(getElementTypeOrSelf(type)) &&
"type doesn't have a zero representation");
TypedAttr zeroAttr = builder.getZeroAttr(type);
assert(zeroAttr && "unsupported type for zero attribute");
return builder.create<arith::ConstantOp>(loc, zeroAttr);
}
//===----------------------------------------------------------------------===//
// AddIOp
//===----------------------------------------------------------------------===//

View File

@@ -426,6 +426,7 @@ struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
// Create the new `transfer_read`.
auto newReadOp = rewriter.create<vector::TransferReadOp>(
readOp.getLoc(), collapsedVT, collapsedMem, indices,
readOp.getPadding(),
ArrayRef<bool>(origInBounds).drop_back(numCollapseDims - 1));
// Cast back to the original vector type.

View File

@@ -1183,6 +1183,10 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
auto srcRank = extractOp.getTensor().getType().getRank();
SmallVector<bool> inBounds(dstRank, true);
// Get the value to pad transfer reads with 0.
Value padding =
arith::getZeroConstant(rewriter, loc, resultType.getElementType());
// 2a. Handle scalar broadcast access.
if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
MLIRContext *ctx = rewriter.getContext();
@@ -1190,7 +1194,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
auto transferReadOp = rewriter.create<vector::TransferReadOp>(
loc, resultType, extractOp.getTensor(), transferReadIdxs,
loc, resultType, extractOp.getTensor(), transferReadIdxs, padding,
permutationMap, inBounds);
// Mask this broadcasting xfer_read here rather than relying on the generic
@@ -1227,8 +1231,8 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
}
auto transferReadOp = rewriter.create<vector::TransferReadOp>(
loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
inBounds);
loc, resultType, extractOp.getTensor(), transferReadIdxs, padding,
permutationMap, inBounds);
LDBG("Vectorised as contiguous load: " << extractOp);
return VectorizationHookResult{VectorizationHookStatus::NewOp,
@@ -1384,7 +1388,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
/// performed to the maximal common vector size implied by the `linalgOp`
/// iteration space. This eager broadcasting is introduced in the
/// permutation_map of the vector.transfer_read operations. The eager
/// broadcasting makes it trivial to detrmine where broadcast, transposes and
/// broadcasting makes it trivial to determine where broadcast, transposes and
/// reductions should occur, without any bookkeeping. The tradeoff is that, in
/// the absence of good canonicalizations, the amount of work increases.
/// This is not deemed a problem as we expect canonicalizations and foldings to
@@ -1439,7 +1443,8 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
Operation *read = rewriter.create<vector::TransferReadOp>(
loc, readType, opOperand->get(), indices, readMap);
loc, readType, opOperand->get(), indices,
/*padding=*/arith::getZeroConstant(rewriter, loc, elemType), readMap);
read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
Value readValue = read->getResult(0);
@@ -2641,6 +2646,7 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
Value readValue = rewriter.create<vector::TransferReadOp>(
loc, readType, copyOp.getSource(), indices,
/*padding=*/arith::getZeroConstant(rewriter, loc, srcElementType),
rewriter.getMultiDimIdentityMap(srcType.getRank()));
if (cast<VectorType>(readValue.getType()).getRank() == 0) {
readValue =
@@ -3487,15 +3493,18 @@ struct Conv1DGenerator
SmallVector<Value> resPadding(resShape.size(), zero);
// Read the whole lhs, rhs and res in one shot (with zero padding).
Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
lhsPadding);
Value lhs = rewriter.create<vector::TransferReadOp>(
loc, lhsType, lhsShaped, lhsPadding,
/*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
// This is needed only for Conv.
Value rhs = nullptr;
if (oper == ConvOperationKind::Conv)
rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
rhsPadding);
Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
resPadding);
rhs = rewriter.create<vector::TransferReadOp>(
loc, rhsType, rhsShaped, rhsPadding,
/*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
Value res = rewriter.create<vector::TransferReadOp>(
loc, resType, resShaped, resPadding,
/*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
// The base vectorization case for channeled convolution is input:
// {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
@@ -3742,19 +3751,22 @@ struct Conv1DGenerator
// Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
// 0].
Value lhs = rewriter.create<vector::TransferReadOp>(
loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
loc, lhsType, lhsShaped, ValueRange{zero, zero, zero},
/*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
auto maybeMaskedLhs = maybeMaskXferOp(
lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
// Read rhs slice of size {kw, c} @ [0, 0].
Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
ValueRange{zero, zero});
Value rhs = rewriter.create<vector::TransferReadOp>(
loc, rhsType, rhsShaped, ValueRange{zero, zero},
/*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
auto maybeMaskedRhs = maybeMaskXferOp(
rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
// Read res slice of size {n, w, c} @ [0, 0, 0].
Value res = rewriter.create<vector::TransferReadOp>(
loc, resType, resShaped, ValueRange{zero, zero, zero});
loc, resType, resShaped, ValueRange{zero, zero, zero},
/*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
auto maybeMaskedRes = maybeMaskXferOp(
resType.getShape(), resType.getScalableDims(), res.getDefiningOp());

View File

@@ -4261,33 +4261,39 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
/// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
VectorType vectorType, Value source,
ValueRange indices, AffineMapAttr permutationMapAttr,
ValueRange indices, std::optional<Value> padding,
AffineMapAttr permutationMapAttr,
/*optional*/ ArrayAttr inBoundsAttr) {
Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
Value padding = builder.create<arith::ConstantOp>(
result.location, elemType, builder.getZeroAttr(elemType));
if (!padding)
padding = builder.create<ub::PoisonOp>(result.location, elemType);
build(builder, result, vectorType, source, indices, permutationMapAttr,
padding, /*mask=*/Value(), inBoundsAttr);
*padding, /*mask=*/Value(), inBoundsAttr);
}
/// 2. Builder that sets padding to zero an empty mask (variant without attrs).
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
VectorType vectorType, Value source,
ValueRange indices, AffineMap permutationMap,
ValueRange indices, std::optional<Value> padding,
AffineMap permutationMap,
std::optional<ArrayRef<bool>> inBounds) {
auto permutationMapAttr = AffineMapAttr::get(permutationMap);
auto inBoundsAttr = (inBounds && !inBounds.value().empty())
? builder.getBoolArrayAttr(inBounds.value())
: builder.getBoolArrayAttr(
SmallVector<bool>(vectorType.getRank(), false));
build(builder, result, vectorType, source, indices, permutationMapAttr,
inBoundsAttr);
Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
if (!padding)
padding = builder.create<ub::PoisonOp>(result.location, elemType);
build(builder, result, vectorType, source, indices, *padding,
permutationMapAttr, inBoundsAttr);
}
/// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
VectorType vectorType, Value source,
ValueRange indices, Value padding,
ValueRange indices, std::optional<Value> padding,
std::optional<ArrayRef<bool>> inBounds) {
AffineMap permutationMap = getTransferMinorIdentityMap(
llvm::cast<ShapedType>(source.getType()), vectorType);
@@ -4296,21 +4302,12 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
? builder.getBoolArrayAttr(inBounds.value())
: builder.getBoolArrayAttr(
SmallVector<bool>(vectorType.getRank(), false));
build(builder, result, vectorType, source, indices, permutationMapAttr,
padding,
/*mask=*/Value(), inBoundsAttr);
}
/// 4. Builder that sets padding to zero and permutation map to
/// 'getMinorIdentityMap'.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
VectorType vectorType, Value source,
ValueRange indices,
std::optional<ArrayRef<bool>> inBounds) {
Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
Value padding = builder.create<arith::ConstantOp>(
result.location, elemType, builder.getZeroAttr(elemType));
build(builder, result, vectorType, source, indices, padding, inBounds);
if (!padding)
padding = builder.create<ub::PoisonOp>(result.location, elemType);
build(builder, result, vectorType, source, indices, permutationMapAttr,
*padding,
/*mask=*/Value(), inBoundsAttr);
}
template <typename EmitFun>

View File

@@ -173,7 +173,7 @@ struct DistributedLoadStoreHelper {
}
SmallVector<bool> inBounds(indices.size(), true);
return b.create<vector::TransferReadOp>(
loc, cast<VectorType>(type), buffer, indices,
loc, cast<VectorType>(type), buffer, indices, /*padding=*/std::nullopt,
ArrayRef<bool>(inBounds.begin(), inBounds.end()));
}

View File

@@ -660,7 +660,8 @@ public:
VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
vectorType.getElementType());
vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
loc, flatVectorType, collapsedSource, collapsedIndices,
transferReadOp.getPadding(), collapsedMap);
flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
// 4. Replace the old transfer_read with the new one reading from the

View File

@@ -21,7 +21,7 @@ func.func @vec1d_1(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
// CHECK: for {{.*}} step 128
// CHECK-NEXT: %{{.*}} = affine.apply #[[$map_id1]](%[[C0]])
// CHECK-NEXT: %{{.*}} = affine.apply #[[$map_id1]](%[[C0]])
// CHECK-NEXT: %{{.*}} = arith.constant 0.0{{.*}}: f32
// CHECK-NEXT: %{{.*}} = ub.poison : f32
// CHECK-NEXT: {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {permutation_map = #[[$map_proj_d0d1_0]]} : memref<?x?xf32>, vector<128xf32>
affine.for %i0 = 0 to %M { // vectorized due to scalar -> vector
%a0 = affine.load %A[%c0, %c0] : memref<?x?xf32>
@@ -47,7 +47,7 @@ func.func @vec1d_2(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
%P = memref.dim %B, %c2 : memref<?x?x?xf32>
// CHECK:for [[IV3:%[a-zA-Z0-9]+]] = 0 to [[ARG_M]] step 128
// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}}: f32
// CHECK-NEXT: %[[CST:.*]] = ub.poison : f32
// CHECK-NEXT: {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %[[CST]] : memref<?x?xf32>, vector<128xf32>
affine.for %i3 = 0 to %M { // vectorized
%a3 = affine.load %A[%c0, %i3] : memref<?x?xf32>
@@ -76,7 +76,7 @@ func.func @vec1d_3(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
// CHECK-NEXT: for [[IV9:%[0-9a-zA-Z_]*]] = 0 to [[ARG_N]] {
// CHECK-NEXT: %[[APP9_0:[0-9a-zA-Z_]+]] = affine.apply {{.*}}([[IV9]], [[IV8]])
// CHECK-NEXT: %[[APP9_1:[0-9a-zA-Z_]+]] = affine.apply {{.*}}([[IV9]], [[IV8]])
// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}}: f32
// CHECK-NEXT: %[[CST:.*]] = ub.poison : f32
// CHECK-NEXT: {{.*}} = vector.transfer_read %{{.*}}[%[[APP9_0]], %[[APP9_1]]], %[[CST]] : memref<?x?xf32>, vector<128xf32>
affine.for %i8 = 0 to %M { // vectorized
affine.for %i9 = 0 to %N {
@@ -280,7 +280,7 @@ func.func @vec_rejected_3(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
// CHECK:for [[IV4:%[0-9a-zA-Z_]+]] = 0 to [[ARG_M]] step 128 {
// CHECK-NEXT: for [[IV5:%[0-9a-zA-Z_]*]] = 0 to [[ARG_N]] {
// CHECK-NEXT: %{{.*}} = arith.constant 0.0{{.*}}: f32
// CHECK-NEXT: %{{.*}} = ub.poison : f32
// CHECK-NEXT: {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{[a-zA-Z0-9_]*}} : memref<?x?xf32>, vector<128xf32>
affine.for %i4 = 0 to %M { // vectorized
affine.for %i5 = 0 to %N { // not vectorized, would vectorize with --test-fastest-varying=1
@@ -424,7 +424,7 @@ func.func @vec_rejected_8(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
// CHECK: for [[IV18:%[a-zA-Z0-9]+]] = 0 to [[ARG_M]] step 128
// CHECK: %{{.*}} = affine.apply #[[$map_id1]](%{{.*}})
// CHECK: %{{.*}} = affine.apply #[[$map_id1]](%{{.*}})
// CHECK: %{{.*}} = arith.constant 0.0{{.*}}: f32
// CHECK: %{{.*}} = ub.poison : f32
// CHECK: {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {permutation_map = #[[$map_proj_d0d1_0]]} : memref<?x?xf32>, vector<128xf32>
affine.for %i17 = 0 to %M { // not vectorized, the 1-D pattern that matched %{{.*}} in DFS post-order prevents vectorizing %{{.*}}
affine.for %i18 = 0 to %M { // vectorized due to scalar -> vector
@@ -458,7 +458,7 @@ func.func @vec_rejected_9(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
// CHECK: for [[IV18:%[a-zA-Z0-9]+]] = 0 to [[ARG_M]] step 128
// CHECK: %{{.*}} = affine.apply #[[$map_id1]](%{{.*}})
// CHECK-NEXT: %{{.*}} = affine.apply #[[$map_id1]](%{{.*}})
// CHECK-NEXT: %{{.*}} = arith.constant 0.0{{.*}}: f32
// CHECK-NEXT: %{{.*}} = ub.poison : f32
// CHECK-NEXT: {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {permutation_map = #[[$map_proj_d0d1_0]]} : memref<?x?xf32>, vector<128xf32>
affine.for %i17 = 0 to %M { // not vectorized, the 1-D pattern that matched %i18 in DFS post-order prevents vectorizing %{{.*}}
affine.for %i18 = 0 to %M { // vectorized due to scalar -> vector

View File

@@ -11,7 +11,7 @@ func.func @vec_affine_apply(%arg0: memref<8x12x16xf32>, %arg1: memref<8x24x48xf3
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 48 step 8 {
// CHECK-NEXT: %[[S0:.*]] = affine.apply #[[$MAP_ID0]](%[[ARG3]])
// CHECK-NEXT: %[[S1:.*]] = affine.apply #[[$MAP_ID1]](%[[ARG4]])
// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: %[[CST:.*]] = ub.poison : f32
// CHECK-NEXT: %[[S2:.*]] = vector.transfer_read %[[ARG0]][%[[ARG2]], %[[S0]], %[[S1]]], %[[CST]] : memref<8x12x16xf32>, vector<8xf32>
// CHECK-NEXT: vector.transfer_write %[[S2]], %[[ARG1]][%[[ARG2]], %[[ARG3]], %[[ARG4]]] : vector<8xf32>, memref<8x24x48xf32>
// CHECK-NEXT: }
@@ -42,7 +42,7 @@ func.func @vec_affine_apply_2(%arg0: memref<8x12x16xf32>, %arg1: memref<8x24x48x
// CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 12 {
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 48 step 8 {
// CHECK-NEXT: %[[S0:.*]] = affine.apply #[[$MAP_ID2]](%[[ARG4]])
// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: %[[CST:.*]] = ub.poison : f32
// CHECK-NEXT: %[[S1:.*]] = vector.transfer_read %[[ARG0]][%[[ARG2]], %[[ARG3]], %[[S0]]], %[[CST]] : memref<8x12x16xf32>, vector<8xf32>
// CHECK-NEXT: vector.transfer_write %[[S1]], %[[ARG1]][%[[ARG2]], %[[ARG3]], %[[ARG4]]] : vector<8xf32>, memref<8x24x48xf32>
// CHECK-NEXT: }
@@ -140,7 +140,7 @@ func.func @affine_map_with_expr_2(%arg0: memref<8x12x16xf32>, %arg1: memref<8x24
// CHECK-NEXT: %[[S0:.*]] = affine.apply #[[$MAP_ID3]](%[[ARG3]], %[[ARG4]], %[[I0]])
// CHECK-NEXT: %[[S1:.*]] = affine.apply #[[$MAP_ID4]](%[[ARG3]], %[[ARG4]], %[[I0]])
// CHECK-NEXT: %[[S2:.*]] = affine.apply #[[$MAP_ID5]](%[[ARG3]], %[[ARG4]], %[[I0]])
// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: %[[CST:.*]] = ub.poison : f32
// CHECK-NEXT: %[[S3:.*]] = vector.transfer_read %[[ARG0]][%[[S0]], %[[S1]], %[[S2]]], %[[CST]] {permutation_map = #[[$MAP_ID6]]} : memref<8x12x16xf32>, vector<8xf32>
// CHECK-NEXT: vector.transfer_write %[[S3]], %[[ARG1]][%[[ARG3]], %[[ARG4]], %[[ARG5]]] : vector<8xf32>, memref<8x24x48xf32>
// CHECK-NEXT: }

View File

@@ -11,8 +11,8 @@
// CHECK-LABEL: @base_case
// CHECK-SAME: %[[I:.+]]: index, %[[J:.+]]: index, %[[M:.+]]:
// CHECK: %[[PAD:.+]] = arith.constant 0 : i8
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[PAD:.+]] = arith.constant 123 : i8
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[M]]
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
// CHECK-SAME: : memref<?x?x?x8xi8> into memref<?x?x?xi8>
@@ -36,8 +36,8 @@ func.func @base_case(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<
// CHECK-LABEL: @with_3d_vector
// CHECK-SAME: %[[I:.+]]: index, %[[J:.+]]: index, %[[M:.+]]:
// CHECK: %[[PAD:.+]] = arith.constant 0 : i8
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[M]]
// CHECK-DAG: %[[PAD:.+]] = arith.constant 123 : i8
// CHECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[M]]
// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
// CHECK-SAME: : memref<?x?x2x8xi8> into memref<?x?xi8>
// CHECK-NEXT: %[[T0:.+]] = vector.transfer_read %[[COLLAPSED]][%[[I]], %[[J]]], %[[PAD]] {in_bounds = [true]}

View File

@@ -85,8 +85,8 @@ func.func @transfer_read_dims_mismatch_contiguous(
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous(
// CHECK-SAME: %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>) -> vector<2x3x2xi8> {
// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C0_I8:.+]] = arith.constant 0 : i8
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
// CHECK-SAME: : memref<5x4x3x2xi8, {{.+}}> into memref<5x24xi8, {{.+}}>
@@ -116,8 +116,8 @@ func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
// CHECK-SAME: %[[MEM:.+]]: memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>)
// CHECK-SAME: -> vector<1x1x4x3x2xi8>
// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C0_I8:.+]] = arith.constant 0 : i8
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3, 4]]
// CHECK-SAME: : memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>
@@ -149,8 +149,8 @@ func.func @transfer_read_non_contiguous_unit_dims(
// CHECK-LABEL: func.func @transfer_read_non_contiguous_unit_dims(
// CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>) -> vector<1x1x3x2xi8> {
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : i8
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]]
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
// CHECK-SAME: : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
@@ -182,8 +182,8 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices(
// CHECK-SAME: %[[IDX_1:.+]]: index, %[[IDX_2:.+]]: index,
// CHECK-SAME: %[[MEM:.+]]: memref<1x43x4x6xi32>
// CHECK: %[[C0_I32:.+]] = arith.constant 0 : i32
// CHECK: %[[C_0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C0_I32:.+]] = arith.constant 0 : i32
// CHECK-DAG: %[[C_0:.+]] = arith.constant 0 : index
// CHECK: %[[COLLAPSED_IN:.+]] = memref.collapse_shape %[[MEM]]
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
// CHECK-SAME: : memref<1x43x4x6xi32> into memref<1x43x24xi32>
@@ -241,8 +241,8 @@ func.func @transfer_read_leading_dynamic_dims(
// CHECK-LABEL: func @transfer_read_leading_dynamic_dims
// CHECK-SAME: %[[MEM:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[IDX_1:.+]]: index, %[[IDX_2:.+]]: index
// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C0_I8:.+]] = arith.constant 0 : i8
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
@@ -304,8 +304,8 @@ func.func @transfer_read_dynamic_dim_to_flatten(
// CHECK-SAME: %[[IDX_1:arg0]]
// CHECK-SAME: %[[IDX_2:arg1]]
// CHECK-SAME: %[[MEM:arg2]]
// CHECK: %[[C0_I32:.+]] = arith.constant 0 : i32
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C0_I32:.+]] = arith.constant 0 : i32
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
// CHECK-SAME: memref<1x?x4x6xi32> into memref<1x?x24xi32>

View File

@@ -1132,8 +1132,8 @@ func.func @warp_execute_nd_distribute(%laneid: index, %v0: vector<1x64x1xf32>, %
// CHECK-SCF-IF: gpu.barrier
// CHECK-SCF-IF: %[[WID:.*]] = affine.apply #[[$TIMES2]]()[%[[LANEID]]]
// CHECK-SCF-IF-DAG: %[[R0:.*]] = vector.transfer_read %{{.*}}[%[[LANEID]], %[[C0]], %[[C0]]], %cst {in_bounds = [true, true, true]} : memref<32x64x1xf32, 3>, vector<1x64x1xf32>
// CHECK-SCF-IF-DAG: %[[R1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[WID]], %[[C0]]], %cst {in_bounds = [true, true, true]} : memref<1x64x128xf32, 3>, vector<1x2x128xf32>
// CHECK-SCF-IF-DAG: %[[R0:.*]] = vector.transfer_read %{{.*}}[%[[LANEID]], %[[C0]], %[[C0]]], %{{.*}} {in_bounds = [true, true, true]} : memref<32x64x1xf32, 3>, vector<1x64x1xf32>
// CHECK-SCF-IF-DAG: %[[R1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[WID]], %[[C0]]], %{{.*}} {in_bounds = [true, true, true]} : memref<1x64x128xf32, 3>, vector<1x2x128xf32>
// CHECK-SCF-IF: return %[[R0]], %[[R1]] : vector<1x64x1xf32>, vector<1x2x128xf32>
return %r#0, %r#1 : vector<1x64x1xf32>, vector<1x2x128xf32>
}