233 lines
8.8 KiB
C++
233 lines
8.8 KiB
C++
//===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===//
|
|
//
|
|
// Copyright 2019 The MLIR Authors.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
// =============================================================================
|
|
//
|
|
// This file defines various operation fold utilities. These utilities are
|
|
// intended to be used by passes to unify and simplify their logic.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Transforms/FoldUtils.h"
|
|
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
#include "mlir/IR/Operation.h"
|
|
#include "mlir/StandardOps/Ops.h"
|
|
|
|
using namespace mlir;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// OperationFolder
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Attempts to fold `op` and, when folding yields replacement values, rewires
/// every use of `op`'s results to those values and erases `op`.
///
/// `processGeneratedConstants` is forwarded unchanged to the value-producing
/// overload of `tryToFold`. `preReplaceAction`, when provided, is invoked on
/// `op` once folding has succeeded and before any use is replaced.
///
/// Returns failure if `op` is one of the folder's uniqued constants or if the
/// fold itself fails; returns success otherwise.
LogicalResult OperationFolder::tryToFold(
    Operation *op,
    llvm::function_ref<void(Operation *)> processGeneratedConstants,
    llvm::function_ref<void(Operation *)> preReplaceAction) {
  assert(op->getFunction() == function &&
         "cannot constant fold op from another function");

  // Operations tracked in `referencedDialects` are constants this folder has
  // already uniqued; they are known to be folded, so there is nothing to do.
  if (referencedDialects.count(op) != 0)
    return failure();

  // Run the fold and collect the replacement value for each result.
  SmallVector<Value *, 8> foldedValues;
  if (failed(tryToFold(op, foldedValues, processGeneratedConstants)))
    return failure();

  // Folding succeeded: give the caller a chance to react before the uses of
  // `op` are rewritten and `op` is erased.
  if (preReplaceAction)
    preReplaceAction(op);

  // An empty list means the operation was updated in place; otherwise swap in
  // the folded values result-by-result and drop the operation.
  if (!foldedValues.empty()) {
    unsigned resultIdx = 0;
    for (Value *replacement : foldedValues)
      op->getResult(resultIdx++)->replaceAllUsesWith(replacement);
    op->erase();
  }
  return success();
}
|
|
|
|
/// Notifies that the given constant `op` should be remove from this
|
|
/// OperationFolder's internal bookkeeping.
|
|
void OperationFolder::notifyRemoval(Operation *op) {
|
|
assert(op->getFunction() == function &&
|
|
"cannot remove constant from another function");
|
|
|
|
// Check to see if this operation is uniqued within the folder.
|
|
auto it = referencedDialects.find(op);
|
|
if (it == referencedDialects.end())
|
|
return;
|
|
|
|
// Get the constant value for this operation, this is the value that was used
|
|
// to unique the operation internally.
|
|
Attribute constValue;
|
|
matchPattern(op, m_Constant(&constValue));
|
|
assert(constValue);
|
|
|
|
// Erase all of the references to this operation.
|
|
auto type = op->getResult(0)->getType();
|
|
for (auto *dialect : it->second)
|
|
uniquedConstants.erase(std::make_tuple(dialect, constValue, type));
|
|
referencedDialects.erase(it);
|
|
}
|
|
|
|
/// A utility function used to materialize a constant for a given attribute and
|
|
/// type. On success, a valid constant value is returned. Otherwise, null is
|
|
/// returned
|
|
static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
|
|
Attribute value, Type type,
|
|
Location loc) {
|
|
auto insertPt = builder.getInsertionPoint();
|
|
(void)insertPt;
|
|
|
|
// Ask the dialect to materialize a constant operation for this value.
|
|
if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) {
|
|
assert(insertPt == builder.getInsertionPoint());
|
|
assert(matchPattern(constOp, m_Constant(&value)));
|
|
return constOp;
|
|
}
|
|
|
|
// If the dialect is unable to materialize a constant, check to see if the
|
|
// standard constant can be used.
|
|
if (ConstantOp::isBuildableWith(value, type))
|
|
return builder.create<ConstantOp>(loc, type, value);
|
|
return nullptr;
|
|
}
|
|
|
|
/// Tries to perform folding on the given `op`. If successful, populates
|
|
/// `results` with the results of the folding.
|
|
LogicalResult OperationFolder::tryToFold(
|
|
Operation *op, SmallVectorImpl<Value *> &results,
|
|
llvm::function_ref<void(Operation *)> processGeneratedConstants) {
|
|
assert(op->getFunction() == function &&
|
|
"cannot constant fold op from another function");
|
|
|
|
SmallVector<Attribute, 8> operandConstants;
|
|
SmallVector<OpFoldResult, 8> foldResults;
|
|
|
|
// Check to see if any operands to the operation is constant and whether
|
|
// the operation knows how to constant fold itself.
|
|
operandConstants.assign(op->getNumOperands(), Attribute());
|
|
for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
|
|
matchPattern(op->getOperand(i), m_Constant(&operandConstants[i]));
|
|
|
|
// If this is a commutative binary operation with a constant on the left
|
|
// side move it to the right side.
|
|
if (operandConstants.size() == 2 && operandConstants[0] &&
|
|
!operandConstants[1] && op->isCommutative()) {
|
|
std::swap(op->getOpOperand(0), op->getOpOperand(1));
|
|
std::swap(operandConstants[0], operandConstants[1]);
|
|
}
|
|
|
|
// Attempt to constant fold the operation.
|
|
if (failed(op->fold(operandConstants, foldResults)))
|
|
return failure();
|
|
|
|
// Check to see if the operation was just updated in place.
|
|
if (foldResults.empty())
|
|
return success();
|
|
assert(foldResults.size() == op->getNumResults());
|
|
|
|
// Create a builder to insert new operations into the entry block.
|
|
auto &entry = function->getBody().front();
|
|
OpBuilder builder(&entry, entry.empty() ? entry.end() : entry.begin());
|
|
|
|
// Create the result constants and replace the results.
|
|
auto *dialect = op->getDialect();
|
|
for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
|
|
assert(!foldResults[i].isNull() && "expected valid OpFoldResult");
|
|
|
|
// Check if the result was an SSA value.
|
|
if (auto *repl = foldResults[i].dyn_cast<Value *>()) {
|
|
results.emplace_back(repl);
|
|
continue;
|
|
}
|
|
|
|
// Check to see if there is a canonicalized version of this constant.
|
|
auto *res = op->getResult(i);
|
|
Attribute attrRepl = foldResults[i].get<Attribute>();
|
|
if (auto *constOp = tryGetOrCreateConstant(dialect, builder, attrRepl,
|
|
res->getType(), op->getLoc())) {
|
|
results.push_back(constOp->getResult(0));
|
|
continue;
|
|
}
|
|
// If materialization fails, cleanup any operations generated for the
|
|
// previous results and return failure.
|
|
for (Operation &op : llvm::make_early_inc_range(
|
|
llvm::make_range(entry.begin(), builder.getInsertionPoint()))) {
|
|
notifyRemoval(&op);
|
|
op.erase();
|
|
}
|
|
return failure();
|
|
}
|
|
|
|
// Process any newly generated operations.
|
|
if (processGeneratedConstants) {
|
|
for (auto i = entry.begin(), e = builder.getInsertionPoint(); i != e; ++i)
|
|
processGeneratedConstants(&*i);
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Try to get or create a new constant entry. On success this returns the
|
|
/// constant operation value, nullptr otherwise.
|
|
Operation *OperationFolder::tryGetOrCreateConstant(Dialect *dialect,
|
|
OpBuilder &builder,
|
|
Attribute value, Type type,
|
|
Location loc) {
|
|
// Check if an existing mapping already exists.
|
|
auto constKey = std::make_tuple(dialect, value, type);
|
|
auto *&constInst = uniquedConstants[constKey];
|
|
if (constInst)
|
|
return constInst;
|
|
|
|
// If one doesn't exist, try to materialize one.
|
|
if (!(constInst = materializeConstant(dialect, builder, value, type, loc)))
|
|
return nullptr;
|
|
|
|
// Check to see if the generated constant is in the expected dialect.
|
|
auto *newDialect = constInst->getDialect();
|
|
if (newDialect == dialect) {
|
|
referencedDialects[constInst].push_back(dialect);
|
|
return constInst;
|
|
}
|
|
|
|
// If it isn't, then we also need to make sure that the mapping for the new
|
|
// dialect is valid.
|
|
auto newKey = std::make_tuple(newDialect, value, type);
|
|
|
|
// If an existing operation in the new dialect already exists, delete the
|
|
// materialized operation in favor of the existing one.
|
|
if (auto *existingOp = uniquedConstants.lookup(newKey)) {
|
|
constInst->erase();
|
|
referencedDialects[existingOp].push_back(dialect);
|
|
return constInst = existingOp;
|
|
}
|
|
|
|
// Otherwise, update the new dialect to the materialized operation.
|
|
referencedDialects[constInst].assign({dialect, newDialect});
|
|
auto newIt = uniquedConstants.insert({newKey, constInst});
|
|
return newIt.first->second;
|
|
}
|