The greedy rewriter is used in many different flows and offers a lot of
convenience (worklist management, debugging actions, tracing, etc.). But
it combines two kinds of greedy behavior: 1) how ops are matched, and 2)
folding wherever it can.
These are independent forms of greediness, and coupling them leads to
inefficiency. E.g., in cases where one needs to create different phases in a
lowering and apply patterns in a specific order split across different
passes, using the driver means needlessly retrying folding / running
multiple rounds of folding attempts where one final run would have
sufficed.
Of course folks can avoid this behavior by building their own driver
locally, but this is also a commonly requested feature that folks keep
working around in suboptimal ways.
For downstream users, there should be no behavioral change. Updating from
the deprecated entry point should just be a find-and-replace (e.g., of the
`find ./ -type f -exec sed -i
's|applyPatternsAndFoldGreedily|applyPatternsGreedily|g' {} \;` variety),
as the API arguments haven't changed between the two.
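
A minimal before/after sketch of a call site (the surrounding pass boilerplate is just illustrative context; it mirrors the driver usage in the pass at the end of the file below):

```cpp
// Inside some pass's runOnOperation(), with `patterns` a RewritePatternSet.

// Before (deprecated spelling):
//   if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
//     signalPassFailure();

// After (same arguments, new name):
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
  signalPassFailure();
```

For reference, the Linalg specialization pass below already uses the new spelling.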
//===- Specialize.cpp - linalg generic ops to named ops ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a method to specialize generic operations to named
// operations. Conceptually it is the opposite of generalize.cpp.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/TypeID.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"

namespace mlir {
#define GEN_PASS_DEF_LINALGSPECIALIZEGENERICOPSPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

#define DEBUG_TYPE "linalg-specialization"

#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP)                                \
  (rewriter.replaceOpWithNewOp<NEWOP>(                                         \
      genericOp,                                                               \
      ValueRange{genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 1 : 0],            \
                 genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 0 : 1]},           \
      ValueRange{genericOp.getDpsInits()[0]}))

#define REPLACE_UNARY_OP(NEWOP)                                                \
  (rewriter.replaceOpWithNewOp<NEWOP>(genericOp,                               \
                                      ValueRange{genericOp.getDpsInputs()[0]}, \
                                      ValueRange{genericOp.getDpsInits()[0]}))

using namespace mlir;
using namespace mlir::linalg;

// Given an elementwise single binary linalg generic op, checks whether the
// binary op accesses operands as swapped. e.g.
// this differentiates between a linalg-generic body that contains:
//   ^bb0(%a: f32, %b: f32, %c : f32):
//         %0 = arith.subf %a, %b : f32
//         linalg.yield %0: f32
// against:
//   ^bb0(%a: f32, %b: f32, %c : f32):
//         %0 = arith.subf %b, %a : f32
//         linalg.yield %0: f32
// The former is linalg.sub(a,b), the latter is linalg.sub(b,a).
static bool areBinOpsSwapped(GenericOp genericOp) {
  Block *body = genericOp.getBody();
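  // Callers have already matched the elementwise single-binary interface, so
  // the body is a single binary op followed by linalg.yield.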
  Operation *op = &body->front();
  bool swapped = false;
  if (op->getOpOperand(0).get() != body->getArgument(0)) {
    swapped = true;
    assert(op->getOpOperand(0).get() == body->getArgument(1) &&
           op->getOpOperand(1).get() == body->getArgument(0) &&
           "binary op uses just one block arg");
  }
  return swapped;
}

//===----------------------------------------------------------------------===//
// Specialize linalg generic to matmul variants.
//===----------------------------------------------------------------------===//
/// Identifies linalg.generic that is essentially a named op of the form:
//    ` linalg.{batch_}?matmul{_transpose_a | _transpose_b}? `
//
// It is possible that a linalg.generic may be implementing a matmul but not
// in a straightforward way, e.g. below is a matrix multiply over some slice
// ```
// %0 = linalg.generic {
//        indexing_maps = [affine_map<(d0, d1, d2) -> (3, d1, d0)>,
//                         affine_map<(d0, d1, d2) -> (d0, 5, d2)>,
//                         affine_map<(d0, d1, d2) -> (d2, d1, 13)>],
//        iterator_types = ["parallel", "parallel", "parallel"]}
//   ins(%A, %B : tensor<20x20x20xf32>, tensor<20x20x20xf32>)
//   outs(%C : tensor<20x20x20xf32>) {
//     ^bb0(%a: f32, %b: f32, %c : f32):
//         %mul = arith.mulf %a, %b : f32
//         %add = arith.addf %mul, %c : f32
//         linalg.yield %add : f32
// } -> tensor<20x20x20xf32>
// ```
// It is not possible to represent the above as a named op.
// e.g. linalg.batch_matmul(%A, %B : tensor<20x20x20xf32>, ...) is
// not the same as the linalg.generic above.
namespace {
enum class IndexMatchResult {
  Match = 0,  // identity map.
  Transposed, // transposed map.
  Mismatch    // none of the above.
};

// Checks whether the input Affine `map` contains two consecutive dims that
// can be interpreted as accessing a 2D matrix. It is assumed that the row
// and column dimensions are adjacent axes (in this order) and start at
// `rowDimIdx` in the input map.
//
// e.g. consider A matrix in `C[M,N] = A[M,K] * B[K,N]`. We will check
// whether the map of A is identity (match), transposed, or something
// completely different (mis-match). Similar for B and C.
static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
                                        unsigned expectedPosOfRowDim,
                                        unsigned expectedPosOfColDim) {
  // Get the matrix multiply indices. They are past the batch indices.
  auto exprOfRowDim = map.getResults()[rowDimIdx];
  auto exprOfColDim = map.getResults()[rowDimIdx + 1];

  // They should be pure dimension ids.
  if (exprOfRowDim.getKind() != AffineExprKind::DimId ||
      exprOfColDim.getKind() != AffineExprKind::DimId)
    return IndexMatchResult::Mismatch;

  auto posRowDim = cast<AffineDimExpr>(exprOfRowDim).getPosition();
  auto posColDim = cast<AffineDimExpr>(exprOfColDim).getPosition();

  if (expectedPosOfRowDim == posRowDim && expectedPosOfColDim == posColDim)
    return IndexMatchResult::Match;

  if (expectedPosOfRowDim == posColDim && expectedPosOfColDim == posRowDim)
    return IndexMatchResult::Transposed;

  return IndexMatchResult::Mismatch;
}

// Replaces genericOp with `NamedOpTy` op, supplied as a template arg.
// All the variants expressed as pseudo regular expression:
//     `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
// have the same number of ins/outs, so it's easy to stamp different versions.
template <typename NamedOpTy>
static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op) {
  LinalgOp namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
      op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
      ValueRange{op.getDpsInits()[0]});
  return namedOp;
}

// Converts linalg.generic to named linalg.*matmul* where possible.
static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
                                                         GenericOp genericOp) {
  if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
    return failure();

  // Early exit if not projected permutations.
  auto mapRange = genericOp.getIndexingMapsArray();
  if (llvm::any_of(mapRange,
                   [](AffineMap m) { return !m.isProjectedPermutation(); }))
    return failure();

  // A linalg generic contraction can be across multiple axes, e.g.
  // ```
  //  linalg.generic
  //       {indexing_maps = [affine_map<(m, n, k1, k2) -> (m, k1, k2)>,
  //                         affine_map<(m, n, k1, k2) -> (k2, k1, n)>,
  //                         affine_map<(m, n, k1, k2) -> (m, n)>],
  //        iterator_types = ["parallel", "parallel",
  //                          "reduction", "reduction"]}
  //        ins(%A, %B : tensor<10x20x30xf32>, tensor<30x20x40xf32>)
  //        outs(%C : tensor<10x40xf32>) {
  //          ^bb0(%a: f32, %b: f32, %c: f32):
  //            %1 = arith.mulf %a, %b : f32
  //            %2 = arith.addf %c, %1 : f32
  //            linalg.yield %2 : f32
  //        } -> tensor<10x40xf32>
  // ```
  // In the above contraction there are two reduction dimensions {k1, k2},
  // and although it is a valid linalg contraction, it is not a named-op
  // matrix multiply kind. Therefore, reject multi-dim reduction.
  auto res = inferContractionDims(genericOp);
  if (!succeeded(res))
    return failure();
  auto dims = *res;
  if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
    return failure();

  if (!mlir::linalg::detail::isContractionBody(
          *genericOp.getBlock(), [](Operation *first, Operation *second) {
            if ((isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second)) ||
                (isa<arith::MulIOp>(first) && isa<arith::AddIOp>(second)) ||
                (isa<complex::MulOp>(first) && isa<complex::AddOp>(second)))
              return true;
            return false;
          }))
    return failure();

  // Check the rank of the operands.
  auto indexingMaps = genericOp.getIndexingMapsArray();
  if (llvm::any_of(indexingMaps, [&dims](AffineMap m) {
        return m.getResults().size() !=
               dims.batch.size() + 2 /* any two of {m,n,k} */;
      }))
    return failure();

  auto numOfBatchDims = dims.batch.size();
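  // A named (batch) matmul iterates over exactly the batch dimensions plus one
  // m, one n, and one k dimension; any extra loop dimension in the maps means
  // this generic cannot be one of the named matmul variants.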
  if (indexingMaps[0].getNumDims() != numOfBatchDims + 3)
    return failure();

  if (numOfBatchDims) {
    // Each operand in a linalg generic contraction could express different
    // permutations for its batch dimension. But for a named op it must be the
    // identity since separate maps are not specified.
    if (llvm::any_of(indexingMaps, [numOfBatchDims](AffineMap m) {
          for (unsigned i = 0; i < numOfBatchDims; ++i) {
            auto expr = m.getResults()[i];
            if (expr.getKind() != AffineExprKind::DimId ||
                cast<AffineDimExpr>(expr).getPosition() != i)
              return true;
          }
          return false;
        }))
      return failure();
  }

  auto a =
      matchOperandMap(indexingMaps[0], numOfBatchDims, dims.m[0], dims.k[0]);
  auto b =
      matchOperandMap(indexingMaps[1], numOfBatchDims, dims.k[0], dims.n[0]);
  auto c =
      matchOperandMap(indexingMaps[2], numOfBatchDims, dims.m[0], dims.n[0]);

  if (llvm::is_contained({a, b, c}, IndexMatchResult::Mismatch))
    return failure();

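  // The output map must be the identity (there is no named variant with a
  // transposed result), and at most one of A and B may be transposed.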
  if (c != IndexMatchResult::Match ||
      (a == IndexMatchResult::Transposed && b == IndexMatchResult::Transposed))
    return failure();

  /// Codegen the different matmul variants.
  if (numOfBatchDims) {
    if (a == IndexMatchResult::Transposed)
      return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter,
                                                               genericOp);
    if (b == IndexMatchResult::Transposed)
      return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(rewriter,
                                                               genericOp);
    return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
  }

  if (a == IndexMatchResult::Transposed)
    return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp);
  if (b == IndexMatchResult::Transposed)
    return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp);
  return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
}

} // namespace

//===----------------------------------------------------------------------===//
// Categorize linalg generic to named op where possible.
//===----------------------------------------------------------------------===//
FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
                                                       GenericOp genericOp) {
  // Copy
  if (isaCopyOpInterface(genericOp)) {
    LinalgOp namedOp = rewriter.replaceOpWithNewOp<CopyOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
    return namedOp;
  }

  // Fill
  if (isaFillOpInterface(genericOp)) {
    LinalgOp namedOp = rewriter.replaceOpWithNewOp<FillOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
    return namedOp;
  }

  // Broadcast
  std::optional<SmallVector<int64_t>> equivalentToBroadcast =
      isaBroadcastOpInterface(genericOp);
  if (equivalentToBroadcast) {
    auto dims = *equivalentToBroadcast;
    LinalgOp namedOp = rewriter.replaceOpWithNewOp<BroadcastOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
        dims);
    return namedOp;
  }

  // Transpose
  std::optional<SmallVector<int64_t>> equivalentToTranspose =
      isaTransposeOpInterface(genericOp);
  if (equivalentToTranspose) {
    auto permutation = *equivalentToTranspose;
    LinalgOp namedOp = rewriter.replaceOpWithNewOp<TransposeOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
        permutation);
    return namedOp;
  }

  // Elementwise Unary
  if (isaElemwiseSingleUnaryOpInterface(genericOp)) {
    Operation *op = &genericOp.getBody()->front();
    if (isa<math::ExpOp>(op)) {
      LinalgOp namedOp = REPLACE_UNARY_OP(ExpOp);
      return namedOp;
    }
  }

  // Elementwise Binary
  if (isaElemwiseSingleBinaryOpInterface(genericOp)) {
    bool swap = areBinOpsSwapped(genericOp);
    Operation *op = &genericOp.getBody()->front();
    if (isa<arith::AddFOp>(op)) {
      LinalgOp namedOp = REPLACE_BINARY_OP(AddOp, swap);
      return namedOp;
    }
    if (isa<arith::SubFOp>(op)) {
      LinalgOp namedOp = REPLACE_BINARY_OP(SubOp, swap);
      return namedOp;
    }
    if (isa<arith::MulFOp>(op)) {
      LinalgOp namedOp = REPLACE_BINARY_OP(MulOp, swap);
      return namedOp;
    }
    if (isa<arith::DivFOp>(op)) {
      LinalgOp namedOp = REPLACE_BINARY_OP(DivOp, swap);
      return namedOp;
    }
  }

  // Contraction - e.g. matmul
  if (isaContractionOpInterface(genericOp)) {
    return specializeLinalgContractions(rewriter, genericOp);
  }
  return failure();
}

namespace {
struct LinalgSpecializeGenericOpsPass
    : public impl::LinalgSpecializeGenericOpsPassBase<
          LinalgSpecializeGenericOpsPass> {

  using impl::LinalgSpecializeGenericOpsPassBase<
      LinalgSpecializeGenericOpsPass>::LinalgSpecializeGenericOpsPassBase;
  void runOnOperation() override;
};
} // namespace

void LinalgSpecializeGenericOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  populateLinalgGenericOpsSpecializationPatterns(patterns);
  populateDecomposeProjectedPermutationPatterns(patterns);

  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();
}

void mlir::linalg::populateLinalgGenericOpsSpecializationPatterns(
    RewritePatternSet &patterns) {
  patterns.add<LinalgSpecializationPattern>(patterns.getContext());
}