Files
clang-p2996/mlir/lib/Dialect/MemRef/Transforms/ExpandRealloc.cpp
Martin Erhart 8037deb7af [mlir][memref] Add pass to expand realloc operations, simplify lowering to LLVM
There are two motivations for this change:
1. It considerably simplifies adding support for the realloc operation to the
   new buffer deallocation pass by lowering the realloc such that no
   deallocation operation is inserted and the deallocation pass itself can
   insert that dealloc
2. The lowering is expressed on a higher level and thus easier to understand,
   and the lowerings of the memref operations it is composed of don't have to
   be duplicated in the MemRefToLLVM lowering (also see discussion in
   https://reviews.llvm.org/D133424)

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D159430
2023-09-05 08:58:40 +00:00

176 lines
7.3 KiB
C++

//===- ExpandRealloc.cpp - Expand memref.realloc ops into its components --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace memref {
#define GEN_PASS_DEF_EXPANDREALLOC
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir
using namespace mlir;
namespace {
/// The `realloc` operation performs a conditional allocation and copy to
/// increase the size of a buffer if necessary. This pattern converts the
/// `realloc` operation into this sequence of simpler operations.
///
/// Example of an expansion:
/// ```mlir
/// %realloc = memref.realloc %alloc (%size) : memref<?xf32> to memref<?xf32>
/// ```
/// is expanded to
/// ```mlir
/// %c0 = arith.constant 0 : index
/// %dim = memref.dim %alloc, %c0 : memref<?xf32>
/// %is_old_smaller = arith.cmpi ult, %dim, %size
/// %realloc = scf.if %is_old_smaller -> (memref<?xf32>) {
///   %new_alloc = memref.alloc(%size) : memref<?xf32>
///   %subview = memref.subview %new_alloc[0] [%dim] [1]
///   memref.copy %alloc, %subview
///   memref.dealloc %alloc
///   scf.yield %new_alloc : memref<?xf32>
/// } else {
///   %reinterpret_cast = memref.reinterpret_cast %alloc to
///     offset: [0], sizes: [%size], strides: [1]
///   scf.yield %reinterpret_cast : memref<?xf32>
/// }
/// ```
///
/// The `memref.dealloc` of the old buffer is only emitted when the pattern was
/// constructed with `emitDeallocs` set to true.
struct ExpandReallocOpPattern : public OpRewritePattern<memref::ReallocOp> {
  /// `emitDeallocs` controls whether the old buffer is deallocated in the
  /// branch that allocates a new buffer.
  ExpandReallocOpPattern(MLIRContext *ctx, bool emitDeallocs)
      : OpRewritePattern(ctx), emitDeallocs(emitDeallocs) {}

  /// Replaces `op` with an `scf.if` that yields either a freshly allocated
  /// buffer holding a copy of the old contents (old size < requested size) or
  /// a reinterpret-cast of the old buffer. Always returns success.
  LogicalResult matchAndRewrite(memref::ReallocOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();

    // The accessors are already statically typed as MemRefType, so no casts
    // are needed to query shape and layout.
    MemRefType sourceType = op.getSource().getType();
    MemRefType resultType = op.getType();

    assert(resultType.getRank() == 1 &&
           "result MemRef must have exactly one rank");
    assert(sourceType.getRank() == 1 &&
           "source MemRef must have exactly one rank");
    assert(resultType.getLayout().isIdentity() &&
           "result MemRef must have identity layout (or none)");
    assert(sourceType.getLayout().isIdentity() &&
           "source MemRef must have identity layout (or none)");

    // Get the size of the original buffer. A static size is kept as an
    // attribute; a dynamic size is queried at runtime with `memref.dim`.
    int64_t inputSize = sourceType.getDimSize(0);
    OpFoldResult currSize = rewriter.getIndexAttr(inputSize);
    if (ShapedType::isDynamic(inputSize)) {
      Value dimZero = getValueOrCreateConstantIndexOp(rewriter, loc,
                                                      rewriter.getIndexAttr(0));
      currSize = rewriter.create<memref::DimOp>(loc, op.getSource(), dimZero)
                     .getResult();
    }

    // Get the requested size that the new buffer should have.
    int64_t outputSize = resultType.getDimSize(0);
    OpFoldResult targetSize = ShapedType::isDynamic(outputSize)
                                  ? OpFoldResult{op.getDynamicResultSize()}
                                  : rewriter.getIndexAttr(outputSize);

    // Only allocate a new buffer and copy over the values in the old buffer if
    // the old buffer is smaller than the requested size.
    Value lhs = getValueOrCreateConstantIndexOp(rewriter, loc, currSize);
    Value rhs = getValueOrCreateConstantIndexOp(rewriter, loc, targetSize);
    Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                                lhs, rhs);

    auto ifOp = rewriter.create<scf::IfOp>(
        loc, cond,
        [&](OpBuilder &builder, Location loc) {
          // Allocate the new buffer. If it is a dynamic memref we need to pass
          // an additional operand for the size at runtime, otherwise the static
          // size is encoded in the result type.
          SmallVector<Value> dynamicSizeOperands;
          if (op.getDynamicResultSize())
            dynamicSizeOperands.push_back(op.getDynamicResultSize());

          Value newAlloc = builder.create<memref::AllocOp>(
              loc, resultType, dynamicSizeOperands, op.getAlignmentAttr());

          // Take a subview of the new (bigger) buffer such that we can copy the
          // old values over (the copy operation requires both operands to have
          // the same shape).
          Value subview = builder.create<memref::SubViewOp>(
              loc, newAlloc, ArrayRef<OpFoldResult>{builder.getIndexAttr(0)},
              ArrayRef<OpFoldResult>{currSize},
              ArrayRef<OpFoldResult>{builder.getIndexAttr(1)});
          builder.create<memref::CopyOp>(loc, op.getSource(), subview);

          // Insert the deallocation of the old buffer only if requested
          // (enabled by default).
          if (emitDeallocs)
            builder.create<memref::DeallocOp>(loc, op.getSource());

          builder.create<scf::YieldOp>(loc, newAlloc);
        },
        [&](OpBuilder &builder, Location loc) {
          // We need to reinterpret-cast here because either the input or output
          // type might be static, which means we need to cast from static to
          // dynamic or vice-versa. If both are static and the original buffer
          // is already bigger than the requested size, the cast represents a
          // subview operation.
          Value casted = builder.create<memref::ReinterpretCastOp>(
              loc, resultType, op.getSource(), builder.getIndexAttr(0),
              ArrayRef<OpFoldResult>{targetSize},
              ArrayRef<OpFoldResult>{builder.getIndexAttr(1)});
          builder.create<scf::YieldOp>(loc, casted);
        });

    rewriter.replaceOp(op, ifOp.getResult(0));
    return success();
  }

private:
  /// Whether a `memref.dealloc` of the source buffer is emitted in the
  /// reallocating branch.
  const bool emitDeallocs;
};
/// Pass that rewrites every `memref.realloc` under the operation it runs on
/// using the expansion patterns registered by
/// `memref::populateExpandReallocPatterns`.
struct ExpandReallocPass
    : public memref::impl::ExpandReallocBase<ExpandReallocPass> {
  /// Stores `emitDeallocs` in the pass option of the same name. Marked
  /// `explicit` so a `bool` never converts implicitly into a pass object.
  explicit ExpandReallocPass(bool emitDeallocs) {
    this->emitDeallocs.setValue(emitDeallocs);
  }

  /// Runs a partial conversion in which `memref.realloc` is the only illegal
  /// operation; signals pass failure if the conversion fails.
  void runOnOperation() override {
    MLIRContext &ctx = getContext();

    RewritePatternSet patterns(&ctx);
    memref::populateExpandReallocPatterns(patterns, emitDeallocs.getValue());

    // Mark the dialects of the ops produced by the expansion as legal, then
    // single out `memref.realloc` as illegal.
    ConversionTarget target(ctx);
    target.addLegalDialect<arith::ArithDialect, scf::SCFDialect,
                           memref::MemRefDialect>();
    target.addIllegalOp<memref::ReallocOp>();

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace
/// Adds the pattern that expands `memref.realloc` to `patterns`.
/// `emitDeallocs` is forwarded to the pattern's constructor.
void mlir::memref::populateExpandReallocPatterns(RewritePatternSet &patterns,
                                                 bool emitDeallocs) {
  MLIRContext *context = patterns.getContext();
  patterns.add<ExpandReallocOpPattern>(context, emitDeallocs);
}
/// Creates a new `ExpandReallocPass` instance configured with `emitDeallocs`.
std::unique_ptr<Pass> mlir::memref::createExpandReallocPass(bool emitDeallocs) {
  std::unique_ptr<Pass> pass =
      std::make_unique<ExpandReallocPass>(emitDeallocs);
  return pass;
}