Files
clang-p2996/mlir/lib/Dialect/Func/IR/FuncOps.cpp
Jeff Niu 53406427cd [mlir] FunctionOpInterface: turn required attributes into interface methods (Reland)
Reland D139447, D139471 With flang actually working

- FunctionOpInterface: make get/setFunctionType interface methods

This patch removes the concept of a `function_type`-named type attribute
as a requirement for implementors of FunctionOpInterface. Instead, this
type should be provided through two interface methods, `getFunctionType`
and `setFunctionTypeAttr` (*Attr because functions may use different
concrete function types), which should be automatically implemented by
ODS for ops that define a `$function_type` attribute.

This also allows FunctionOpInterface to materialize function types if
they don't carry them in an attribute, for example.

Importantly, all the function "helpers" still accept an attribute name to
use in parsing and printing functions, for example.

- FunctionOpInterface: arg and result attrs dispatch to interface

This patch removes the `arg_attrs` and `res_attrs` named attributes as a
requirement for FunctionOpInterface and replaces them with interface
methods for the getters, setters, and removers of the relevant
attributes. This allows operations to use their own storage for the
argument and result attributes.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D139736
2022-12-10 15:17:09 -08:00

376 lines
14 KiB
C++

//===- FuncOps.cpp - Func Dialect Operations ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include <numeric>
#include "mlir/Dialect/Func/IR/FuncOpsDialect.cpp.inc"
using namespace mlir;
using namespace mlir::func;
//===----------------------------------------------------------------------===//
// FuncDialect Interfaces
//===----------------------------------------------------------------------===//
namespace {
/// Inliner hooks for the func dialect: every legality query answers "yes",
/// and `ReturnOp` terminators are rewritten when a callee body is inlined.
struct FuncInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks
  //===--------------------------------------------------------------------===//

  /// Call operations are always legal to inline.
  bool isLegalToInline(Operation *call, Operation *callable,
                       bool wouldBeCloned) const final {
    return true;
  }

  /// Any operation is legal to inline into any region.
  bool isLegalToInline(Operation *, Region *, bool,
                       BlockAndValueMapping &) const final {
    return true;
  }

  /// Any region is legal to inline into any other region.
  bool isLegalToInline(Region *, Region *, bool,
                       BlockAndValueMapping &) const final {
    return true;
  }

  //===--------------------------------------------------------------------===//
  // Transformation Hooks
  //===--------------------------------------------------------------------===//

  /// Rewrites an inlined `ReturnOp` into a `cf::BranchOp` that jumps to
  /// `newDest` carrying the returned operands. Any other terminator is left
  /// untouched.
  void handleTerminator(Operation *op, Block *newDest) const final {
    if (auto returnOp = dyn_cast<ReturnOp>(op)) {
      // Insert the branch right before the return, then drop the return.
      OpBuilder builder(op);
      builder.create<cf::BranchOp>(op->getLoc(), newDest,
                                   returnOp.getOperands());
      op->erase();
    }
  }

  /// Forwards the operands of an inlined `ReturnOp` to the values that stood
  /// in for the call results. `op` is cast to `ReturnOp` unconditionally, and
  /// the operand count must equal `valuesToRepl.size()`.
  void handleTerminator(Operation *op,
                        ArrayRef<Value> valuesToRepl) const final {
    auto returnOp = cast<ReturnOp>(op);
    assert(returnOp.getNumOperands() == valuesToRepl.size());

    // Pair each placeholder with the return operand at the same position.
    for (unsigned idx = 0, numOperands = returnOp.getNumOperands();
         idx != numOperands; ++idx)
      valuesToRepl[idx].replaceAllUsesWith(returnOp.getOperand(idx));
  }
};
} // namespace
//===----------------------------------------------------------------------===//
// FuncDialect
//===----------------------------------------------------------------------===//
/// Dialect initialization hook: registers the operations and interfaces
/// provided by this file with the dialect.
void FuncDialect::initialize() {
// With GET_OP_LIST defined, the generated FuncOps.cpp.inc supplies the
// template argument list for addOperations<...>.
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
>();
// Attach the inlining hooks defined by FuncInlinerInterface above.
addInterfaces<FuncInlinerInterface>();
}
/// Builds a `ConstantOp` at `loc` producing `type` from `value`.
/// Returns nullptr when `ConstantOp::isBuildableWith(value, type)` rejects
/// the attribute/type pair.
Operation *FuncDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                            Type type, Location loc) {
  // Bail out early on anything a ConstantOp cannot represent.
  if (!ConstantOp::isBuildableWith(value, type))
    return nullptr;

  auto symbolRef = value.cast<FlatSymbolRefAttr>();
  return builder.create<ConstantOp>(loc, type, symbolRef);
}
//===----------------------------------------------------------------------===//
// CallOp
//===----------------------------------------------------------------------===//
/// Verifies that the `callee` attribute resolves to a `FuncOp` whose
/// signature matches this call's operand and result types exactly.
LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // The call must carry a flat 'callee' symbol reference.
  auto calleeAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
  if (!calleeAttr)
    return emitOpError("requires a 'callee' symbol reference attribute");

  // Resolve the reference starting from the nearest enclosing symbol table.
  FuncOp callee =
      symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, calleeAttr);
  if (!callee)
    return emitOpError() << "'" << calleeAttr.getValue()
                         << "' does not reference a valid function";

  FunctionType calleeType = callee.getFunctionType();

  // Operands: same count, then the same type position by position.
  const unsigned numInputs = calleeType.getNumInputs();
  if (numInputs != getNumOperands())
    return emitOpError("incorrect number of operands for callee");
  for (unsigned idx = 0; idx != numInputs; ++idx) {
    if (getOperand(idx).getType() == calleeType.getInput(idx))
      continue;
    return emitOpError("operand type mismatch: expected operand type ")
           << calleeType.getInput(idx) << ", but provided "
           << getOperand(idx).getType() << " for operand number " << idx;
  }

  // Results: same count, then the same type position by position.
  const unsigned numResults = calleeType.getNumResults();
  if (numResults != getNumResults())
    return emitOpError("incorrect number of results for callee");
  for (unsigned idx = 0; idx != numResults; ++idx) {
    if (getResult(idx).getType() == calleeType.getResult(idx))
      continue;
    // Attach both type lists as notes so the mismatch is easy to spot.
    auto diag = emitOpError("result type mismatch at index ") << idx;
    diag.attachNote() << " op result types: " << getResultTypes();
    diag.attachNote() << "function result types: " << calleeType.getResults();
    return diag;
  }

  return success();
}
/// Returns the function type implied by this call: its operand types as
/// inputs and its result types as results.
FunctionType CallOp::getCalleeType() {
  auto inputTypes = getOperandTypes();
  auto resultTypes = getResultTypes();
  return FunctionType::get(getContext(), inputTypes, resultTypes);
}
//===----------------------------------------------------------------------===//
// CallIndirectOp
//===----------------------------------------------------------------------===//
/// Turns an indirect call whose callee operand is a constant symbol
/// reference into a direct `CallOp` on that symbol. Fails to match otherwise.
LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall,
                                           PatternRewriter &rewriter) {
  // Try to bind the callee operand to a constant symbol reference.
  SymbolRefAttr calledFn;
  const bool hasConstantCallee =
      matchPattern(indirectCall.getCallee(), m_Constant(&calledFn));
  if (!hasConstantCallee)
    return failure();

  // Swap the indirect call for a direct one with the same results and args.
  rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn,
                                      indirectCall.getResultTypes(),
                                      indirectCall.getArgOperands());
  return success();
}
//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
/// Verifies that this constant names an existing `FuncOp` in the enclosing
/// module and that the function's type equals the constant's result type.
///
/// Emits "reference to undefined function" when the symbol cannot be
/// resolved, which includes the case where the op is not nested inside a
/// `ModuleOp` at all, and "reference to function with mismatched type" when
/// the types differ.
LogicalResult ConstantOp::verify() {
  StringRef fnName = getValue();
  Type type = getType();

  // Try to find the referenced function. getParentOfType returns a null
  // handle when there is no enclosing ModuleOp (e.g. a detached op or a
  // non-module top-level op); looking a symbol up on that null handle would
  // crash, so treat the reference as unresolved instead.
  FuncOp fn;
  if (auto module = (*this)->getParentOfType<ModuleOp>())
    fn = module.lookupSymbol<FuncOp>(fnName);
  if (!fn)
    return emitOpError() << "reference to undefined function '" << fnName
                         << "'";

  // Check that the referenced function has the correct type.
  if (fn.getFunctionType() != type)
    return emitOpError("reference to function with mismatched type");
  return success();
}
/// Folds the constant to its symbol reference attribute. The op takes no
/// operands, so `operands` must be empty.
OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.empty() && "constant has no operands");
  return OpFoldResult(getValueAttr());
}
/// Suggests the name "f" for this op's result when printing assembly.
void ConstantOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  Value functionValue = getResult();
  setNameFn(functionValue, "f");
}
/// Returns true iff `value` is a `FlatSymbolRefAttr` and `type` is a
/// `FunctionType`.
bool ConstantOp::isBuildableWith(Attribute value, Type type) {
  if (!value.isa<FlatSymbolRefAttr>())
    return false;
  return type.isa<FunctionType>();
}
//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//
/// Creates a detached `FuncOp` named `name` with signature `type` and the
/// extra attributes `attrs`, using `FuncOp::build` to populate the state.
FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
                      ArrayRef<NamedAttribute> attrs) {
  OperationState state(location, getOperationName());
  OpBuilder builder(location->getContext());
  FuncOp::build(builder, state, name, type, attrs);

  Operation *funcOperation = Operation::create(state);
  return cast<FuncOp>(funcOperation);
}
/// Overload taking a dialect attribute range: the range is copied into
/// contiguous storage and forwarded to the `ArrayRef` overload.
FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
                      Operation::dialect_attr_range attrs) {
  SmallVector<NamedAttribute, 8> attrStorage(attrs.begin(), attrs.end());
  return create(location, name, type, ArrayRef<NamedAttribute>(attrStorage));
}
/// Overload that also installs per-argument attribute dictionaries: creates
/// the function through the `ArrayRef` overload, then sets all argument
/// attributes from `argAttrs`.
FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
                      ArrayRef<NamedAttribute> attrs,
                      ArrayRef<DictionaryAttr> argAttrs) {
  FuncOp created = create(location, name, type, attrs);
  created.setAllArgAttrs(argAttrs);
  return created;
}
/// Populates `state` for a `FuncOp`: symbol name, function type attribute,
/// the caller-supplied `attrs`, and one (empty) body region. When `argAttrs`
/// is non-empty it must have one entry per function input, and the argument
/// attributes are recorded as well; no result attributes are added.
void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
                   FunctionType type, ArrayRef<NamedAttribute> attrs,
                   ArrayRef<DictionaryAttr> argAttrs) {
  // Mandatory attributes first, then whatever the caller passed in.
  StringAttr symbolName = builder.getStringAttr(name);
  state.addAttribute(SymbolTable::getSymbolAttrName(), symbolName);
  TypeAttr functionType = TypeAttr::get(type);
  state.addAttribute(getFunctionTypeAttrName(state.name), functionType);
  state.attributes.append(attrs.begin(), attrs.end());
  state.addRegion();

  // Argument attributes are optional.
  if (!argAttrs.empty()) {
    assert(type.getNumInputs() == argAttrs.size());
    function_interface_impl::addArgAndResultAttrs(
        builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
        getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
  }
}
/// Parses a `FuncOp` via the shared function-op parser. Variadic signatures
/// are not allowed, and the signature type is built with
/// `Builder::getFunctionType`.
ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
  // Callback that turns parsed argument/result types into the function type.
  // The variadic flag and the error string are not used.
  auto makeFunctionType =
      [](Builder &typeBuilder, ArrayRef<Type> inputs, ArrayRef<Type> outputs,
         function_interface_impl::VariadicFlag /*unused*/,
         std::string & /*unused*/) {
        return typeBuilder.getFunctionType(inputs, outputs);
      };

  auto typeAttrName = getFunctionTypeAttrName(result.name);
  auto argAttrsName = getArgAttrsAttrName(result.name);
  auto resAttrsName = getResAttrsAttrName(result.name);
  return function_interface_impl::parseFunctionOp(
      parser, result, /*allowVariadic=*/false, typeAttrName, makeFunctionType,
      argAttrsName, resAttrsName);
}
/// Prints this function via the shared function-op printer, as a
/// non-variadic function.
void FuncOp::print(OpAsmPrinter &p) {
  auto typeAttrName = getFunctionTypeAttrName();
  auto argAttrsName = getArgAttrsAttrName();
  auto resAttrsName = getResAttrsAttrName();
  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
                                           typeAttrName, argAttrsName,
                                           resAttrsName);
}
/// Copies this function's attributes and body into `dest`.
///
/// Attributes are merged by name: entries already present on `dest` are
/// inserted first and therefore keep their values; this function only
/// contributes attributes `dest` does not have yet. The body blocks are then
/// cloned into `dest`'s body using `mapper`.
void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) {
  // MapVector::insert ignores keys that are already present, so seeding it
  // with dest's attributes gives them priority while keeping insertion order.
  llvm::MapVector<StringAttr, Attribute> mergedAttrs;
  for (const NamedAttribute &existing : dest->getAttrs())
    mergedAttrs.insert({existing.getName(), existing.getValue()});
  for (const NamedAttribute &incoming : (*this)->getAttrs())
    mergedAttrs.insert({incoming.getName(), incoming.getValue()});

  // Flatten the merged map back into an attribute list.
  SmallVector<NamedAttribute> attrList;
  attrList.reserve(mergedAttrs.size());
  for (const auto &entry : mergedAttrs)
    attrList.emplace_back(entry.first, entry.second);
  dest->setAttrs(DictionaryAttr::get(getContext(), attrList));

  // Finally duplicate the body.
  getBody().cloneInto(&dest.getBody(), mapper);
}
/// Returns a deep copy of this function, blocks included. Operands that refer
/// to values outside the function are remapped through `mapper` when it has
/// an entry for them and left as-is otherwise; mappings for the cloned values
/// are added to `mapper`.
///
/// For a function with a body, any entry argument already present in `mapper`
/// is treated as deleted: it is omitted from the clone's input types and from
/// its argument attributes.
FuncOp FuncOp::clone(BlockAndValueMapping &mapper) {
  // Start from a region-less copy of this operation.
  FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions());

  if (!isExternal()) {
    FunctionType oldType = getFunctionType();
    const unsigned oldNumArgs = oldType.getNumInputs();

    // Indices of the arguments that the caller did not pre-map.
    SmallVector<unsigned, 4> keptArgs;
    keptArgs.reserve(oldNumArgs);
    for (unsigned argIdx = 0; argIdx != oldNumArgs; ++argIdx)
      if (!mapper.contains(getArgument(argIdx)))
        keptArgs.push_back(argIdx);

    // Only touch the signature when at least one argument was dropped.
    if (keptArgs.size() != oldNumArgs) {
      SmallVector<Type, 4> newInputs;
      newInputs.reserve(keptArgs.size());
      for (unsigned argIdx : keptArgs)
        newInputs.push_back(oldType.getInput(argIdx));
      newFunc.setType(FunctionType::get(oldType.getContext(), newInputs,
                                        oldType.getResults()));

      // Keep the argument attributes aligned with the surviving arguments.
      if (ArrayAttr argAttrs = getAllArgAttrs()) {
        SmallVector<Attribute> newArgAttrs;
        newArgAttrs.reserve(keptArgs.size());
        for (unsigned argIdx : keptArgs)
          newArgAttrs.push_back(argAttrs[argIdx]);
        newFunc.setAllArgAttrs(newArgAttrs);
      }
    }
  }

  // Copy attributes and body into the new function and hand it back.
  cloneInto(newFunc, mapper);
  return newFunc;
}
/// Convenience overload: clones this function starting from an empty value
/// mapping.
FuncOp FuncOp::clone() {
  BlockAndValueMapping emptyMapping;
  return clone(emptyMapping);
}
//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//
/// Verifies that the returned operands match the result types of the
/// enclosing `FuncOp`, both in count and type by type. The parent operation
/// is cast to `FuncOp` unconditionally.
LogicalResult ReturnOp::verify() {
  auto function = cast<FuncOp>((*this)->getParentOp());
  ArrayRef<Type> results = function.getFunctionType().getResults();

  // Operand count must equal the number of declared results.
  if (getNumOperands() != results.size()) {
    return emitOpError("has ")
           << getNumOperands() << " operands, but enclosing function (@"
           << function.getName() << ") returns " << results.size();
  }

  // Each operand must have exactly the declared result type.
  for (unsigned idx = 0, numResults = results.size(); idx != numResults;
       ++idx) {
    if (getOperand(idx).getType() == results[idx])
      continue;
    return emitError() << "type of return operand " << idx << " ("
                       << getOperand(idx).getType()
                       << ") doesn't match function result type ("
                       << results[idx] << ")"
                       << " in function @" << function.getName();
  }

  return success();
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"