Files
clang-p2996/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
Daniil Dudkin d4bde6968e [mlir][irdl] Introduce a way to define regions
This patch introduces new operations:
`irdl.region` and `irdl.regions`.
The former lets us specify characteristics of a region,
such as the arguments of the entry block and the number of blocks.
The latter accepts all results of the former operations
to define the set of regions for the operation.

Example:

```
    irdl.dialect @example {
      irdl.operation @op_with_regions {
          %r0 = irdl.region
          %r1 = irdl.region()
          %v0 = irdl.is i32
          %v1 = irdl.is i64
          %r2 = irdl.region(%v0, %v1)
          %r3 = irdl.region with size 3

          irdl.regions(%r0, %r1, %r2, %r3)
      }
    }
```

The above snippet demonstrates an operation named `@op_with_regions`,
which is constrained to have four regions.

* Region `%r0` doesn't have any constraints on the arguments or the number of blocks.
* Region `%r1` should have an empty set of arguments.
* Region `%r2` should have two arguments of types `i32` and `i64`.
* Region `%r3` should contain exactly three blocks.

In the future, the block count constraint may be extended to support a range of possible block counts.

Reviewed By: math-fehr, Mogball

Differential Revision: https://reviews.llvm.org/D155112
2023-08-23 17:55:10 +03:00

248 lines
8.2 KiB
C++

//===- IRDL.cpp - IRDL dialect ----------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/IRDL/IR/IRDL.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/ExtensibleDialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/Metadata.h"
#include "llvm/Support/Casting.h"
using namespace mlir;
using namespace mlir::irdl;
//===----------------------------------------------------------------------===//
// IRDL dialect.
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/IRDL/IR/IRDL.cpp.inc"
#include "mlir/Dialect/IRDL/IR/IRDLDialect.cpp.inc"
/// Register the IRDL dialect's operations, types, and attributes with this
/// dialect instance. Each registration call takes its template argument list
/// from an `.inc` file included in place; the GET_OP_LIST, GET_TYPEDEF_LIST
/// and GET_ATTRDEF_LIST macros select which section of that file is expanded.
void IRDLDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/IRDL/IR/IRDLOps.cpp.inc"
>();
addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/IRDL/IR/IRDLTypesGen.cpp.inc"
>();
addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/IRDL/IR/IRDLAttributes.cpp.inc"
>();
}
//===----------------------------------------------------------------------===//
// Parsing/Printing
//===----------------------------------------------------------------------===//
/// Parse an optional region into `region`.
///
/// If the region is absent from the input, or is parsed but ends up with no
/// blocks, one empty block is appended to `region`. This guarantees that the
/// region holds at least one block afterwards. Fails only when a region is
/// present but cannot be parsed.
static ParseResult parseSingleBlockRegion(OpAsmParser &p, Region &region) {
  auto maybeParsed = p.parseOptionalRegion(region);
  // A region was present in the input but was malformed.
  if (maybeParsed.has_value() && failed(maybeParsed.value()))
    return failure();

  // Nothing was written, or the parsed region had no blocks: give it one.
  if (region.empty())
    region.emplaceBlock();
  return success();
}
/// Print the body region of a single-block op.
///
/// The region is elided when its first block contains no operations. That
/// elided form reads back through parseSingleBlockRegion, which recreates a
/// single empty block.
///
/// A region with no blocks at all is also elided rather than dereferencing a
/// non-existent front block. This can happen when the op was constructed
/// without going through parseSingleBlockRegion and is printed before it has
/// been verified.
static void printSingleBlockRegion(OpAsmPrinter &p, Operation *op,
                                   Region &region) {
  if (region.empty() || region.front().empty())
    return;
  p.printRegion(region);
}
/// Verify that the dialect's name is accepted by Dialect::isValidNamespace.
LogicalResult DialectOp::verify() {
  const bool nameIsValid = Dialect::isValidNamespace(getName());
  if (nameIsValid)
    return success();
  return emitOpError("invalid dialect name");
}
/// Verify that this op has exactly one variadicity entry per operand, so that
/// index i of the operand list and of the variadicity list describe the same
/// value.
LogicalResult OperandsOp::verify() {
  const size_t variadicityCount = getVariadicity().size();
  const size_t operandCount = getNumOperands();

  if (operandCount == variadicityCount)
    return success();

  return emitOpError()
         << "the number of operands and their variadicities must be "
            "the same, but got "
         << operandCount << " and " << variadicityCount << " respectively";
}
/// Verify that this op has exactly one variadicity entry per operand, so that
/// index i of the operand list and of the variadicity list describe the same
/// value.
LogicalResult ResultsOp::verify() {
  const size_t variadicityCount = getVariadicity().size();
  const size_t operandCount = this->getNumOperands();

  if (operandCount == variadicityCount)
    return success();

  return emitOpError()
         << "the number of operands and their variadicities must be "
            "the same, but got "
         << operandCount << " and " << variadicityCount << " respectively";
}
/// Verify that the list of attribute names and the list of attribute value
/// constraints have the same length, so each name has exactly one constraint.
LogicalResult AttributesOp::verify() {
  const size_t nameCount = getAttributeValueNames().size();
  const size_t valueCount = getAttributeValues().size();

  if (nameCount == valueCount)
    return success();

  return emitOpError()
         << "the number of attribute names and their constraints must be "
            "the same but got "
         << nameCount << " and " << valueCount << " respectively";
}
/// Parse one SSA value, optionally preceded by a variadicity keyword. A value
/// written without a keyword gets the `single` variadicity.
///
/// value-with-variadicity ::= ("single" | "optional" | "variadic")? ssa-value
static ParseResult
parseValueWithVariadicity(OpAsmParser &p,
                          OpAsmParser::UnresolvedOperand &operand,
                          VariadicityAttr &variadicityAttr) {
  MLIRContext *ctx = p.getBuilder().getContext();

  // Try the keywords in a fixed order; at most one of them is consumed.
  Variadicity kind = Variadicity::single;
  if (succeeded(p.parseOptionalKeyword("single")))
    kind = Variadicity::single;
  else if (succeeded(p.parseOptionalKeyword("optional")))
    kind = Variadicity::optional;
  else if (succeeded(p.parseOptionalKeyword("variadic")))
    kind = Variadicity::variadic;
  variadicityAttr = VariadicityAttr::get(ctx, kind);

  // The SSA value itself is mandatory.
  return p.parseOperand(operand);
}
/// Parse a parenthesized, comma-separated list of values, each optionally
/// preceded by a variadicity keyword (default: `single`).
///
/// Operands are appended to `operands`. The matching variadicities are
/// collected into `variadicityAttr`, so index i of both describes the same
/// value.
///
/// values-with-variadicity ::=
///   `(` (value-with-variadicity (`,` value-with-variadicity)*)? `)`
/// value-with-variadicity ::= ("single" | "optional" | "variadic")? ssa-value
static ParseResult parseValuesWithVariadicity(
    OpAsmParser &p, SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
    VariadicityArrayAttr &variadicityAttr) {
  SmallVector<VariadicityAttr> kinds;

  auto parseElement = [&]() -> ParseResult {
    OpAsmParser::UnresolvedOperand value;
    VariadicityAttr kind;
    if (failed(parseValueWithVariadicity(p, value, kind)))
      return failure();
    operands.push_back(value);
    kinds.push_back(kind);
    return success();
  };

  if (failed(p.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
                                       parseElement)))
    return failure();

  variadicityAttr =
      VariadicityArrayAttr::get(p.getBuilder().getContext(), kinds);
  return success();
}
/// Print a parenthesized, comma-separated list of values. A value whose
/// variadicity is not `single` is prefixed with its variadicity keyword;
/// `single` is left implicit, matching the parser's default.
///
/// values-with-variadicity ::=
///   `(` (value-with-variadicity (`,` value-with-variadicity)*)? `)`
/// value-with-variadicity ::= ("single" | "optional" | "variadic")? ssa-value
static void printValuesWithVariadicity(OpAsmPrinter &p, Operation *op,
                                       OperandRange operands,
                                       VariadicityArrayAttr variadicityAttr) {
  p << "(";
  for (unsigned idx = 0, e = operands.size(); idx != e; ++idx) {
    if (idx != 0)
      p << ", ";
    Variadicity kind = variadicityAttr[idx].getValue();
    if (kind != Variadicity::single)
      p << stringifyVariadicity(kind) << " ";
    p << operands[idx];
  }
  p << ")";
}
/// Parse the optional attribute dictionary of an AttributesOp:
///
///   `{` attribute `=` ssa-value (`,` attribute `=` ssa-value)* `}`
///
/// Each parsed attribute name goes into `attrNamesAttr` and its SSA value is
/// appended to `attrOperands`, index for index. When the opening brace is
/// absent, nothing is consumed and `attrNamesAttr` becomes an empty array.
static ParseResult
parseAttributesOp(OpAsmParser &p,
                  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &attrOperands,
                  ArrayAttr &attrNamesAttr) {
  SmallVector<Attribute> names;

  if (p.parseOptionalLBrace().succeeded()) {
    auto parseEntry = [&]() -> ParseResult {
      const bool entryFailed = p.parseAttribute(names.emplace_back()) ||
                               p.parseEqual() ||
                               p.parseOperand(attrOperands.emplace_back());
      return entryFailed ? failure() : success();
    };
    if (failed(p.parseCommaSeparatedList(parseEntry)) ||
        failed(p.parseRBrace()))
      return failure();
  }

  attrNamesAttr = p.getBuilder().getArrayAttr(names);
  return success();
}
/// Print the attribute dictionary of an AttributesOp as
/// `{name = %value, ...}`. Nothing at all is printed when there are no
/// attribute names, which corresponds to the brace-less form that
/// parseAttributesOp accepts.
static void printAttributesOp(OpAsmPrinter &p, AttributesOp op,
                              OperandRange attrArgs, ArrayAttr attrNames) {
  const size_t count = attrNames.size();
  if (count == 0)
    return;

  p << "{";
  for (size_t i = 0; i < count; ++i) {
    if (i > 0)
      p << ", ";
    p << attrNames[i] << " = " << attrArgs[i];
  }
  p << '}';
}
/// Verify the optional block-count constraint of a region: when it is
/// specified, it must be at least 1. An absent constraint is always valid.
LogicalResult RegionOp::verify() {
  IntegerAttr blockCountAttr = getNumberOfBlocksAttr();
  if (!blockCountAttr)
    return success();

  const int64_t blockCount = blockCountAttr.getInt();
  if (blockCount > 0)
    return success();

  return emitOpError("the number of blocks is expected to be >= 1 but got ")
         << blockCount;
}
#include "mlir/Dialect/IRDL/IR/IRDLInterfaces.cpp.inc"
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/IRDL/IR/IRDLTypesGen.cpp.inc"
#include "mlir/Dialect/IRDL/IR/IRDLEnums.cpp.inc"
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/IRDL/IR/IRDLAttributes.cpp.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/IRDL/IR/IRDLOps.cpp.inc"