clang-p2996/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
River Riddle abfd1a8b3b [mlir][PDL] Add support for PDL bytecode and expose PDL support to OwningRewritePatternList
PDL patterns are now supported via a new `PDLPatternModule` class. This class contains a `ModuleOp` with the `pdl::PatternOp` operations that represent the patterns, as well as a collection of registered C++ functions for native constraints, creations, rewrites, etc. that may be invoked from the PDL patterns. Instances of this class are added to an `OwningRewritePatternList` in the same fashion as C++ `RewritePattern`s, i.e. via the `insert` method.
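
The registration flow described above, where named C++ functions are attached to a pattern module and later invoked by name from the patterns, can be modeled with a small standalone sketch. It does not use MLIR: `ToyPatternModule`, `ToyValue`, `ToyConstraintFn`, and `toyMultiEntityConstraint` are hypothetical stand-ins for `PDLPatternModule`, `PDLValue`, and the test's native constraints, and an operation is modeled only by its name string.

```cpp
#include <functional>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for PDLValue: an operation is modeled by its name.
using ToyValue = std::string;

// Hypothetical native constraint signature (no constant params or rewriter).
using ToyConstraintFn = std::function<bool(const std::vector<ToyValue> &)>;

// Hypothetical stand-in for PDLPatternModule: the patterns themselves are
// elided; only the table of named native functions is modeled.
class ToyPatternModule {
public:
  void registerConstraintFunction(const std::string &name, ToyConstraintFn fn) {
    constraints[name] = std::move(fn);
  }

  // Called when a pattern references a native constraint by name.
  bool invokeConstraint(const std::string &name,
                        const std::vector<ToyValue> &args) const {
    auto it = constraints.find(name);
    if (it == constraints.end())
      throw std::out_of_range("unregistered constraint: " + name);
    return it->second(args);
  }

private:
  std::map<std::string, ToyConstraintFn> constraints;
};

// Loosely mirrors customMultiEntityConstraint: only the second value is
// checked (with an added bounds check).
inline bool toyMultiEntityConstraint(const std::vector<ToyValue> &values) {
  return values.size() > 1 && values[1] == "test.op";
}
```

In the real test the equivalent step is `pdlPattern.registerConstraintFunction("multi_entity_constraint", customMultiEntityConstraint)`; the name string is what the PDL pattern uses to refer to the C++ function.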

The PDL bytecode is an in-memory representation of the PDL interpreter dialect that can be interpreted and executed efficiently. The bytecode consists of a code array (for opcodes, memory locations, etc.) and a memory buffer (for storing attributes, operations, values, and any other necessary data). The bytecode operations are effectively a one-to-one mapping of the PDLInterp dialect operations, with a few exceptions where the in-memory representation of the bytecode can be more efficient than the MLIR representation. For example, a single generic `AreEqual` bytecode op can represent `AreEqualOp`, `CheckAttributeOp`, and `CheckTypeOp`.
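
The split between a code array and a memory buffer, and the way one `AreEqual` opcode can serve several checks, can be illustrated with a toy interpreter. This is a sketch under assumptions, not the real PDL bytecode: `OpCode`, `ToyByteCode`, and `execute` are hypothetical, the operand layout of `AreEqual` (two memory indices followed by true/false jump targets) is invented for illustration, and memory entries are compared by pointer identity, which is valid for uniqued MLIR entities such as attributes and types.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical opcodes for a toy bytecode (not the real PDL layout).
enum class OpCode : std::uint16_t {
  AreEqual, // AreEqual lhsIdx rhsIdx trueDest falseDest
  Success,  // the match succeeded
  Failure,  // the match failed
};

struct ToyByteCode {
  std::vector<std::uint16_t> code; // opcodes, memory indices, jump targets
  std::vector<const void *> memory; // entities referenced by the code
};

// Runs the bytecode from offset 0 and reports whether it reached Success.
inline bool execute(const ToyByteCode &bc) {
  std::size_t pc = 0;
  while (true) {
    assert(pc < bc.code.size() && "program counter out of range");
    switch (static_cast<OpCode>(bc.code[pc])) {
    case OpCode::AreEqual: {
      // One opcode covers attribute, type, and value equality: each check
      // reduces to comparing two entries of the memory buffer by identity.
      const void *lhs = bc.memory[bc.code[pc + 1]];
      const void *rhs = bc.memory[bc.code[pc + 2]];
      pc = lhs == rhs ? bc.code[pc + 3] : bc.code[pc + 4];
      break;
    }
    case OpCode::Success:
      return true;
    case OpCode::Failure:
      return false;
    default:
      assert(false && "unknown opcode");
      return false;
    }
  }
}
```

With memory `{&a, &a, &b}` and code `AreEqual 0 1 5 6; Success; Failure`, execution jumps to offset 5 and succeeds; pointing the second operand at memory index 2 instead makes it jump to offset 6 and fail.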

The execution of the bytecode is split into two phases: matching and rewriting. When matching, all of the matched patterns are collected so that parts of the matcher do not have to be re-run. For the given root operation, these matched patterns are then considered alongside the native C++ patterns, which rewrite immediately in place via `RewritePattern::matchAndRewrite`. When a PDL pattern is matched and has the highest benefit, it is passed back to the bytecode to execute its rewriter.
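
The two-phase flow above can be modeled in a few lines. In this simplified sketch, `ToyCandidate` and `applyHighestBenefit` are hypothetical; phase 1 (running the bytecode matcher once to collect every PDL match for the root) is assumed to have already produced `pdlMatches`, and trying candidates in decreasing benefit order until one rewrites is an illustrative assumption, not a transcription of MLIR's driver code.

```cpp
#include <algorithm>
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical model of one pattern that could apply to a root operation.
struct ToyCandidate {
  std::string name;
  unsigned benefit;
  // Native C++ pattern: matches and rewrites in one call (like
  // RewritePattern::matchAndRewrite) and returns false if it does not match.
  // PDL pattern: already matched in phase 1, so this only runs the rewriter.
  std::function<bool()> matchAndRewrite;
};

// Phase 2 of the model: merge the collected PDL matches with the native
// patterns, try them from highest to lowest benefit, and stop at the first
// one that rewrites. Returns its name, or "" if nothing applied.
inline std::string
applyHighestBenefit(std::vector<ToyCandidate> pdlMatches,
                    const std::vector<ToyCandidate> &nativePatterns) {
  std::vector<ToyCandidate> all = std::move(pdlMatches);
  all.insert(all.end(), nativePatterns.begin(), nativePatterns.end());
  std::stable_sort(all.begin(), all.end(),
                   [](const ToyCandidate &a, const ToyCandidate &b) {
                     return a.benefit > b.benefit;
                   });
  for (const ToyCandidate &candidate : all)
    if (candidate.matchAndRewrite())
      return candidate.name;
  return "";
}
```

For example, if a native pattern with benefit 5 fails to match and another with benefit 3 matches, the benefit-3 pattern rewrites; adding a PDL match with benefit 9 makes the PDL rewriter run instead, because it is tried first.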

Differential Revision: https://reviews.llvm.org/D89107
2020-12-01 15:05:50 -08:00

//===- TestPDLByteCode.cpp - Test rewriter bytecode functionality ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

/// Custom constraint invoked from PDL.
static LogicalResult customSingleEntityConstraint(PDLValue value,
                                                  ArrayAttr constantParams,
                                                  PatternRewriter &rewriter) {
  Operation *rootOp = value.cast<Operation *>();
  return success(rootOp->getName().getStringRef() == "test.op");
}

/// Custom constraint invoked from PDL with multiple values; it applies the
/// single-entity check to the second value.
static LogicalResult customMultiEntityConstraint(ArrayRef<PDLValue> values,
                                                 ArrayAttr constantParams,
                                                 PatternRewriter &rewriter) {
  return customSingleEntityConstraint(values[1], constantParams, rewriter);
}

/// Custom creator invoked from PDL.
static PDLValue customCreate(ArrayRef<PDLValue> args, ArrayAttr constantParams,
                             PatternRewriter &rewriter) {
  return rewriter.createOperation(
      OperationState(args[0].cast<Operation *>()->getLoc(), "test.success"));
}

/// Custom rewriter invoked from PDL.
static void customRewriter(Operation *root, ArrayRef<PDLValue> args,
                           ArrayAttr constantParams,
                           PatternRewriter &rewriter) {
  OperationState successOpState(root->getLoc(), "test.success");
  successOpState.addOperands(args[0].cast<Value>());
  successOpState.addAttribute("constantParams", constantParams);
  rewriter.createOperation(successOpState);
  rewriter.eraseOp(root);
}

namespace {
struct TestPDLByteCodePass
    : public PassWrapper<TestPDLByteCodePass, OperationPass<ModuleOp>> {
  void runOnOperation() final {
    ModuleOp module = getOperation();

    // The test cases are split across two modules: one containing the
    // patterns and one containing the operations to rewrite.
    ModuleOp patternModule = module.lookupSymbol<ModuleOp>("patterns");
    ModuleOp irModule = module.lookupSymbol<ModuleOp>("ir");
    if (!patternModule || !irModule)
      return;

    // Process the pattern module. Detach it from the parent module so that
    // ownership can be handed to the PDLPatternModule.
    patternModule.getOperation()->remove();
    PDLPatternModule pdlPattern(patternModule);
    pdlPattern.registerConstraintFunction("multi_entity_constraint",
                                          customMultiEntityConstraint);
    pdlPattern.registerConstraintFunction("single_entity_constraint",
                                          customSingleEntityConstraint);
    pdlPattern.registerCreateFunction("creator", customCreate);
    pdlPattern.registerRewriteFunction("rewriter", customRewriter);
    OwningRewritePatternList patternList(std::move(pdlPattern));

    // Invoke the pattern driver with the provided patterns.
    (void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(),
                                       std::move(patternList));
  }
};
} // end anonymous namespace

namespace mlir {
namespace test {
void registerTestPDLByteCodePass() {
  PassRegistration<TestPDLByteCodePass>("test-pdl-bytecode-pass",
                                        "Test PDL ByteCode functionality");
}
} // namespace test
} // namespace mlir
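
The pass hands its patterns to `applyPatternsAndFoldGreedily`, which keeps applying patterns to the `ir` module until nothing changes and reports whether it converged; the pass discards that result with `(void)`. The standalone model below illustrates only that fixpoint loop. It is a sketch under assumptions: operations are plain name strings, `ToyRewrite` and `applyGreedily` are hypothetical, folding and worklist ordering are omitted, and the iteration cap of 10 is an assumed default.

```cpp
#include <functional>
#include <optional>
#include <string>
#include <vector>

// Hypothetical rewrite: returns the replacement op name, or std::nullopt if
// it does not match the given op.
using ToyRewrite =
    std::function<std::optional<std::string>(const std::string &)>;

// Sweeps over all ops, applying the first matching rewrite to each, until a
// full sweep changes nothing (converged) or the iteration cap is reached.
// Returns true on convergence, loosely mirroring the driver's LogicalResult.
inline bool applyGreedily(std::vector<std::string> &ops,
                          const std::vector<ToyRewrite> &rewrites,
                          int maxIterations = 10) {
  for (int iter = 0; iter < maxIterations; ++iter) {
    bool changed = false;
    for (std::string &op : ops) {
      for (const ToyRewrite &rewrite : rewrites) {
        if (std::optional<std::string> replacement = rewrite(op)) {
          op = *replacement;
          changed = true;
          break;
        }
      }
    }
    if (!changed)
      return true;
  }
  return false;
}
```

A rewrite that turns `test.op` into `test.success` converges after one changing sweep, while a rewrite that keeps flipping an op between two names never converges and hits the cap, which is the situation the driver's return value reports.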