The only benefit of FunctionPass is that it filters out function declarations. This isn't enough to justify carrying it around, as we can simply filter out declarations when necessary within the pass. In the future, we can also explore better scheduling primitives to filter out declarations at the pipeline level. The definition of FunctionPass is left intact for now to allow time for downstream users to migrate. Differential Revision: https://reviews.llvm.org/D117182
84 lines
2.8 KiB
C++
84 lines
2.8 KiB
C++
//===- ShapeToShapeLowering.cpp - Prepare for lowering to Standard --------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "PassDetail.h"
|
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
|
#include "mlir/Dialect/Shape/IR/Shape.h"
|
|
#include "mlir/Dialect/Shape/Transforms/Passes.h"
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::shape;
|
|
|
|
namespace {
|
|
/// Converts `shape.num_elements` to `shape.reduce`.
|
|
struct NumElementsOpConverter : public OpRewritePattern<NumElementsOp> {
|
|
public:
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(NumElementsOp op,
|
|
PatternRewriter &rewriter) const final;
|
|
};
|
|
} // namespace
|
|
|
|
/// Replaces `op` (a `shape.num_elements`) with a `shape.reduce` over
/// `op.getShape()` whose body multiplies the running value by each element,
/// starting from a constant 1.
///
/// The initial constant, the multiplication and the reduce result all use the
/// result type of `op`.
///
/// Returns failure (leaving the IR untouched) if the dialect cannot
/// materialize the initial constant for that type.
LogicalResult
NumElementsOpConverter::matchAndRewrite(NumElementsOp op,
                                        PatternRewriter &rewriter) const {
  auto loc = op.getLoc();
  Type valueType = op.getResult().getType();

  // Seed the reduction with 1, materialized by the op's dialect as a constant
  // of `valueType`. materializeConstant is allowed to return null when it
  // cannot build such a constant; bail out instead of dereferencing it. No IR
  // has been created at this point, so failing here is clean.
  auto *initOp = op->getDialect()->materializeConstant(
      rewriter, rewriter.getIndexAttr(1), valueType, loc);
  if (!initOp)
    return failure();
  Value init = initOp->getResult(0);

  ReduceOp reduce = rewriter.create<ReduceOp>(loc, op.getShape(), init);

  // Generate the reduction body: multiply block arguments #1 and #2 and yield
  // the product. (Presumably these are the current extent and the running
  // accumulator -- see the `shape.reduce` op definition.)
  //
  // The body ops are created through the rewriter rather than a standalone
  // OpBuilder, so the pattern driver is notified of every new op. The guard
  // restores the rewriter's insertion point when the scope ends.
  {
    OpBuilder::InsertionGuard guard(rewriter);
    Block *body = reduce.getBody();
    rewriter.setInsertionPointToEnd(body);
    Value product = rewriter.create<MulOp>(loc, valueType, body->getArgument(1),
                                           body->getArgument(2));
    rewriter.create<shape::YieldOp>(loc, product);
  }

  rewriter.replaceOp(op, reduce.getResult());
  return success();
}
|
|
|
|
namespace {
|
|
struct ShapeToShapeLowering
|
|
: public ShapeToShapeLoweringBase<ShapeToShapeLowering> {
|
|
void runOnOperation() override;
|
|
};
|
|
} // namespace
|
|
|
|
void ShapeToShapeLowering::runOnOperation() {
|
|
MLIRContext &ctx = getContext();
|
|
|
|
RewritePatternSet patterns(&ctx);
|
|
populateShapeRewritePatterns(patterns);
|
|
|
|
ConversionTarget target(getContext());
|
|
target.addLegalDialect<arith::ArithmeticDialect, ShapeDialect,
|
|
StandardOpsDialect>();
|
|
target.addIllegalOp<NumElementsOp>();
|
|
if (failed(mlir::applyPartialConversion(getOperation(), target,
|
|
std::move(patterns))))
|
|
signalPassFailure();
|
|
}
|
|
|
|
/// Appends the shape-to-shape rewrite patterns to `patterns`. Currently this
/// is the `shape.num_elements` -> `shape.reduce` converter. The patterns are
/// created in the pattern set's own context.
void mlir::populateShapeRewritePatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<NumElementsOpConverter>(context);
}
|
|
|
|
std::unique_ptr<Pass> mlir::createShapeToShapeLowering() {
|
|
return std::make_unique<ShapeToShapeLowering>();
|
|
}
|