Files
clang-p2996/mlir/lib/Transforms/Utils/FoldUtils.cpp

171 lines
6.1 KiB
C++

//===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines various operation fold utilities. These utilities are
// intended to be used by passes to unify and simply their logic.
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/StandardOps/Ops.h"
using namespace mlir;
FoldHelper::FoldHelper(Function *f) : function(f) {}
LogicalResult
FoldHelper::tryToFold(Operation *op,
std::function<void(Operation *)> preReplaceAction) {
assert(op->getFunction() == function &&
"cannot constant fold op from another function");
// The constant op also implements the constant fold hook; it can be folded
// into the value it contains. We need to consider constants before the
// constant folding logic to avoid re-creating the same constant later.
// TODO: Extend to support dialect-specific constant ops.
if (auto constant = dyn_cast<ConstantOp>(op)) {
// If this constant is dead, update bookkeeping and signal the caller.
if (constant.use_empty()) {
notifyRemoval(op);
op->erase();
return success();
}
// Otherwise, try to see if we can de-duplicate it.
return tryToUnify(op);
}
SmallVector<Attribute, 8> operandConstants;
SmallVector<OpFoldResult, 8> results;
// Check to see if any operands to the operation is constant and whether
// the operation knows how to constant fold itself.
operandConstants.assign(op->getNumOperands(), Attribute());
for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
matchPattern(op->getOperand(i), m_Constant(&operandConstants[i]));
// If this is a commutative binary operation with a constant on the left
// side move it to the right side.
if (operandConstants.size() == 2 && operandConstants[0] &&
!operandConstants[1] && op->isCommutative()) {
std::swap(op->getOpOperand(0), op->getOpOperand(1));
std::swap(operandConstants[0], operandConstants[1]);
}
// Attempt to constant fold the operation.
if (failed(op->fold(operandConstants, results)))
return failure();
// Constant folding succeeded. We will start replacing this op's uses and
// eventually erase this op. Invoke the callback provided by the caller to
// perform any pre-replacement action.
if (preReplaceAction)
preReplaceAction(op);
// Check to see if the operation was just updated in place.
if (results.empty())
return success();
assert(results.size() == op->getNumResults());
// Create the result constants and replace the results.
FuncBuilder builder(op);
for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
auto *res = op->getResult(i);
if (res->use_empty()) // Ignore dead uses.
continue;
assert(!results[i].isNull() && "expected valid OpFoldResult");
// Check if the result was an SSA value.
if (auto *repl = results[i].dyn_cast<Value *>()) {
if (repl != res)
res->replaceAllUsesWith(repl);
continue;
}
// If we already have a canonicalized version of this constant, just reuse
// it. Otherwise create a new one.
Attribute attrRepl = results[i].get<Attribute>();
auto &constInst =
uniquedConstants[std::make_pair(attrRepl, res->getType())];
if (!constInst) {
// TODO: Extend to support dialect-specific constant ops.
auto newOp =
builder.create<ConstantOp>(op->getLoc(), res->getType(), attrRepl);
// Register to the constant map and also move up to entry block to
// guarantee dominance.
constInst = newOp.getOperation();
moveConstantToEntryBlock(constInst);
}
res->replaceAllUsesWith(constInst->getResult(0));
}
op->erase();
return success();
}
void FoldHelper::notifyRemoval(Operation *op) {
assert(op->getFunction() == function &&
"cannot remove constant from another function");
Attribute constValue;
if (!matchPattern(op, m_Constant(&constValue)))
return;
// This constant is dead. keep uniquedConstants up to date.
auto it = uniquedConstants.find({constValue, op->getResult(0)->getType()});
if (it != uniquedConstants.end() && it->second == op)
uniquedConstants.erase(it);
}
LogicalResult FoldHelper::tryToUnify(Operation *op) {
Attribute constValue;
matchPattern(op, m_Constant(&constValue));
assert(constValue);
// Check to see if we already have a constant with this type and value:
auto &constInst =
uniquedConstants[std::make_pair(constValue, op->getResult(0)->getType())];
if (constInst) {
// If this constant is already our uniqued one, then leave it alone.
if (constInst == op)
return failure();
// Otherwise replace this redundant constant with the uniqued one. We know
// this is safe because we move constants to the top of the function when
// they are uniqued, so we know they dominate all uses.
op->getResult(0)->replaceAllUsesWith(constInst->getResult(0));
op->erase();
return success();
}
// If we have no entry, then we should unique this constant as the
// canonical version. To ensure safe dominance, move the operation to the
// entry block of the function.
constInst = op;
moveConstantToEntryBlock(op);
return failure();
}
void FoldHelper::moveConstantToEntryBlock(Operation *op) {
// Insert at the very top of the entry block.
auto &entryBB = function->front();
op->moveBefore(&entryBB, entryBB.begin());
}