The MLIR classes Type/Attribute/Operation/Op/Value support cast/dyn_cast/isa/dyn_cast_or_null functionality through LLVM's doCast functionality, in addition to defining methods with the same names. This change begins migrating uses of these methods to the corresponding free-function calls, which has been decided to be the more consistent style. Note that some classes, such as AffineExpr, still only define the methods directly; this change does not include work to support free-function cast/isa calls for them. Context: - https://mlir.llvm.org/deprecation/ at "Use the free function variants for dyn_cast/cast/isa/…" - Original discussion at https://discourse.llvm.org/t/preferred-casting-style-going-forward/68443 Implementation: This patch updates all remaining uses of the deprecated functionality in mlir/. This was done with clang-tidy as described below, plus further modifications to GPUBase.td and OpenMPOpsInterfaces.td. Steps are described per line, as comments are removed by git: 0. Retrieve the change from the following branch to build clang-tidy with an additional check: main...tpopp:llvm-project:tidy-cast-check 1. Build clang-tidy. 2. Run clang-tidy over your entire codebase with all checks disabled except the relevant one. Run it on all header files as well. 3. Delete .inc files that were also modified, so the next build regenerates them in a pristine state. ``` ninja -C $BUILD_DIR clang-tidy run-clang-tidy -clang-tidy-binary=$BUILD_DIR/bin/clang-tidy -checks='-*,misc-cast-functions'\ -header-filter=mlir/ mlir/* -fix rm -rf $BUILD_DIR/tools/mlir/**/*.inc ``` Differential Revision: https://reviews.llvm.org/D151542
161 lines
5.0 KiB
C++
161 lines
5.0 KiB
C++
//===- DropEquivalentBufferResults.cpp - Calling convention conversion ----===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This pass drops return values from functions if they are equivalent to one of
|
|
// their arguments. E.g.:
|
|
//
|
|
// ```
|
|
// func.func @foo(%m : memref<?xf32>) -> (memref<?xf32>) {
|
|
// return %m : memref<?xf32>
|
|
// }
|
|
// ```
|
|
//
|
|
// This function is rewritten to:
|
|
//
|
|
// ```
|
|
// func.func @foo(%m : memref<?xf32>) {
|
|
// return
|
|
// }
|
|
// ```
|
|
//
|
|
// All call sites are updated accordingly. If a function returns a cast of a
|
|
// function argument, it is also considered equivalent. A cast is inserted at
|
|
// the call site in that case.
|
|
|
|
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
|
|
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/IR/Operation.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
|
|
namespace mlir {
|
|
namespace bufferization {
|
|
#define GEN_PASS_DEF_DROPEQUIVALENTBUFFERRESULTS
|
|
#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
|
|
} // namespace bufferization
|
|
} // namespace mlir
|
|
|
|
using namespace mlir;
|
|
|
|
/// Return the unique ReturnOp that terminates `funcOp`.
|
|
/// Return nullptr if there is no such unique ReturnOp.
|
|
static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
|
|
func::ReturnOp returnOp;
|
|
for (Block &b : funcOp.getBody()) {
|
|
if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
|
|
if (returnOp)
|
|
return nullptr;
|
|
returnOp = candidateOp;
|
|
}
|
|
}
|
|
return returnOp;
|
|
}
|
|
|
|
/// Return the func::FuncOp called by `callOp`.
|
|
static func::FuncOp getCalledFunction(CallOpInterface callOp) {
|
|
SymbolRefAttr sym =
|
|
llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
|
|
if (!sym)
|
|
return nullptr;
|
|
return dyn_cast_or_null<func::FuncOp>(
|
|
SymbolTable::lookupNearestSymbolFrom(callOp, sym));
|
|
}
|
|
|
|
/// Drop results of non-external functions in `module` that are equivalent to
/// one of the function's arguments, possibly through a chain of memref.cast
/// ops. Every func.call that resolves to such a function is rewritten
/// accordingly.
///
/// Only functions whose body has exactly one func.return terminator are
/// handled; all other functions are left untouched. At each rewritten call
/// site, an erased result is replaced by the corresponding call operand. A
/// memref.cast is inserted when the operand type differs from the original
/// result type.
///
/// NOTE(review): public functions are rewritten as well, which changes their
/// visible signature. Confirm this is intended for callers outside `module`.
///
/// Always returns success().
LogicalResult
mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
  IRRewriter rewriter(module.getContext());

  // Collect all call sites once, grouped by callee, instead of re-walking the
  // whole module (and re-resolving every callee symbol) for each function.
  // Rewriting the calls of one function only erases calls to that function,
  // so the handles collected for other functions stay valid.
  DenseMap<Operation *, SmallVector<func::CallOp>> callersOf;
  module.walk([&](func::CallOp callOp) {
    if (func::FuncOp callee = getCalledFunction(callOp))
      callersOf[callee.getOperation()].push_back(callOp);
  });

  for (auto funcOp : module.getOps<func::FuncOp>()) {
    if (funcOp.isExternal())
      continue;
    func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
    // TODO: Support functions with multiple blocks.
    if (!returnOp)
      continue;

    // Compute erased results. Function arguments are the arguments of the
    // entry block of the body.
    Block &entryBlock = funcOp.getBody().front();
    SmallVector<Value> newReturnValues;
    BitVector erasedResultIndices(funcOp.getFunctionType().getNumResults());
    DenseMap<int64_t, int64_t> resultToArgs;
    for (const auto &it : llvm::enumerate(returnOp.getOperands())) {
      // Look through memref.cast chains: a returned cast of an argument is
      // considered equivalent to that argument. This is independent of the
      // argument being compared, so it is computed once per result.
      Value source = it.value();
      while (auto castOp = source.getDefiningOp<memref::CastOp>())
        source = castOp.getSource();

      auto bbArg = dyn_cast<BlockArgument>(source);
      if (bbArg && bbArg.getOwner() == &entryBlock) {
        resultToArgs[it.index()] = bbArg.getArgNumber();
        erasedResultIndices.set(it.index());
      } else {
        newReturnValues.push_back(it.value());
      }
    }

    // Nothing to drop: leave the function and its call sites untouched.
    // Recreating the calls would be wasted work and would discard any
    // attributes attached to the original call ops.
    if (erasedResultIndices.none())
      continue;

    // Update function.
    funcOp.eraseResults(erasedResultIndices);
    returnOp.getOperandsMutable().assign(newReturnValues);

    // Update function calls.
    auto callersIt = callersOf.find(funcOp.getOperation());
    if (callersIt == callersOf.end())
      continue;
    for (func::CallOp callOp : callersIt->second) {
      rewriter.setInsertionPoint(callOp);
      auto newCallOp = rewriter.create<func::CallOp>(callOp.getLoc(), funcOp,
                                                     callOp.getOperands());
      SmallVector<Value> newResults;
      int64_t nextResult = 0;
      for (int64_t i = 0, e = callOp.getNumResults(); i < e; ++i) {
        auto erasedIt = resultToArgs.find(i);
        if (erasedIt == resultToArgs.end()) {
          // This result was not erased.
          newResults.push_back(newCallOp.getResult(nextResult++));
          continue;
        }

        // This result was erased: forward the matching call operand instead.
        Value replacement = callOp.getOperand(erasedIt->second);
        Type expectedType = callOp.getResult(i).getType();
        if (replacement.getType() != expectedType) {
          // A cast must be inserted at the call site.
          replacement = rewriter.create<memref::CastOp>(
              callOp.getLoc(), expectedType, replacement);
        }
        newResults.push_back(replacement);
      }
      rewriter.replaceOp(callOp, newResults);
    }
  }

  return success();
}
|
|
|
|
namespace {
|
|
struct DropEquivalentBufferResultsPass
|
|
: bufferization::impl::DropEquivalentBufferResultsBase<
|
|
DropEquivalentBufferResultsPass> {
|
|
void runOnOperation() override {
|
|
if (failed(bufferization::dropEquivalentBufferResults(getOperation())))
|
|
return signalPassFailure();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
std::unique_ptr<Pass>
|
|
mlir::bufferization::createDropEquivalentBufferResultsPass() {
|
|
return std::make_unique<DropEquivalentBufferResultsPass>();
|
|
}
|