Files
clang-p2996/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
Jeff Niu 53406427cd [mlir] FunctionOpInterface: turn required attributes into interface methods (Reland)
Reland D139447, D139471 With flang actually working

- FunctionOpInterface: make get/setFunctionType interface methods

This patch removes the concept of a `function_type`-named type attribute
as a requirement for implementors of FunctionOpInterface. Instead, this
type should be provided through two interface methods, `getFunctionType`
and `setFunctionTypeAttr` (*Attr because functions may use different
concrete function types), which should be automatically implemented by
ODS for ops that define a `$function_type` attribute.

This also allows FunctionOpInterface to materialize function types if
they don't carry them in an attribute, for example.

Importantly, all the function helpers still accept an attribute name to
use when parsing and printing functions.

- FunctionOpInterface: arg and result attrs dispatch to interface

This patch removes the `arg_attrs` and `res_attrs` named attributes as a
requirement for FunctionOpInterface and replaces them with interface
methods for the getters, setters, and removers of the relevant
attributes. This allows operations to use their own storage for the
argument and result attributes.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D139736
2022-12-10 15:17:09 -08:00

380 lines
13 KiB
C++

//===- MLProgramOps.cpp - MLProgram dialect ops implementation ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/FunctionImplementation.h"
using namespace mlir;
using namespace mlir::ml_program;
//===----------------------------------------------------------------------===//
// Custom asm helpers
//===----------------------------------------------------------------------===//
/// Parse and print an ordering clause for a variadic of consuming tokens
/// and a producing token.
///
/// Syntax:
///   ordering(%0, %1 -> !ml_program.token)
///   ordering(() -> !ml_program.token)
///
/// If both the consuming and producing token are not present on the op, then
/// the clause prints nothing.
///
/// The whole clause is optional: when the `ordering` keyword is absent this
/// succeeds without consuming any input. Once the keyword has been consumed,
/// the rest of the clause is mandatory and any malformed piece makes the
/// parse fail.
static ParseResult parseTokenOrdering(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &consumeTokens,
    Type &produceTokenType) {
  // No `ordering` keyword: the clause was simply omitted.
  if (failed(parser.parseOptionalKeyword("ordering")))
    return success();

  // The keyword commits us to the clause. A missing '(' must be reported as a
  // failure; parseLParen has already emitted the diagnostic, so returning
  // success here would leave an error behind on a "successful" parse.
  if (failed(parser.parseLParen()))
    return failure();

  // Parse consuming token list. If there are no consuming tokens, the
  // '()' null list represents this.
  if (succeeded(parser.parseOptionalLParen())) {
    if (failed(parser.parseRParen()))
      return failure();
  } else {
    if (failed(parser.parseOperandList(consumeTokens,
                                       /*requiredOperandCount=*/-1)))
      return failure();
  }

  // Parse producer token: `-> type` followed by the closing ')'.
  if (failed(parser.parseArrow()))
    return failure();
  if (failed(parser.parseType(produceTokenType)))
    return failure();
  if (failed(parser.parseRParen()))
    return failure();
  return success();
}
/// Prints the `ordering(...)` clause that parseTokenOrdering reads back.
/// Consuming tokens are printed as an operand list, or as `()` when there are
/// none. The producing token type is printed after ` -> ` when present. If
/// the op has neither consuming tokens nor a producing token, nothing is
/// printed at all.
static void printTokenOrdering(OpAsmPrinter &p, Operation *op,
                               OperandRange consumeTokens,
                               Type produceTokenType) {
  const bool hasConsumers = !consumeTokens.empty();
  const bool hasProducer = static_cast<bool>(produceTokenType);
  if (!hasConsumers && !hasProducer)
    return;

  p << " ordering(";
  if (hasConsumers) {
    p.printOperands(consumeTokens);
  } else {
    p << "()";
  }

  if (hasProducer) {
    p << " -> ";
    p.printType(produceTokenType);
  }
  p << ")";
}
/// Parser for the `custom<TypedInitialValue>(...)` directive: an optional
/// parenthesized initial-value attribute followed by a mandatory `: type`.
///
/// Uninitialized:
///   some.op : tensor<3xi32>
/// Initialized to narrower type than op:
///   some.op (dense<0> : tensor<3xi32>) : tensor<?xi32>
///
/// `attr` is only assigned when the parenthesized value is present;
/// `typeAttr` is always assigned on success.
static ParseResult parseTypedInitialValue(OpAsmParser &parser,
                                          TypeAttr &typeAttr, Attribute &attr) {
  // Optional `( <attribute> )`. Once '(' is seen, both the attribute and the
  // closing ')' are required.
  if (succeeded(parser.parseOptionalLParen()) &&
      (parser.parseAttribute(attr) || parser.parseRParen()))
    return failure();

  // Mandatory `: <type>`.
  Type declaredType;
  if (parser.parseColonType(declaredType))
    return failure();

  typeAttr = TypeAttr::get(declaredType);
  return success();
}
/// Printer counterpart of parseTypedInitialValue: emits `(attr)` when an
/// initial value is set, then ` : ` followed by the declared type.
static void printTypedInitialValue(OpAsmPrinter &p, Operation *op,
                                   TypeAttr type, Attribute attr) {
  // The initial value is optional; only print the parenthesized form if set.
  if (attr)
    p << "(" << attr << ")";

  // The declared type is always present.
  p << " : ";
  p.printAttribute(type);
}
/// some.op custom<SymbolVisibility>($sym_visibility) $sym_name
/// ->
/// some.op public @foo
/// some.op private @foo
///
/// Parses a mandatory visibility keyword (`public`, `private` or `nested`)
/// and stores it as a string attribute in `symVisibilityAttr`. A missing or
/// unrecognized keyword is reported as an error.
static ParseResult parseSymbolVisibility(OpAsmParser &parser,
                                         StringAttr &symVisibilityAttr) {
  StringRef symVisibility;
  // Use the optional form so we can emit a diagnostic that lists every
  // accepted keyword; if none matches, `symVisibility` stays empty.
  (void)parser.parseOptionalKeyword(&symVisibility,
                                    {"public", "private", "nested"});
  if (symVisibility.empty())
    return parser.emitError(parser.getCurrentLocation())
           << "expected 'public', 'private', or 'nested'";

  // The empty case returned above, so a keyword is guaranteed here.
  symVisibilityAttr = parser.getBuilder().getStringAttr(symVisibility);
  return success();
}
/// Prints the visibility keyword read by parseSymbolVisibility. An absent
/// attribute is printed as `public`.
static void printSymbolVisibility(OpAsmPrinter &p, Operation *op,
                                  StringAttr symVisibilityAttr) {
  StringRef visibility =
      symVisibilityAttr ? symVisibilityAttr.getValue() : StringRef("public");
  p << visibility;
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc"
//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//
/// Parses an ml_program.func using the shared function-op syntax. Variadic
/// signatures are rejected (`allowVariadic` is false), so the type builder
/// ignores the variadic flag and the error out-parameter.
ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
  auto makeFunctionType = [](Builder &builder, ArrayRef<Type> inputs,
                             ArrayRef<Type> outputs,
                             function_interface_impl::VariadicFlag,
                             std::string &) {
    return builder.getFunctionType(inputs, outputs);
  };

  auto typeAttrName = getFunctionTypeAttrName(result.name);
  auto argAttrsName = getArgAttrsAttrName(result.name);
  auto resAttrsName = getResAttrsAttrName(result.name);
  return function_interface_impl::parseFunctionOp(
      parser, result, /*allowVariadic=*/false, typeAttrName, makeFunctionType,
      argAttrsName, resAttrsName);
}
/// Prints this function with the shared function-op syntax (never variadic).
void FuncOp::print(OpAsmPrinter &p) {
  auto typeAttrName = getFunctionTypeAttrName();
  auto argAttrsName = getArgAttrsAttrName();
  auto resAttrsName = getResAttrsAttrName();
  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
                                           typeAttrName, argAttrsName,
                                           resAttrsName);
}
//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//
/// A mutable global may be declared without an initial value. An immutable
/// global must carry one.
LogicalResult GlobalOp::verify() {
  if (getIsMutable() || getValue())
    return success();
  return emitOpError() << "immutable global must have an initial value";
}
//===----------------------------------------------------------------------===//
// GlobalLoadOp
//===----------------------------------------------------------------------===//
/// Resolves the referenced global, starting the lookup from this op's parent.
/// Returns a null op if the symbol does not resolve to a GlobalOp.
GlobalOp GlobalLoadOp::getGlobalOp(SymbolTableCollection &symbolTable) {
  Operation *lookupFrom = getOperation()->getParentOp();
  return symbolTable.lookupNearestSymbolFrom<GlobalOp>(lookupFrom,
                                                       getGlobalAttr());
}
/// Checks that the referenced global exists and that its type matches the
/// loaded result type.
LogicalResult
GlobalLoadOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  GlobalOp global = getGlobalOp(symbolTable);
  if (!global)
    return emitOpError() << "undefined global: " << getGlobal();

  Type globalType = global.getType();
  Type loadedType = getResult().getType();
  if (globalType == loadedType)
    return success();
  return emitOpError() << "cannot load from global typed " << globalType
                       << " as " << loadedType;
}
//===----------------------------------------------------------------------===//
// GlobalLoadConstOp
//===----------------------------------------------------------------------===//
/// Resolves the referenced global, starting the lookup from this op's parent.
/// Returns a null op if the symbol does not resolve to a GlobalOp.
GlobalOp GlobalLoadConstOp::getGlobalOp(SymbolTableCollection &symbolTable) {
  Operation *lookupFrom = getOperation()->getParentOp();
  return symbolTable.lookupNearestSymbolFrom<GlobalOp>(lookupFrom,
                                                       getGlobalAttr());
}
/// Checks that the referenced global exists, is immutable (a const load from
/// a mutable global is rejected), and has the same type as the result.
LogicalResult
GlobalLoadConstOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  GlobalOp global = getGlobalOp(symbolTable);
  if (!global)
    return emitOpError() << "undefined global: " << getGlobal();

  if (global.getIsMutable())
    return emitOpError() << "cannot load as const from mutable global "
                         << getGlobal();

  Type globalType = global.getType();
  Type loadedType = getResult().getType();
  if (globalType != loadedType)
    return emitOpError() << "cannot load from global typed " << globalType
                         << " as " << loadedType;
  return success();
}
//===----------------------------------------------------------------------===//
// GlobalLoadGraphOp
//===----------------------------------------------------------------------===//
/// Resolves the referenced global, starting the lookup from this op's parent.
/// Returns a null op if the symbol does not resolve to a GlobalOp.
GlobalOp GlobalLoadGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) {
  Operation *lookupFrom = getOperation()->getParentOp();
  return symbolTable.lookupNearestSymbolFrom<GlobalOp>(lookupFrom,
                                                       getGlobalAttr());
}
/// Checks that the referenced global exists and that its type matches the
/// loaded result type.
LogicalResult
GlobalLoadGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  GlobalOp global = getGlobalOp(symbolTable);
  if (!global)
    return emitOpError() << "undefined global: " << getGlobal();

  Type globalType = global.getType();
  Type loadedType = getResult().getType();
  if (globalType == loadedType)
    return success();
  return emitOpError() << "cannot load from global typed " << globalType
                       << " as " << loadedType;
}
//===----------------------------------------------------------------------===//
// GlobalStoreOp
//===----------------------------------------------------------------------===//
/// Resolves the referenced global, starting the lookup from this op's parent.
/// Returns a null op if the symbol does not resolve to a GlobalOp.
GlobalOp GlobalStoreOp::getGlobalOp(SymbolTableCollection &symbolTable) {
  Operation *lookupFrom = getOperation()->getParentOp();
  return symbolTable.lookupNearestSymbolFrom<GlobalOp>(lookupFrom,
                                                       getGlobalAttr());
}
/// Checks that the referenced global exists, is mutable, and has the same
/// type as the stored value.
LogicalResult
GlobalStoreOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  GlobalOp global = getGlobalOp(symbolTable);
  if (!global)
    return emitOpError() << "undefined global: " << getGlobal();
  if (!global.getIsMutable())
    return emitOpError() << "cannot store to an immutable global "
                         << getGlobal();

  Type globalType = global.getType();
  Type storedType = getValue().getType();
  if (globalType == storedType)
    return success();
  return emitOpError() << "cannot store to a global typed " << globalType
                       << " from " << storedType;
}
//===----------------------------------------------------------------------===//
// GlobalStoreGraphOp
//===----------------------------------------------------------------------===//
/// Resolves the referenced global, starting the lookup from this op's parent.
/// Returns a null op if the symbol does not resolve to a GlobalOp.
GlobalOp GlobalStoreGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) {
  Operation *lookupFrom = getOperation()->getParentOp();
  return symbolTable.lookupNearestSymbolFrom<GlobalOp>(lookupFrom,
                                                       getGlobalAttr());
}
/// Checks that the referenced global exists, is mutable, and has the same
/// type as the stored value.
LogicalResult
GlobalStoreGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  GlobalOp global = getGlobalOp(symbolTable);
  if (!global)
    return emitOpError() << "undefined global: " << getGlobal();
  if (!global.getIsMutable())
    return emitOpError() << "cannot store to an immutable global "
                         << getGlobal();

  Type globalType = global.getType();
  Type storedType = getValue().getType();
  if (globalType == storedType)
    return success();
  return emitOpError() << "cannot store to a global typed " << globalType
                       << " from " << storedType;
}
//===----------------------------------------------------------------------===//
// SubgraphOp
//===----------------------------------------------------------------------===//
/// Parses an ml_program.subgraph using the shared function-op syntax.
/// Variadic signatures are rejected (`allowVariadic` is false), so the type
/// builder ignores the variadic flag and the error out-parameter.
ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) {
  auto makeFunctionType = [](Builder &builder, ArrayRef<Type> inputs,
                             ArrayRef<Type> outputs,
                             function_interface_impl::VariadicFlag,
                             std::string &) {
    return builder.getFunctionType(inputs, outputs);
  };

  auto typeAttrName = getFunctionTypeAttrName(result.name);
  auto argAttrsName = getArgAttrsAttrName(result.name);
  auto resAttrsName = getResAttrsAttrName(result.name);
  return function_interface_impl::parseFunctionOp(
      parser, result, /*allowVariadic=*/false, typeAttrName, makeFunctionType,
      argAttrsName, resAttrsName);
}
/// Prints this subgraph with the shared function-op syntax (never variadic).
void SubgraphOp::print(OpAsmPrinter &p) {
  auto typeAttrName = getFunctionTypeAttrName();
  auto argAttrsName = getArgAttrsAttrName();
  auto resAttrsName = getResAttrsAttrName();
  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
                                           typeAttrName, argAttrsName,
                                           resAttrsName);
}
//===----------------------------------------------------------------------===//
// OutputOp
//===----------------------------------------------------------------------===//
/// Verifies that the operands of this terminator match, in number and in
/// type, the result types of the enclosing SubgraphOp's function type.
LogicalResult OutputOp::verify() {
  auto subgraph = cast<SubgraphOp>((*this)->getParentOp());
  ArrayRef<Type> resultTypes = subgraph.getFunctionType().getResults();

  // The operand count must match the signature...
  unsigned numOutputs = getNumOperands();
  if (numOutputs != resultTypes.size())
    return emitOpError("has ")
           << numOutputs << " operands, but enclosing function (@"
           << subgraph.getName() << ") outputs " << resultTypes.size();

  // ...and so must each operand's type, position by position.
  for (unsigned idx = 0; idx < numOutputs; ++idx) {
    Type actualType = getOperand(idx).getType();
    if (actualType == resultTypes[idx])
      continue;
    return emitError() << "type of output operand " << idx << " ("
                       << actualType
                       << ") doesn't match function result type ("
                       << resultTypes[idx] << ")"
                       << " in function @" << subgraph.getName();
  }
  return success();
}
//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//
/// Verifies that the operands of this terminator match, in number and in
/// type, the result types of the enclosing FuncOp's function type.
LogicalResult ReturnOp::verify() {
  auto func = cast<FuncOp>((*this)->getParentOp());
  ArrayRef<Type> resultTypes = func.getFunctionType().getResults();

  // The operand count must match the signature...
  unsigned numReturned = getNumOperands();
  if (numReturned != resultTypes.size())
    return emitOpError("has ")
           << numReturned << " operands, but enclosing function (@"
           << func.getName() << ") returns " << resultTypes.size();

  // ...and so must each operand's type, position by position.
  for (unsigned idx = 0; idx < numReturned; ++idx) {
    Type actualType = getOperand(idx).getType();
    if (actualType == resultTypes[idx])
      continue;
    return emitError() << "type of return operand " << idx << " ("
                       << actualType
                       << ") doesn't match function result type ("
                       << resultTypes[idx] << ")"
                       << " in function @" << func.getName();
  }
  return success();
}