Reland D139447, D139471 With flang actually working - FunctionOpInterface: make get/setFunctionType interface methods This patch removes the concept of a `function_type`-named type attribute as a requirement for implementors of FunctionOpInterface. Instead, this type should be provided through two interface methods, `getFunctionType` and `setFunctionTypeAttr` (*Attr because functions may use different concrete function types), which should be automatically implemented by ODS for ops that define a `$function_type` attribute. This also allows FunctionOpInterface to materialize function types if they don't carry them in an attribute, for example. Importantly, all the function "helper" still accept an attribute name to use in parsing and printing functions, for example. - FunctionOpInterface: arg and result attrs dispatch to interface This patch removes the `arg_attrs` and `res_attrs` named attributes as a requirement for FunctionOpInterface and replaces them with interface methods for the getters, setters, and removers of the relevent attributes. This allows operations to use their own storage for the argument and result attributes. Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D139736
380 lines
13 KiB
C++
380 lines
13 KiB
C++
//===- MLProgramOps.cpp - MLProgram dialect ops implementation ------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/FunctionImplementation.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::ml_program;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Custom asm helpers
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Parse and print an ordering clause for a variadic of consuming tokens
|
|
/// and an producing token.
|
|
///
|
|
/// Syntax:
|
|
/// ordering(%0, %1 -> !ml_program.token)
|
|
/// ordering(() -> !ml_program.token)
|
|
///
|
|
/// If both the consuming and producing token are not present on the op, then
|
|
/// the clause prints nothing.
|
|
static ParseResult parseTokenOrdering(
|
|
OpAsmParser &parser,
|
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &consumeTokens,
|
|
Type &produceTokenType) {
|
|
if (failed(parser.parseOptionalKeyword("ordering")) ||
|
|
failed(parser.parseLParen()))
|
|
return success();
|
|
|
|
// Parse consuming token list. If there are no consuming tokens, the
|
|
// '()' null list represents this.
|
|
if (succeeded(parser.parseOptionalLParen())) {
|
|
if (failed(parser.parseRParen()))
|
|
return failure();
|
|
} else {
|
|
if (failed(parser.parseOperandList(consumeTokens,
|
|
/*requiredOperandCount=*/-1)))
|
|
return failure();
|
|
}
|
|
|
|
// Parse producer token.
|
|
if (failed(parser.parseArrow()))
|
|
return failure();
|
|
if (failed(parser.parseType(produceTokenType)))
|
|
return failure();
|
|
|
|
if (failed(parser.parseRParen()))
|
|
return failure();
|
|
|
|
return success();
|
|
}
|
|
|
|
static void printTokenOrdering(OpAsmPrinter &p, Operation *op,
|
|
OperandRange consumeTokens,
|
|
Type produceTokenType) {
|
|
if (consumeTokens.empty() && !produceTokenType)
|
|
return;
|
|
|
|
p << " ordering(";
|
|
if (consumeTokens.empty())
|
|
p << "()";
|
|
else
|
|
p.printOperands(consumeTokens);
|
|
if (produceTokenType) {
|
|
p << " -> ";
|
|
p.printType(produceTokenType);
|
|
}
|
|
p << ")";
|
|
}
|
|
|
|
/// some.op custom<TypeOrAttr>($type, $attr)
|
|
///
|
|
/// Uninitialized:
|
|
/// some.op : tensor<3xi32>
|
|
/// Initialized to narrower type than op:
|
|
/// some.op (dense<0> : tensor<3xi32>) : tensor<?xi32>
|
|
static ParseResult parseTypedInitialValue(OpAsmParser &parser,
|
|
TypeAttr &typeAttr, Attribute &attr) {
|
|
if (succeeded(parser.parseOptionalLParen())) {
|
|
if (failed(parser.parseAttribute(attr)))
|
|
return failure();
|
|
if (failed(parser.parseRParen()))
|
|
return failure();
|
|
}
|
|
|
|
Type type;
|
|
if (failed(parser.parseColonType(type)))
|
|
return failure();
|
|
typeAttr = TypeAttr::get(type);
|
|
return success();
|
|
}
|
|
|
|
static void printTypedInitialValue(OpAsmPrinter &p, Operation *op,
|
|
TypeAttr type, Attribute attr) {
|
|
if (attr) {
|
|
p << "(";
|
|
p.printAttribute(attr);
|
|
p << ")";
|
|
}
|
|
|
|
p << " : ";
|
|
p.printAttribute(type);
|
|
}
|
|
|
|
/// some.op custom<SymbolVisibility>($sym_visibility) $sym_name
|
|
/// ->
|
|
/// some.op public @foo
|
|
/// some.op private @foo
|
|
static ParseResult parseSymbolVisibility(OpAsmParser &parser,
|
|
StringAttr &symVisibilityAttr) {
|
|
StringRef symVisibility;
|
|
(void)parser.parseOptionalKeyword(&symVisibility,
|
|
{"public", "private", "nested"});
|
|
if (symVisibility.empty())
|
|
return parser.emitError(parser.getCurrentLocation())
|
|
<< "expected 'public', 'private', or 'nested'";
|
|
if (!symVisibility.empty())
|
|
symVisibilityAttr = parser.getBuilder().getStringAttr(symVisibility);
|
|
return success();
|
|
}
|
|
|
|
static void printSymbolVisibility(OpAsmPrinter &p, Operation *op,
|
|
StringAttr symVisibilityAttr) {
|
|
if (!symVisibilityAttr)
|
|
p << "public";
|
|
else
|
|
p << symVisibilityAttr.getValue();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TableGen'd op method definitions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc"
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// FuncOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
|
|
auto buildFuncType =
|
|
[](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
|
|
function_interface_impl::VariadicFlag,
|
|
std::string &) { return builder.getFunctionType(argTypes, results); };
|
|
|
|
return function_interface_impl::parseFunctionOp(
|
|
parser, result, /*allowVariadic=*/false,
|
|
getFunctionTypeAttrName(result.name), buildFuncType,
|
|
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
|
|
}
|
|
|
|
void FuncOp::print(OpAsmPrinter &p) {
|
|
function_interface_impl::printFunctionOp(
|
|
p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
|
|
getArgAttrsAttrName(), getResAttrsAttrName());
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// GlobalOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
LogicalResult GlobalOp::verify() {
|
|
if (!getIsMutable() && !getValue())
|
|
return emitOpError() << "immutable global must have an initial value";
|
|
return success();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// GlobalLoadOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
GlobalOp GlobalLoadOp::getGlobalOp(SymbolTableCollection &symbolTable) {
|
|
return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
|
|
getOperation()->getParentOp(), getGlobalAttr());
|
|
}
|
|
|
|
LogicalResult
|
|
GlobalLoadOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
|
GlobalOp referrent = getGlobalOp(symbolTable);
|
|
if (!referrent)
|
|
return emitOpError() << "undefined global: " << getGlobal();
|
|
|
|
if (referrent.getType() != getResult().getType()) {
|
|
return emitOpError() << "cannot load from global typed "
|
|
<< referrent.getType() << " as "
|
|
<< getResult().getType();
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// GlobalLoadConstOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
GlobalOp GlobalLoadConstOp::getGlobalOp(SymbolTableCollection &symbolTable) {
|
|
return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
|
|
getOperation()->getParentOp(), getGlobalAttr());
|
|
}
|
|
|
|
LogicalResult
|
|
GlobalLoadConstOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
|
GlobalOp referrent = getGlobalOp(symbolTable);
|
|
if (!referrent)
|
|
return emitOpError() << "undefined global: " << getGlobal();
|
|
|
|
if (referrent.getIsMutable())
|
|
return emitOpError() << "cannot load as const from mutable global "
|
|
<< getGlobal();
|
|
|
|
if (referrent.getType() != getResult().getType())
|
|
return emitOpError() << "cannot load from global typed "
|
|
<< referrent.getType() << " as "
|
|
<< getResult().getType();
|
|
|
|
return success();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// GlobalLoadGraphOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
GlobalOp GlobalLoadGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) {
|
|
return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
|
|
getOperation()->getParentOp(), getGlobalAttr());
|
|
}
|
|
|
|
LogicalResult
|
|
GlobalLoadGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
|
GlobalOp referrent = getGlobalOp(symbolTable);
|
|
if (!referrent)
|
|
return emitOpError() << "undefined global: " << getGlobal();
|
|
|
|
if (referrent.getType() != getResult().getType()) {
|
|
return emitOpError() << "cannot load from global typed "
|
|
<< referrent.getType() << " as "
|
|
<< getResult().getType();
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// GlobalStoreOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
GlobalOp GlobalStoreOp::getGlobalOp(SymbolTableCollection &symbolTable) {
|
|
return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
|
|
getOperation()->getParentOp(), getGlobalAttr());
|
|
}
|
|
|
|
LogicalResult
|
|
GlobalStoreOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
|
GlobalOp referrent = getGlobalOp(symbolTable);
|
|
if (!referrent)
|
|
return emitOpError() << "undefined global: " << getGlobal();
|
|
|
|
if (!referrent.getIsMutable()) {
|
|
return emitOpError() << "cannot store to an immutable global "
|
|
<< getGlobal();
|
|
}
|
|
|
|
if (referrent.getType() != getValue().getType()) {
|
|
return emitOpError() << "cannot store to a global typed "
|
|
<< referrent.getType() << " from "
|
|
<< getValue().getType();
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// GlobalStoreGraphOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
GlobalOp GlobalStoreGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) {
|
|
return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
|
|
getOperation()->getParentOp(), getGlobalAttr());
|
|
}
|
|
|
|
LogicalResult
|
|
GlobalStoreGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
|
GlobalOp referrent = getGlobalOp(symbolTable);
|
|
if (!referrent)
|
|
return emitOpError() << "undefined global: " << getGlobal();
|
|
|
|
if (!referrent.getIsMutable()) {
|
|
return emitOpError() << "cannot store to an immutable global "
|
|
<< getGlobal();
|
|
}
|
|
|
|
if (referrent.getType() != getValue().getType()) {
|
|
return emitOpError() << "cannot store to a global typed "
|
|
<< referrent.getType() << " from "
|
|
<< getValue().getType();
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SubgraphOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) {
|
|
auto buildFuncType =
|
|
[](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
|
|
function_interface_impl::VariadicFlag,
|
|
std::string &) { return builder.getFunctionType(argTypes, results); };
|
|
|
|
return function_interface_impl::parseFunctionOp(
|
|
parser, result, /*allowVariadic=*/false,
|
|
getFunctionTypeAttrName(result.name), buildFuncType,
|
|
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
|
|
}
|
|
|
|
void SubgraphOp::print(OpAsmPrinter &p) {
|
|
function_interface_impl::printFunctionOp(
|
|
p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
|
|
getArgAttrsAttrName(), getResAttrsAttrName());
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// OutputOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
LogicalResult OutputOp::verify() {
|
|
auto function = cast<SubgraphOp>((*this)->getParentOp());
|
|
|
|
// The operand number and types must match the function signature.
|
|
const auto &results = function.getFunctionType().getResults();
|
|
if (getNumOperands() != results.size())
|
|
return emitOpError("has ")
|
|
<< getNumOperands() << " operands, but enclosing function (@"
|
|
<< function.getName() << ") outputs " << results.size();
|
|
|
|
for (unsigned i = 0, e = results.size(); i != e; ++i)
|
|
if (getOperand(i).getType() != results[i])
|
|
return emitError() << "type of output operand " << i << " ("
|
|
<< getOperand(i).getType()
|
|
<< ") doesn't match function result type ("
|
|
<< results[i] << ")"
|
|
<< " in function @" << function.getName();
|
|
|
|
return success();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ReturnOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
LogicalResult ReturnOp::verify() {
|
|
auto function = cast<FuncOp>((*this)->getParentOp());
|
|
|
|
// The operand number and types must match the function signature.
|
|
const auto &results = function.getFunctionType().getResults();
|
|
if (getNumOperands() != results.size())
|
|
return emitOpError("has ")
|
|
<< getNumOperands() << " operands, but enclosing function (@"
|
|
<< function.getName() << ") returns " << results.size();
|
|
|
|
for (unsigned i = 0, e = results.size(); i != e; ++i)
|
|
if (getOperand(i).getType() != results[i])
|
|
return emitError() << "type of return operand " << i << " ("
|
|
<< getOperand(i).getType()
|
|
<< ") doesn't match function result type ("
|
|
<< results[i] << ")"
|
|
<< " in function @" << function.getName();
|
|
|
|
return success();
|
|
}
|