[mlir][vector] Avoid setting padding by default to 0 in vector.transfer_read, prefer ub.poison (#146088)

Context:
`vector.transfer_read` always requires a padding value. Most of its
builders take no `padding` value and assume the safe default of `0`.
However, this should be a conscious choice by the API user, as the
implicit default makes it easy to introduce bugs. For example, while
making this patch I found several cases where the padding value was not
propagated when one `vector.transfer_read` was transformed into another
`vector.transfer_read`. These bugs were invariably caused by builders
that don't require the caller to specify padding.

Additionally, `ub.poison` is a better default value, as it explicitly
encodes that the user "doesn't care" about the actual padding value;
users who do care are forced to specify the padding semantics they want.

With that in mind, this patch changes the builders of
`vector.transfer_read` to always take a `std::optional<Value> padding`
argument. The argument must always be passed explicitly, but for
convenience users can pass `std::nullopt`, which pads the transfer read
with `ub.poison`.

---------

Signed-off-by: Fabian Mora <fabian.mora-cordero@amd.com>
Author: Fabian Mora
Date: 2025-06-30 15:20:42 -04:00
Committed by: GitHub
Parent: 6a57af8d03
Commit: 878d3594ed

15 changed files with 108 additions and 79 deletions

View File

@@ -154,6 +154,11 @@ Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
                      Value lhs, Value rhs);

 arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred);

+/// Creates an `arith.constant` operation with a zero value of type `type`. This
+/// method asserts if `type` is invalid for representing zero with
+/// `arith.constant`.
+Value getZeroConstant(OpBuilder &builder, Location loc, Type type);
+
 } // namespace arith
 } // namespace mlir

View File

@@ -21,7 +21,10 @@ def Vector_Dialect : Dialect {
   let useDefaultAttributePrinterParser = 1;
   let hasConstantMaterializer = 1;
-  let dependentDialects = ["arith::ArithDialect"];
+  let dependentDialects = [
+    "arith::ArithDialect",
+    "ub::UBDialect"
+  ];
 }

 // Base class for Vector dialect ops.

View File

@@ -1543,30 +1543,29 @@ def Vector_TransferReadOp :
   }];
   let builders = [
-    /// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
+    /// 1. Builder that sets padding to `padding` or poison if not provided and
+    /// an empty mask (variant with attrs).
     OpBuilder<(ins "VectorType":$vectorType,
                    "Value":$source,
                    "ValueRange":$indices,
+                   "std::optional<Value>":$padding,
                    "AffineMapAttr":$permutationMapAttr,
                    "ArrayAttr":$inBoundsAttr)>,
-    /// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
+    /// 2. Builder that sets padding to `padding` or poison if not provided and
+    /// an empty mask (variant without attrs).
     OpBuilder<(ins "VectorType":$vectorType,
                    "Value":$source,
                    "ValueRange":$indices,
+                   "std::optional<Value>":$padding,
                    "AffineMap":$permutationMap,
                    CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>,
-    /// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
+    /// 3. Builder that sets padding to `padding` or poison if not provided and
+    /// permutation map to 'getMinorIdentityMap'.
     OpBuilder<(ins "VectorType":$vectorType,
                    "Value":$source,
                    "ValueRange":$indices,
-                   "Value":$padding,
-                   CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>,
-    /// 4. Builder that sets padding to zero and permutation map to
-    /// 'getMinorIdentityMap'.
-    OpBuilder<(ins "VectorType":$vectorType,
-                   "Value":$source,
-                   "ValueRange":$indices,
-                   CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>,
+                   "std::optional<Value>":$padding,
+                   CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>
   ];
   let extraClassDeclaration = [{

View File

@@ -1257,7 +1257,8 @@ static Operation *vectorizeAffineLoad(AffineLoadOp loadOp,
   LLVM_DEBUG(permutationMap.print(dbgs()));

   auto transfer = state.builder.create<vector::TransferReadOp>(
-      loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices, permutationMap);
+      loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices,
+      /*padding=*/std::nullopt, permutationMap);

   // Register replacement for future uses in the scope.
   state.registerOpVectorReplacement(loadOp, transfer);

View File

@@ -292,6 +292,16 @@ bool arith::ConstantIndexOp::classof(Operation *op) {
   return false;
 }

+Value mlir::arith::getZeroConstant(OpBuilder &builder, Location loc,
+                                   Type type) {
+  // TODO: Incorporate this check to `FloatAttr::get*`.
+  assert(!isa<Float8E8M0FNUType>(getElementTypeOrSelf(type)) &&
+         "type doesn't have a zero representation");
+  TypedAttr zeroAttr = builder.getZeroAttr(type);
+  assert(zeroAttr && "unsupported type for zero attribute");
+  return builder.create<arith::ConstantOp>(loc, zeroAttr);
+}
+
 //===----------------------------------------------------------------------===//
 // AddIOp
 //===----------------------------------------------------------------------===//
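
The new helper gives call sites a one-liner for the explicit-zero case.
A minimal usage sketch (assuming a `rewriter` and `loc` from some
surrounding pattern; the f32 element type is illustrative):

  // Materialize `arith.constant 0.0 : f32` to use as explicit
  // transfer_read padding; asserts if the type has no zero value.
  Value zeroPad =
      arith::getZeroConstant(rewriter, loc, rewriter.getF32Type());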

View File

@@ -426,6 +426,7 @@ struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
     // Create the new `transfer_read`.
     auto newReadOp = rewriter.create<vector::TransferReadOp>(
         readOp.getLoc(), collapsedVT, collapsedMem, indices,
+        readOp.getPadding(),
         ArrayRef<bool>(origInBounds).drop_back(numCollapseDims - 1));

     // Cast back to the original vector type.

View File

@@ -1183,6 +1183,10 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
   auto srcRank = extractOp.getTensor().getType().getRank();
   SmallVector<bool> inBounds(dstRank, true);

+  // Get the value to pad transfer reads with 0.
+  Value padding =
+      arith::getZeroConstant(rewriter, loc, resultType.getElementType());
+
   // 2a. Handle scalar broadcast access.
   if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
     MLIRContext *ctx = rewriter.getContext();
@@ -1190,7 +1194,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
     auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);

     auto transferReadOp = rewriter.create<vector::TransferReadOp>(
-        loc, resultType, extractOp.getTensor(), transferReadIdxs,
+        loc, resultType, extractOp.getTensor(), transferReadIdxs, padding,
         permutationMap, inBounds);

     // Mask this broadcasting xfer_read here rather than relying on the generic
@@ -1227,8 +1231,8 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
   }

   auto transferReadOp = rewriter.create<vector::TransferReadOp>(
-      loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
-      inBounds);
+      loc, resultType, extractOp.getTensor(), transferReadIdxs, padding,
+      permutationMap, inBounds);

   LDBG("Vectorised as contiguous load: " << extractOp);
   return VectorizationHookResult{VectorizationHookStatus::NewOp,
@@ -1384,7 +1388,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
 /// performed to the maximal common vector size implied by the `linalgOp`
 /// iteration space. This eager broadcasting is introduced in the
 /// permutation_map of the vector.transfer_read operations. The eager
-/// broadcasting makes it trivial to detrmine where broadcast, transposes and
+/// broadcasting makes it trivial to determine where broadcast, transposes and
 /// reductions should occur, without any bookkeeping. The tradeoff is that, in
 /// the absence of good canonicalizations, the amount of work increases.
 /// This is not deemed a problem as we expect canonicalizations and foldings to
@@ -1439,7 +1443,8 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
     SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);

     Operation *read = rewriter.create<vector::TransferReadOp>(
-        loc, readType, opOperand->get(), indices, readMap);
+        loc, readType, opOperand->get(), indices,
+        /*padding=*/arith::getZeroConstant(rewriter, loc, elemType), readMap);
     read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
     Value readValue = read->getResult(0);
@@ -2641,6 +2646,7 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
   Value readValue = rewriter.create<vector::TransferReadOp>(
       loc, readType, copyOp.getSource(), indices,
+      /*padding=*/arith::getZeroConstant(rewriter, loc, srcElementType),
       rewriter.getMultiDimIdentityMap(srcType.getRank()));
   if (cast<VectorType>(readValue.getType()).getRank() == 0) {
     readValue =
@@ -3487,15 +3493,18 @@ struct Conv1DGenerator
     SmallVector<Value> resPadding(resShape.size(), zero);

     // Read the whole lhs, rhs and res in one shot (with zero padding).
-    Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
-                                                        lhsPadding);
+    Value lhs = rewriter.create<vector::TransferReadOp>(
+        loc, lhsType, lhsShaped, lhsPadding,
+        /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
     // This is needed only for Conv.
     Value rhs = nullptr;
     if (oper == ConvOperationKind::Conv)
-      rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
-                                                    rhsPadding);
-    Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
-                                                        resPadding);
+      rhs = rewriter.create<vector::TransferReadOp>(
+          loc, rhsType, rhsShaped, rhsPadding,
+          /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
+    Value res = rewriter.create<vector::TransferReadOp>(
+        loc, resType, resShaped, resPadding,
+        /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));

     // The base vectorization case for channeled convolution is input:
     // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
@@ -3742,19 +3751,22 @@ struct Conv1DGenerator
     // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
     // 0].
     Value lhs = rewriter.create<vector::TransferReadOp>(
-        loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
+        loc, lhsType, lhsShaped, ValueRange{zero, zero, zero},
+        /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
     auto maybeMaskedLhs = maybeMaskXferOp(
         lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());

     // Read rhs slice of size {kw, c} @ [0, 0].
-    Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
-                                                        ValueRange{zero, zero});
+    Value rhs = rewriter.create<vector::TransferReadOp>(
+        loc, rhsType, rhsShaped, ValueRange{zero, zero},
+        /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
     auto maybeMaskedRhs = maybeMaskXferOp(
         rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());

     // Read res slice of size {n, w, c} @ [0, 0, 0].
     Value res = rewriter.create<vector::TransferReadOp>(
-        loc, resType, resShaped, ValueRange{zero, zero, zero});
+        loc, resType, resShaped, ValueRange{zero, zero, zero},
+        /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
     auto maybeMaskedRes = maybeMaskXferOp(
         resType.getShape(), resType.getScalableDims(), res.getDefiningOp());

View File

@@ -4261,33 +4261,39 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
 /// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                            VectorType vectorType, Value source,
-                           ValueRange indices, AffineMapAttr permutationMapAttr,
+                           ValueRange indices, std::optional<Value> padding,
+                           AffineMapAttr permutationMapAttr,
                            /*optional*/ ArrayAttr inBoundsAttr) {
   Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
-  Value padding = builder.create<arith::ConstantOp>(
-      result.location, elemType, builder.getZeroAttr(elemType));
+  if (!padding)
+    padding = builder.create<ub::PoisonOp>(result.location, elemType);
   build(builder, result, vectorType, source, indices, permutationMapAttr,
-        padding, /*mask=*/Value(), inBoundsAttr);
+        *padding, /*mask=*/Value(), inBoundsAttr);
 }
 /// 2. Builder that sets padding to zero an empty mask (variant without attrs).
 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                            VectorType vectorType, Value source,
-                           ValueRange indices, AffineMap permutationMap,
+                           ValueRange indices, std::optional<Value> padding,
+                           AffineMap permutationMap,
                            std::optional<ArrayRef<bool>> inBounds) {
   auto permutationMapAttr = AffineMapAttr::get(permutationMap);
   auto inBoundsAttr = (inBounds && !inBounds.value().empty())
                           ? builder.getBoolArrayAttr(inBounds.value())
                           : builder.getBoolArrayAttr(
                                 SmallVector<bool>(vectorType.getRank(), false));
-  build(builder, result, vectorType, source, indices, permutationMapAttr,
-        inBoundsAttr);
+  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
+  if (!padding)
+    padding = builder.create<ub::PoisonOp>(result.location, elemType);
+  build(builder, result, vectorType, source, indices, *padding,
+        permutationMapAttr, inBoundsAttr);
 }
 /// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                            VectorType vectorType, Value source,
-                           ValueRange indices, Value padding,
+                           ValueRange indices, std::optional<Value> padding,
                            std::optional<ArrayRef<bool>> inBounds) {
   AffineMap permutationMap = getTransferMinorIdentityMap(
       llvm::cast<ShapedType>(source.getType()), vectorType);
@@ -4296,21 +4302,12 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           ? builder.getBoolArrayAttr(inBounds.value())
                           : builder.getBoolArrayAttr(
                                 SmallVector<bool>(vectorType.getRank(), false));
-  build(builder, result, vectorType, source, indices, permutationMapAttr,
-        padding,
-        /*mask=*/Value(), inBoundsAttr);
-}
-
-/// 4. Builder that sets padding to zero and permutation map to
-/// 'getMinorIdentityMap'.
-void TransferReadOp::build(OpBuilder &builder, OperationState &result,
-                           VectorType vectorType, Value source,
-                           ValueRange indices,
-                           std::optional<ArrayRef<bool>> inBounds) {
   Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
-  Value padding = builder.create<arith::ConstantOp>(
-      result.location, elemType, builder.getZeroAttr(elemType));
-  build(builder, result, vectorType, source, indices, padding, inBounds);
+  if (!padding)
+    padding = builder.create<ub::PoisonOp>(result.location, elemType);
+  build(builder, result, vectorType, source, indices, permutationMapAttr,
+        *padding,
+        /*mask=*/Value(), inBoundsAttr);
 }

 template <typename EmitFun>

View File

@@ -173,7 +173,7 @@ struct DistributedLoadStoreHelper {
     }
     SmallVector<bool> inBounds(indices.size(), true);
     return b.create<vector::TransferReadOp>(
-        loc, cast<VectorType>(type), buffer, indices,
+        loc, cast<VectorType>(type), buffer, indices, /*padding=*/std::nullopt,
         ArrayRef<bool>(inBounds.begin(), inBounds.end()));
   }

View File

@@ -660,7 +660,8 @@ public:
     VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                 vectorType.getElementType());
     vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
-        loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
+        loc, flatVectorType, collapsedSource, collapsedIndices,
+        transferReadOp.getPadding(), collapsedMap);
     flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));

     // 4. Replace the old transfer_read with the new one reading from the
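
This hunk shows the general shape of the fix for rewrites that derive a
new `vector.transfer_read` from an existing one: forward the original
padding operand instead of letting a defaulted builder re-pad with zero.
A sketch of the pattern (the `new*` names are placeholders following the
surrounding rewrite; `getPadding()` is the accessor used above):

  // Propagate the source op's padding operand to the rewritten read.
  auto newRead = rewriter.create<vector::TransferReadOp>(
      oldRead.getLoc(), newVectorType, newSource, newIndices,
      /*padding=*/oldRead.getPadding(), newPermutationMap);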

View File

@@ -21,7 +21,7 @@ func.func @vec1d_1(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK: for {{.*}} step 128
 // CHECK-NEXT: %{{.*}} = affine.apply #[[$map_id1]](%[[C0]])
 // CHECK-NEXT: %{{.*}} = affine.apply #[[$map_id1]](%[[C0]])
-// CHECK-NEXT: %{{.*}} = arith.constant 0.0{{.*}}: f32
+// CHECK-NEXT: %{{.*}} = ub.poison : f32
 // CHECK-NEXT: {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {permutation_map = #[[$map_proj_d0d1_0]]} : memref<?x?xf32>, vector<128xf32>
   affine.for %i0 = 0 to %M { // vectorized due to scalar -> vector
     %a0 = affine.load %A[%c0, %c0] : memref<?x?xf32>
@@ -47,7 +47,7 @@ func.func @vec1d_2(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
   %P = memref.dim %B, %c2 : memref<?x?x?xf32>
 // CHECK:for [[IV3:%[a-zA-Z0-9]+]] = 0 to [[ARG_M]] step 128
-// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}}: f32
+// CHECK-NEXT: %[[CST:.*]] = ub.poison : f32
 // CHECK-NEXT: {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %[[CST]] : memref<?x?xf32>, vector<128xf32>
   affine.for %i3 = 0 to %M { // vectorized
     %a3 = affine.load %A[%c0, %i3] : memref<?x?xf32>
@@ -76,7 +76,7 @@ func.func @vec1d_3(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-NEXT: for [[IV9:%[0-9a-zA-Z_]*]] = 0 to [[ARG_N]] {
 // CHECK-NEXT:   %[[APP9_0:[0-9a-zA-Z_]+]] = affine.apply {{.*}}([[IV9]], [[IV8]])
 // CHECK-NEXT:   %[[APP9_1:[0-9a-zA-Z_]+]] = affine.apply {{.*}}([[IV9]], [[IV8]])
-// CHECK-NEXT:   %[[CST:.*]] = arith.constant 0.0{{.*}}: f32
+// CHECK-NEXT:   %[[CST:.*]] = ub.poison : f32
 // CHECK-NEXT:   {{.*}} = vector.transfer_read %{{.*}}[%[[APP9_0]], %[[APP9_1]]], %[[CST]] : memref<?x?xf32>, vector<128xf32>
   affine.for %i8 = 0 to %M { // vectorized
     affine.for %i9 = 0 to %N {
@@ -280,7 +280,7 @@ func.func @vec_rejected_3(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK:for [[IV4:%[0-9a-zA-Z_]+]] = 0 to [[ARG_M]] step 128 {
 // CHECK-NEXT:   for [[IV5:%[0-9a-zA-Z_]*]] = 0 to [[ARG_N]] {
-// CHECK-NEXT:   %{{.*}} = arith.constant 0.0{{.*}}: f32
+// CHECK-NEXT:   %{{.*}} = ub.poison : f32
 // CHECK-NEXT:   {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{[a-zA-Z0-9_]*}} : memref<?x?xf32>, vector<128xf32>
   affine.for %i4 = 0 to %M { // vectorized
     affine.for %i5 = 0 to %N { // not vectorized, would vectorize with --test-fastest-varying=1
@@ -424,7 +424,7 @@ func.func @vec_rejected_8(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK: for [[IV18:%[a-zA-Z0-9]+]] = 0 to [[ARG_M]] step 128
 // CHECK:   %{{.*}} = affine.apply #[[$map_id1]](%{{.*}})
 // CHECK:   %{{.*}} = affine.apply #[[$map_id1]](%{{.*}})
-// CHECK:   %{{.*}} = arith.constant 0.0{{.*}}: f32
+// CHECK:   %{{.*}} = ub.poison : f32
 // CHECK:   {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {permutation_map = #[[$map_proj_d0d1_0]]} : memref<?x?xf32>, vector<128xf32>
   affine.for %i17 = 0 to %M { // not vectorized, the 1-D pattern that matched %{{.*}} in DFS post-order prevents vectorizing %{{.*}}
     affine.for %i18 = 0 to %M { // vectorized due to scalar -> vector
@@ -458,7 +458,7 @@ func.func @vec_rejected_9(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK: for [[IV18:%[a-zA-Z0-9]+]] = 0 to [[ARG_M]] step 128
 // CHECK:      %{{.*}} = affine.apply #[[$map_id1]](%{{.*}})
 // CHECK-NEXT: %{{.*}} = affine.apply #[[$map_id1]](%{{.*}})
-// CHECK-NEXT: %{{.*}} = arith.constant 0.0{{.*}}: f32
+// CHECK-NEXT: %{{.*}} = ub.poison : f32
 // CHECK-NEXT: {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {permutation_map = #[[$map_proj_d0d1_0]]} : memref<?x?xf32>, vector<128xf32>
   affine.for %i17 = 0 to %M { // not vectorized, the 1-D pattern that matched %i18 in DFS post-order prevents vectorizing %{{.*}}
     affine.for %i18 = 0 to %M { // vectorized due to scalar -> vector

View File

@@ -11,7 +11,7 @@ func.func @vec_affine_apply(%arg0: memref<8x12x16xf32>, %arg1: memref<8x24x48xf32>) {
 // CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 48 step 8 {
 // CHECK-NEXT:   %[[S0:.*]] = affine.apply #[[$MAP_ID0]](%[[ARG3]])
 // CHECK-NEXT:   %[[S1:.*]] = affine.apply #[[$MAP_ID1]](%[[ARG4]])
-// CHECK-NEXT:   %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT:   %[[CST:.*]] = ub.poison : f32
 // CHECK-NEXT:   %[[S2:.*]] = vector.transfer_read %[[ARG0]][%[[ARG2]], %[[S0]], %[[S1]]], %[[CST]] : memref<8x12x16xf32>, vector<8xf32>
 // CHECK-NEXT:   vector.transfer_write %[[S2]], %[[ARG1]][%[[ARG2]], %[[ARG3]], %[[ARG4]]] : vector<8xf32>, memref<8x24x48xf32>
 // CHECK-NEXT: }
@@ -42,7 +42,7 @@ func.func @vec_affine_apply_2(%arg0: memref<8x12x16xf32>, %arg1: memref<8x24x48xf32>) {
 // CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 12 {
 // CHECK-NEXT:   affine.for %[[ARG4:.*]] = 0 to 48 step 8 {
 // CHECK-NEXT:     %[[S0:.*]] = affine.apply #[[$MAP_ID2]](%[[ARG4]])
-// CHECK-NEXT:     %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT:     %[[CST:.*]] = ub.poison : f32
 // CHECK-NEXT:     %[[S1:.*]] = vector.transfer_read %[[ARG0]][%[[ARG2]], %[[ARG3]], %[[S0]]], %[[CST]] : memref<8x12x16xf32>, vector<8xf32>
 // CHECK-NEXT:     vector.transfer_write %[[S1]], %[[ARG1]][%[[ARG2]], %[[ARG3]], %[[ARG4]]] : vector<8xf32>, memref<8x24x48xf32>
 // CHECK-NEXT:   }
@@ -140,7 +140,7 @@ func.func @affine_map_with_expr_2(%arg0: memref<8x12x16xf32>, %arg1: memref<8x24x48xf32>
 // CHECK-NEXT:   %[[S0:.*]] = affine.apply #[[$MAP_ID3]](%[[ARG3]], %[[ARG4]], %[[I0]])
 // CHECK-NEXT:   %[[S1:.*]] = affine.apply #[[$MAP_ID4]](%[[ARG3]], %[[ARG4]], %[[I0]])
 // CHECK-NEXT:   %[[S2:.*]] = affine.apply #[[$MAP_ID5]](%[[ARG3]], %[[ARG4]], %[[I0]])
-// CHECK-NEXT:   %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT:   %[[CST:.*]] = ub.poison : f32
 // CHECK-NEXT:   %[[S3:.*]] = vector.transfer_read %[[ARG0]][%[[S0]], %[[S1]], %[[S2]]], %[[CST]] {permutation_map = #[[$MAP_ID6]]} : memref<8x12x16xf32>, vector<8xf32>
 // CHECK-NEXT:   vector.transfer_write %[[S3]], %[[ARG1]][%[[ARG3]], %[[ARG4]], %[[ARG5]]] : vector<8xf32>, memref<8x24x48xf32>
 // CHECK-NEXT: }

View File

@@ -11,8 +11,8 @@

 // CHECK-LABEL: @base_case
 // CHECK-SAME:  %[[I:.+]]: index, %[[J:.+]]: index, %[[M:.+]]:
-// CHECK:       %[[PAD:.+]] = arith.constant 0 : i8
-// CHECK:       %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:   %[[PAD:.+]] = arith.constant 123 : i8
+// CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
 // CHECK:       %[[COLLAPSE:.+]] = memref.collapse_shape %[[M]]
 // CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
 // CHECK-SAME:  : memref<?x?x?x8xi8> into memref<?x?x?xi8>
@@ -36,8 +36,8 @@ func.func @base_case(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<
 // CHECK-LABEL: @with_3d_vector
 // CHECK-SAME:  %[[I:.+]]: index, %[[J:.+]]: index, %[[M:.+]]:
-// CHECK:       %[[PAD:.+]] = arith.constant 0 : i8
-// CHECK:       %[[COLLAPSED:.+]] = memref.collapse_shape %[[M]]
+// CHECK-DAG:   %[[PAD:.+]] = arith.constant 123 : i8
+// CHECK-DAG:   %[[COLLAPSED:.+]] = memref.collapse_shape %[[M]]
 // CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
 // CHECK-SAME:  : memref<?x?x2x8xi8> into memref<?x?xi8>
 // CHECK-NEXT:  %[[T0:.+]] = vector.transfer_read %[[COLLAPSED]][%[[I]], %[[J]]], %[[PAD]] {in_bounds = [true]}

View File

@@ -85,8 +85,8 @@ func.func @transfer_read_dims_mismatch_contiguous(
 // CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous(
 // CHECK-SAME:    %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>) -> vector<2x3x2xi8> {
-// CHECK:         %[[C0_I8:.+]] = arith.constant 0 : i8
-// CHECK:         %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C0_I8:.+]] = arith.constant 0 : i8
+// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
 // CHECK:         %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
 // CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
 // CHECK-SAME:      : memref<5x4x3x2xi8, {{.+}}> into memref<5x24xi8, {{.+}}>
@@ -116,8 +116,8 @@ func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
 // CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
 // CHECK-SAME:    %[[MEM:.+]]: memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>)
 // CHECK-SAME:    -> vector<1x1x4x3x2xi8>
-// CHECK:         %[[C0_I8:.+]] = arith.constant 0 : i8
-// CHECK:         %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C0_I8:.+]] = arith.constant 0 : i8
+// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
 // CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
 // CHECK-SAME{LITERAL}: [[0], [1], [2, 3, 4]]
 // CHECK-SAME:      : memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>
@@ -149,8 +149,8 @@ func.func @transfer_read_non_contiguous_unit_dims(
 // CHECK-LABEL: func.func @transfer_read_non_contiguous_unit_dims(
 // CHECK-SAME:    %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>) -> vector<1x1x3x2xi8> {
-// CHECK:         %[[VAL_1:.*]] = arith.constant 0 : i8
-// CHECK:         %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[VAL_1:.*]] = arith.constant 0 : i8
+// CHECK-DAG:     %[[VAL_2:.*]] = arith.constant 0 : index
 // CHECK:         %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]]
 // CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
 // CHECK-SAME:      : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
@@ -182,8 +182,8 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
 // CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices(
 // CHECK-SAME:    %[[IDX_1:.+]]: index, %[[IDX_2:.+]]: index,
 // CHECK-SAME:    %[[MEM:.+]]: memref<1x43x4x6xi32>
-// CHECK:         %[[C0_I32:.+]] = arith.constant 0 : i32
-// CHECK:         %[[C_0:.+]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C0_I32:.+]] = arith.constant 0 : i32
+// CHECK-DAG:     %[[C_0:.+]] = arith.constant 0 : index
 // CHECK:         %[[COLLAPSED_IN:.+]] = memref.collapse_shape %[[MEM]]
 // CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
 // CHECK-SAME:      : memref<1x43x4x6xi32> into memref<1x43x24xi32>
@@ -241,8 +241,8 @@ func.func @transfer_read_leading_dynamic_dims(
 // CHECK-LABEL: func @transfer_read_leading_dynamic_dims
 // CHECK-SAME:    %[[MEM:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[IDX_1:.+]]: index, %[[IDX_2:.+]]: index
-// CHECK:         %[[C0_I8:.+]] = arith.constant 0 : i8
-// CHECK:         %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C0_I8:.+]] = arith.constant 0 : i8
+// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
 // CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
 // CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
 // CHECK-SAME:      : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
@@ -304,8 +304,8 @@ func.func @transfer_read_dynamic_dim_to_flatten(
 // CHECK-SAME:    %[[IDX_1:arg0]]
 // CHECK-SAME:    %[[IDX_2:arg1]]
 // CHECK-SAME:    %[[MEM:arg2]]
-// CHECK:         %[[C0_I32:.+]] = arith.constant 0 : i32
-// CHECK:         %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C0_I32:.+]] = arith.constant 0 : i32
+// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
 // CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
 // CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
 // CHECK-SAME:      memref<1x?x4x6xi32> into memref<1x?x24xi32>

View File

@@ -1132,8 +1132,8 @@ func.func @warp_execute_nd_distribute(%laneid: index, %v0: vector<1x64x1xf32>, %
 //  CHECK-SCF-IF:   gpu.barrier
 //  CHECK-SCF-IF:   %[[WID:.*]] = affine.apply #[[$TIMES2]]()[%[[LANEID]]]
-//  CHECK-SCF-IF-DAG:   %[[R0:.*]] = vector.transfer_read %{{.*}}[%[[LANEID]], %[[C0]], %[[C0]]], %cst {in_bounds = [true, true, true]} : memref<32x64x1xf32, 3>, vector<1x64x1xf32>
-//  CHECK-SCF-IF-DAG:   %[[R1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[WID]], %[[C0]]], %cst {in_bounds = [true, true, true]} : memref<1x64x128xf32, 3>, vector<1x2x128xf32>
+//  CHECK-SCF-IF-DAG:   %[[R0:.*]] = vector.transfer_read %{{.*}}[%[[LANEID]], %[[C0]], %[[C0]]], %{{.*}} {in_bounds = [true, true, true]} : memref<32x64x1xf32, 3>, vector<1x64x1xf32>
+//  CHECK-SCF-IF-DAG:   %[[R1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[WID]], %[[C0]]], %{{.*}} {in_bounds = [true, true, true]} : memref<1x64x128xf32, 3>, vector<1x2x128xf32>
 //  CHECK-SCF-IF:   return %[[R0]], %[[R1]] : vector<1x64x1xf32>, vector<1x2x128xf32>
   return %r#0, %r#1 : vector<1x64x1xf32>, vector<1x2x128xf32>
 }