[mlir][linalg] Extract GeneralizePadOpPattern into a standalone transformation (#117329)
Currently, `GeneralizePadOpPattern` is grouped under `populatePadOpVectorizationPatterns`. However, as noted in #111349, this transformation "decomposes" rather than "vectorizes" `tensor.pad`. As such, it functions as: * a vectorization _pre-processing_ transformation, not * a vectorization transformation itself. To clarify its purpose, this PR turns `GeneralizePadOpPattern` into a standalone transformation by: * introducing a dedicated `populateDecomposePadPatterns` method, * adding a `apply_patterns.linalg.decompose_pad` Transform Dialect Op, * removing it from `populatePadOpVectorizationPatterns`. In addition, to better reflect its role, it is renamed as "decomposition" rather then "generalization". This is in line with the recent renaming of similar ops, i.e. tensor.pack/tensor.unpack Ops in #116439.
This commit is contained in:
committed by
GitHub
parent
56eb559b1d
commit
1b2c8f104f
@@ -52,6 +52,17 @@ def ApplyDecomposeTensorPackUnpackPatternsOp
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
def ApplyDecomposeTensorPadPatternsOp
|
||||
: Op<Transform_Dialect, "apply_patterns.linalg.decompose_pad",
|
||||
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
|
||||
let description = [{
|
||||
Collect patterns to decompose tensor.pad into e.g. tensor::EmptyOp,
|
||||
linalg::FillOp and tensor::InsertSliceOp.
|
||||
}];
|
||||
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
def ApplyFoldUnitExtentDimsViaReshapesPatternsOp : Op<Transform_Dialect,
|
||||
"apply_patterns.linalg.fold_unit_extent_dims_via_reshapes",
|
||||
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
|
||||
|
||||
@@ -1503,8 +1503,8 @@ using OptimizeCopyFn =
|
||||
|
||||
/// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
|
||||
/// InsertSliceOp. For now, only constant padding values are supported.
|
||||
struct GeneralizePadOpPattern : public OpRewritePattern<tensor::PadOp> {
|
||||
GeneralizePadOpPattern(MLIRContext *context, PatternBenefit benefit = 1)
|
||||
struct DecomposePadOpPattern : public OpRewritePattern<tensor::PadOp> {
|
||||
DecomposePadOpPattern(MLIRContext *context, PatternBenefit benefit = 1)
|
||||
: OpRewritePattern<tensor::PadOp>(context, benefit) {}
|
||||
LogicalResult matchAndRewrite(tensor::PadOp padOp,
|
||||
PatternRewriter &rewriter) const override;
|
||||
@@ -1688,6 +1688,10 @@ void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
|
||||
/// outer dims to be unit.
|
||||
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns);
|
||||
|
||||
/// Populates patterns to decompose tensor.pad into e.g.
|
||||
/// tensor.empty, linalg.fill, tensor.insert_slice.
|
||||
void populateDecomposePadPatterns(RewritePatternSet &patterns);
|
||||
|
||||
/// Populates patterns to transform linalg.conv_2d_xxx operations into
|
||||
/// linalg.generic (for img2col packing) and linalg.matmul.
|
||||
/// \see rewriteInIm2Col for more details.
|
||||
|
||||
@@ -25,5 +25,7 @@ using namespace mlir;
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void mlir::populateTensorToLinalgPatterns(RewritePatternSet &patterns) {
|
||||
patterns.add<mlir::linalg::GeneralizePadOpPattern>(patterns.getContext());
|
||||
// TODO: Add the remaining patterns, e.g. to decompose Pack/Unpack Ops.
|
||||
// Alternatively, delete this file.
|
||||
patterns.add<mlir::linalg::DecomposePadOpPattern>(patterns.getContext());
|
||||
}
|
||||
|
||||
@@ -234,6 +234,11 @@ void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns(
|
||||
linalg::populateDecomposePackUnpackPatterns(patterns);
|
||||
}
|
||||
|
||||
void transform::ApplyDecomposeTensorPadPatternsOp::populatePatterns(
|
||||
RewritePatternSet &patterns) {
|
||||
linalg::populateDecomposePadPatterns(patterns);
|
||||
}
|
||||
|
||||
void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
|
||||
RewritePatternSet &patterns) {
|
||||
linalg::ControlDropUnitDims options;
|
||||
@@ -3491,8 +3496,12 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
|
||||
// Add misc. vectorization patterns (e.g. for tensor.insert_slice)
|
||||
linalg::populateInsertSliceVectorizationPatterns(patterns);
|
||||
|
||||
if (getVectorizePadding())
|
||||
if (getVectorizePadding()) {
|
||||
linalg::populatePadOpVectorizationPatterns(patterns);
|
||||
// This creates an alternative path for lowering tensor.pad - by
|
||||
// decomposing it into e.g. linalg.fill.
|
||||
linalg::populateDecomposePadPatterns(patterns);
|
||||
}
|
||||
vector::populateVectorStepLoweringPatterns(patterns);
|
||||
|
||||
TrackingListener listener(state, *this);
|
||||
|
||||
@@ -921,7 +921,7 @@ LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
|
||||
|
||||
/// Filling `dest` using FillOp constant padding value if possible.
|
||||
/// Otherwise, generate a tensor::GenerateOp.
|
||||
Value GeneralizePadOpPattern::createFillOrGenerateOp(
|
||||
Value DecomposePadOpPattern::createFillOrGenerateOp(
|
||||
RewriterBase &rewriter, tensor::PadOp padOp, Value dest,
|
||||
const SmallVector<Value> &dynSizes) const {
|
||||
auto padValue = padOp.getConstantPaddingValue();
|
||||
@@ -938,8 +938,8 @@ Value GeneralizePadOpPattern::createFillOrGenerateOp(
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
|
||||
PatternRewriter &rewriter) const {
|
||||
DecomposePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
|
||||
PatternRewriter &rewriter) const {
|
||||
// Given an OpFoldResult, return an index-typed value.
|
||||
auto getIdxValue = [&](OpFoldResult ofr) {
|
||||
if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
|
||||
@@ -1623,3 +1623,7 @@ void linalg::populateDecomposePackUnpackPatterns(RewritePatternSet &patterns) {
|
||||
// TODO: Add and test patterns for tensor.unpack
|
||||
patterns.add<DecomposeOuterUnitDimsPackOpPattern>(patterns.getContext());
|
||||
}
|
||||
|
||||
void linalg::populateDecomposePadPatterns(RewritePatternSet &patterns) {
|
||||
patterns.add<DecomposePadOpPattern>(patterns.getContext());
|
||||
}
|
||||
|
||||
@@ -2770,12 +2770,6 @@ void mlir::linalg::populateInsertSliceVectorizationPatterns(
|
||||
|
||||
void mlir::linalg::populatePadOpVectorizationPatterns(
|
||||
RewritePatternSet &patterns, PatternBenefit baseBenefit) {
|
||||
// TODO: The following pattern implements "decomposition" and
|
||||
// optional "vectorization". Seperate "decomposition" into a sepereate
|
||||
// pre-processing pattern group.
|
||||
patterns.add<GeneralizePadOpPattern>(patterns.getContext(), baseBenefit);
|
||||
|
||||
// Try these specialized patterns first before resorting to the generic one.
|
||||
patterns.add<PadOpVectorizationWithTransferReadPattern,
|
||||
PadOpVectorizationWithTransferWritePattern,
|
||||
PadOpVectorizationWithInsertSlicePattern>(
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-generalize-pad-tensor" %s | FileCheck %s
|
||||
// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-decompose-pad-tensor" %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @generalize_pad_tensor_static_shape(
|
||||
// CHECK-SAME: %[[IN:.*]]: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32> {
|
||||
@@ -202,6 +202,8 @@ module attributes {transform.with_named_sequence} {
|
||||
%func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">
|
||||
|
||||
transform.apply_patterns to %func_op {
|
||||
// TODO: Split into two tests, one for each pattern
|
||||
transform.apply_patterns.linalg.decompose_pad
|
||||
transform.apply_patterns.linalg.pad_vectorization
|
||||
} : !transform.op<"func.func">
|
||||
transform.yield
|
||||
@@ -236,6 +238,8 @@ module attributes {transform.with_named_sequence} {
|
||||
%func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">
|
||||
|
||||
transform.apply_patterns to %func_op {
|
||||
// TODO: Split into two tests, one for each pattern
|
||||
transform.apply_patterns.linalg.decompose_pad
|
||||
transform.apply_patterns.linalg.pad_vectorization
|
||||
} : !transform.op<"func.func">
|
||||
transform.yield
|
||||
@@ -270,6 +274,8 @@ module attributes {transform.with_named_sequence} {
|
||||
%func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">
|
||||
|
||||
transform.apply_patterns to %func_op {
|
||||
// TODO: Split into two tests, one for each pattern
|
||||
transform.apply_patterns.linalg.decompose_pad
|
||||
transform.apply_patterns.linalg.pad_vectorization
|
||||
} : !transform.op<"func.func">
|
||||
transform.yield
|
||||
|
||||
@@ -70,8 +70,8 @@ struct TestLinalgTransforms
|
||||
llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
|
||||
"in vector.contract form"),
|
||||
llvm::cl::init(false)};
|
||||
Option<bool> testGeneralizePadTensor{
|
||||
*this, "test-generalize-pad-tensor",
|
||||
Option<bool> testDecomposePadTensor{
|
||||
*this, "test-decompose-pad-tensor",
|
||||
llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
|
||||
llvm::cl::init(false)};
|
||||
Option<bool> testDecomposeTensorPackOp{
|
||||
@@ -166,9 +166,9 @@ static void applyLinalgToVectorPatterns(func::FuncOp funcOp) {
|
||||
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
|
||||
}
|
||||
|
||||
static void applyGeneralizePadTensorPatterns(func::FuncOp funcOp) {
|
||||
static void applyDecomposePadPatterns(func::FuncOp funcOp) {
|
||||
RewritePatternSet patterns(funcOp.getContext());
|
||||
patterns.add<GeneralizePadOpPattern>(funcOp.getContext());
|
||||
patterns.add<DecomposePadOpPattern>(funcOp.getContext());
|
||||
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
|
||||
}
|
||||
|
||||
@@ -235,8 +235,8 @@ void TestLinalgTransforms::runOnOperation() {
|
||||
return applyVectorTransferForwardingPatterns(getOperation());
|
||||
if (testGenericToVectorPattern)
|
||||
return applyLinalgToVectorPatterns(getOperation());
|
||||
if (testGeneralizePadTensor)
|
||||
return applyGeneralizePadTensorPatterns(getOperation());
|
||||
if (testDecomposePadTensor)
|
||||
return applyDecomposePadPatterns(getOperation());
|
||||
if (testDecomposeTensorPackOp)
|
||||
return applyDecomposeTensorPackPatterns(getOperation());
|
||||
if (testDecomposeTensorUnPackOp)
|
||||
|
||||
Reference in New Issue
Block a user