//===- Ops.cpp - Loop MLIR Operations -------------------------------------===// // // Copyright 2019 The MLIR Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= #include "mlir/Dialect/LoopOps/LoopOps.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/Module.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/StandardTypes.h" #include "mlir/IR/Value.h" #include "mlir/StandardOps/Ops.h" #include "mlir/Support/MathExtras.h" #include "mlir/Support/STLExtras.h" using namespace mlir; using namespace mlir::loop; //===----------------------------------------------------------------------===// // LoopOpsDialect //===----------------------------------------------------------------------===// LoopOpsDialect::LoopOpsDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context) { addOperations< #define GET_OP_LIST #include "mlir/Dialect/LoopOps/LoopOps.cpp.inc" >(); } //===----------------------------------------------------------------------===// // ForOp //===----------------------------------------------------------------------===// void ForOp::build(Builder *builder, OperationState *result, Value *lb, Value *ub, Value *step) { result->addOperands({lb, ub, step}); Region *bodyRegion = result->addRegion(); ForOp::ensureTerminator(*bodyRegion, *builder, result->location); bodyRegion->front().addArgument(builder->getIndexType()); } LogicalResult verify(ForOp op) { if (auto cst = dyn_cast_or_null(op.step()->getDefiningOp())) if (cst.getValue() <= 0) return op.emitOpError("constant step operand must be nonnegative"); // Check that the body defines as single block argument for the induction // variable. auto *body = op.getBody(); if (body->getNumArguments() != 1 || !body->getArgument(0)->getType().isIndex()) return op.emitOpError("expected body to have a single index argument for " "the induction variable"); return success(); } static void print(OpAsmPrinter *p, ForOp op) { *p << op.getOperationName() << " " << *op.getInductionVar() << " = " << *op.lowerBound() << " to " << *op.upperBound() << " step " << *op.step(); p->printRegion(op.region(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/false); p->printOptionalAttrDict(op.getAttrs()); } static ParseResult parseForOp(OpAsmParser *parser, OperationState *result) { auto &builder = parser->getBuilder(); OpAsmParser::OperandType inductionVariable, lb, ub, step; // Parse the induction variable followed by '='. if (parser->parseRegionArgument(inductionVariable) || parser->parseEqual()) return failure(); // Parse loop bounds. Type indexType = builder.getIndexType(); if (parser->parseOperand(lb) || parser->resolveOperand(lb, indexType, result->operands) || parser->parseKeyword("to") || parser->parseOperand(ub) || parser->resolveOperand(ub, indexType, result->operands) || parser->parseKeyword("step") || parser->parseOperand(step) || parser->resolveOperand(step, indexType, result->operands)) return failure(); // Parse the body region. Region *body = result->addRegion(); if (parser->parseRegion(*body, inductionVariable, indexType)) return failure(); ForOp::ensureTerminator(*body, builder, result->location); // Parse the optional attribute list. if (parser->parseOptionalAttributeDict(result->attributes)) return failure(); return success(); } ForOp mlir::loop::getForInductionVarOwner(Value *val) { auto *ivArg = dyn_cast(val); if (!ivArg) return ForOp(); assert(ivArg->getOwner() && "unlinked block argument"); auto *containingInst = ivArg->getOwner()->getParentOp(); return dyn_cast_or_null(containingInst); } //===----------------------------------------------------------------------===// // IfOp //===----------------------------------------------------------------------===// void IfOp::build(Builder *builder, OperationState *result, Value *cond, bool withElseRegion) { result->addOperands(cond); Region *thenRegion = result->addRegion(); Region *elseRegion = result->addRegion(); IfOp::ensureTerminator(*thenRegion, *builder, result->location); if (withElseRegion) IfOp::ensureTerminator(*elseRegion, *builder, result->location); } static LogicalResult verify(IfOp op) { // Verify that the entry of each child region does not have arguments. for (auto ®ion : op.getOperation()->getRegions()) { if (region.empty()) continue; for (auto &b : region) if (b.getNumArguments() != 0) return op.emitOpError( "requires that child entry blocks have no arguments"); } return success(); } static ParseResult parseIfOp(OpAsmParser *parser, OperationState *result) { // Create the regions for 'then'. result->regions.reserve(2); Region *thenRegion = result->addRegion(); Region *elseRegion = result->addRegion(); auto &builder = parser->getBuilder(); OpAsmParser::OperandType cond; Type i1Type = builder.getIntegerType(1); if (parser->parseOperand(cond) || parser->resolveOperand(cond, i1Type, result->operands)) return failure(); // Parse the 'then' region. if (parser->parseRegion(*thenRegion, {}, {})) return failure(); IfOp::ensureTerminator(*thenRegion, parser->getBuilder(), result->location); // If we find an 'else' keyword then parse the 'else' region. if (!parser->parseOptionalKeyword("else")) { if (parser->parseRegion(*elseRegion, {}, {})) return failure(); IfOp::ensureTerminator(*elseRegion, parser->getBuilder(), result->location); } // Parse the optional attribute list. if (parser->parseOptionalAttributeDict(result->attributes)) return failure(); return success(); } static void print(OpAsmPrinter *p, IfOp op) { *p << IfOp::getOperationName() << " " << *op.condition(); p->printRegion(op.thenRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/false); // Print the 'else' regions if it exists and has a block. auto &elseRegion = op.elseRegion(); if (!elseRegion.empty()) { *p << " else"; p->printRegion(elseRegion, /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/false); } p->printOptionalAttrDict(op.getAttrs()); } //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// #define GET_OP_CLASSES #include "mlir/Dialect/LoopOps/LoopOps.cpp.inc"