Files
clang-p2996/mlir/lib/Transforms/SCCP.cpp
Billy Zhu 34a65980d7 [MLIR] Erase location of folded constants (#75415)
Follow up to the discussion from #75258, and serves as an alternate
solution for #74670.

Set the location to Unknown for deduplicated / moved / materialized
constants by OperationFolder. This makes sure that the folded constants
don't end up with an arbitrary location of one of the original ops that
became it, and that hoisted ops don't confuse the stepping order.
2023-12-21 09:54:48 -08:00

136 lines
4.7 KiB
C++

//===- SCCP.cpp - Sparse Conditional Constant Propagation -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass performs a sparse conditional constant propagation
// in MLIR. It identifies values known to be constant, propagates that
// information throughout the IR, and replaces them. This is done with an
// optimistic dataflow analysis that assumes that all values are constant until
// proven otherwise.
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/Passes.h"
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/FoldUtils.h"
namespace mlir {
#define GEN_PASS_DEF_SCCP
#include "mlir/Transforms/Passes.h.inc"
} // namespace mlir
using namespace mlir;
using namespace mlir::dataflow;
//===----------------------------------------------------------------------===//
// SCCP Rewrites
//===----------------------------------------------------------------------===//
/// Try to substitute `value` with a materialized constant. The solver is
/// queried for the constant lattice state of `value`; if that state is missing,
/// uninitialized, or does not hold a constant attribute, nothing is changed.
/// Otherwise `folder` is asked for a (possibly already existing) constant in
/// the block `builder` currently points into, and all uses of `value` are
/// redirected to it. Returns success only when the uses were replaced.
static LogicalResult replaceWithConstant(DataFlowSolver &solver,
                                         OpBuilder &builder,
                                         OperationFolder &folder, Value value) {
  const auto *state = solver.lookupState<Lattice<ConstantValue>>(value);
  if (!state)
    return failure();

  // The uninitialized check must come before the constant is queried.
  const ConstantValue &known = state->getValue();
  if (known.isUninitialized())
    return failure();

  const auto constantAttr = known.getConstantValue();
  if (!constantAttr)
    return failure();

  // Ask the folder to materialize (or reuse) a constant of the value's type in
  // the block that currently holds the insertion point.
  Block *insertionBlock = builder.getInsertionBlock();
  Dialect *constantDialect = known.getConstantDialect();
  Value materialized = folder.getOrCreateConstant(
      insertionBlock, constantDialect, constantAttr, value.getType());
  if (!materialized)
    return failure();

  value.replaceAllUsesWith(materialized);
  return success();
}
/// Rewrite the given regions using the computed analysis results. Every value
/// the solver determined to be constant has its uses redirected to a
/// materialized constant, and operations that become trivially dead as a
/// result are erased.
static void rewrite(DataFlowSolver &solver, MLIRContext *context,
                    MutableArrayRef<Region> initialRegions) {
  // Blocks still waiting to be processed. The blocks of each region are pushed
  // in reverse so that popping from the back yields them in region order.
  SmallVector<Block *> pendingBlocks;
  auto enqueueRegions = [&pendingBlocks](MutableArrayRef<Region> regions) {
    for (Region &region : regions) {
      for (Block &block : llvm::reverse(region))
        pendingBlocks.push_back(&block);
    }
  };

  // The folder creates and uniques the constants; the builder carries the
  // insertion point whose block is handed to the folder.
  OperationFolder folder(context);
  OpBuilder builder(context);

  enqueueRegions(initialRegions);
  while (!pendingBlocks.empty()) {
    Block *current = pendingBlocks.pop_back_val();

    // Early-increment iteration keeps the walk valid when `op` gets erased.
    for (Operation &op : llvm::make_early_inc_range(*current)) {
      builder.setInsertionPoint(&op);

      // Attempt the replacement for every result, even after one of them
      // fails. An operation without results never counts as fully replaced.
      bool allResultsReplaced = op.getNumResults() != 0;
      for (Value result : op.getResults()) {
        if (failed(replaceWithConstant(solver, builder, folder, result)))
          allResultsReplaced = false;
      }

      if (allResultsReplaced && wouldOpBeTriviallyDead(&op)) {
        // Every result now lives on as a constant, so the op can go away.
        assert(op.use_empty() && "expected all uses to be replaced");
        op.erase();
      } else {
        // The operation stays; queue any of its regions for processing.
        enqueueRegions(op.getRegions());
      }
    }

    // Finally replace constant block arguments, with the insertion point at
    // the start of the block.
    builder.setInsertionPointToStart(current);
    for (BlockArgument arg : current->getArguments())
      (void)replaceWithConstant(solver, builder, folder, arg);
  }
}
//===----------------------------------------------------------------------===//
// SCCP Pass
//===----------------------------------------------------------------------===//
namespace {
/// The SCCP pass. It derives from `impl::SCCPBase`, which is pulled in through
/// `GEN_PASS_DEF_SCCP` and "mlir/Transforms/Passes.h.inc", and only overrides
/// `runOnOperation`, where the analysis is run and the IR is rewritten.
struct SCCP : public impl::SCCPBase<SCCP> {
void runOnOperation() override;
};
} // namespace
void SCCP::runOnOperation() {
Operation *op = getOperation();
DataFlowSolver solver;
solver.load<DeadCodeAnalysis>();
solver.load<SparseConstantPropagation>();
if (failed(solver.initializeAndRun(op)))
return signalPassFailure();
rewrite(solver, op->getContext(), op->getRegions());
}
/// Create a new instance of the SCCP pass.
std::unique_ptr<Pass> mlir::createSCCPPass() {
  std::unique_ptr<Pass> pass = std::make_unique<SCCP>();
  return pass;
}