[mlir][transform] extract a minimal DomainAndOperandsAffineMapT… (#145034)
…ransferInterface out of LinalgStructuredInterface and use that for PadTilingInterface Along the way, a bug was found on the handling of scalar values, fix it and add a test.
This commit is contained in:
committed by
GitHub
parent
376b71442d
commit
269cb22ae8
@@ -222,9 +222,59 @@ def LinalgFillOpInterface : OpInterface<"FillOpInterface"> {
|
||||
];
|
||||
}
|
||||
|
||||
def IndexingMapOpInterface : OpInterface<"IndexingMapOpInterface"> {
|
||||
let description = [{
|
||||
Interface for operations that connect an iteration domain to operands via
|
||||
affine maps. Provides methods to access indexing maps between iteration
|
||||
domain and operand index spaces.
|
||||
}];
|
||||
let cppNamespace = "::mlir::linalg";
|
||||
let methods = [
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Return the indexing maps attribute within the current operation.
|
||||
}],
|
||||
/*retTy=*/"ArrayAttr",
|
||||
/*methodName=*/"getIndexingMaps"
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Return the indexing maps within the current operation.
|
||||
}],
|
||||
/*retTy=*/"SmallVector<AffineMap>",
|
||||
/*methodName=*/"getIndexingMapsArray",
|
||||
/*args=*/(ins),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
auto range = $_op.getIndexingMaps()
|
||||
.template getAsValueRange<AffineMapAttr>();
|
||||
return {range.begin(), range.end()};
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Return the input or output indexing map for `opOperand`.
|
||||
}],
|
||||
/*retTy=*/"AffineMap",
|
||||
/*methodName=*/"getMatchingIndexingMap",
|
||||
/*args=*/(ins "OpOperand*":$opOperand),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
assert(opOperand->getOwner() == this->getOperation());
|
||||
auto indexingMaps =
|
||||
$_op.getIndexingMaps().template getAsValueRange<AffineMapAttr>();
|
||||
return *(indexingMaps.begin() + opOperand->getOperandNumber());
|
||||
}]
|
||||
>,
|
||||
];
|
||||
}
|
||||
|
||||
// The 'LinalgStructuredInterface' provides access to the 'LinalgOp' interface.
|
||||
def LinalgStructuredInterface
|
||||
: OpInterface<"LinalgOp", [DestinationStyleOpInterface]> {
|
||||
: OpInterface<"LinalgOp", [
|
||||
DestinationStyleOpInterface,
|
||||
IndexingMapOpInterface
|
||||
]> {
|
||||
let cppNamespace = "::mlir::linalg";
|
||||
let methods = [
|
||||
//===------------------------------------------------------------------===//
|
||||
@@ -465,21 +515,6 @@ def LinalgStructuredInterface
|
||||
blockArgument.getArgNumber());
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Return the input or output indexing map for `opOperand`.
|
||||
}],
|
||||
/*retTy=*/"AffineMap",
|
||||
/*methodName=*/"getMatchingIndexingMap",
|
||||
/*args=*/(ins "OpOperand*":$opOperand),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
assert(opOperand->getOwner() == this->getOperation());
|
||||
auto indexingMaps =
|
||||
$_op.getIndexingMaps().template getAsValueRange<AffineMapAttr>();
|
||||
return *(indexingMaps.begin() + opOperand->getOperandNumber());
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Return the indexing map for a `result`.
|
||||
@@ -576,27 +611,6 @@ def LinalgStructuredInterface
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{ return success(); }]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Return the indexing maps attribute within the current operation.
|
||||
}],
|
||||
/*retTy=*/"ArrayAttr",
|
||||
/*methodName=*/"getIndexingMaps"
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Return the indexing maps within the current operation.
|
||||
}],
|
||||
/*retTy=*/"SmallVector<AffineMap>",
|
||||
/*methodName=*/"getIndexingMapsArray",
|
||||
/*args=*/(ins),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
auto range = $_op.getIndexingMaps()
|
||||
.template getAsValueRange<AffineMapAttr>();
|
||||
return {range.begin(), range.end()};
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Return true if any of the operands has a dynamic shape.
|
||||
|
||||
@@ -612,10 +612,9 @@ using PadSizeComputationFunction =
|
||||
const PadTilingInterfaceOptions &)>;
|
||||
|
||||
/// Specific helper for Linalg ops.
|
||||
FailureOr<SmallVector<OpFoldResult>>
|
||||
computeLinalgPaddedShape(RewriterBase &rewriter, OpOperand &operandToPad,
|
||||
ArrayRef<Range> iterationDomain,
|
||||
const PadTilingInterfaceOptions &options);
|
||||
FailureOr<SmallVector<OpFoldResult>> computeIndexingMapOpInterfacePaddedShape(
|
||||
RewriterBase &rewriter, OpOperand &operandToPad,
|
||||
ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options);
|
||||
|
||||
/// Pad the iterator dimensions `options.paddingDimensions` of `opToPad`.
|
||||
///
|
||||
@@ -632,7 +631,7 @@ rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
|
||||
const PadTilingInterfaceOptions &constOptions,
|
||||
SmallVector<tensor::PadOp> &padOps,
|
||||
PadSizeComputationFunction computePaddingSizeFun =
|
||||
&computeLinalgPaddedShape);
|
||||
&computeIndexingMapOpInterfacePaddedShape);
|
||||
|
||||
namespace detail {
|
||||
|
||||
|
||||
@@ -2229,10 +2229,12 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
|
||||
return diag;
|
||||
}
|
||||
|
||||
// Only Linalg ops for now, until TilingInterface exposes a loopsToOperand
|
||||
// map / C++ APIs to compute the effect of padding on operands.
|
||||
if (!isa<LinalgOp>(targetOp.getOperation())) {
|
||||
auto diag = emitSilenceableError() << "only LinalgOp supported atm";
|
||||
// Only IndexingMapOpInterface ops for now, until TilingInterface exposes a
|
||||
// loopsToOperand map / C++ APIs to compute the effect of padding on
|
||||
// operands.
|
||||
if (!isa<IndexingMapOpInterface>(targetOp.getOperation())) {
|
||||
auto diag = emitSilenceableError() << "only IndexingMapOpInterface ops "
|
||||
"supported atm";
|
||||
diag.attachNote(target->getLoc()) << "target op";
|
||||
return diag;
|
||||
}
|
||||
|
||||
@@ -155,11 +155,13 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
|
||||
return paddedShape;
|
||||
}
|
||||
|
||||
FailureOr<SmallVector<OpFoldResult>> linalg::computeLinalgPaddedShape(
|
||||
FailureOr<SmallVector<OpFoldResult>>
|
||||
linalg::computeIndexingMapOpInterfacePaddedShape(
|
||||
RewriterBase &rewriter, OpOperand &operandToPad,
|
||||
ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options) {
|
||||
auto linalgOp = llvm::dyn_cast<LinalgOp>(operandToPad.getOwner());
|
||||
if (!linalgOp)
|
||||
auto transferOp =
|
||||
llvm::dyn_cast<IndexingMapOpInterface>(operandToPad.getOwner());
|
||||
if (!transferOp)
|
||||
return failure();
|
||||
|
||||
// clang-format off
|
||||
@@ -173,7 +175,7 @@ FailureOr<SmallVector<OpFoldResult>> linalg::computeLinalgPaddedShape(
|
||||
for (const Range &range : iterationDomain)
|
||||
loopUpperBounds.push_back(range.size);
|
||||
|
||||
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&operandToPad);
|
||||
AffineMap indexingMap = transferOp.getMatchingIndexingMap(&operandToPad);
|
||||
return computePaddedShape(
|
||||
rewriter, cast<TypedValue<RankedTensorType>>(operandToPad.get()),
|
||||
indexingMap, loopUpperBounds, options);
|
||||
@@ -255,7 +257,18 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
|
||||
SmallVector<Value> newOperands;
|
||||
newOperands.reserve(opToPad->getNumOperands());
|
||||
for (OpOperand &opOperand : opToPad->getOpOperands()) {
|
||||
LLVM_DEBUG(DBGS() << "--start padding oprd: " << opOperand.get() << "\n");
|
||||
Value operand = opOperand.get();
|
||||
LLVM_DEBUG(DBGS() << "--start padding oprd: " << operand << "\n");
|
||||
|
||||
// 2.a. Skip scalar-like operands.
|
||||
Type operandType = operand.getType();
|
||||
if (!isa<RankedTensorType>(operandType)) {
|
||||
assert(!isa<ShapedType>(operandType) ||
|
||||
isa<VectorType>(operandType) &&
|
||||
"Unexpected non-vector ShapedType");
|
||||
newOperands.push_back(operand);
|
||||
continue;
|
||||
}
|
||||
// 2.a. Compute padded shape.
|
||||
FailureOr<SmallVector<OpFoldResult>> maybePaddedShape =
|
||||
computePaddingSizeFun(rewriter, opOperand, iterationDomain, options);
|
||||
@@ -266,14 +279,16 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
|
||||
// 2.b. Expect proper `paddingValues`.
|
||||
// TODO: we may want to allow garbage padding in the future, in which case
|
||||
// we would just not assert.
|
||||
assert(opOperand.getOperandNumber() < options.paddingValues.size() &&
|
||||
"--no padding value specified");
|
||||
if (opOperand.getOperandNumber() >= options.paddingValues.size()) {
|
||||
return rewriter.notifyMatchFailure(opToPad,
|
||||
"--no padding value specified");
|
||||
}
|
||||
Attribute paddingValueAttr =
|
||||
options.paddingValues[opOperand.getOperandNumber()];
|
||||
|
||||
// 2.c. Perform actual padding.
|
||||
Value paddedOperand = padOperand(
|
||||
rewriter, opToPad, cast<TypedValue<RankedTensorType>>(opOperand.get()),
|
||||
rewriter, opToPad, cast<TypedValue<RankedTensorType>>(operand),
|
||||
*maybePaddedShape, paddingValueAttr);
|
||||
LLVM_DEBUG(DBGS() << "--done padding operand: " << paddedOperand << "\n");
|
||||
|
||||
|
||||
@@ -1,5 +1,33 @@
|
||||
// RUN: mlir-opt --transform-interpreter -canonicalize -split-input-file --verify-diagnostics %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: pad_fill
|
||||
// CHECK: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : tensor<8x25xf32>) -> tensor<8x25xf32>
|
||||
func.func @pad_fill(%value: f32, %output: tensor<24x25xf32>) -> tensor<24x25xf32>
|
||||
{
|
||||
%0 = linalg.fill ins(%value : f32) outs(%output : tensor<24x25xf32>) -> tensor<24x25xf32>
|
||||
func.return %0 : tensor<24x25xf32>
|
||||
}
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||
%fill = transform.structured.match ops{["linalg.fill"]} in %arg1
|
||||
: (!transform.any_op) -> !transform.any_op
|
||||
|
||||
// Tile to 5 then pad to 8
|
||||
%fill_l1, %loops_l1 = transform.structured.tile_using_for %fill tile_sizes [5]
|
||||
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||
|
||||
%fill_padded, %_ = transform.structured.pad_tiling_interface %fill_l1 to padding_sizes [8] {
|
||||
padding_values=[0.0 : f32, 0.0 : f32],
|
||||
padding_dimensions=[0]
|
||||
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: pad_lhs
|
||||
func.func @pad_lhs(
|
||||
%arg0: tensor<24x12xf32>, %arg1: tensor<12x25xf32>, %arg2: tensor<24x25xf32>)
|
||||
|
||||
Reference in New Issue
Block a user