This patch attaches more precise side effects to the existing ops with memory effects, recording which OpOperand/OpResult/BlockArgument an operation reads or writes instead of only recording that some value is read or written. This makes the precise side effects convenient to use for analysis and optimization. Related discussion: https://discourse.llvm.org/t/rfc-add-operandindex-to-sideeffect-instance/79243
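For example, `MaterializeInDestinationOp::getEffects` below now records its write against the `dest` operand itself (`&getDestMutable()`) rather than against a plain value. A minimal consumer-side sketch, assuming the `EffectInstance::getEffectValue<OpOperand *>()` accessor proposed in the RFC (treat its exact name/signature as an assumption, not settled API):

```cpp
// Sketch only: walk an op's memory effects and report which OpOperand is
// written, using the refined effect instances from this patch.
#include "mlir/Interfaces/SideEffectInterfaces.h"
using namespace mlir;

void reportPreciseWrites(Operation *op) {
  auto iface = dyn_cast<MemoryEffectOpInterface>(op);
  if (!iface)
    return;
  SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>> effects;
  iface.getEffects(effects);
  for (auto &instance : effects) {
    if (!isa<MemoryEffects::Write>(instance.getEffect()))
      continue;
    // Precise form: the effect instance can point at the exact OpOperand,
    // not just a Value, so analyses no longer have to guess which operand
    // was meant. `getEffectValue` is the accessor name from the RFC.
    if (OpOperand *operand = instance.getEffectValue<OpOperand *>())
      llvm::errs() << "write on operand #" << operand->getOperandNumber()
                   << "\n";
  }
}
```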
1190 lines · 45 KiB · C++
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include <optional>

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

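/// Try to cast the given ranked memref value to the given ranked memref type.
/// If a plain `memref.cast` is not guaranteed to succeed at runtime (e.g., a
/// dynamic offset or stride would become static), allocate a buffer of the
/// target type and copy instead. Illustrative realloc path (assuming the
/// default alloc/copy functions in `options`):
///
///   %alloc = memref.alloc(%sz) : memref<?xf32>
///   memref.copy %value, %alloc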
FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
    OpBuilder &b, Value value, MemRefType destType,
    const BufferizationOptions &options) {
  auto srcType = llvm::cast<MemRefType>(value.getType());

  // Element type, rank and memory space must match.
  if (srcType.getElementType() != destType.getElementType())
    return failure();
  if (srcType.getMemorySpace() != destType.getMemorySpace())
    return failure();
  if (srcType.getRank() != destType.getRank())
    return failure();

  // In case the affine maps are different, we may need to use a copy if we go
  // from dynamic to static offset or stride (the canonicalization cannot know
  // at this point that it is really cast compatible).
  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
    int64_t sourceOffset, targetOffset;
    SmallVector<int64_t, 4> sourceStrides, targetStrides;
    if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
        failed(getStridesAndOffset(target, targetStrides, targetOffset)))
      return false;
    auto dynamicToStatic = [](int64_t a, int64_t b) {
      return ShapedType::isDynamic(a) && !ShapedType::isDynamic(b);
    };
    if (dynamicToStatic(sourceOffset, targetOffset))
      return false;
    for (auto it : llvm::zip(sourceStrides, targetStrides))
      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
        return false;
    return true;
  };

  // Note: If `areCastCompatible`, a cast is valid, but may fail at runtime. To
  // ensure that we only generate casts that always succeed at runtime, we
  // check a few extra conditions in `isGuaranteedCastCompatible`.
  if (memref::CastOp::areCastCompatible(srcType, destType) &&
      isGuaranteedCastCompatible(srcType, destType)) {
    Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
    return casted;
  }

  auto loc = value.getLoc();
  SmallVector<Value, 4> dynamicOperands;
  for (int i = 0; i < destType.getRank(); ++i) {
    if (destType.getShape()[i] != ShapedType::kDynamic)
      continue;
    Value size = b.create<memref::DimOp>(loc, value, i);
    dynamicOperands.push_back(size);
  }

  FailureOr<Value> copy =
      options.createAlloc(b, loc, destType, dynamicOperands);
  if (failed(copy))
    return failure();
  if (failed(options.createMemCpy(b, loc, value, *copy)))
    return failure();
  return copy;
}

/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
/// to_memref op are different, a memref.cast is needed.
LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
    RewriterBase &rewriter, ToMemrefOp toMemref,
    const BufferizationOptions &options) {
  auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>();
  if (!memrefToTensor)
    return failure();

  Type srcType = memrefToTensor.getMemref().getType();
  Type destType = toMemref.getType();

  // Directly rewrite if the type did not change.
  if (srcType == destType) {
    rewriter.replaceOp(toMemref, memrefToTensor.getMemref());
    return success();
  }

  auto rankedSrcType = llvm::dyn_cast<MemRefType>(srcType);
  auto rankedDestType = llvm::dyn_cast<MemRefType>(destType);
  auto unrankedSrcType = llvm::dyn_cast<UnrankedMemRefType>(srcType);

  // Ranked memref -> Ranked memref cast.
  if (rankedSrcType && rankedDestType) {
    FailureOr<Value> replacement = castOrReallocMemRefValue(
        rewriter, memrefToTensor.getMemref(), rankedDestType, options);
    if (failed(replacement))
      return failure();

    rewriter.replaceOp(toMemref, *replacement);
    return success();
  }

  // Unranked memref -> Ranked memref cast: May require a copy.
  // TODO: Not implemented at the moment.
  if (unrankedSrcType && rankedDestType)
    return failure();

  // Unranked memref -> unranked memref cast
  // Ranked memref -> unranked memref cast: No copy needed.
  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
         "expected that types are cast compatible");
  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
                                              memrefToTensor.getMemref());
  return success();
}

void mlir::bufferization::populateDynamicDimSizes(
    OpBuilder &b, Location loc, Value shapedValue,
    SmallVector<Value> &dynamicDims) {
  auto shapedType = llvm::cast<ShapedType>(shapedValue.getType());
  for (int64_t i = 0; i < shapedType.getRank(); ++i) {
    if (shapedType.isDynamicDim(i)) {
      if (llvm::isa<MemRefType>(shapedType)) {
        dynamicDims.push_back(b.create<memref::DimOp>(loc, shapedValue, i));
      } else {
        assert(llvm::isa<RankedTensorType>(shapedType) && "expected tensor");
        dynamicDims.push_back(b.create<tensor::DimOp>(loc, shapedValue, i));
      }
    }
  }
}

//===----------------------------------------------------------------------===//
// AllocTensorOp
//===----------------------------------------------------------------------===//

LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
                                       const BufferizationOptions &options) {
  OpBuilder::InsertionGuard g(rewriter);
  Location loc = getLoc();

  // Nothing to do for dead AllocTensorOps.
  if (getOperation()->getUses().empty()) {
    rewriter.eraseOp(getOperation());
    return success();
  }

  // Get "copy" buffer.
  Value copyBuffer;
  if (getCopy()) {
    FailureOr<Value> maybeCopyBuffer = getBuffer(rewriter, getCopy(), options);
    if (failed(maybeCopyBuffer))
      return failure();
    copyBuffer = *maybeCopyBuffer;
  }

  // Create memory allocation.
  auto allocType = bufferization::getBufferType(getResult(), options);
  if (failed(allocType))
    return failure();
  SmallVector<Value> dynamicDims = getDynamicSizes();
  if (getCopy()) {
    assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`");
    populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
  }
  FailureOr<Value> alloc = options.createAlloc(
      rewriter, loc, llvm::cast<MemRefType>(*allocType), dynamicDims);
  if (failed(alloc))
    return failure();

  // Create memory copy (if any).
  if (getCopy()) {
    if (failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc)))
      return failure();
  }

  // Replace op.
  replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);

  return success();
}

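// For illustration: with a `copy` operand and the default allocation/copy
// functions, `%0 = bufferization.alloc_tensor() copy(%t) : tensor<5xf32>`
// bufferizes roughly to a `memref.alloc` of the inferred buffer type followed
// by a `memref.copy` from %t's buffer; %0 is then replaced by the allocated
// buffer (wrapped in a to_tensor where tensor uses remain).
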
bool AllocTensorOp::resultBufferizesToMemoryWrite(OpResult opResult,
                                                  const AnalysisState &state) {
  // AllocTensorOps do not write unless they have a `copy` value.
  return static_cast<bool>(getCopy());
}

bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
                                           const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return true;
}

bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
                                            const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return false;
}

AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
                                                   const AnalysisState &state) {
  // This is a new allocation. It does not alias with any other buffer.
  return {};
}

FailureOr<BaseMemRefType>
AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
                             SmallVector<Value> &invocationStack) {
  assert(value == getResult() && "invalid value");

  // Compute memory space of this allocation.
  Attribute memorySpace;
  if (getMemorySpace().has_value()) {
    memorySpace = *getMemorySpace();
  } else if (getCopy()) {
    auto copyBufferType =
        bufferization::getBufferType(getCopy(), options, invocationStack);
    if (failed(copyBufferType))
      return failure();
    memorySpace = copyBufferType->getMemorySpace();
  } else if (auto ms = options.defaultMemorySpaceFn(getType())) {
    memorySpace = *ms;
  } else {
    return getOperation()->emitError("could not infer memory space");
  }

  return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace);
}

LogicalResult AllocTensorOp::verify() {
  if (getCopy() && !getDynamicSizes().empty())
    return emitError("dynamic sizes not needed when copying a tensor");
  if (!getCopy() && getType().getNumDynamicDims() !=
                        static_cast<int64_t>(getDynamicSizes().size()))
    return emitError("expected ")
           << getType().getNumDynamicDims() << " dynamic sizes";
  if (getCopy() && getCopy().getType() != getType())
    return emitError("expected that `copy` and return type match");
  return success();
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes) {
  build(builder, result, type, dynamicSizes, /*copy=*/Value(),
        /*size_hint=*/Value(),
        /*memory_space=*/IntegerAttr());
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes,
                          Value copy) {
  build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
        /*memory_space=*/IntegerAttr());
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          TensorType type, ValueRange dynamicSizes, Value copy,
                          IntegerAttr memorySpace) {
  build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
        memorySpace);
}

namespace {
/// Change the type of the result of a `bufferization.alloc_tensor` by making
/// the result type statically sized along dimensions that were defined as
/// dynamic in the original operation, but whose size was defined using a
/// `constant` op. For example:
///
///   %c5 = arith.constant 5: index
///   %0 = bufferization.alloc_tensor(%arg0, %c5) : tensor<?x?xf32>
///
/// to
///
///   %0 = bufferization.alloc_tensor(%arg0) : tensor<?x5xf32>
struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
  using OpRewritePattern<AllocTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocTensorOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getCopy())
      return failure();
    SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape());
    SmallVector<Value> newDynamicSizes;
    unsigned int dynValCounter = 0;
    for (int64_t i = 0; i < op.getType().getRank(); ++i) {
      if (!op.isDynamicDim(i))
        continue;
      Value value = op.getDynamicSizes()[dynValCounter++];
      APInt intVal;
      if (matchPattern(value, m_ConstantInt(&intVal))) {
        int64_t dim = intVal.getSExtValue();
        if (dim >= 0)
          newShape[i] = intVal.getSExtValue();
        else
          newDynamicSizes.push_back(value);
      } else {
        newDynamicSizes.push_back(value);
      }
    }
    RankedTensorType newType = RankedTensorType::get(
        newShape, op.getType().getElementType(), op.getType().getEncoding());
    if (newType == op.getType())
      return failure();
    auto newOp = rewriter.create<AllocTensorOp>(
        op.getLoc(), newType, newDynamicSizes, /*copy=*/Value());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};

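/// Fold `tensor.dim` of an `alloc_tensor` with a constant index to the
/// corresponding dynamic size operand. Illustrative example:
///
///   %0 = bufferization.alloc_tensor(%sz) : tensor<?xf32>
///   %d = tensor.dim %0, %c0 : tensor<?xf32>
///
/// folds %d to %sz.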
struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto allocTensorOp = dimOp.getSource().getDefiningOp<AllocTensorOp>();
    if (!allocTensorOp || !maybeConstantIndex)
      return failure();
    if (*maybeConstantIndex < 0 ||
        *maybeConstantIndex >= allocTensorOp.getType().getRank())
      return failure();
    if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(
        dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
    return success();
  }
};
} // namespace

void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *ctx) {
  results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
}

LogicalResult AllocTensorOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto shapes = llvm::to_vector<4>(
      llvm::map_range(llvm::seq<int64_t>(0, getType().getRank()),
                      [&](int64_t dim) -> OpFoldResult {
                        if (isDynamicDim(dim))
                          return getDynamicSize(builder, dim);
                        return builder.getIndexAttr(getStaticSize(dim));
                      }));
  reifiedReturnShapes.emplace_back(std::move(shapes));
  return success();
}

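// For reference, the custom assembly format handled by the parser/printer
// below looks like this (illustrative; `copy` and `size_hint` are optional,
// and dynamic sizes must be empty when `copy` is present):
//
//   %0 = bufferization.alloc_tensor(%d0, %d1) : tensor<?x?xf32>
//   %1 = bufferization.alloc_tensor() copy(%t) size_hint=%h : tensor<?x?xf32>
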
ParseResult AllocTensorOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizesOperands;
  if (parser.parseLParen() || parser.parseOperandList(dynamicSizesOperands) ||
      parser.parseRParen())
    return failure();
  ParseResult copyKeyword = parser.parseOptionalKeyword("copy");
  OpAsmParser::UnresolvedOperand copyOperand;
  if (copyKeyword.succeeded())
    if (parser.parseLParen() || parser.parseOperand(copyOperand) ||
        parser.parseRParen())
      return failure();
  ParseResult sizeHintKeyword = parser.parseOptionalKeyword("size_hint");
  OpAsmParser::UnresolvedOperand sizeHintOperand;
  if (sizeHintKeyword.succeeded())
    if (parser.parseEqual() || parser.parseOperand(sizeHintOperand))
      return failure();
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
    return failure();

  TensorType type;
  if (parser.parseCustomTypeWithFallback(type))
    return failure();
  result.addTypes(type);

  Type indexType = parser.getBuilder().getIndexType();
  if (parser.resolveOperands(dynamicSizesOperands, indexType, result.operands))
    return failure();
  if (copyKeyword.succeeded())
    if (parser.resolveOperand(copyOperand, type, result.operands))
      return failure();
  if (sizeHintKeyword.succeeded())
    if (parser.resolveOperand(sizeHintOperand, indexType, result.operands))
      return failure();
  result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getDenseI32ArrayAttr(
                          {static_cast<int32_t>(dynamicSizesOperands.size()),
                           static_cast<int32_t>(copyKeyword.succeeded()),
                           static_cast<int32_t>(sizeHintKeyword.succeeded())}));
  return success();
}

void AllocTensorOp::print(OpAsmPrinter &p) {
  p << "(" << getDynamicSizes() << ")";
  if (getCopy())
    p << " copy(" << getCopy() << ")";
  if (getSizeHint())
    p << " size_hint=" << getSizeHint();
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              AllocTensorOp::getOperandSegmentSizeAttr()});
  p << " : ";
  auto type = getResult().getType();
  if (auto validType = llvm::dyn_cast<::mlir::TensorType>(type))
    p.printStrippedAttrOrType(validType);
  else
    p << type;
}

Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
  assert(isDynamicDim(idx) && "expected dynamic dim");
  if (getCopy())
    return b.create<tensor::DimOp>(getLoc(), getCopy(), idx);
  return getOperand(getIndexOfDynamicSize(idx));
}

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

OpFoldResult CloneOp::fold(FoldAdaptor adaptor) {
  return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
}

namespace {

/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
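/// Illustrative example:
///
///   %1 = bufferization.clone %0 : memref<?xf32> to memref<?xf32>
///   ... uses of %1 ...
///   memref.dealloc %1
///
/// folds %1 to %0 (inserting a memref.cast if the types differ) and erases
/// the now-redundant dealloc.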
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.getInput();
    if (source.getType() != cloneOp.getType() &&
        !memref::CastOp::areCastCompatible({source.getType()},
                                           {cloneOp.getType()}))
      return failure();

    // Aims to find the dealloc op for the canonical source
    // which otherwise could prevent removal of unnecessary allocs.
    Value canonicalSource = source;
    while (auto iface = dyn_cast_or_null<ViewLikeOpInterface>(
               canonicalSource.getDefiningOp()))
      canonicalSource = iface.getViewSource();

    std::optional<Operation *> maybeCloneDeallocOp =
        memref::findDealloc(cloneOp.getOutput());
    // Skip if either of them has > 1 deallocate operations.
    if (!maybeCloneDeallocOp.has_value())
      return failure();
    std::optional<Operation *> maybeSourceDeallocOp =
        memref::findDealloc(canonicalSource);
    if (!maybeSourceDeallocOp.has_value())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations in between
    // cloneOp and redundantDealloc, as otherwise we might deallocate an alias
    // of source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias
    // of the source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      // Bail if we run out of operations while looking for a deallocation op.
      if (!pos)
        return failure();
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    if (source.getType() != cloneOp.getType())
      source = rewriter.create<memref::CastOp>(cloneOp.getLoc(),
                                               cloneOp.getType(), source);
    rewriter.replaceOp(cloneOp, source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // namespace

void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
}

//===----------------------------------------------------------------------===//
// DeallocTensorOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
                                         const BufferizationOptions &options) {
  FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options);
  if (failed(buffer))
    return failure();
  rewriter.create<memref::DeallocOp>(getLoc(), *buffer);
  rewriter.eraseOp(getOperation());
  return success();
}

//===----------------------------------------------------------------------===//
// MaterializeInDestinationOp
//===----------------------------------------------------------------------===//

bool MaterializeInDestinationOp::bufferizesToMemoryRead(
    OpOperand &opOperand, const AnalysisState &state) {
  return opOperand == getSourceMutable();
}

bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
    OpOperand &opOperand, const AnalysisState &state) {
  if (opOperand == getDestMutable()) {
    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
    return true;
  }
  return false;
}

bool MaterializeInDestinationOp::mustBufferizeInPlace(
    OpOperand &opOperand, const AnalysisState &state) {
  // The source is only read and not written, so it always bufferizes in-place
  // by default. The destination is written and is forced to bufferize in-place
  // (if it is a tensor).
  return true;
}

AliasingValueList
MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
                                              const AnalysisState &state) {
  if (opOperand == getDestMutable()) {
    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
    return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
  }
  return {};
}

LogicalResult
MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
                                      const BufferizationOptions &options) {
  bool tensorDest = isa<TensorType>(getDest().getType());
  Value buffer;
  if (tensorDest) {
    FailureOr<Value> maybeBuffer = getBuffer(rewriter, getDest(), options);
    if (failed(maybeBuffer))
      return failure();
    buffer = *maybeBuffer;
  } else {
    assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
    buffer = getDest();
  }
  auto srcBuffer = getBuffer(rewriter, getSource(), options);
  if (failed(srcBuffer))
    return failure();
  if (failed(options.createMemCpy(rewriter, getLoc(), *srcBuffer, buffer)))
    return failure();
  replaceOpWithBufferizedValues(rewriter, getOperation(),
                                tensorDest ? ValueRange(buffer) : ValueRange());
  return success();
}

bool MaterializeInDestinationOp::bufferizesToElementwiseAccess(
    const AnalysisState &state, ArrayRef<OpOperand *> opOperands) {
  // As elements are copied from the "source" buffer to the "dest" buffer,
  // already copied elements are not read a second time.
  return true;
}

LogicalResult MaterializeInDestinationOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  if (getOperation()->getNumResults() == 1) {
    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
    reifiedReturnShapes.resize(1,
                               SmallVector<OpFoldResult>(getType().getRank()));
    reifiedReturnShapes[0] =
        tensor::getMixedSizes(builder, getLoc(), getDest());
  }
  return success();
}

Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
                                                        Location loc) {
  if (isa<TensorType>(getDest().getType())) {
    // The subset is the entire destination tensor.
    return getDest();
  }

  // The "restrict" attribute is transferred from this op to the newly created
  // to_tensor op. If this op does not have the "restrict" attribute, the
  // subset extraction cannot be built because there is no guarantee that there
  // is no pre-existing "restrict" to_tensor op with the same/an aliasing
  // destination.
  if (!getRestrict())
    return {};

  // Build a bufferization.to_tensor op.
  assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
  assert(getRestrict() &&
         "expected that ops with memrefs dest have 'restrict'");
  setRestrict(false);
  return builder.create<ToTensorOp>(loc, getDest(), /*restrict=*/true,
                                    getWritable());
}

bool MaterializeInDestinationOp::isEquivalentSubset(
    Value candidate, function_ref<bool(Value, Value)> equivalenceFn) {
  return equivalenceFn(getDest(), candidate);
}

SmallVector<Value>
MaterializeInDestinationOp::getValuesNeededToBuildSubsetExtraction() {
  return {getDest()};
}

OpOperand &MaterializeInDestinationOp::getSourceOperand() {
  return getOperation()->getOpOperand(0) /*source*/;
}

bool MaterializeInDestinationOp::operatesOnEquivalentSubset(
    SubsetOpInterface subsetOp,
    function_ref<bool(Value, Value)> equivalenceFn) {
  return false;
}

bool MaterializeInDestinationOp::operatesOnDisjointSubset(
    SubsetOpInterface subsetOp,
    function_ref<bool(Value, Value)> equivalenceFn) {
  return false;
}

LogicalResult MaterializeInDestinationOp::verify() {
  if (!isa<TensorType, BaseMemRefType>(getDest().getType()))
    return emitOpError("'dest' must be a tensor or a memref");
  if (auto destType = dyn_cast<TensorType>(getDest().getType())) {
    if (getOperation()->getNumResults() != 1)
      return emitOpError("tensor 'dest' implies exactly one tensor result");
    if (destType != getResult().getType())
      return emitOpError("result and 'dest' types must match");
  }
  if (isa<BaseMemRefType>(getDest().getType()) &&
      getOperation()->getNumResults() != 0)
    return emitOpError("memref 'dest' implies zero results");
  if (getRestrict() && !isa<BaseMemRefType>(getDest().getType()))
    return emitOpError("'restrict' is valid only for memref destinations");
  if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
    return emitOpError("'writable' must be specified if and only if the "
                       "destination is of memref type");
  TensorType srcType = getSource().getType();
  ShapedType destType = cast<ShapedType>(getDest().getType());
  if (srcType.hasRank() != destType.hasRank())
    return emitOpError("source/destination shapes are incompatible");
  if (srcType.hasRank()) {
    if (srcType.getRank() != destType.getRank())
      return emitOpError("rank mismatch between source and destination shape");
    for (auto [src, dest] :
         llvm::zip(srcType.getShape(), destType.getShape())) {
      if (src == ShapedType::kDynamic || dest == ShapedType::kDynamic) {
        // Cannot verify dynamic dimension size. Assume that they match at
        // runtime.
        continue;
      }
      if (src != dest)
        return emitOpError("source/destination shapes are incompatible");
    }
  }
  return success();
}

void MaterializeInDestinationOp::build(OpBuilder &builder,
                                       OperationState &state, Value source,
                                       Value dest) {
  auto destTensorType = dyn_cast<TensorType>(dest.getType());
  build(builder, state, /*result=*/destTensorType ? destTensorType : Type(),
        source, dest);
}

bool MaterializeInDestinationOp::isWritable(Value value,
                                            const AnalysisState &state) {
  return isa<TensorType>(getDest().getType()) ? true : getWritable();
}

MutableOperandRange MaterializeInDestinationOp::getDpsInitsMutable() {
  return getDestMutable();
}

void MaterializeInDestinationOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  // A memref 'dest' is written in place. Record a precise write effect on the
  // 'dest' operand (rather than just on its value) so that analyses can tell
  // exactly which operand is written.
  if (isa<BaseMemRefType>(getDest().getType()))
    effects.emplace_back(MemoryEffects::Write::get(), &getDestMutable(),
                         SideEffects::DefaultResource::get());
}

//===----------------------------------------------------------------------===//
// ToTensorOp
//===----------------------------------------------------------------------===//

bool ToTensorOp::isWritable(Value value, const AnalysisState &state) {
  return getWritable();
}

OpFoldResult ToTensorOp::fold(FoldAdaptor) {
  if (auto toMemref = getMemref().getDefiningOp<ToMemrefOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
    if (toMemref->getBlock() == this->getOperation()->getBlock() &&
        toMemref->getNextNode() == this->getOperation())
      return toMemref.getTensor();
  return {};
}

namespace {
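/// Fold `tensor.dim` of a `to_tensor` op into `memref.dim` of the underlying
/// memref. Illustrative example:
///
///   %t = bufferization.to_tensor %m : memref<?xf32>
///   %d = tensor.dim %t, %c0 : tensor<?xf32>
/// becomes
///   %d = memref.dim %m, %c0 : memref<?xf32>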
struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto memrefToTensorOp = dimOp.getSource().getDefiningOp<ToTensorOp>();
    if (!memrefToTensorOp)
      return failure();

    rewriter.replaceOpWithNewOp<memref::DimOp>(
        dimOp, memrefToTensorOp.getMemref(), dimOp.getIndex());
    return success();
  }
};
} // namespace

void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfToTensorFolder>(context);
}

//===----------------------------------------------------------------------===//
// ToMemrefOp
//===----------------------------------------------------------------------===//

OpFoldResult ToMemrefOp::fold(FoldAdaptor) {
  if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
    if (memrefToTensor.getMemref().getType() == getType())
      return memrefToTensor.getMemref();
  return {};
}

namespace {

/// Replace tensor.cast + to_memref by to_memref + memref.cast.
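/// Illustrative example:
///
///   %1 = tensor.cast %0 : tensor<4xf32> to tensor<?xf32>
///   %2 = bufferization.to_memref %1 : memref<?xf32>
/// becomes
///   %m = bufferization.to_memref %0 : memref<4xf32>
///   %2 = memref.cast %m : memref<4xf32> to memref<?xf32>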
struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        toMemref.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType = llvm::dyn_cast<RankedTensorType>(
        tensorCastOperand.getOperand().getType());
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
                                               tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
                                                memref);
    return success();
  }
};

/// Canonicalize bufferization.to_tensor + bufferization.to_memref. Insert a
/// cast if necessary.
struct ToMemrefToTensorFolding : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    BufferizationOptions options;
    options.bufferAlignment = 0;
    return foldToMemrefToTensorPair(rewriter, toMemref, options);
  }
};

/// Fold a load on a to_memref operation into a tensor.extract on the
/// corresponding tensor.
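/// Illustrative example:
///
///   %m = bufferization.to_memref %t : memref<?xf32>
///   %v = memref.load %m[%i] : memref<?xf32>
/// becomes
///   %v = tensor.extract %t[%i] : tensor<?xf32>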
struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern<memref::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto toMemref = load.getMemref().getDefiningOp<ToMemrefOp>();
    if (!toMemref)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.getTensor(),
                                                   load.getIndices());
    return success();
  }
};

/// Fold dim of a to_memref into the dim of the tensor.
struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.getSource().getDefiningOp<ToMemrefOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource,
                                               dimOp.getIndex());
    return success();
  }
};

} // namespace

void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast,
              ToMemrefToTensorFolding>(context);
}

LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
                                    const BufferizationOptions &options) {
  // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
  (void)foldToMemrefToTensorPair(rewriter, *this, options);
  // Note: The return value of `bufferize` indicates whether there was an error
  // or not. (And not whether the pattern matched or not.)
  return success();
}

std::optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder,
                                                 Value alloc) {
  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
      .getOperation();
}

std::optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocOp::inferReturnTypes(
    MLIRContext *context, std::optional<::mlir::Location> location,
    ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
    RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
  DeallocOpAdaptor adaptor(operands, attributes, properties, regions);
  inferredReturnTypes = SmallVector<Type>(adaptor.getRetained().size(),
                                          IntegerType::get(context, 1));
  return success();
}

LogicalResult DeallocOp::verify() {
  if (getMemrefs().size() != getConditions().size())
    return emitOpError(
        "must have the same number of conditions as memrefs to deallocate");
  if (getRetained().size() != getUpdatedConditions().size())
    return emitOpError("must have the same number of updated conditions "
                       "(results) as retained operands");
  return success();
}

static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
                                            ValueRange memrefs,
                                            ValueRange conditions,
                                            PatternRewriter &rewriter) {
  if (deallocOp.getMemrefs() == memrefs &&
      deallocOp.getConditions() == conditions)
    return failure();

  rewriter.modifyOpInPlace(deallocOp, [&]() {
    deallocOp.getMemrefsMutable().assign(memrefs);
    deallocOp.getConditionsMutable().assign(conditions);
  });
  return success();
}

namespace {

/// Remove duplicate values in the list of memrefs to be deallocated. The
/// corresponding condition value must be updated accordingly: if a memref
/// occurs twice, its two conditions might not cover the same set of cases, so
/// we combine them by computing their disjunction.
/// Example:
/// ```mlir
/// bufferization.dealloc (%arg0, %arg0 : ...) if (%arg1, %arg2)
/// ```
/// is canonicalized to
/// ```mlir
/// %0 = arith.ori %arg1, %arg2 : i1
/// bufferization.dealloc (%arg0 : memref<2xi32>) if (%0)
/// ```
struct DeallocRemoveDuplicateDeallocMemrefs
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    // Unique memrefs to be deallocated.
    DenseMap<Value, unsigned> memrefToCondition;
    SmallVector<Value> newMemrefs, newConditions;
    for (auto [i, memref, cond] :
         llvm::enumerate(deallocOp.getMemrefs(), deallocOp.getConditions())) {
      if (memrefToCondition.count(memref)) {
        // If the dealloc conditions don't match, we need to make sure that the
        // dealloc happens on the union of cases.
        Value &newCond = newConditions[memrefToCondition[memref]];
        if (newCond != cond)
          newCond =
              rewriter.create<arith::OrIOp>(deallocOp.getLoc(), newCond, cond);
      } else {
        memrefToCondition.insert({memref, newConditions.size()});
        newMemrefs.push_back(memref);
        newConditions.push_back(cond);
      }
    }

    // Return failure if we don't change anything so that we don't run into an
    // infinite loop of pattern applications.
    return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
                                  rewriter);
  }
};

/// Remove duplicate values in the list of retained memrefs. We need to make
/// sure the corresponding result condition value is replaced properly.
/// Example:
/// ```mlir
/// %0:2 = bufferization.dealloc retain (%arg3, %arg3 : ...)
/// ```
/// is canonicalized to
/// ```mlir
/// %0 = bufferization.dealloc retain (%arg3 : memref<2xi32>)
/// ```
struct DeallocRemoveDuplicateRetainedMemrefs
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    // Unique retained values
    DenseMap<Value, unsigned> seen;
    SmallVector<Value> newRetained;
    SmallVector<unsigned> resultReplacementIdx;
    unsigned i = 0;
    for (auto retained : deallocOp.getRetained()) {
      if (seen.count(retained)) {
        resultReplacementIdx.push_back(seen[retained]);
        continue;
      }

      seen[retained] = i;
      newRetained.push_back(retained);
      resultReplacementIdx.push_back(i++);
    }

    // Return failure if we don't change anything so that we don't run into an
    // infinite loop of pattern applications.
    if (newRetained.size() == deallocOp.getRetained().size())
      return failure();

    // We need to create a new op because the number of results is always the
    // same as the number of condition operands.
    auto newDeallocOp =
        rewriter.create<DeallocOp>(deallocOp.getLoc(), deallocOp.getMemrefs(),
                                   deallocOp.getConditions(), newRetained);
    SmallVector<Value> replacements(
        llvm::map_range(resultReplacementIdx, [&](unsigned idx) {
          return newDeallocOp.getUpdatedConditions()[idx];
        }));
    rewriter.replaceOp(deallocOp, replacements);
    return success();
  }
};

/// Erase deallocation operations where the variadic list of memrefs to
/// deallocate is empty. Example:
/// ```mlir
/// %0 = bufferization.dealloc retain (%arg0: memref<2xi32>)
/// ```
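/// Such a dealloc frees nothing; it is erased and each result (updated
/// condition) is replaced by a constant 'false'.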
struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    if (deallocOp.getMemrefs().empty()) {
      Value constFalse = rewriter.create<arith::ConstantOp>(
          deallocOp.getLoc(), rewriter.getBoolAttr(false));
      rewriter.replaceOp(
          deallocOp, SmallVector<Value>(deallocOp.getUpdatedConditions().size(),
                                        constFalse));
      return success();
    }
    return failure();
  }
};

/// Removes memrefs from the deallocation list if their associated condition is
/// always 'false'.
///
/// Example:
/// ```
/// bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
///                       if (%arg2, %false)
/// ```
/// becomes
/// ```
/// bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg2)
/// ```
struct EraseAlwaysFalseDealloc : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newMemrefs, newConditions;
    for (auto [memref, cond] :
         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
      if (!matchPattern(cond, m_Zero())) {
        newMemrefs.push_back(memref);
        newConditions.push_back(cond);
      }
    }

    return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
                                  rewriter);
  }
};

/// The `memref.extract_strided_metadata` is often inserted to get the base
/// memref if the operand is not already guaranteed to be the result of a
/// memref allocation operation. This canonicalization pattern removes this
/// extraction operation if the operand is now produced by an allocation
/// operation (e.g., due to other canonicalizations simplifying the IR).
///
/// Example:
/// ```mlir
/// %alloc = memref.alloc() : memref<2xi32>
/// %base_memref, %offset, %size, %stride = memref.extract_strided_metadata
///   %alloc : memref<2xi32> -> memref<i32>, index, index, index
/// bufferization.dealloc (%base_memref : memref<i32>) if (%cond)
/// ```
/// is canonicalized to
/// ```mlir
/// %alloc = memref.alloc() : memref<2xi32>
/// bufferization.dealloc (%alloc : memref<2xi32>) if (%cond)
/// ```
struct SkipExtractMetadataOfAlloc : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newMemrefs(
        llvm::map_range(deallocOp.getMemrefs(), [&](Value memref) {
          auto extractStridedOp =
              memref.getDefiningOp<memref::ExtractStridedMetadataOp>();
          if (!extractStridedOp)
            return memref;
          Value allocMemref = extractStridedOp.getOperand();
          auto allocOp = allocMemref.getDefiningOp<MemoryEffectOpInterface>();
          if (!allocOp)
            return memref;
          if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(allocMemref))
            return allocMemref;
          return memref;
        }));

    return updateDeallocIfChanged(deallocOp, newMemrefs,
                                  deallocOp.getConditions(), rewriter);
  }
};

/// Removes pairs of `bufferization.dealloc` and alloc operations if there is
/// no other user of the allocated value and the allocating operation can be
/// safely removed. If the same value is present multiple times, this pattern
/// relies on other canonicalization patterns to remove the duplicate first.
///
/// Example:
/// ```mlir
/// %alloc = memref.alloc() : memref<2xi32>
/// bufferization.dealloc (%alloc, %arg0, : ...) if (%true, %true)
/// ```
/// is canonicalized to
/// ```mlir
/// bufferization.dealloc (%arg0 : ...) if (%true)
/// ```
struct RemoveAllocDeallocPairWhenNoOtherUsers
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newMemrefs, newConditions;
    SmallVector<Operation *> toDelete;
    for (auto [memref, cond] :
         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
      if (auto allocOp = memref.getDefiningOp<MemoryEffectOpInterface>()) {
        // Check that it is indeed an allocate effect, that the op has no other
        // side effects (which would not allow us to remove the op), and that
        // there are no other users.
        if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(memref) &&
            hasSingleEffect<MemoryEffects::Allocate>(allocOp, memref) &&
            memref.hasOneUse()) {
          toDelete.push_back(allocOp);
          continue;
        }
      }

      newMemrefs.push_back(memref);
      newConditions.push_back(cond);
    }

    if (failed(updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
                                      rewriter)))
      return failure();

    for (Operation *op : toDelete)
      rewriter.eraseOp(op);

    return success();
  }
};

} // anonymous namespace

void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  populateDeallocOpCanonicalizationPatterns(results, context);
}

void bufferization::populateDeallocOpCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<DeallocRemoveDuplicateDeallocMemrefs,
               DeallocRemoveDuplicateRetainedMemrefs, EraseEmptyDealloc,
               EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc,
               RemoveAllocDeallocPairWhenNoOtherUsers>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"