Files
clang-p2996/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
Chris Lattner 41d4aa7de6 [SymbolRefAttr] Revise SymbolRefAttr to hold a StringAttr.
SymbolRefAttr is fundamentally a base string plus a sequence
of nested references.  Instead of storing the string data as
a copied StringRef, store it as an already-uniqued StringAttr.

This makes a lot of things simpler and more efficient because:
1) references to the symbol are already stored as StringAttrs:
   there is no need to copy the string data into MLIRContext
   multiple times.
2) This allows pointer comparisons instead of string
   comparisons (or redundant uniquing) within SymbolTable.cpp.
3) This allows SymbolTable to hold a DenseMap instead of a
   StringMap (which again copies the string data and slows
   lookup).

This is a moderately invasive patch, so I kept a lot of
compatibility APIs around.  It would be nice to explore changing
getName() to return a StringAttr for example (right now you have
to use getNameAttr()), and eliminate things like the StringRef
version of getSymbol.

Differential Revision: https://reviews.llvm.org/D108899
2021-08-29 21:54:47 -07:00

119 lines
5.2 KiB
C++

//===- TestPDLByteCode.cpp - Test rewriter bytecode functionality ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
/// Custom single-entity constraint invoked from PDL.
///
/// Succeeds only when the matched entity is an operation whose name is
/// "test.op". `constantParams` and `rewriter` are not used.
static LogicalResult customSingleEntityConstraint(PDLValue value,
                                                  ArrayAttr constantParams,
                                                  PatternRewriter &rewriter) {
  // The pattern is expected to bind an operation here; `cast` assumes the
  // value holds an `Operation *`.
  StringRef opName = value.cast<Operation *>()->getName().getStringRef();
  return success(opName == "test.op");
}
/// Custom multi-entity constraint invoked from PDL.
///
/// Only the second entity is inspected. It must satisfy
/// customSingleEntityConstraint, i.e. be an operation named "test.op".
/// Fails (rather than indexing out of bounds) when fewer than two entities
/// are supplied.
static LogicalResult customMultiEntityConstraint(ArrayRef<PDLValue> values,
                                                 ArrayAttr constantParams,
                                                 PatternRewriter &rewriter) {
  // Guard the access to values[1]: a pattern passing too few arguments should
  // simply fail to match.
  if (values.size() < 2)
    return failure();
  return customSingleEntityConstraint(values[1], constantParams, rewriter);
}
/// Custom variadic constraint invoked from PDL.
///
/// Expects at least two entities: a value range followed by a type range.
/// Succeeds only when at least two entities are supplied, none of them is
/// null, and both ranges contain exactly two elements.
static LogicalResult
customMultiEntityVariadicConstraint(ArrayRef<PDLValue> values,
                                    ArrayAttr constantParams,
                                    PatternRewriter &rewriter) {
  // Reject malformed input up front. This covers too few arguments (which
  // would otherwise index past the end of `values`) and null entries.
  if (values.size() < 2 ||
      llvm::any_of(values, [](const PDLValue &value) { return !value; }))
    return failure();

  ValueRange operandValues = values[0].cast<ValueRange>();
  TypeRange typeValues = values[1].cast<TypeRange>();
  return success(operandValues.size() == 2 && typeValues.size() == 2);
}
// Custom creator invoked from PDL: builds a "test.success" operation at the
// location of the first argument (expected to be an operation). The new
// operation is appended to `results` as the single result.
static void customCreate(ArrayRef<PDLValue> args, ArrayAttr constantParams,
                         PatternRewriter &rewriter, PDLResultList &results) {
  Location loc = args[0].cast<Operation *>()->getLoc();
  OperationState state(loc, "test.success");
  Operation *created = rewriter.createOperation(state);
  results.push_back(created);
}
/// Custom rewrite function invoked from PDL that yields two results taken
/// from the root operation (the first argument): its operand values, followed
/// by the types of those operands.
static void customVariadicResultCreate(ArrayRef<PDLValue> args,
                                       ArrayAttr constantParams,
                                       PatternRewriter &rewriter,
                                       PDLResultList &results) {
  Operation::operand_range operands =
      args[0].cast<Operation *>()->getOperands();
  results.push_back(operands);
  results.push_back(operands.getTypes());
}
/// Custom rewrite function invoked from PDL. It ignores its inputs and yields
/// the `f32` type as its only result.
static void customCreateType(ArrayRef<PDLValue> args, ArrayAttr constantParams,
                             PatternRewriter &rewriter,
                             PDLResultList &results) {
  Type f32 = rewriter.getF32Type();
  results.push_back(f32);
}
/// Custom rewriter invoked from PDL.
///
/// Creates a "test.success" operation at the root operation's location (the
/// root is the first argument). The new operation takes the second argument
/// as its only operand and stores `constantParams` under the
/// "constantParams" attribute. The root operation is then erased.
static void customRewriter(ArrayRef<PDLValue> args, ArrayAttr constantParams,
                           PatternRewriter &rewriter, PDLResultList &results) {
  Operation *rootOp = args[0].cast<Operation *>();
  Value operand = args[1].cast<Value>();

  OperationState state(rootOp->getLoc(), "test.success");
  state.addOperands(operand);
  state.addAttribute("constantParams", constantParams);
  rewriter.createOperation(state);

  rewriter.eraseOp(rootOp);
}
namespace {
/// Test pass for PDL bytecode functionality.
///
/// The input module is expected to contain two nested modules: `@patterns`,
/// holding the PDL patterns, and `@ir`, holding the operations to rewrite.
/// The patterns are applied to `@ir` with the greedy pattern rewrite driver.
/// If either nested module is missing, the pass does nothing.
struct TestPDLByteCodePass
    : public PassWrapper<TestPDLByteCodePass, OperationPass<ModuleOp>> {
  StringRef getArgument() const final { return "test-pdl-bytecode-pass"; }
  StringRef getDescription() const final {
    return "Test PDL ByteCode functionality";
  }

  void runOnOperation() final {
    ModuleOp module = getOperation();
    MLIRContext *ctx = module->getContext();

    // Locate the two nested modules that make up a test case.
    auto patternModule =
        module.lookupSymbol<ModuleOp>(StringAttr::get(ctx, "patterns"));
    auto irModule = module.lookupSymbol<ModuleOp>(StringAttr::get(ctx, "ir"));
    if (!patternModule || !irModule)
      return;

    // Detach the pattern module from its parent before building the PDL
    // pattern module from it, then hook up the native callbacks.
    patternModule->remove();
    PDLPatternModule pdlPattern(patternModule);
    registerNativeFunctions(pdlPattern);

    // Apply the patterns to the IR module. The driver's result is
    // deliberately discarded.
    RewritePatternSet patternList(std::move(pdlPattern));
    (void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(),
                                       std::move(patternList));
  }

private:
  /// Registers the native constraint and rewrite functions that the test PDL
  /// patterns refer to by name.
  static void registerNativeFunctions(PDLPatternModule &pdlPattern) {
    // Constraints.
    pdlPattern.registerConstraintFunction("multi_entity_constraint",
                                          customMultiEntityConstraint);
    pdlPattern.registerConstraintFunction("single_entity_constraint",
                                          customSingleEntityConstraint);
    pdlPattern.registerConstraintFunction("multi_entity_var_constraint",
                                          customMultiEntityVariadicConstraint);
    // Rewrites.
    pdlPattern.registerRewriteFunction("creator", customCreate);
    pdlPattern.registerRewriteFunction("var_creator",
                                       customVariadicResultCreate);
    pdlPattern.registerRewriteFunction("type_creator", customCreateType);
    pdlPattern.registerRewriteFunction("rewriter", customRewriter);
  }
};
} // end anonymous namespace
namespace mlir {
namespace test {
void registerTestPDLByteCodePass() { PassRegistration<TestPDLByteCodePass>(); }
} // namespace test
} // namespace mlir