Add support for specializing linalg.broadcast and linalg.transpose from linalg.generic. Also does some refactoring to reuse specialization checks, migrating some common uses to op interface methods.
//===- Specialize.cpp - linalg generic ops to named ops ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a method to specialize generic operations to named
// operations. Conceptually it is the opposite of generalize.cpp.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Complex/IR/Complex.h"
|
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
|
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
|
|
#include "mlir/Dialect/Linalg/Passes.h"
|
|
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
|
#include "mlir/Dialect/Math/IR/Math.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Support/TypeID.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
#include "llvm/Support/Debug.h"
|
|
|
|
namespace mlir {
#define GEN_PASS_DEF_LINALGSPECIALIZEGENERICOPSPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

#define DEBUG_TYPE "linalg-specialization"

#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP)                                \
  (rewriter.replaceOpWithNewOp<NEWOP>(                                         \
      genericOp,                                                               \
      ValueRange{genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 1 : 0],            \
                 genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 0 : 1]},           \
      ValueRange{genericOp.getDpsInits()[0]}))

#define REPLACE_UNARY_OP(NEWOP)                                                \
  (rewriter.replaceOpWithNewOp<NEWOP>(genericOp,                               \
                                      ValueRange{genericOp.getDpsInputs()[0]}, \
                                      ValueRange{genericOp.getDpsInits()[0]}))

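// For illustration: with operands swapped, REPLACE_BINARY_OP(SubOp, true)
// expands to
//   rewriter.replaceOpWithNewOp<SubOp>(
//       genericOp,
//       ValueRange{genericOp.getDpsInputs()[1], genericOp.getDpsInputs()[0]},
//       ValueRange{genericOp.getDpsInits()[0]})
// i.e. it stamps a `linalg.sub` with the two inputs exchanged.
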
using namespace mlir;
using namespace mlir::linalg;

// Given an elementwise single binary linalg generic op, checks whether the
// binary op accesses operands as swapped, e.g.
// this differentiates between a linalg-generic body that contains:
//    ^bb0(%a: f32, %b: f32, %c : f32):
//         %0 = arith.subf %a, %b : f32
//         linalg.yield %0: f32
// against:
//    ^bb0(%a: f32, %b: f32, %c : f32):
//         %0 = arith.subf %b, %a : f32
//         linalg.yield %0: f32
// The former is linalg.sub(a,b); the latter is linalg.sub(b,a).
static bool areBinOpsSwapped(GenericOp genericOp) {
  Block *body = genericOp.getBody();
  Operation *op = &body->front();
  bool swapped = false;
  if (op->getOpOperand(0).get() != body->getArgument(0)) {
    swapped = true;
    assert(op->getOpOperand(0).get() == body->getArgument(1) &&
           op->getOpOperand(1).get() == body->getArgument(0) &&
           "binary op uses just one block arg");
  }
  return swapped;
}

//===----------------------------------------------------------------------===//
// Specialize linalg generic to matmul variants.
//===----------------------------------------------------------------------===//
// Identifies a linalg.generic that is essentially a named op of the form:
// ` linalg.{batch_}?matmul{_transpose_a | _transpose_b}? `
//
// It is possible that a linalg.generic may be implementing a matmul but not
// in a straightforward way, e.g. below is a matrix multiply over some slice:
// ```
//  %0 = linalg.generic {
//          indexing_maps = [affine_map<(d0, d1, d2) -> (3, d1, d0)>,
//                           affine_map<(d0, d1, d2) -> (d0, 5, d2)>,
//                           affine_map<(d0, d1, d2) -> (d2, d1, 13)>],
//          iterator_types = ["parallel", "parallel", "parallel"]}
//          ins(%A, %B : tensor<20x20x20xf32>, tensor<20x20x20xf32>)
//          outs(%C : tensor<20x20x20xf32>) {
//             ^bb0(%a: f32, %b: f32, %c : f32):
//                %mul = arith.mulf %a, %b : f32
//                %add = arith.addf %mul, %c : f32
//                linalg.yield %add : f32
//          } -> tensor<20x20x20xf32>
// ```
// It is not possible to represent the above as a named op,
// e.g. linalg.batch_matmul(%A, %B : tensor<20x20x20xf32>, ...) is
// not the same as the linalg.generic above.
namespace {
enum class IndexMatchResult {
  Match = 0,  // identity map.
  Transposed, // transposed map.
  Mismatch    // none of the above.
};

// Checks whether the input Affine `map` contains two consecutive dims that
// can be interpreted as accessing a 2D matrix. It is assumed that the row
// and column dimensions are adjacent axes (in this order) and start at
// `rowDimIdx` in the input map.
//
// e.g. consider the A matrix in `C[M,N] = A[M,K] * B[K,N]`. We will check
// whether the map of A is identity (match), transposed, or something
// completely different (mismatch). Similarly for B and C.
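// For illustration, with iteration space (d0, d1, d2) = (m, n, k) and no
// batch dims, the A-operand map of a plain matmul is
// affine_map<(d0, d1, d2) -> (d0, d2)>, and
//   matchOperandMap(mapA, /*rowDimIdx=*/0, /*expectedPosOfRowDim=*/0,
//                   /*expectedPosOfColDim=*/2)
// returns Match. For matmul_transpose_a, A's map is
// affine_map<(d0, d1, d2) -> (d2, d0)> and the same call returns Transposed.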
static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
                                        unsigned expectedPosOfRowDim,
                                        unsigned expectedPosOfColDim) {
  // Get the matrix multiply indices. They are past the batch indices.
  auto exprOfRowDim = map.getResults()[rowDimIdx];
  auto exprOfColDim = map.getResults()[rowDimIdx + 1];

  // They should be pure dimension ids.
  if (exprOfRowDim.getKind() != AffineExprKind::DimId ||
      exprOfColDim.getKind() != AffineExprKind::DimId)
    return IndexMatchResult::Mismatch;

  auto posRowDim = cast<AffineDimExpr>(exprOfRowDim).getPosition();
  auto posColDim = cast<AffineDimExpr>(exprOfColDim).getPosition();

  if (expectedPosOfRowDim == posRowDim && expectedPosOfColDim == posColDim)
    return IndexMatchResult::Match;

  if (expectedPosOfRowDim == posColDim && expectedPosOfColDim == posRowDim)
    return IndexMatchResult::Transposed;

  return IndexMatchResult::Mismatch;
}

// Replaces genericOp with the `NamedOpTy` op, supplied as a template arg.
// All the variants expressed as a pseudo regular expression:
//   `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
// have the same number of ins/outs, so it is easy to stamp different variants.
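// e.g. replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp)
// rewrites the generic in place as
//   linalg.matmul_transpose_b ins(%A, %B) outs(%C)
// reusing the generic's first two DPS inputs and its single DPS init.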
template <typename NamedOpTy>
static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op) {
  LinalgOp namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
      op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
      ValueRange{op.getDpsInits()[0]});
  return namedOp;
}

// Converts linalg.generic to named linalg.*matmul* where possible.
static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
                                                        GenericOp genericOp) {
  if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
    return failure();

  // Early exit if not projected permutations.
  auto mapRange = genericOp.getIndexingMapsArray();
  if (llvm::any_of(mapRange,
                   [](AffineMap m) { return !m.isProjectedPermutation(); }))
    return failure();

  // A linalg generic contraction can be across multiple axes, e.g.
  // ```
  //  linalg.generic
  //       {indexing_maps = [affine_map<(m, n, k1, k2) -> (m, k1, k2)>,
  //                         affine_map<(m, n, k1, k2) -> (k2, k1, n)>,
  //                         affine_map<(m, n, k1, k2) -> (m, n)>],
  //        iterator_types = ["parallel", "parallel",
  //                          "reduction", "reduction"]}
  //        ins(%A, %B : tensor<10x20x30xf32>, tensor<30x20x40xf32>)
  //        outs(%C : tensor<10x40xf32>) {
  //           ^bb0(%a: f32, %b: f32, %c: f32):
  //               %1 = arith.mulf %a, %b : f32
  //               %2 = arith.addf %c, %1 : f32
  //               linalg.yield %2 : f32
  //        } -> tensor<10x40xf32>
  // ```
  // The above contraction has two reduction dimensions {k1, k2}, and while it
  // is a valid linalg contraction, it is not a named-op matrix multiply kind.
  // Therefore, reject multi-dim reductions.
  auto res = inferContractionDims(genericOp);
  if (!succeeded(res))
    return failure();
  auto dims = *res;
  if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
    return failure();

  if (!mlir::linalg::detail::isContractionBody(
          *genericOp.getBlock(), [](Operation *first, Operation *second) {
            if ((isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second)) ||
                (isa<arith::MulIOp>(first) && isa<arith::AddIOp>(second)) ||
                (isa<complex::MulOp>(first) && isa<complex::AddOp>(second)))
              return true;
            return false;
          }))
    return failure();

  // Check the rank of the operands.
  auto indexingMaps = genericOp.getIndexingMapsArray();
  if (llvm::any_of(indexingMaps, [&dims](AffineMap m) {
        return m.getResults().size() !=
               dims.batch.size() + 2 /* any two of {m,n,k} */;
      }))
    return failure();

  auto numOfBatchDims = dims.batch.size();
  if (indexingMaps[0].getNumDims() != numOfBatchDims + 3)
    return failure();

  if (numOfBatchDims) {
    // Each operand in a linalg generic contraction could express a different
    // permutation for its batch dimensions. But for a named op the batch
    // dimensions must be identity, since separate maps are not specified.
    if (llvm::any_of(indexingMaps, [numOfBatchDims](AffineMap m) {
          for (unsigned i = 0; i < numOfBatchDims; ++i) {
            auto expr = m.getResults()[i];
            if (expr.getKind() != AffineExprKind::DimId ||
                cast<AffineDimExpr>(expr).getPosition() != i)
              return true;
          }
          return false;
        }))
      return failure();
  }

  auto a =
      matchOperandMap(indexingMaps[0], numOfBatchDims, dims.m[0], dims.k[0]);
  auto b =
      matchOperandMap(indexingMaps[1], numOfBatchDims, dims.k[0], dims.n[0]);
  auto c =
      matchOperandMap(indexingMaps[2], numOfBatchDims, dims.m[0], dims.n[0]);

  if (llvm::is_contained({a, b, c}, IndexMatchResult::Mismatch))
    return failure();

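  // Note: the named matmul variants can transpose at most one input (there is
  // no `*_transpose_a_transpose_b` named op) and cannot express a transposed
  // output, so reject those shapes here.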
  if (c != IndexMatchResult::Match ||
      (a == IndexMatchResult::Transposed && b == IndexMatchResult::Transposed))
    return failure();

  // Codegen the different matmul variants.
  if (numOfBatchDims) {
    if (a == IndexMatchResult::Transposed)
      return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter,
                                                               genericOp);
    if (b == IndexMatchResult::Transposed)
      return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(rewriter,
                                                               genericOp);
    return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
  }

  if (a == IndexMatchResult::Transposed)
    return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp);
  if (b == IndexMatchResult::Transposed)
    return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp);
  return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
}

} // namespace

//===----------------------------------------------------------------------===//
// Categorize linalg generic to named op where possible.
//===----------------------------------------------------------------------===//
FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
                                                      GenericOp genericOp) {
  // Copy
  if (isaCopyOpInterface(genericOp)) {
    LinalgOp namedOp = rewriter.replaceOpWithNewOp<CopyOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
    return namedOp;
  }

  // Fill
  if (isaFillOpInterface(genericOp)) {
    LinalgOp namedOp = rewriter.replaceOpWithNewOp<FillOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
    return namedOp;
  }

  // Broadcast
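  // An illustrative sketch (assuming tensor semantics): a generic with maps
  // [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>] whose
  // body just yields its input broadcasts along dim 1, and becomes
  //   linalg.broadcast ins(%in) outs(%out) dimensions = [1]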
  std::optional<SmallVector<int64_t>> equivalentToBroadcast =
      isaBroadcastOpInterface(genericOp);
  if (equivalentToBroadcast) {
    auto dims = *equivalentToBroadcast;
    LinalgOp namedOp = rewriter.replaceOpWithNewOp<BroadcastOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
        dims);
    return namedOp;
  }

  // Transpose
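  // Similarly (illustrative sketch): a 2-d generic whose input map is
  // affine_map<(d0, d1) -> (d1, d0)> and whose body just yields its input
  // becomes
  //   linalg.transpose ins(%in) outs(%out) permutation = [1, 0]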
  std::optional<SmallVector<int64_t>> equivalentToTranspose =
      isaTransposeOpInterface(genericOp);
  if (equivalentToTranspose) {
    auto permutation = *equivalentToTranspose;
    LinalgOp namedOp = rewriter.replaceOpWithNewOp<TransposeOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
        permutation);
    return namedOp;
  }

  // Elementwise Unary
  if (isaElemwiseSingleUnaryOpInterface(genericOp)) {
    Operation *op = &genericOp.getBody()->front();
    if (isa<math::ExpOp>(op)) {
      LinalgOp namedOp = REPLACE_UNARY_OP(ExpOp);
      return namedOp;
    }
  }

  // Elementwise Binary
  if (isaElemwiseSingleBinaryOpInterface(genericOp)) {
    bool swap = areBinOpsSwapped(genericOp);
    Operation *op = &genericOp.getBody()->front();
    if (isa<arith::AddFOp>(op)) {
      LinalgOp namedOp = REPLACE_BINARY_OP(AddOp, swap);
      return namedOp;
    }
    if (isa<arith::SubFOp>(op)) {
      LinalgOp namedOp = REPLACE_BINARY_OP(SubOp, swap);
      return namedOp;
    }
    if (isa<arith::MulFOp>(op)) {
      LinalgOp namedOp = REPLACE_BINARY_OP(MulOp, swap);
      return namedOp;
    }
    if (isa<arith::DivFOp>(op)) {
      LinalgOp namedOp = REPLACE_BINARY_OP(DivOp, swap);
      return namedOp;
    }
  }

  // Contraction - e.g. matmul
  if (isaContractionOpInterface(genericOp)) {
    return specializeLinalgContractions(rewriter, genericOp);
  }
  return failure();
}

namespace {
struct LinalgSpecializeGenericOpsPass
    : public impl::LinalgSpecializeGenericOpsPassBase<
          LinalgSpecializeGenericOpsPass> {

  using impl::LinalgSpecializeGenericOpsPassBase<
      LinalgSpecializeGenericOpsPass>::LinalgSpecializeGenericOpsPassBase;
  void runOnOperation() override;
};
} // namespace

void LinalgSpecializeGenericOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  populateLinalgGenericOpsSpecializationPatterns(patterns);

  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();
}

void mlir::linalg::populateLinalgGenericOpsSpecializationPatterns(
    RewritePatternSet &patterns) {
  patterns.add<LinalgSpecializationPattern>(patterns.getContext());
}
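
// Usage sketch: the specialization can be exercised through this pass (the
// flag name below is assumed to match the pass's tablegen definition) or
// programmatically on a matched op:
//   mlir-opt --linalg-specialize-generic-ops input.mlir
//   FailureOr<LinalgOp> maybeNamed = specializeGenericOp(rewriter, genericOp);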