[mlir][transform] extract a minimal DomainAndOperandsAffineMapTransferInterface out of LinalgStructuredInterface and use that for PadTilingInterface (#145034)

Along the way, a bug was found in the handling of scalar values; fix it and add a test.
Author: Nicolas Vasilache
Date: 2025-06-20 15:45:21 +02:00
Committed by: GitHub
Parent: 376b71442d
Commit: 269cb22ae8

5 changed files with 112 additions and 54 deletions


@@ -222,9 +222,59 @@ def LinalgFillOpInterface : OpInterface<"FillOpInterface"> {
];
}
+def IndexingMapOpInterface : OpInterface<"IndexingMapOpInterface"> {
+  let description = [{
+    Interface for operations that connect an iteration domain to operands via
+    affine maps. Provides methods to access indexing maps between iteration
+    domain and operand index spaces.
+  }];
+  let cppNamespace = "::mlir::linalg";
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the indexing maps attribute within the current operation.
+      }],
+      /*retTy=*/"ArrayAttr",
+      /*methodName=*/"getIndexingMaps"
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the indexing maps within the current operation.
+      }],
+      /*retTy=*/"SmallVector<AffineMap>",
+      /*methodName=*/"getIndexingMapsArray",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        auto range = $_op.getIndexingMaps()
+                         .template getAsValueRange<AffineMapAttr>();
+        return {range.begin(), range.end()};
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the input or output indexing map for `opOperand`.
+      }],
+      /*retTy=*/"AffineMap",
+      /*methodName=*/"getMatchingIndexingMap",
+      /*args=*/(ins "OpOperand*":$opOperand),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        assert(opOperand->getOwner() == this->getOperation());
+        auto indexingMaps =
+          $_op.getIndexingMaps().template getAsValueRange<AffineMapAttr>();
+        return *(indexingMaps.begin() + opOperand->getOperandNumber());
+      }]
+    >,
+  ];
+}
// The 'LinalgStructuredInterface' provides access to the 'LinalgOp' interface.
def LinalgStructuredInterface
-    : OpInterface<"LinalgOp", [DestinationStyleOpInterface]> {
+    : OpInterface<"LinalgOp", [
+        DestinationStyleOpInterface,
+        IndexingMapOpInterface
+      ]> {
let cppNamespace = "::mlir::linalg";
let methods = [
//===------------------------------------------------------------------===//
@@ -465,21 +515,6 @@ def LinalgStructuredInterface
blockArgument.getArgNumber());
}]
>,
-    InterfaceMethod<
-      /*desc=*/[{
-        Return the input or output indexing map for `opOperand`.
-      }],
-      /*retTy=*/"AffineMap",
-      /*methodName=*/"getMatchingIndexingMap",
-      /*args=*/(ins "OpOperand*":$opOperand),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        assert(opOperand->getOwner() == this->getOperation());
-        auto indexingMaps =
-          $_op.getIndexingMaps().template getAsValueRange<AffineMapAttr>();
-        return *(indexingMaps.begin() + opOperand->getOperandNumber());
-      }]
-    >,
InterfaceMethod<
/*desc=*/[{
Return the indexing map for a `result`.
@@ -576,27 +611,6 @@ def LinalgStructuredInterface
/*methodBody=*/"",
/*defaultImplementation=*/[{ return success(); }]
>,
-    InterfaceMethod<
-      /*desc=*/[{
-        Return the indexing maps attribute within the current operation.
-      }],
-      /*retTy=*/"ArrayAttr",
-      /*methodName=*/"getIndexingMaps"
-    >,
-    InterfaceMethod<
-      /*desc=*/[{
-        Return the indexing maps within the current operation.
-      }],
-      /*retTy=*/"SmallVector<AffineMap>",
-      /*methodName=*/"getIndexingMapsArray",
-      /*args=*/(ins),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        auto range = $_op.getIndexingMaps()
-                         .template getAsValueRange<AffineMapAttr>();
-        return {range.begin(), range.end()};
-      }]
-    >,
InterfaceMethod<
/*desc=*/[{
Return true if any of the operands has a dynamic shape.

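The net effect of the TableGen changes above is that code which only needs indexing maps can now require the narrow `IndexingMapOpInterface` instead of the full `LinalgOp` interface. A minimal C++ sketch of coding against it (not part of the patch; the include path and helper name are illustrative assumptions):

```cpp
// Sketch only: a helper that needs indexing maps can take the new narrow
// interface rather than LinalgOp. Include path and name are assumptions.
#include "mlir/Dialect/Linalg/IR/Linalg.h"

using namespace mlir;

static AffineMap mapForOperand(linalg::IndexingMapOpInterface op,
                               OpOperand &operand) {
  // Both methods were moved down from LinalgStructuredInterface into the
  // new interface, so this works for any implementing op, Linalg or not.
  SmallVector<AffineMap> maps = op.getIndexingMapsArray();
  assert(operand.getOperandNumber() < maps.size() && "operand out of range");
  return op.getMatchingIndexingMap(&operand);
}
```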

@@ -612,10 +612,9 @@ using PadSizeComputationFunction =
const PadTilingInterfaceOptions &)>;
/// Specific helper for Linalg ops.
-FailureOr<SmallVector<OpFoldResult>>
-computeLinalgPaddedShape(RewriterBase &rewriter, OpOperand &operandToPad,
-                         ArrayRef<Range> iterationDomain,
-                         const PadTilingInterfaceOptions &options);
+FailureOr<SmallVector<OpFoldResult>> computeIndexingMapOpInterfacePaddedShape(
+    RewriterBase &rewriter, OpOperand &operandToPad,
+    ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options);
/// Pad the iterator dimensions `options.paddingDimensions` of `opToPad`.
///
@@ -632,7 +631,7 @@ rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
const PadTilingInterfaceOptions &constOptions,
SmallVector<tensor::PadOp> &padOps,
PadSizeComputationFunction computePaddingSizeFun =
-        &computeLinalgPaddedShape);
+        &computeIndexingMapOpInterfacePaddedShape);
namespace detail {

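Because `rewriteAsPaddedOp` accepts the shape computation as a `PadSizeComputationFunction` parameter, the renamed helper above is simply its new default. A hedged sketch of a custom callback that delegates to it (the wrapper name is illustrative, and the include path assumes these declarations live in Linalg's Transforms.h):

```cpp
// Sketch: a custom PadSizeComputationFunction wrapping the new default.
// Relies only on the declarations shown in the diff above.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

using namespace mlir;

static FailureOr<SmallVector<OpFoldResult>>
customPaddedShape(RewriterBase &rewriter, OpOperand &operandToPad,
                  ArrayRef<Range> iterationDomain,
                  const linalg::PadTilingInterfaceOptions &options) {
  // A real callback could adjust the computed sizes before or after
  // delegating to the interface-based default introduced by this patch.
  return linalg::computeIndexingMapOpInterfacePaddedShape(
      rewriter, operandToPad, iterationDomain, options);
}
// Usage: pass `customPaddedShape` as the trailing `computePaddingSizeFun`
// argument of linalg::rewriteAsPaddedOp.
```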

@@ -2229,10 +2229,12 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
return diag;
}
-  // Only Linalg ops for now, until TilingInterface exposes a loopsToOperand
-  // map / C++ APIs to compute the effect of padding on operands.
-  if (!isa<LinalgOp>(targetOp.getOperation())) {
-    auto diag = emitSilenceableError() << "only LinalgOp supported atm";
+  // Only IndexingMapOpInterface ops for now, until TilingInterface exposes a
+  // loopsToOperand map / C++ APIs to compute the effect of padding on
+  // operands.
+  if (!isa<IndexingMapOpInterface>(targetOp.getOperation())) {
+    auto diag = emitSilenceableError() << "only IndexingMapOpInterface ops "
+                                          "supported atm";
diag.attachNote(target->getLoc()) << "target op";
return diag;
}


@@ -155,11 +155,13 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
return paddedShape;
}
-FailureOr<SmallVector<OpFoldResult>> linalg::computeLinalgPaddedShape(
+FailureOr<SmallVector<OpFoldResult>>
+linalg::computeIndexingMapOpInterfacePaddedShape(
RewriterBase &rewriter, OpOperand &operandToPad,
ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options) {
-  auto linalgOp = llvm::dyn_cast<LinalgOp>(operandToPad.getOwner());
-  if (!linalgOp)
+  auto transferOp =
+      llvm::dyn_cast<IndexingMapOpInterface>(operandToPad.getOwner());
+  if (!transferOp)
return failure();
// clang-format off
@@ -173,7 +175,7 @@ FailureOr<SmallVector<OpFoldResult>> linalg::computeLinalgPaddedShape(
for (const Range &range : iterationDomain)
loopUpperBounds.push_back(range.size);
-  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&operandToPad);
+  AffineMap indexingMap = transferOp.getMatchingIndexingMap(&operandToPad);
return computePaddedShape(
rewriter, cast<TypedValue<RankedTensorType>>(operandToPad.get()),
indexingMap, loopUpperBounds, options);
@@ -255,7 +257,18 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
SmallVector<Value> newOperands;
newOperands.reserve(opToPad->getNumOperands());
for (OpOperand &opOperand : opToPad->getOpOperands()) {
-    LLVM_DEBUG(DBGS() << "--start padding oprd: " << opOperand.get() << "\n");
+    Value operand = opOperand.get();
+    LLVM_DEBUG(DBGS() << "--start padding oprd: " << operand << "\n");
+    // 2.a. Skip scalar-like operands.
+    Type operandType = operand.getType();
+    if (!isa<RankedTensorType>(operandType)) {
+      assert(!isa<ShapedType>(operandType) ||
+             isa<VectorType>(operandType) &&
+                 "Unexpected non-vector ShapedType");
+      newOperands.push_back(operand);
+      continue;
+    }
// 2.a. Compute padded shape.
FailureOr<SmallVector<OpFoldResult>> maybePaddedShape =
computePaddingSizeFun(rewriter, opOperand, iterationDomain, options);
@@ -266,14 +279,16 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
// 2.b. Expect proper `paddingValues`.
// TODO: we may want to allow garbage padding in the future, in which case
// we would just not assert.
-    assert(opOperand.getOperandNumber() < options.paddingValues.size() &&
-           "--no padding value specified");
+    if (opOperand.getOperandNumber() >= options.paddingValues.size()) {
+      return rewriter.notifyMatchFailure(opToPad,
+                                         "--no padding value specified");
+    }
Attribute paddingValueAttr =
options.paddingValues[opOperand.getOperandNumber()];
// 2.c. Perform actual padding.
Value paddedOperand = padOperand(
-        rewriter, opToPad, cast<TypedValue<RankedTensorType>>(opOperand.get()),
+        rewriter, opToPad, cast<TypedValue<RankedTensorType>>(operand),
*maybePaddedShape, paddingValueAttr);
LLVM_DEBUG(DBGS() << "--done padding operand: " << paddedOperand << "\n");

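The scalar fix above boils down to a type test: only `RankedTensorType` operands can be padded with `tensor.pad`, so scalar-like operands (such as the `f32` fill value exercised by the new test below) must be forwarded unchanged rather than reach the `cast<TypedValue<RankedTensorType>>`. A standalone sketch of that classification (the function name is illustrative, not from the patch):

```cpp
// Sketch: the operand classification used by the fix. Anything that is not
// a ranked tensor is passed through unpadded; per the patch's assert, a
// non-tensor ShapedType must be a vector.
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

static bool isPaddableOperand(Type operandType) {
  if (isa<RankedTensorType>(operandType))
    return true; // eligible for tensor::PadOp-based padding
  assert((!isa<ShapedType>(operandType) || isa<VectorType>(operandType)) &&
         "Unexpected non-vector ShapedType");
  return false; // scalar-like: forward the operand unchanged
}
```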

@@ -1,5 +1,33 @@
// RUN: mlir-opt --transform-interpreter -canonicalize -split-input-file --verify-diagnostics %s | FileCheck %s
+// CHECK-LABEL: pad_fill
+//       CHECK: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : tensor<8x25xf32>) -> tensor<8x25xf32>
+func.func @pad_fill(%value: f32, %output: tensor<24x25xf32>) -> tensor<24x25xf32>
+{
+  %0 = linalg.fill ins(%value : f32) outs(%output : tensor<24x25xf32>) -> tensor<24x25xf32>
+  func.return %0 : tensor<24x25xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %fill = transform.structured.match ops{["linalg.fill"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    // Tile to 5 then pad to 8
+    %fill_l1, %loops_l1 = transform.structured.tile_using_for %fill tile_sizes [5]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %fill_padded, %_ = transform.structured.pad_tiling_interface %fill_l1 to padding_sizes [8] {
+      padding_values=[0.0 : f32, 0.0 : f32],
+      padding_dimensions=[0]
+    } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// -----
// CHECK-LABEL: pad_lhs
func.func @pad_lhs(
%arg0: tensor<24x12xf32>, %arg1: tensor<12x25xf32>, %arg2: tensor<24x25xf32>)