Files
clang-p2996/mlir/test/lib/TestDialect/TestDialect.cpp
River Riddle a20d96e436 Update the Inliner pass to work on SCCs of the CallGraph.
This allows for the inliner to work on arbitrary call operations. The updated inliner will also work bottom-up through the callgraph enabling support for multiple levels of inlining.

PiperOrigin-RevId: 272813876
2019-10-03 23:05:21 -07:00

236 lines
8.8 KiB
C++

//===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "TestDialect.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/InliningUtils.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// TestDialect Interfaces
//===----------------------------------------------------------------------===//
namespace {
struct TestOpFolderDialectInterface : public OpFolderDialectInterface {
  using OpFolderDialectInterface::OpFolderDialectInterface;

  /// Folder hook deciding whether constants should be materialized into
  /// `region`, a region attached to an operation that is *not* isolated from
  /// above. Returns true exactly when the region's parent is a OneRegionOp.
  bool shouldMaterializeInto(Region *region) const final {
    // Only regions owned by a one-region operation accept constants.
    auto *owner = region->getParentOp();
    return isa<OneRegionOp>(owner);
  }
};
/// This class defines the interface for handling inlining with standard
/// operations.
struct TestInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
//===--------------------------------------------------------------------===//
// Analysis Hooks
//===--------------------------------------------------------------------===//
bool isLegalToInline(Region *, Region *, BlockAndValueMapping &) const final {
// Inlining into test dialect regions is legal.
return true;
}
bool isLegalToInline(Operation *, Region *,
BlockAndValueMapping &) const final {
return true;
}
bool shouldAnalyzeRecursively(Operation *op) const override {
// Analyze recursively if this is not a functional region operation, it
// froms a separate functional scope.
return !isa<FunctionalRegionOp>(op);
}
//===--------------------------------------------------------------------===//
// Transformation Hooks
//===--------------------------------------------------------------------===//
/// Handle the given inlined terminator by replacing it with a new operation
/// as necessary.
void handleTerminator(Operation *op,
ArrayRef<Value *> valuesToRepl) const final {
// Only handle "test.return" here.
auto returnOp = dyn_cast<TestReturnOp>(op);
if (!returnOp)
return;
// Replace the values directly with the return operands.
assert(returnOp.getNumOperands() == valuesToRepl.size());
for (const auto &it : llvm::enumerate(returnOp.getOperands()))
valuesToRepl[it.index()]->replaceAllUsesWith(it.value());
}
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// TestDialect
//===----------------------------------------------------------------------===//
/// Constructs the test dialect in `context` under the name returned by
/// getDialectName(), then registers its operations and dialect interfaces.
TestDialect::TestDialect(MLIRContext *context)
: Dialect(getDialectName(), context) {
// The template argument list is supplied by the GET_OP_LIST section of the
// included "TestOps.cpp.inc" file.
addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
>();
// Attach the folder and inliner interfaces defined earlier in this file.
addInterfaces<TestOpFolderDialectInterface, TestInlinerInterface>();
// NOTE(review): presumably lets operations that are not registered above
// still be accepted under this dialect's namespace -- confirm against Dialect.
allowUnknownOperations();
}
//===----------------------------------------------------------------------===//
// Test IsolatedRegionOp - parse passthrough region arguments.
//===----------------------------------------------------------------------===//
/// Custom parser for the isolated region op: reads a single operand, resolves
/// it as an index-typed value, and then parses the body region using that
/// same operand name as the region's (index-typed) argument, with name
/// shadowing enabled.
static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
                                         OperationState &result) {
  OpAsmParser::OperandType operandInfo;
  Type indexTy = parser.getBuilder().getIndexType();

  // Read the input operand, then bind it to the index type.
  if (parser.parseOperand(operandInfo))
    return failure();
  if (parser.resolveOperand(operandInfo, indexTy, result.operands))
    return failure();

  // The region argument reuses the operand's parsed name and type.
  Region *region = result.addRegion();
  return parser.parseRegion(*region, operandInfo, indexTy,
                            /*enableNameShadowing=*/true);
}
/// Custom printer for the isolated region op: emits the op name, its operand,
/// and the region without its entry block arguments. The region's arguments
/// are shadowed by the operand so they print under the operand's name.
static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
  auto *input = op.getOperand();
  auto &bodyRegion = op.region();

  p << "test.isolated_region ";
  p.printOperand(input);
  p.shadowRegionArgs(bodyRegion, input);
  p.printRegion(bodyRegion, /*printEntryBlockArgs=*/false);
}
//===----------------------------------------------------------------------===//
// Test parser.
//===----------------------------------------------------------------------===//
/// Custom parser for the wrapped keyword op: reads one keyword and stores it
/// on the operation as the string attribute "keyword".
static ParseResult parseWrappedKeywordOp(OpAsmParser &parser,
                                         OperationState &result) {
  StringRef parsedKeyword;
  if (parser.parseKeyword(&parsedKeyword))
    return failure();

  // Record the keyword text so the printer can reproduce it.
  auto keywordAttr = parser.getBuilder().getStringAttr(parsedKeyword);
  result.addAttribute("keyword", keywordAttr);
  return success();
}
/// Custom printer for the wrapped keyword op: emits the operation name
/// followed by a space and the stored keyword.
static void print(OpAsmPrinter &p, WrappedKeywordOp op) {
  p << WrappedKeywordOp::getOperationName();
  p << " " << op.keyword();
}
//===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
//===----------------------------------------------------------------------===//
/// Custom parser for the wrapping region op. Expects the keyword `wraps`
/// followed by an operation in generic form. The parsed operation is placed
/// in a fresh single-block region, a TestReturnOp forwarding all of its
/// results is appended as terminator, and the wrapping op takes its result
/// types from that terminator's operands.
static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
                                         OperationState &result) {
  if (parser.parseKeyword("wraps"))
    return failure();

  // Set up the region and the single block that will hold the wrapped op.
  Region &body = *result.addRegion();
  body.push_back(new Block);
  Block &block = body.back();

  // Parse the wrapped op directly into that block.
  Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
  if (!wrappedOp)
    return failure();

  // Terminate the block with a return that forwards every result of the
  // wrapped operation.
  SmallVector<Value *, 8> returnOperands(wrappedOp->getResults());
  OpBuilder builder(parser.getBuilder().getContext());
  builder.setInsertionPointToEnd(&block);
  builder.create<TestReturnOp>(result.location, returnOperands);

  // The wrapping op's result types mirror the terminator's operand types.
  Operation &terminator = block.back();
  result.types.append(terminator.operand_type_begin(),
                      terminator.operand_type_end());
  return success();
}
/// Custom printer for the wrapping region op: emits the operation name, the
/// keyword `wraps`, and then the first operation of the region's first block
/// in generic form.
static void print(OpAsmPrinter &p, WrappingRegionOp op) {
  auto &wrappedOp = op.region().front().front();
  p << op.getOperationName() << " wraps ";
  p.printGenericOp(&wrappedOp);
}
//===----------------------------------------------------------------------===//
// Test PolyForOp - parse list of region arguments.
//===----------------------------------------------------------------------===//
/// Custom parser for the poly-for op: reads an undelimited list of region
/// arguments and then the body region, giving every argument the index type.
static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
  // Region arguments come first, with no surrounding delimiter.
  SmallVector<OpAsmParser::OperandType, 4> regionArgs;
  if (parser.parseRegionArgumentList(regionArgs))
    return failure();

  // One index type per parsed argument.
  Type indexTy = parser.getBuilder().getIndexType();
  SmallVector<Type, 4> regionArgTypes(regionArgs.size(), indexTy);

  // Finally parse the body with those arguments.
  Region *region = result.addRegion();
  return parser.parseRegion(*region, regionArgs, regionArgTypes);
}
//===----------------------------------------------------------------------===//
// Test removing op with inner ops.
//===----------------------------------------------------------------------===//
namespace {
/// Rewrite pattern that matches every TestOpWithRegionPattern and replaces it
/// with no values.
struct TestRemoveOpWithInnerOps
    : public OpRewritePattern<TestOpWithRegionPattern> {
  using OpRewritePattern::OpRewritePattern;

  /// Unconditionally replaces `op` with an empty value list and reports a
  /// successful match.
  PatternMatchResult matchAndRewrite(TestOpWithRegionPattern op,
                                     PatternRewriter &rewriter) const override {
    rewriter.replaceOp(op, /*newValues=*/llvm::None);
    return matchSuccess();
  }
};
} // end anonymous namespace
/// Adds the TestRemoveOpWithInnerOps pattern, constructed with `context`, to
/// `results` as the canonicalization pattern for TestOpWithRegionPattern.
void TestOpWithRegionPattern::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<TestRemoveOpWithInnerOps>(context);
}
/// Folds the operation to its own operand. The constant `operands`
/// attributes are not consulted.
OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
  auto *forwarded = operand();
  return forwarded;
}
/// Return-type inference hook that always produces a failing result: when a
/// location is supplied it emits the error "expected to fail" there, and in
/// all cases it returns a one-element vector holding a null Type. The
/// operands, attributes and regions are not inspected.
SmallVector<Type, 2> mlir::OpWithInferTypeInterfaceOp::inferReturnTypes(
    llvm::Optional<Location> location, ArrayRef<Value *> operands,
    ArrayRef<NamedAttribute> attributes, ArrayRef<Region> regions) {
  // A diagnostic can only be attached when the caller provided a location.
  if (location.hasValue())
    mlir::emitError(location.getValue()) << "expected to fail";

  // A single null type is the result.
  SmallVector<Type, 2> inferredTypes;
  inferredTypes.push_back(Type());
  return inferredTypes;
}
// Static initialization for Test dialect registration. This file-scope object
// exists only for the side effect of its construction at program start-up.
// NOTE(review): presumably the DialectRegistration constructor is what adds
// TestDialect to the global dialect registry -- confirm against its definition.
static mlir::DialectRegistration<mlir::TestDialect> testDialect;
#define GET_OP_CLASSES
#include "TestOps.cpp.inc"