This commit refactors the expected form of native constraint and rewrite
functions, and greatly reduces the necessary user complexity required when
defining a native function. Namely, this commit adds in automatic processing
of the necessary PDLValue glue code, and allows for users to define
constraint/rewrite functions using the C++ types that they actually want to
use.
As an example, lets see a simple example rewrite defined today:
```
static void rewriteFn(PatternRewriter &rewriter, PDLResultList &results,
ArrayRef<PDLValue> args) {
ValueRange operandValues = args[0].cast<ValueRange>();
TypeRange typeValues = args[1].cast<TypeRange>();
...
// Create an operation at some point and pass it back to PDL.
Operation *op = rewriter.create<SomeOp>(...);
results.push_back(op);
}
```
After this commit, that same rewrite could be defined as:
```
static Operation *rewriteFn(PatternRewriter &rewriter ValueRange operandValues,
TypeRange typeValues) {
...
// Create an operation at some point and pass it back to PDL.
return rewriter.create<SomeOp>(...);
}
```
Differential Revision: https://reviews.llvm.org/D122086
215 lines
8.5 KiB
C++
215 lines
8.5 KiB
C++
//===- CPPGen.cpp ---------------------------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This files contains a PDLL generator that outputs C++ code that defines PDLL
|
|
// patterns as individual C++ PDLPatternModules for direct use in native code,
|
|
// and also defines any native constraints whose bodies were defined in PDLL.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Tools/PDLL/CodeGen/CPPGen.h"
|
|
#include "mlir/Dialect/PDL/IR/PDL.h"
|
|
#include "mlir/Dialect/PDL/IR/PDLOps.h"
|
|
#include "mlir/IR/BuiltinOps.h"
|
|
#include "mlir/Tools/PDLL/AST/Nodes.h"
|
|
#include "llvm/ADT/SmallString.h"
|
|
#include "llvm/ADT/StringExtras.h"
|
|
#include "llvm/ADT/StringSet.h"
|
|
#include "llvm/ADT/TypeSwitch.h"
|
|
#include "llvm/Support/ErrorHandling.h"
|
|
#include "llvm/Support/FormatVariadic.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::pdll;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CodeGen
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
class CodeGen {
|
|
public:
|
|
CodeGen(raw_ostream &os) : os(os) {}
|
|
|
|
/// Generate C++ code for the given PDL pattern module.
|
|
void generate(const ast::Module &astModule, ModuleOp module);
|
|
|
|
private:
|
|
void generate(pdl::PatternOp pattern, StringRef patternName,
|
|
StringSet<> &nativeFunctions);
|
|
|
|
/// Generate C++ code for all user defined constraints and rewrites with
|
|
/// native code.
|
|
void generateConstraintAndRewrites(const ast::Module &astModule,
|
|
ModuleOp module,
|
|
StringSet<> &nativeFunctions);
|
|
void generate(const ast::UserConstraintDecl *decl,
|
|
StringSet<> &nativeFunctions);
|
|
void generate(const ast::UserRewriteDecl *decl, StringSet<> &nativeFunctions);
|
|
void generateConstraintOrRewrite(StringRef name, bool isConstraint,
|
|
ArrayRef<ast::VariableDecl *> inputs,
|
|
StringRef codeBlock,
|
|
StringSet<> &nativeFunctions);
|
|
|
|
/// The stream to output to.
|
|
raw_ostream &os;
|
|
};
|
|
} // namespace
|
|
|
|
void CodeGen::generate(const ast::Module &astModule, ModuleOp module) {
|
|
SetVector<std::string, SmallVector<std::string>, StringSet<>> patternNames;
|
|
StringSet<> nativeFunctions;
|
|
|
|
// Generate code for any native functions within the module.
|
|
generateConstraintAndRewrites(astModule, module, nativeFunctions);
|
|
|
|
os << "namespace {\n";
|
|
std::string basePatternName = "GeneratedPDLLPattern";
|
|
int patternIndex = 0;
|
|
for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
|
|
// If the pattern has a name, use that. Otherwise, generate a unique name.
|
|
if (Optional<StringRef> patternName = pattern.sym_name()) {
|
|
patternNames.insert(patternName->str());
|
|
} else {
|
|
std::string name;
|
|
do {
|
|
name = (basePatternName + Twine(patternIndex++)).str();
|
|
} while (!patternNames.insert(name));
|
|
}
|
|
|
|
generate(pattern, patternNames.back(), nativeFunctions);
|
|
}
|
|
os << "} // end namespace\n\n";
|
|
|
|
// Emit function to add the generated matchers to the pattern list.
|
|
os << "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns("
|
|
"::mlir::RewritePatternSet &patterns) {\n";
|
|
for (const auto &name : patternNames)
|
|
os << " patterns.add<" << name << ">(patterns.getContext());\n";
|
|
os << "}\n";
|
|
}
|
|
|
|
void CodeGen::generate(pdl::PatternOp pattern, StringRef patternName,
|
|
StringSet<> &nativeFunctions) {
|
|
const char *patternClassStartStr = R"(
|
|
struct {0} : ::mlir::PDLPatternModule {{
|
|
{0}(::mlir::MLIRContext *context)
|
|
: ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>(
|
|
)";
|
|
os << llvm::formatv(patternClassStartStr, patternName);
|
|
|
|
os << "R\"mlir(";
|
|
pattern->print(os, OpPrintingFlags().enableDebugInfo());
|
|
os << "\n )mlir\", context)) {\n";
|
|
|
|
// Register any native functions used within the pattern.
|
|
StringSet<> registeredNativeFunctions;
|
|
auto checkRegisterNativeFn = [&](StringRef fnName, StringRef fnType) {
|
|
if (!nativeFunctions.count(fnName) ||
|
|
!registeredNativeFunctions.insert(fnName).second)
|
|
return;
|
|
os << " register" << fnType << "Function(\"" << fnName << "\", "
|
|
<< fnName << "PDLFn);\n";
|
|
};
|
|
pattern.walk([&](Operation *op) {
|
|
if (auto constraintOp = dyn_cast<pdl::ApplyNativeConstraintOp>(op))
|
|
checkRegisterNativeFn(constraintOp.name(), "Constraint");
|
|
else if (auto rewriteOp = dyn_cast<pdl::ApplyNativeRewriteOp>(op))
|
|
checkRegisterNativeFn(rewriteOp.name(), "Rewrite");
|
|
});
|
|
os << " }\n};\n\n";
|
|
}
|
|
|
|
void CodeGen::generateConstraintAndRewrites(const ast::Module &astModule,
|
|
ModuleOp module,
|
|
StringSet<> &nativeFunctions) {
|
|
// First check to see which constraints and rewrites are actually referenced
|
|
// in the module.
|
|
StringSet<> usedFns;
|
|
module.walk([&](Operation *op) {
|
|
TypeSwitch<Operation *>(op)
|
|
.Case<pdl::ApplyNativeConstraintOp, pdl::ApplyNativeRewriteOp>(
|
|
[&](auto op) { usedFns.insert(op.name()); });
|
|
});
|
|
|
|
for (const ast::Decl *decl : astModule.getChildren()) {
|
|
TypeSwitch<const ast::Decl *>(decl)
|
|
.Case<ast::UserConstraintDecl, ast::UserRewriteDecl>(
|
|
[&](const auto *decl) {
|
|
// We only generate code for inline native decls that have been
|
|
// referenced.
|
|
if (decl->getCodeBlock() &&
|
|
usedFns.contains(decl->getName().getName()))
|
|
this->generate(decl, nativeFunctions);
|
|
});
|
|
}
|
|
}
|
|
|
|
void CodeGen::generate(const ast::UserConstraintDecl *decl,
|
|
StringSet<> &nativeFunctions) {
|
|
return generateConstraintOrRewrite(decl->getName().getName(),
|
|
/*isConstraint=*/true, decl->getInputs(),
|
|
*decl->getCodeBlock(), nativeFunctions);
|
|
}
|
|
|
|
void CodeGen::generate(const ast::UserRewriteDecl *decl,
|
|
StringSet<> &nativeFunctions) {
|
|
return generateConstraintOrRewrite(decl->getName().getName(),
|
|
/*isConstraint=*/false, decl->getInputs(),
|
|
*decl->getCodeBlock(), nativeFunctions);
|
|
}
|
|
|
|
void CodeGen::generateConstraintOrRewrite(StringRef name, bool isConstraint,
|
|
ArrayRef<ast::VariableDecl *> inputs,
|
|
StringRef codeBlock,
|
|
StringSet<> &nativeFunctions) {
|
|
nativeFunctions.insert(name);
|
|
|
|
// TODO: Should there be something explicit for handling optionality?
|
|
auto getCppType = [&](ast::Type type) -> StringRef {
|
|
return llvm::TypeSwitch<ast::Type, StringRef>(type)
|
|
.Case([&](ast::AttributeType) { return "::mlir::Attribute"; })
|
|
.Case([&](ast::OperationType) {
|
|
// TODO: Allow using the derived Op class when possible.
|
|
return "::mlir::Operation *";
|
|
})
|
|
.Case([&](ast::TypeType) { return "::mlir::Type"; })
|
|
.Case([&](ast::ValueType) { return "::mlir::Value"; })
|
|
.Case([&](ast::TypeRangeType) { return "::mlir::TypeRange"; })
|
|
.Case([&](ast::ValueRangeType) { return "::mlir::ValueRange"; });
|
|
};
|
|
os << "static " << (isConstraint ? "::mlir::LogicalResult " : "void ") << name
|
|
<< "PDLFn(::mlir::PatternRewriter &rewriter, "
|
|
<< (isConstraint ? "" : "::mlir::PDLResultList &results, ")
|
|
<< "::llvm::ArrayRef<::mlir::PDLValue> values) {\n";
|
|
|
|
const char *argumentInitStr = R"(
|
|
{0} {1} = {{};
|
|
if (values[{2}])
|
|
{1} = values[{2}].cast<{0}>();
|
|
(void){1};
|
|
)";
|
|
for (const auto &it : llvm::enumerate(inputs)) {
|
|
const ast::VariableDecl *input = it.value();
|
|
os << llvm::formatv(argumentInitStr, getCppType(input->getType()),
|
|
input->getName().getName(), it.index());
|
|
}
|
|
|
|
os << " " << codeBlock.trim() << "\n}\n";
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CPPGen
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
void mlir::pdll::codegenPDLLToCPP(const ast::Module &astModule, ModuleOp module,
|
|
raw_ostream &os) {
|
|
CodeGen codegen(os);
|
|
codegen.generate(astModule, module);
|
|
}
|