This transform takes a module and a function name, and replaces the signature of the function by reordering the arguments and results according to the interchange arrays. The function is expected to be defined in the module, and the interchange arrays must match the number of arguments and results of the function.
122 lines
5.1 KiB
C++
122 lines
5.1 KiB
C++
//===- Utils.cpp - Utilities to support the Func dialect ------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements utilities for the Func dialect.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Func/Utils/Utils.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/IR/IRMapping.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
|
|
using namespace mlir;
|
|
|
|
// Replaces `funcOp` with a new function of the same name whose arguments and
// results are permuted: argument `i` of the new function is argument
// `newArgsOrder[i]` of the original, and result `i` of the new function is
// result `newResultsOrder[i]` of the original. Visibility, discardable
// attributes and per-argument / per-result attributes are carried over, the
// body is cloned into the new function, and the original function is erased.
//
// `newArgsOrder` and `newResultsOrder` must have exactly one entry per
// argument / result of `funcOp` (asserted). Returns failure, before any IR is
// modified, if the body of `funcOp` does not consist of exactly one block.
FailureOr<func::FuncOp>
func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp,
                              ArrayRef<unsigned> newArgsOrder,
                              ArrayRef<unsigned> newResultsOrder) {
  assert(funcOp.getNumArguments() == newArgsOrder.size() &&
         "newArgsOrder must match the number of arguments in the function");
  assert(funcOp.getNumResults() == newResultsOrder.size() &&
         "newResultsOrder must match the number of results in the function");

  if (!funcOp.getBody().hasOneBlock())
    return rewriter.notifyMatchFailure(
        funcOp, "expected function to have exactly one block");

  // Compute the permuted signature along with the locations of the permuted
  // block arguments.
  ArrayRef<Type> origInputTypes = funcOp.getFunctionType().getInputs();
  ArrayRef<Type> origOutputTypes = funcOp.getFunctionType().getResults();
  SmallVector<Type> newInputTypes, newOutputTypes;
  SmallVector<Location> locs;
  for (unsigned int idx : newArgsOrder) {
    newInputTypes.push_back(origInputTypes[idx]);
    // `idx` is already the index of the original argument; indexing
    // `newArgsOrder` a second time would pair the type of one argument with
    // the location of another.
    locs.push_back(funcOp.getArgument(idx).getLoc());
  }
  for (unsigned int idx : newResultsOrder)
    newOutputTypes.push_back(origOutputTypes[idx]);

  // Generate an empty new function operation with the same name as the
  // original, placed right before it.
  rewriter.setInsertionPoint(funcOp);
  auto newFuncOp = rewriter.create<func::FuncOp>(
      funcOp.getLoc(), funcOp.getName(),
      rewriter.getFunctionType(newInputTypes, newOutputTypes));

  Region &newRegion = newFuncOp.getBody();
  rewriter.createBlock(&newRegion, newRegion.begin(), newInputTypes, locs);
  newFuncOp.setVisibility(funcOp.getVisibility());
  newFuncOp->setDiscardableAttrs(funcOp->getDiscardableAttrDictionary());

  // Map the arguments of the original function to the new function in
  // the new order and adjust the attributes accordingly.
  IRMapping operandMapper;
  SmallVector<DictionaryAttr> argAttrs, resultAttrs;
  funcOp.getAllArgAttrs(argAttrs);
  for (unsigned int i = 0; i < newArgsOrder.size(); ++i) {
    operandMapper.map(funcOp.getArgument(newArgsOrder[i]),
                      newFuncOp.getArgument(i));
    newFuncOp.setArgAttrs(i, argAttrs[newArgsOrder[i]]);
  }
  funcOp.getAllResultAttrs(resultAttrs);
  for (unsigned int i = 0; i < newResultsOrder.size(); ++i)
    newFuncOp.setResultAttrs(i, resultAttrs[newResultsOrder[i]]);

  // Clone the operations from the original function to the new function.
  rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
  for (Operation &op : funcOp.getOps())
    rewriter.clone(op, operandMapper);

  // Handle the return operation: the cloned terminator still returns values
  // in the original order, so replace it with one whose operands follow
  // `newResultsOrder`.
  auto returnOp = cast<func::ReturnOp>(
      newFuncOp.getFunctionBody().begin()->getTerminator());
  SmallVector<Value> newReturnValues;
  for (unsigned int idx : newResultsOrder)
    newReturnValues.push_back(returnOp.getOperand(idx));
  rewriter.setInsertionPoint(returnOp);
  auto newReturnOp =
      rewriter.create<func::ReturnOp>(newFuncOp.getLoc(), newReturnValues);
  newReturnOp->setDiscardableAttrs(returnOp->getDiscardableAttrDictionary());
  rewriter.eraseOp(returnOp);

  // The original function is fully superseded by the new one.
  rewriter.eraseOp(funcOp);

  return newFuncOp;
}
|
|
|
|
// Replaces `callOp` with a new call to the same callee whose operands and
// results are permuted: operand `i` of the new call is operand
// `newArgsOrder[i]` of the original, and result `i` of the new call stands in
// for result `newResultsOrder[i]` of the original. All uses of the original
// results are redirected to the new call and the original call is erased.
//
// `newArgsOrder` and `newResultsOrder` must have exactly one entry per
// operand / result of `callOp` (asserted).
func::CallOp
func::replaceCallOpWithNewOrder(RewriterBase &rewriter, func::CallOp callOp,
                                ArrayRef<unsigned> newArgsOrder,
                                ArrayRef<unsigned> newResultsOrder) {
  assert(
      callOp.getNumOperands() == newArgsOrder.size() &&
      "newArgsOrder must match the number of operands in the call operation");
  assert(
      callOp.getNumResults() == newResultsOrder.size() &&
      "newResultsOrder must match the number of results in the call operation");

  // Gather the operands in their new positions.
  SmallVector<Value> permutedOperands;
  for (unsigned pos = 0, numArgs = newArgsOrder.size(); pos != numArgs; ++pos) {
    Value origOperand = callOp.getOperand(newArgsOrder[pos]);
    permutedOperands.push_back(origOperand);
  }

  // Gather the result types in their new positions.
  SmallVector<Type> permutedResultTypes;
  for (unsigned pos = 0, numResults = newResultsOrder.size();
       pos != numResults; ++pos) {
    Type origResultType = callOp.getResult(newResultsOrder[pos]).getType();
    permutedResultTypes.push_back(origResultType);
  }

  // Emit the reordered call right before the original one and carry over its
  // no-inline attribute.
  rewriter.setInsertionPoint(callOp);
  func::CallOp reorderedCall = rewriter.create<func::CallOp>(
      callOp.getLoc(), callOp.getCallee(), permutedResultTypes,
      permutedOperands);
  reorderedCall.setNoInlineAttr(callOp.getNoInlineAttr());

  // Redirect every use of an original result to the matching new result.
  for (unsigned newIndex = 0, numResults = newResultsOrder.size();
       newIndex != numResults; ++newIndex) {
    unsigned origIndex = newResultsOrder[newIndex];
    rewriter.replaceAllUsesWith(callOp.getResult(origIndex),
                                reorderedCall.getResult(newIndex));
  }

  rewriter.eraseOp(callOp);
  return reorderedCall;
}
|