Files
clang-p2996/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
River Riddle b54c724be0 [mlir:OpConversionPattern] Add overloads for taking an Adaptor instead of ArrayRef
This has been a TODO for a long time, and it brings about many advantages (namely nice accessors, and less fragile code). The existing overloads that accept ArrayRef are now treated as deprecated and will be removed in a followup (after a small grace period). Most of the upstream MLIR usages have been fixed by this commit, the rest will be handled in a followup.

Differential Revision: https://reviews.llvm.org/D110293
2021-09-24 17:51:41 +00:00

93 lines
3.4 KiB
C++

//===- Bufferize.cpp - Bufferization for std ops --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements bufferization of std ops.
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
namespace {
/// Bufferizes an `IndexCastOp` whose result is a ranked tensor by re-creating
/// the op with a memref result type that has the same shape and element type.
class BufferizeIndexCastOp : public OpConversionPattern<IndexCastOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  /// Replaces `op` with an `IndexCastOp` over the adaptor's `in` operand that
  /// produces a memref. Reports a match failure (instead of asserting) when
  /// the result type of `op` is not a ranked tensor.
  LogicalResult
  matchAndRewrite(IndexCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Checked cast: an unchecked `cast<>` would assert on any result type
    // that is not a RankedTensorType. Bail out gracefully so the conversion
    // driver can report the failure.
    auto tensorType = op.getType().dyn_cast<RankedTensorType>();
    if (!tensorType)
      return rewriter.notifyMatchFailure(op, "requires ranked tensor result");

    rewriter.replaceOpWithNewOp<IndexCastOp>(
        op, adaptor.in(),
        MemRefType::get(tensorType.getShape(), tensorType.getElementType()));
    return success();
  }
};
/// Bufferizes a `SelectOp` that has an integer-typed (scalar) condition by
/// re-creating it from the operands supplied by the adaptor.
class BufferizeSelectOp : public OpConversionPattern<SelectOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  /// Rewrites `op` into a `SelectOp` over the adaptor's condition, true value
  /// and false value. Fails to match unless the condition type of `op` is an
  /// `IntegerType`.
  LogicalResult
  matchAndRewrite(SelectOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const bool hasScalarCondition =
        op.condition().getType().isa<IntegerType>();
    if (hasScalarCondition) {
      rewriter.replaceOpWithNewOp<SelectOp>(op, adaptor.condition(),
                                            adaptor.true_value(),
                                            adaptor.false_value());
      return success();
    }
    return rewriter.notifyMatchFailure(op, "requires scalar condition");
  }
};
} // namespace
/// Adds the std-op bufferization patterns defined in this file
/// (`BufferizeSelectOp` and `BufferizeIndexCastOp`) to `patterns`, each
/// constructed with `typeConverter` and the context owned by `patterns`.
void mlir::populateStdBufferizePatterns(BufferizeTypeConverter &typeConverter,
                                        RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<BufferizeSelectOp, BufferizeIndexCastOp>(typeConverter, ctx);
}
namespace {
/// Function pass that runs a partial dialect conversion using the std
/// bufferization patterns registered by `populateStdBufferizePatterns`.
struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> {
  /// Sets up the conversion target and patterns, then applies a partial
  /// conversion to the current function; signals pass failure if the
  /// conversion fails.
  void runOnFunction() override {
    MLIRContext &ctx = getContext();
    BufferizeTypeConverter typeConverter;

    RewritePatternSet patterns(&ctx);
    populateStdBufferizePatterns(typeConverter, patterns);

    // Everything in scf, std and memref is legal by default; the two ops
    // below override that with dynamic legality rules.
    ConversionTarget target(ctx);
    target.addLegalDialect<scf::SCFDialect, StandardOpsDialect,
                           memref::MemRefDialect>();

    // An index cast is legal once its result type is legal for the type
    // converter.
    target.addDynamicallyLegalOp<IndexCastOp>(
        [&](IndexCastOp op) { return typeConverter.isLegal(op.getType()); });

    // We only bufferize the case of tensor selected type and scalar
    // condition, as that boils down to a select over memref descriptors
    // (don't need to touch the data). Any other select is left as legal.
    target.addDynamicallyLegalOp<SelectOp>([&](SelectOp op) {
      if (typeConverter.isLegal(op.getType()))
        return true;
      return !op.condition().getType().isa<IntegerType>();
    });

    LogicalResult status =
        applyPartialConversion(getFunction(), target, std::move(patterns));
    if (failed(status))
      signalPassFailure();
  }
};
} // namespace
/// Creates a new instance of the std bufferization pass.
std::unique_ptr<Pass> mlir::createStdBufferizePass() {
  std::unique_ptr<Pass> pass = std::make_unique<StdBufferizePass>();
  return pass;
}