Files
clang-p2996/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
2022-04-13 21:38:14 -07:00

108 lines
4.3 KiB
C++

//===- MLProgramOps.cpp - MLProgram dialect ops implementation ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/FunctionImplementation.h"
using namespace mlir;
using namespace mlir::ml_program;
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc"
//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//
/// Parses a FuncOp using the shared function-op parser.
///
/// Variadic signatures are not allowed (`allowVariadic` is false). The
/// resulting type is a plain builtin function type built from the parsed
/// argument and result types.
ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
  const bool allowVariadic = false;

  // The variadic flag and the error-message out-parameter are left unnamed:
  // this builder never consults them.
  auto makeFunctionType = [](Builder &b, ArrayRef<Type> inputs,
                             ArrayRef<Type> outputs,
                             function_interface_impl::VariadicFlag,
                             std::string &) {
    return b.getFunctionType(inputs, outputs);
  };

  return function_interface_impl::parseFunctionOp(parser, result,
                                                  allowVariadic,
                                                  makeFunctionType);
}
/// Prints a FuncOp using the shared function-op printer, always as a
/// non-variadic signature (mirroring `allowVariadic=false` in the parser).
void FuncOp::print(OpAsmPrinter &p) {
  const bool isVariadic = false;
  function_interface_impl::printFunctionOp(p, *this, isVariadic);
}
//===----------------------------------------------------------------------===//
// SubgraphOp
//===----------------------------------------------------------------------===//
/// Parses a SubgraphOp using the shared function-op parser.
///
/// Variadic signatures are not allowed. The op's type is a plain builtin
/// function type built from the parsed argument and result types.
ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) {
  // The type-builder callback is only invoked during this call, so passing the
  // lambda as a temporary is fine. Its variadic-flag and error-string
  // parameters are intentionally unused.
  return function_interface_impl::parseFunctionOp(
      parser, result, /*allowVariadic=*/false,
      [](Builder &b, ArrayRef<Type> inputs, ArrayRef<Type> outputs,
         function_interface_impl::VariadicFlag, std::string &) {
        return b.getFunctionType(inputs, outputs);
      });
}
/// Prints a SubgraphOp using the shared function-op printer, always as a
/// non-variadic signature (mirroring `allowVariadic=false` in the parser).
void SubgraphOp::print(OpAsmPrinter &p) {
  const bool isVariadic = false;
  function_interface_impl::printFunctionOp(p, *this, isVariadic);
}
//===----------------------------------------------------------------------===//
// OutputOp
//===----------------------------------------------------------------------===//
/// Verifies that this OutputOp agrees with the signature of its parent
/// SubgraphOp. The operand count must equal the number of function results,
/// and each operand type must equal the corresponding result type.
///
/// `cast<>` asserts that the parent is a SubgraphOp. Presumably this is
/// guaranteed by a parent-op constraint in the ODS definition (not visible
/// here) — TODO confirm.
LogicalResult OutputOp::verify() {
  auto function = cast<SubgraphOp>((*this)->getParentOp());

  // The operand number and types must match the function signature.
  const auto &results = function.getFunctionType().getResults();
  if (getNumOperands() != results.size())
    return emitOpError("has ")
           << getNumOperands() << " operands, but enclosing function (@"
           << function.getName() << ") outputs " << results.size();

  for (unsigned i = 0, e = results.size(); i != e; ++i) {
    auto operandType = getOperand(i).getType();
    if (operandType != results[i])
      // Use emitOpError (not emitError) so this diagnostic, like the
      // operand-count one above, is prefixed with the op name.
      return emitOpError() << "type of output operand " << i << " ("
                           << operandType
                           << ") doesn't match function result type ("
                           << results[i] << ")"
                           << " in function @" << function.getName();
  }
  return success();
}
//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//
/// Verifies that this ReturnOp agrees with the signature of its parent FuncOp.
/// The operand count must equal the number of function results, and each
/// operand type must equal the corresponding result type.
///
/// `cast<>` asserts that the parent is a FuncOp. Presumably this is guaranteed
/// by a parent-op constraint in the ODS definition (not visible here) — TODO
/// confirm.
LogicalResult ReturnOp::verify() {
  auto function = cast<FuncOp>((*this)->getParentOp());

  // The operand number and types must match the function signature.
  const auto &results = function.getFunctionType().getResults();
  if (getNumOperands() != results.size())
    return emitOpError("has ")
           << getNumOperands() << " operands, but enclosing function (@"
           << function.getName() << ") returns " << results.size();

  for (unsigned i = 0, e = results.size(); i != e; ++i) {
    auto operandType = getOperand(i).getType();
    if (operandType != results[i])
      // Use emitOpError (not emitError) so this diagnostic, like the
      // operand-count one above, is prefixed with the op name.
      return emitOpError() << "type of return operand " << i << " ("
                           << operandType
                           << ") doesn't match function result type ("
                           << results[i] << ")"
                           << " in function @" << function.getName();
  }
  return success();
}