This PR massively reorganizes the Test dialect's source files. It moves manually-written op hooks into `TestOpDefs.cpp`, moves format custom directive parsers and printers into `TestFormatUtils`, adds missing comment blocks, and moves the inclusion of generated source files for types, attributes, enums, etc. into their own source files. This should make the test dialect source code easier to navigate, and it also speeds up compilation of the test dialect by putting generated source files into separate compilation units. Finally, it sets up the test dialect to shard its op definitions, which is done in the next PR.
378 lines
14 KiB
C++
378 lines
14 KiB
C++
//===- TestFormatUtils.cpp - MLIR Test Dialect Assembly Format Utilities --===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "TestFormatUtils.h"
|
|
#include "mlir/IR/Builders.h"
|
|
|
|
using namespace mlir;
|
|
using namespace test;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CustomDirectiveOperands
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Parses `operand [, optOperand] -> (varOperands...)`.
///
/// The optional operand is only attempted when a comma follows the first
/// operand; otherwise `optOperand` is left untouched. Returns failure if any
/// component is malformed.
ParseResult test::parseCustomDirectiveOperands(
    OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
    std::optional<OpAsmParser::UnresolvedOperand> &optOperand,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands) {
  if (parser.parseOperand(operand))
    return failure();

  // A comma right after the first operand announces the optional operand.
  bool hasOptional = succeeded(parser.parseOptionalComma());
  if (hasOptional) {
    optOperand.emplace();
    if (parser.parseOperand(*optOperand))
      return failure();
  }

  // Trailing `-> (list)` holding the variadic operands.
  bool malformedTail = parser.parseArrow() || parser.parseLParen() ||
                       parser.parseOperandList(varOperands) ||
                       parser.parseRParen();
  return malformedTail ? failure() : success();
}
|
|
|
|
/// Prints `operand[, optOperand] -> (varOperands...)`. The optional operand
/// and its leading comma are emitted only when `optOperand` is non-null.
void test::printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
                                        Value operand, Value optOperand,
                                        OperandRange varOperands) {
  printer << operand;
  if (optOperand)
    printer << ", " << optOperand;
  printer << " -> (";
  printer << varOperands;
  printer << ")";
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CustomDirectiveResults
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Parses `: operandType [, optOperandType] -> (varOperandTypes...)`.
///
/// `optOperandType` is only parsed when a comma follows the first type.
/// Returns failure if any component is malformed.
ParseResult
test::parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
                                  Type &optOperandType,
                                  SmallVectorImpl<Type> &varOperandTypes) {
  if (parser.parseColon() || parser.parseType(operandType))
    return failure();

  // The optional type is introduced by a comma; only parse it if present.
  if (succeeded(parser.parseOptionalComma()) &&
      parser.parseType(optOperandType))
    return failure();

  if (parser.parseArrow() || parser.parseLParen() ||
      parser.parseTypeList(varOperandTypes) || parser.parseRParen())
    return failure();
  return success();
}
|
|
|
|
/// Prints ` : operandType[, optOperandType] -> (varOperandTypes...)`. The
/// optional type and its comma are emitted only when `optOperandType` is set.
void test::printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
                                       Type operandType, Type optOperandType,
                                       TypeRange varOperandTypes) {
  printer << " : ";
  printer << operandType;
  if (optOperandType)
    printer << ", " << optOperandType;
  printer << " -> (";
  printer << varOperandTypes;
  printer << ")";
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CustomDirectiveWithTypeRefs
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Parses the keyword `type_refs_capture` followed by the syntax accepted by
/// `parseCustomDirectiveResults`, then checks that the types just parsed are
/// identical to the types passed in (types already parsed elsewhere and
/// referenced by this directive).
///
/// Returns failure if the keyword or the types are malformed. If the
/// re-parsed types differ from the referenced ones, an error is emitted at
/// the start of the type list and failure is returned.
ParseResult test::parseCustomDirectiveWithTypeRefs(
    OpAsmParser &parser, Type operandType, Type optOperandType,
    const SmallVectorImpl<Type> &varOperandTypes) {
  if (parser.parseKeyword("type_refs_capture"))
    return failure();

  // Remember where the type list starts so a mismatch is reported there.
  auto typesLoc = parser.getCurrentLocation();

  Type operandType2, optOperandType2;
  SmallVector<Type, 1> varOperandTypes2;
  if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
                                  varOperandTypes2))
    return failure();

  // Report the mismatch instead of failing without any diagnostic.
  if (operandType != operandType2 || optOperandType != optOperandType2 ||
      varOperandTypes != varOperandTypes2)
    return parser.emitError(typesLoc,
                            "type_refs_capture types do not match the "
                            "referenced operand types");

  return success();
}
|
|
|
|
/// Prints the `type_refs_capture` keyword followed by the referenced types,
/// delegating the type syntax to `printCustomDirectiveResults`.
///
/// Note: the keyword is printed with a trailing space and the delegated
/// printer emits its own leading space before `:`.
void test::printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
                                            Operation *op, Type operandType,
                                            Type optOperandType,
                                            TypeRange varOperandTypes) {
  printer << " type_refs_capture ";
  // Reuse the results printer so both directives share one type syntax.
  printCustomDirectiveResults(printer, op, operandType, optOperandType,
                              varOperandTypes);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CustomDirectiveOperandsAndTypes
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Parses the operand syntax of `parseCustomDirectiveOperands` immediately
/// followed by the type syntax of `parseCustomDirectiveResults`.
///
/// The types are not attempted if the operands fail to parse.
ParseResult test::parseCustomDirectiveOperandsAndTypes(
    OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
    std::optional<OpAsmParser::UnresolvedOperand> &optOperand,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands,
    Type &operandType, Type &optOperandType,
    SmallVectorImpl<Type> &varOperandTypes) {
  if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands))
    return failure();
  if (parseCustomDirectiveResults(parser, operandType, optOperandType,
                                  varOperandTypes))
    return failure();
  return success();
}
|
|
|
|
/// Prints the operands (via `printCustomDirectiveOperands`) and then their
/// types (via `printCustomDirectiveResults`), in that order.
void test::printCustomDirectiveOperandsAndTypes(
    OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
    OperandRange varOperands, Type operandType, Type optOperandType,
    TypeRange varOperandTypes) {
  // Operand section first...
  printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
  // ...then the matching type section.
  printCustomDirectiveResults(printer, op, operandType, optOperandType,
                              varOperandTypes);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CustomDirectiveRegions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Parses a required region, optionally followed by `,` and one more region
/// that is appended to `varRegions`.
///
/// Returns failure if either region is malformed; a missing comma simply
/// means there is no variadic region.
ParseResult test::parseCustomDirectiveRegions(
    OpAsmParser &parser, Region &region,
    SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
  if (parser.parseRegion(region))
    return failure();
  if (failed(parser.parseOptionalComma()))
    return success();
  // At most one variadic region is accepted by this directive.
  std::unique_ptr<Region> varRegion = std::make_unique<Region>();
  if (parser.parseRegion(*varRegion))
    return failure();
  varRegions.emplace_back(std::move(varRegion));
  return success();
}
|
|
|
|
void test::printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
|
|
Region ®ion,
|
|
MutableArrayRef<Region> varRegions) {
|
|
printer.printRegion(region);
|
|
if (!varRegions.empty()) {
|
|
printer << ", ";
|
|
for (Region ®ion : varRegions)
|
|
printer.printRegion(region);
|
|
}
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CustomDirectiveSuccessors
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Parses a required successor, optionally followed by `,` and a second
/// successor for the variadic group.
///
/// The second successor, when present, is appended to `varSuccessors` twice.
/// NOTE(review): presumably intentional so the variadic group holds two
/// entries while `printCustomDirectiveSuccessors` only prints the first —
/// confirm against the op definition before changing.
ParseResult
test::parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
                                     SmallVectorImpl<Block *> &varSuccessors) {
  if (parser.parseSuccessor(successor))
    return failure();

  if (succeeded(parser.parseOptionalComma())) {
    Block *extra = nullptr;
    if (parser.parseSuccessor(extra))
      return failure();
    varSuccessors.push_back(extra);
    varSuccessors.push_back(extra);
  }
  return success();
}
|
|
|
|
/// Prints the required successor and, when the variadic group is non-empty,
/// `, ` followed by only its first successor.
void test::printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
                                          Block *successor,
                                          SuccessorRange varSuccessors) {
  printer << successor;
  if (varSuccessors.empty())
    return;
  printer << ", " << varSuccessors.front();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CustomDirectiveAttributes
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Parses an integer attribute, optionally followed by `,` and a second
/// integer attribute stored in `optAttr`.
///
/// Returns failure if either attribute is malformed.
ParseResult test::parseCustomDirectiveAttributes(OpAsmParser &parser,
                                                 IntegerAttr &attr,
                                                 IntegerAttr &optAttr) {
  if (parser.parseAttribute(attr))
    return failure();
  // No comma: the optional attribute is absent.
  if (failed(parser.parseOptionalComma()))
    return success();
  return parser.parseAttribute(optAttr);
}
|
|
|
|
/// Prints `attribute` and, if `optAttribute` is non-null, `, optAttribute`.
void test::printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
                                          Attribute attribute,
                                          Attribute optAttribute) {
  printer << attribute;
  if (!optAttribute)
    return;
  printer << ", " << optAttribute;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CustomDirectiveAttrDict
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Parses an optional attribute dictionary into `attrs`.
ParseResult test::parseCustomDirectiveAttrDict(OpAsmParser &parser,
                                               NamedAttrList &attrs) {
  if (parser.parseOptionalAttrDict(attrs))
    return failure();
  return success();
}
|
|
|
|
/// Prints the entries of `attrs` as an optional attribute dictionary.
void test::printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
                                        DictionaryAttr attrs) {
  auto namedAttrs = attrs.getValue();
  printer.printOptionalAttrDict(namedAttrs);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CustomDirectiveOptionalOperandRef
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Parses an integer recording whether the referenced optional operand was
/// present: `0` means absent, any other value means present
/// (`printCustomDirectiveOptionalOperandRef` emits `0` or `1`).
///
/// Returns failure if the integer is malformed. If the integer disagrees with
/// whether `optOperand` holds a value, an error is emitted and failure is
/// returned.
ParseResult test::parseCustomDirectiveOptionalOperandRef(
    OpAsmParser &parser,
    std::optional<OpAsmParser::UnresolvedOperand> &optOperand) {
  auto countLoc = parser.getCurrentLocation();
  int64_t operandCount = 0;
  if (parser.parseInteger(operandCount))
    return failure();
  // A non-zero count claims the optional operand was present.
  bool expectOperand = operandCount != 0;
  if (expectOperand != optOperand.has_value())
    return parser.emitError(countLoc,
                            "optional operand count does not match whether "
                            "the optional operand is present");
  return success();
}
|
|
|
|
/// Prints `1` when the referenced optional operand is present, `0` otherwise.
void test::printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
                                                  Operation *op,
                                                  Value optOperand) {
  const char *presenceFlag = optOperand ? "1" : "0";
  printer << presenceFlag;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CustomDirectiveOptionalOperand
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Parses an optional parenthesized operand `(operand)`.
///
/// If no `(` is present, succeeds without touching `optOperand`. Returns
/// failure if the operand or the closing `)` is malformed.
ParseResult test::parseCustomOptionalOperand(
    OpAsmParser &parser,
    std::optional<OpAsmParser::UnresolvedOperand> &optOperand) {
  if (failed(parser.parseOptionalLParen()))
    return success();
  optOperand.emplace();
  if (parser.parseOperand(*optOperand) || parser.parseRParen())
    return failure();
  return success();
}
|
|
|
|
/// Prints `(optOperand) ` (with a trailing space) when the operand is
/// non-null; prints nothing otherwise.
void test::printCustomOptionalOperand(OpAsmPrinter &printer, Operation *,
                                      Value optOperand) {
  if (!optOperand)
    return;
  printer << "(" << optOperand << ") ";
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CustomDirectiveSwitchCases
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Parses zero or more `case <integer> <region>` clauses.
///
/// Each case region is appended to `caseRegions`, and the case values are
/// collected into `cases` as a DenseI64ArrayAttr. Returns failure if a case
/// value or region is malformed.
ParseResult
test::parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases,
                       SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
  SmallVector<int64_t> caseValues;
  while (succeeded(p.parseOptionalKeyword("case"))) {
    int64_t value = 0;
    // The region is appended before parsing, so `caseRegions` owns it even if
    // parsing fails part-way through.
    Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
    if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
      return failure();
    caseValues.push_back(value);
  }
  cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
  return success();
}
|
|
|
|
/// Prints each case as `case <value> <region>` on its own line, pairing the
/// values in `cases` with `caseRegions` positionally. Entry block arguments
/// are not printed.
void test::printSwitchCases(OpAsmPrinter &p, Operation *op,
                            DenseI64ArrayAttr cases, RegionRange caseRegions) {
  for (auto caseAndRegion : llvm::zip(cases.asArrayRef(), caseRegions)) {
    auto caseValue = std::get<0>(caseAndRegion);
    auto *caseRegion = std::get<1>(caseAndRegion);
    p.printNewline();
    p << "case " << caseValue << ' ';
    p.printRegion(*caseRegion, /*printEntryBlockArgs=*/false);
  }
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CustomUsingPropertyInCustom
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Parses `[a, b, c]` into the three-element array `value`.
///
/// Returns true on failure and false on success (the bool convention used by
/// property custom parsers in this file).
bool test::parseUsingPropertyInCustom(OpAsmParser &parser, int64_t value[3]) {
  if (parser.parseLSquare())
    return true;
  for (int idx = 0; idx < 3; ++idx) {
    // Every element after the first is preceded by a comma.
    if (idx > 0 && parser.parseComma())
      return true;
    if (parser.parseInteger(value[idx]))
      return true;
  }
  return static_cast<bool>(parser.parseRSquare());
}
|
|
|
|
/// Prints `value` wrapped in square brackets.
void test::printUsingPropertyInCustom(OpAsmPrinter &printer, Operation *op,
                                      ArrayRef<int64_t> value) {
  printer << '[';
  printer << value;
  printer << ']';
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CustomDirectiveIntProperty
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Parses a single integer into `value`. Returns true on failure.
bool test::parseIntProperty(OpAsmParser &parser, int64_t &value) {
  if (failed(parser.parseInteger(value)))
    return true;
  return false;
}
|
|
|
|
/// Prints the integer property `value` as-is.
void test::printIntProperty(OpAsmPrinter &printer, Operation *op,
                            int64_t value) {
  // Emitted in the same form `parseIntProperty` reads back.
  printer << value;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CustomDirectiveSumProperty
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Parses `second = sum` and checks that `sum == second + first`, where
/// `first` is supplied by the caller.
///
/// Returns true on failure. On a sum mismatch, an error is emitted at the
/// start of the property.
bool test::parseSumProperty(OpAsmParser &parser, int64_t &second,
                            int64_t first) {
  auto startLoc = parser.getCurrentLocation();
  int64_t parsedSum;
  bool malformed = parser.parseInteger(second) || parser.parseEqual() ||
                   parser.parseInteger(parsedSum);
  if (malformed)
    return true;
  if (parsedSum == second + first)
    return false;
  parser.emitError(startLoc, "Expected sum to equal first + second");
  return true;
}
|
|
|
|
/// Prints `second = <second + first>`, the form read by `parseSumProperty`.
void test::printSumProperty(OpAsmPrinter &printer, Operation *op,
                            int64_t second, int64_t first) {
  int64_t total = second + first;
  printer << second << " = " << total;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CustomDirectiveOptionalCustomParser
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Optionally parses `foo <integer-attr>`.
///
/// Returns an empty result when the `foo` keyword is absent. Otherwise it
/// returns the result of parsing the attribute into `result`.
OptionalParseResult test::parseOptionalCustomParser(AsmParser &p,
                                                    IntegerAttr &result) {
  if (failed(p.parseOptionalKeyword("foo")))
    return {};
  return p.parseAttribute(result);
}
|
|
|
|
/// Prints `foo ` followed by the attribute `result`.
void test::printOptionalCustomParser(AsmPrinter &p, Operation *,
                                     IntegerAttr result) {
  // The keyword comes first, then the attribute it gates.
  p << "foo ";
  p.printAttribute(result);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CustomDirectiveAttrElideType
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Parses an attribute whose type is elided in the textual form, using the
/// type carried by `type` as the expected attribute type.
ParseResult test::parseAttrElideType(AsmParser &parser, TypeAttr type,
                                     Attribute &attr) {
  auto elidedType = type.getValue();
  return parser.parseAttribute(attr, elidedType);
}
|
|
|
|
/// Prints `attr` without its type; `parseAttrElideType` recovers the type
/// from the accompanying TypeAttr. The `type` and `op` parameters are unused
/// here.
void test::printAttrElideType(AsmPrinter &printer, Operation *op, TypeAttr type,
                              Attribute attr) {
  printer.printAttributeWithoutType(attr);
}
|