Follow-up to the discussion in #75258, serving as an alternative solution to #74670. Set the location to Unknown for constants that are deduplicated, moved, or materialized by OperationFolder. This ensures that folded constants don't end up with the arbitrary location of one of the original ops they were folded from, and that hoisted ops don't confuse the stepping order in a debugger.
136 lines
4.7 KiB
C++
136 lines
4.7 KiB
C++
//===- SCCP.cpp - Sparse Conditional Constant Propagation -----------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This transformation pass performs a sparse conditional constant propagation
|
|
// in MLIR. It identifies values known to be constant, propagates that
|
|
// information throughout the IR, and replaces them. This is done with an
|
|
// optimistic dataflow analysis that assumes that all values are constant until
|
|
// proven otherwise.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Transforms/Passes.h"
|
|
|
|
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
|
|
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/Dialect.h"
|
|
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Transforms/FoldUtils.h"
|
|
|
|
namespace mlir {
|
|
#define GEN_PASS_DEF_SCCP
|
|
#include "mlir/Transforms/Passes.h.inc"
|
|
} // namespace mlir
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::dataflow;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SCCP Rewrites
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Replace the given value with a constant if the corresponding lattice
|
|
/// represents a constant. Returns success if the value was replaced, failure
|
|
/// otherwise.
|
|
static LogicalResult replaceWithConstant(DataFlowSolver &solver,
|
|
OpBuilder &builder,
|
|
OperationFolder &folder, Value value) {
|
|
auto *lattice = solver.lookupState<Lattice<ConstantValue>>(value);
|
|
if (!lattice || lattice->getValue().isUninitialized())
|
|
return failure();
|
|
const ConstantValue &latticeValue = lattice->getValue();
|
|
if (!latticeValue.getConstantValue())
|
|
return failure();
|
|
|
|
// Attempt to materialize a constant for the given value.
|
|
Dialect *dialect = latticeValue.getConstantDialect();
|
|
Value constant = folder.getOrCreateConstant(
|
|
builder.getInsertionBlock(), dialect, latticeValue.getConstantValue(),
|
|
value.getType());
|
|
if (!constant)
|
|
return failure();
|
|
|
|
value.replaceAllUsesWith(constant);
|
|
return success();
|
|
}
|
|
|
|
/// Rewrite the given regions using the computing analysis. This replaces the
|
|
/// uses of all values that have been computed to be constant, and erases as
|
|
/// many newly dead operations.
|
|
static void rewrite(DataFlowSolver &solver, MLIRContext *context,
|
|
MutableArrayRef<Region> initialRegions) {
|
|
SmallVector<Block *> worklist;
|
|
auto addToWorklist = [&](MutableArrayRef<Region> regions) {
|
|
for (Region ®ion : regions)
|
|
for (Block &block : llvm::reverse(region))
|
|
worklist.push_back(&block);
|
|
};
|
|
|
|
// An operation folder used to create and unique constants.
|
|
OperationFolder folder(context);
|
|
OpBuilder builder(context);
|
|
|
|
addToWorklist(initialRegions);
|
|
while (!worklist.empty()) {
|
|
Block *block = worklist.pop_back_val();
|
|
|
|
for (Operation &op : llvm::make_early_inc_range(*block)) {
|
|
builder.setInsertionPoint(&op);
|
|
|
|
// Replace any result with constants.
|
|
bool replacedAll = op.getNumResults() != 0;
|
|
for (Value res : op.getResults())
|
|
replacedAll &=
|
|
succeeded(replaceWithConstant(solver, builder, folder, res));
|
|
|
|
// If all of the results of the operation were replaced, try to erase
|
|
// the operation completely.
|
|
if (replacedAll && wouldOpBeTriviallyDead(&op)) {
|
|
assert(op.use_empty() && "expected all uses to be replaced");
|
|
op.erase();
|
|
continue;
|
|
}
|
|
|
|
// Add any the regions of this operation to the worklist.
|
|
addToWorklist(op.getRegions());
|
|
}
|
|
|
|
// Replace any block arguments with constants.
|
|
builder.setInsertionPointToStart(block);
|
|
for (BlockArgument arg : block->getArguments())
|
|
(void)replaceWithConstant(solver, builder, folder, arg);
|
|
}
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SCCP Pass
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
struct SCCP : public impl::SCCPBase<SCCP> {
|
|
void runOnOperation() override;
|
|
};
|
|
} // namespace
|
|
|
|
void SCCP::runOnOperation() {
|
|
Operation *op = getOperation();
|
|
|
|
DataFlowSolver solver;
|
|
solver.load<DeadCodeAnalysis>();
|
|
solver.load<SparseConstantPropagation>();
|
|
if (failed(solver.initializeAndRun(op)))
|
|
return signalPassFailure();
|
|
rewrite(solver, op->getContext(), op->getRegions());
|
|
}
|
|
|
|
std::unique_ptr<Pass> mlir::createSCCPPass() {
|
|
return std::make_unique<SCCP>();
|
|
}
|