This makes ignoring a result an explicit choice by the user, and helps prevent accidental errors from dropped results. Marking LogicalResult as nodiscard was always the intention from the beginning, but it got lost along the way. Differential Revision: https://reviews.llvm.org/D95841
184 lines
7.0 KiB
C++
184 lines
7.0 KiB
C++
//===- Generalization.cpp - linalg named ops to generic ops --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Linalg generalization pass. It converts named
// Linalg ops to linalg.generic ops.
//
//===----------------------------------------------------------------------===//
|
|
|
|
#include "PassDetail.h"
|
|
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
|
#include "mlir/Dialect/Linalg/Passes.h"
|
|
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
|
#include "mlir/EDSC/Builders.h"
|
|
#include "mlir/IR/AffineMap.h"
|
|
#include "mlir/IR/Attributes.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
#include "llvm/Support/Debug.h"
|
|
|
|
#define DEBUG_TYPE "linalg-generalization"
|
|
|
|
using namespace mlir;
|
|
|
|
// Creates a linalg.generic op from the given `namedOp`. Returns a null op if
|
|
// the given `namedOp` does not have a region builder.
|
|
static linalg::GenericOp createGenericOpFromNamedOp(linalg::LinalgOp namedOp,
|
|
OpBuilder &builder) {
|
|
auto regionBuilder = namedOp.getRegionBuilder();
|
|
if (!regionBuilder) {
|
|
LLVM_DEBUG(llvm::dbgs() << "no region builder for op: " << namedOp << "\n");
|
|
return nullptr;
|
|
}
|
|
|
|
SmallVector<AffineMap, 4> indexingMaps = namedOp.getIndexingMaps();
|
|
auto iterators = llvm::to_vector<4>(
|
|
namedOp.iterator_types().getAsValueRange<StringAttr>());
|
|
auto resultTypes = namedOp.getOutputTensorTypes();
|
|
SmallVector<Type, 4> types(resultTypes.begin(), resultTypes.end());
|
|
|
|
return builder.create<linalg::GenericOp>(
|
|
namedOp.getLoc(), types, namedOp.getInputs(), namedOp.getOutputs(),
|
|
indexingMaps, iterators,
|
|
[®ionBuilder](OpBuilder &bodyBuilder, Location loc, ValueRange) {
|
|
edsc::ScopedContext scope(bodyBuilder, loc);
|
|
regionBuilder(*bodyBuilder.getBlock());
|
|
});
|
|
}
|
|
|
|
namespace {
|
|
|
|
/// Base class for all linalg generalization patterns. A subclass must provide
|
|
/// the following method:
|
|
/// linalg::GenericOp createGenericOp(RootOp, PatternRewriter &)
|
|
/// for creating the generic op.
|
|
// TODO: remove this pattern after migrating all manually-written named ops
|
|
// into auto-generated ones.
|
|
template <typename ConcretePattern, typename RootOp>
|
|
struct LinalgGeneralizationPattern : OpRewritePattern<RootOp> {
|
|
LinalgGeneralizationPattern(MLIRContext *context,
|
|
linalg::LinalgTransformationFilter marker,
|
|
PatternBenefit benefit = 1)
|
|
: OpRewritePattern<RootOp>(context, benefit), marker(std::move(marker)) {}
|
|
|
|
LogicalResult matchAndRewrite(RootOp rootOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto linalgOp = dyn_cast<linalg::LinalgOp>(rootOp.getOperation());
|
|
if (!linalgOp)
|
|
return failure();
|
|
if (failed(marker.checkAndNotify(rewriter, linalgOp)))
|
|
return failure();
|
|
|
|
auto *pattern = static_cast<const ConcretePattern *>(this);
|
|
linalg::GenericOp genericOp = pattern->createGenericOp(rootOp, rewriter);
|
|
if (!genericOp)
|
|
return failure();
|
|
|
|
rewriter.replaceOp(rootOp, genericOp.getResults());
|
|
marker.replaceLinalgTransformationFilter(rewriter,
|
|
genericOp.getOperation());
|
|
return success();
|
|
}
|
|
|
|
private:
|
|
linalg::LinalgTransformationFilter marker;
|
|
};
|
|
|
|
struct GeneralizeConvOp
|
|
: public LinalgGeneralizationPattern<GeneralizeConvOp, linalg::ConvOp> {
|
|
using LinalgGeneralizationPattern::LinalgGeneralizationPattern;
|
|
|
|
linalg::GenericOp createGenericOp(linalg::ConvOp, OpBuilder &rewriter) const;
|
|
};
|
|
|
|
/// Catch-all pattern for converting all named ops with a region builder into
|
|
/// linalg.generic.
|
|
struct LinalgNamedOpGeneralizationPattern : RewritePattern {
|
|
LinalgNamedOpGeneralizationPattern(MLIRContext *context,
|
|
linalg::LinalgTransformationFilter marker,
|
|
PatternBenefit benefit = 1)
|
|
: RewritePattern(benefit, MatchAnyOpTypeTag()),
|
|
marker(std::move(marker)) {}
|
|
|
|
LogicalResult matchAndRewrite(Operation *rootOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto linalgOp = dyn_cast<linalg::LinalgOp>(rootOp);
|
|
if (!linalgOp)
|
|
return failure();
|
|
if (failed(marker.checkAndNotify(rewriter, linalgOp)))
|
|
return failure();
|
|
|
|
// No nothing to do for linalg.generic and linalg.indexed_generic.
|
|
if (isa<linalg::GenericOp, linalg::IndexedGenericOp>(rootOp))
|
|
return failure();
|
|
|
|
linalg::GenericOp genericOp =
|
|
createGenericOpFromNamedOp(linalgOp, rewriter);
|
|
if (!genericOp)
|
|
return failure();
|
|
|
|
rewriter.replaceOp(rootOp, genericOp.getResults());
|
|
marker.replaceLinalgTransformationFilter(rewriter,
|
|
genericOp.getOperation());
|
|
return success();
|
|
}
|
|
|
|
private:
|
|
linalg::LinalgTransformationFilter marker;
|
|
};
|
|
|
|
struct LinalgGeneralizationPass
|
|
: public LinalgGeneralizationBase<LinalgGeneralizationPass> {
|
|
void runOnFunction() override;
|
|
};
|
|
|
|
} // namespace
|
|
|
|
void LinalgGeneralizationPass::runOnFunction() {
|
|
FuncOp func = getFunction();
|
|
OwningRewritePatternList patterns;
|
|
linalg::populateLinalgConvGeneralizationPatterns(&getContext(), patterns);
|
|
linalg::populateLinalgNamedOpsGeneralizationPatterns(&getContext(), patterns);
|
|
(void)applyPatternsAndFoldGreedily(func.getBody(), std::move(patterns));
|
|
}
|
|
|
|
/// Builds a linalg.generic with no tensor results over `convOp`'s input and
/// output buffers. Its body yields `args[0] * args[1] + args[2]` using float
/// multiply/add ops.
linalg::GenericOp GeneralizeConvOp::createGenericOp(linalg::ConvOp convOp,
                                                    OpBuilder &builder) const {
  // The conv's indexing maps and iterator types are reused unchanged.
  SmallVector<AffineMap, 4> maps = convOp.getIndexingMaps();
  auto iteratorKinds =
      llvm::to_vector<4>(convOp.iterator_types().getAsValueRange<StringAttr>());

  auto buildMulAdd = [](OpBuilder &b, Location loc, ValueRange args) {
    Value lhs = args[0];
    Value rhs = args[1];
    Value acc = args[2];
    Value product = b.create<MulFOp>(loc, lhs, rhs);
    Value sum = b.create<AddFOp>(loc, product, acc);
    b.create<linalg::YieldOp>(loc, sum);
  };

  return builder.create<linalg::GenericOp>(
      convOp.getLoc(), /*resultTensorTypes=*/ArrayRef<Type>(),
      convOp.getInputBuffers(), convOp.getOutputBuffers(), maps, iteratorKinds,
      buildMulAdd);
}
|
|
|
|
/// Adds the linalg.conv -> linalg.generic pattern to `patterns`. `marker`
/// filters which ops the pattern applies to.
void mlir::linalg::populateLinalgConvGeneralizationPatterns(
    MLIRContext *context, OwningRewritePatternList &patterns,
    linalg::LinalgTransformationFilter marker) {
  // `marker` is a by-value local that is not used afterwards.
  patterns.insert<GeneralizeConvOp>(context, std::move(marker));
}
|
|
|
|
/// Adds the catch-all named-op -> linalg.generic pattern to `patterns`.
/// `marker` filters which ops the pattern applies to.
void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
    MLIRContext *context, OwningRewritePatternList &patterns,
    linalg::LinalgTransformationFilter marker) {
  // `marker` is a by-value local that is not used afterwards.
  patterns.insert<LinalgNamedOpGeneralizationPattern>(context,
                                                      std::move(marker));
}
|
|
|
|
/// Creates a new instance of the Linalg generalization pass.
std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgGeneralizationPass() {
  std::unique_ptr<OperationPass<FuncOp>> pass =
      std::make_unique<LinalgGeneralizationPass>();
  return pass;
}
|