108 lines
4.3 KiB
C++
108 lines
4.3 KiB
C++
//===- MLProgramOps.cpp - MLProgram dialect ops implementation ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/FunctionImplementation.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::ml_program;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TableGen'd op method definitions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc"
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// FuncOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Parses the custom assembly form of a FuncOp.
///
/// All of the parsing is delegated to
/// `function_interface_impl::parseFunctionOp`, with variadic signatures
/// disallowed. The only op-specific piece is the callback that turns the
/// parsed argument and result types into a function type.
ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
  // Signature factory handed to the shared parser. The variadic flag and the
  // error-string out-parameter are accepted to satisfy the callback signature
  // but are not consulted.
  auto makeSignature = [](Builder &b, ArrayRef<Type> inputs,
                          ArrayRef<Type> outputs,
                          function_interface_impl::VariadicFlag,
                          std::string &) {
    return b.getFunctionType(inputs, outputs);
  };

  return function_interface_impl::parseFunctionOp(
      parser, result, /*allowVariadic=*/false, makeSignature);
}
|
|
|
|
/// Prints this FuncOp by forwarding to
/// `function_interface_impl::printFunctionOp`, marking the signature as
/// non-variadic.
void FuncOp::print(OpAsmPrinter &p) {
  const bool isVariadic = false;
  function_interface_impl::printFunctionOp(p, *this, isVariadic);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SubgraphOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Parses the custom assembly form of a SubgraphOp.
///
/// All of the parsing is delegated to
/// `function_interface_impl::parseFunctionOp`, with variadic signatures
/// disallowed. The only op-specific piece is the callback that turns the
/// parsed argument and result types into a function type.
ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) {
  // Signature factory handed to the shared parser. The variadic flag and the
  // error-string out-parameter are accepted to satisfy the callback signature
  // but are not consulted.
  auto makeSignature = [](Builder &b, ArrayRef<Type> inputs,
                          ArrayRef<Type> outputs,
                          function_interface_impl::VariadicFlag,
                          std::string &) {
    return b.getFunctionType(inputs, outputs);
  };

  return function_interface_impl::parseFunctionOp(
      parser, result, /*allowVariadic=*/false, makeSignature);
}
|
|
|
|
/// Prints this SubgraphOp by forwarding to
/// `function_interface_impl::printFunctionOp`, marking the signature as
/// non-variadic.
void SubgraphOp::print(OpAsmPrinter &p) {
  const bool isVariadic = false;
  function_interface_impl::printFunctionOp(p, *this, isVariadic);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// OutputOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Verifies that this OutputOp agrees with the signature of its enclosing
/// SubgraphOp.
///
/// Two conditions are checked, in order:
///   1. the number of operands equals the number of result types declared by
///      the enclosing op's function type;
///   2. each operand's type equals the result type at the same position.
/// The first violation found is reported as a diagnostic; otherwise the op
/// verifies successfully.
LogicalResult OutputOp::verify() {
  // NOTE(review): `cast<>` assumes the parent is always a SubgraphOp;
  // presumably the op definition enforces this -- confirm in the ODS.
  auto subgraph = cast<SubgraphOp>((*this)->getParentOp());
  const auto &expectedTypes = subgraph.getFunctionType().getResults();

  // Arity check: one operand per declared result.
  if (getNumOperands() != expectedTypes.size()) {
    return emitOpError("has ")
           << getNumOperands() << " operands, but enclosing function (@"
           << subgraph.getName() << ") outputs " << expectedTypes.size();
  }

  // Element-wise type check; bail out on the first mismatch.
  for (unsigned idx = 0, count = expectedTypes.size(); idx != count; ++idx) {
    if (getOperand(idx).getType() == expectedTypes[idx])
      continue;

    return emitError() << "type of output operand " << idx << " ("
                       << getOperand(idx).getType()
                       << ") doesn't match function result type ("
                       << expectedTypes[idx] << ")"
                       << " in function @" << subgraph.getName();
  }

  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ReturnOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Verifies that this ReturnOp agrees with the signature of its enclosing
/// FuncOp.
///
/// Two conditions are checked, in order:
///   1. the number of operands equals the number of result types declared by
///      the enclosing op's function type;
///   2. each operand's type equals the result type at the same position.
/// The first violation found is reported as a diagnostic; otherwise the op
/// verifies successfully.
LogicalResult ReturnOp::verify() {
  // NOTE(review): `cast<>` assumes the parent is always a FuncOp; presumably
  // the op definition enforces this -- confirm in the ODS.
  auto enclosingFunc = cast<FuncOp>((*this)->getParentOp());
  const auto &expectedTypes = enclosingFunc.getFunctionType().getResults();

  // Arity check: one operand per declared result.
  if (getNumOperands() != expectedTypes.size()) {
    return emitOpError("has ")
           << getNumOperands() << " operands, but enclosing function (@"
           << enclosingFunc.getName() << ") returns " << expectedTypes.size();
  }

  // Element-wise type check; bail out on the first mismatch.
  for (unsigned idx = 0, count = expectedTypes.size(); idx != count; ++idx) {
    if (getOperand(idx).getType() == expectedTypes[idx])
      continue;

    return emitError() << "type of return operand " << idx << " ("
                       << getOperand(idx).getType()
                       << ") doesn't match function result type ("
                       << expectedTypes[idx] << ")"
                       << " in function @" << enclosingFunc.getName();
  }

  return success();
}
|