Handling parallel-region RaW conflicts should usually be the responsibility of the source program, rather than of the bufferization analysis. However, to preserve current functionality, the checks on parallel regions are put behind a bufferization option in this PR, which is on by default. Default functionality will not change, but this PR enables the option to leave parallelism checks out of the bufferization analysis.
636 lines
23 KiB
C++
636 lines
23 KiB
C++
//===- Bufferize.cpp - Bufferization utilities ----------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
|
|
|
|
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
|
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
|
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
|
|
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
|
|
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
|
|
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/IR/Diagnostics.h"
|
|
#include "mlir/IR/Operation.h"
|
|
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
|
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
|
#include "mlir/Pass/PassManager.h"
|
|
#include "mlir/Transforms/Passes.h"
|
|
#include <optional>
|
|
|
|
namespace mlir {
|
|
namespace bufferization {
|
|
#define GEN_PASS_DEF_FINALIZINGBUFFERIZE
|
|
#define GEN_PASS_DEF_BUFFERIZATIONBUFFERIZE
|
|
#define GEN_PASS_DEF_ONESHOTBUFFERIZE
|
|
#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
|
|
} // namespace bufferization
|
|
} // namespace mlir
|
|
|
|
#define DEBUG_TYPE "bufferize"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::bufferization;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// BufferizeTypeConverter
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Materialization callback that wraps a single memref value in a
/// bufferization.to_tensor op producing the requested tensor type.
static Value materializeToTensor(OpBuilder &builder, TensorType type,
                                 ValueRange inputs, Location loc) {
  assert(inputs.size() == 1 && "expected exactly one input");
  Value input = inputs.front();
  assert(isa<BaseMemRefType>(input.getType()) && "expected a memref input");
  return builder.create<bufferization::ToTensorOp>(loc, type, input);
}
|
|
|
|
/// Registers conversions into BufferizeTypeConverter
|
|
/// Registers conversions into BufferizeTypeConverter
BufferizeTypeConverter::BufferizeTypeConverter() {
  // Fallback: leave every type unchanged. (Later-registered conversions take
  // precedence, so the tensor conversions below still fire.)
  addConversion([](Type ty) { return ty; });
  // Ranked tensors become memrefs with the default (identity) layout.
  addConversion([](RankedTensorType tensorTy) -> Type {
    return MemRefType::get(tensorTy.getShape(), tensorTy.getElementType());
  });
  // Unranked tensors become unranked memrefs in memory space 0.
  addConversion([](UnrankedTensorType tensorTy) -> Type {
    return UnrankedMemRefType::get(tensorTy.getElementType(), 0);
  });
  addArgumentMaterialization(materializeToTensor);
  addSourceMaterialization(materializeToTensor);
  addTargetMaterialization([](OpBuilder &b, BaseMemRefType destTy,
                              ValueRange inputs, Location loc) -> Value {
    assert(inputs.size() == 1 && "expected exactly one input");
    Value input = inputs.front();

    if (auto srcTy = dyn_cast<MemRefType>(input.getType())) {
      // MemRef to MemRef cast.
      assert(srcTy != destTy && "expected different types");
      // Unranked to ranked and ranked to unranked casts must be explicit.
      auto rankedDestTy = dyn_cast<MemRefType>(destTy);
      if (!rankedDestTy)
        return nullptr;
      BufferizationOptions opts;
      opts.bufferAlignment = 0;
      FailureOr<Value> replacement =
          castOrReallocMemRefValue(b, input, rankedDestTy, opts);
      return failed(replacement) ? Value() : *replacement;
    }

    if (isa<TensorType>(input.getType())) {
      // Tensor to MemRef cast.
      return b.create<bufferization::ToMemrefOp>(loc, destTy, input);
    }

    llvm_unreachable("only tensor/memref input types supported");
  });
}
|
|
|
|
/// Marks the bufferization materialization ops (to_tensor and to_memref) as
/// legal so a partial conversion leaves them in place.
void mlir::bufferization::populateBufferizeMaterializationLegality(
    ConversionTarget &target) {
  target.addLegalOp<bufferization::ToTensorOp>();
  target.addLegalOp<bufferization::ToMemrefOp>();
}
|
|
|
|
namespace {
|
|
// In a finalizing bufferize conversion, we know that all tensors have been
|
|
// converted to memrefs, thus, this op becomes an identity.
|
|
class BufferizeToTensorOp
|
|
: public OpConversionPattern<bufferization::ToTensorOp> {
|
|
public:
|
|
using OpConversionPattern::OpConversionPattern;
|
|
LogicalResult
|
|
matchAndRewrite(bufferization::ToTensorOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
rewriter.replaceOp(op, adaptor.getMemref());
|
|
return success();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
namespace {
|
|
// In a finalizing bufferize conversion, we know that all tensors have been
|
|
// converted to memrefs, thus, this op becomes an identity.
|
|
class BufferizeToMemrefOp
|
|
: public OpConversionPattern<bufferization::ToMemrefOp> {
|
|
public:
|
|
using OpConversionPattern::OpConversionPattern;
|
|
LogicalResult
|
|
matchAndRewrite(bufferization::ToMemrefOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
rewriter.replaceOp(op, adaptor.getTensor());
|
|
return success();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
/// Populates patterns that fold away bufferization materialization ops
/// (to_tensor/to_memref) once all surrounding IR has been bufferized.
void mlir::bufferization::populateEliminateBufferizeMaterializationsPatterns(
    BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<BufferizeToTensorOp>(typeConverter, ctx);
  patterns.add<BufferizeToMemrefOp>(typeConverter, ctx);
}
|
|
|
|
namespace {
|
|
struct FinalizingBufferizePass
|
|
: public bufferization::impl::FinalizingBufferizeBase<
|
|
FinalizingBufferizePass> {
|
|
using FinalizingBufferizeBase<
|
|
FinalizingBufferizePass>::FinalizingBufferizeBase;
|
|
|
|
void runOnOperation() override {
|
|
auto func = getOperation();
|
|
auto *context = &getContext();
|
|
|
|
BufferizeTypeConverter typeConverter;
|
|
RewritePatternSet patterns(context);
|
|
ConversionTarget target(*context);
|
|
|
|
populateEliminateBufferizeMaterializationsPatterns(typeConverter, patterns);
|
|
|
|
// If all result types are legal, and all block arguments are legal (ensured
|
|
// by func conversion above), then all types in the program are legal.
|
|
//
|
|
// We also check that the operand types are legal to avoid creating invalid
|
|
// IR. For example, this prevents
|
|
// populateEliminateBufferizeMaterializationsPatterns from updating the
|
|
// types of the operands to a return op without updating the enclosing
|
|
// function.
|
|
target.markUnknownOpDynamicallyLegal(
|
|
[&](Operation *op) { return typeConverter.isLegal(op); });
|
|
|
|
if (failed(applyFullConversion(func, target, std::move(patterns))))
|
|
signalPassFailure();
|
|
}
|
|
};
|
|
|
|
/// Maps the textual pass-option value to the corresponding LayoutMapOption
/// enumerator. Aborts on an unrecognized string (the pass option parser
/// restricts the accepted values).
static LayoutMapOption parseLayoutMapOption(const std::string &s) {
  if (s == "identity-layout-map")
    return LayoutMapOption::IdentityLayoutMap;
  if (s == "fully-dynamic-layout-map")
    return LayoutMapOption::FullyDynamicLayoutMap;
  if (s == "infer-layout-map")
    return LayoutMapOption::InferLayoutMap;
  llvm_unreachable("invalid layout map option");
}
|
|
|
|
/// Maps the textual pass-option value to the corresponding analysis
/// heuristic. Aborts on an unrecognized string (the pass option parser
/// restricts the accepted values).
static OneShotBufferizationOptions::AnalysisHeuristic
parseHeuristicOption(const std::string &s) {
  if (s == "bottom-up")
    return OneShotBufferizationOptions::AnalysisHeuristic::BottomUp;
  if (s == "top-down")
    return OneShotBufferizationOptions::AnalysisHeuristic::TopDown;
  if (s == "bottom-up-from-terminators")
    return OneShotBufferizationOptions::AnalysisHeuristic::
        BottomUpFromTerminators;
  if (s == "fuzzer")
    return OneShotBufferizationOptions::AnalysisHeuristic::Fuzzer;
  // Fixed message: was "invalid analysisheuristic option" (missing separator).
  llvm_unreachable("invalid analysis-heuristic option");
}
|
|
|
|
struct OneShotBufferizePass
|
|
: public bufferization::impl::OneShotBufferizeBase<OneShotBufferizePass> {
|
|
OneShotBufferizePass() = default;
|
|
|
|
explicit OneShotBufferizePass(const OneShotBufferizationOptions &options)
|
|
: options(options) {}
|
|
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
registry
|
|
.insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
|
|
}
|
|
|
|
void runOnOperation() override {
|
|
OneShotBufferizationOptions opt;
|
|
if (!options) {
|
|
// Make new bufferization options if none were provided when creating the
|
|
// pass.
|
|
opt.allowReturnAllocsFromLoops = allowReturnAllocsFromLoops;
|
|
opt.allowUnknownOps = allowUnknownOps;
|
|
opt.analysisFuzzerSeed = analysisFuzzerSeed;
|
|
opt.analysisHeuristic = parseHeuristicOption(analysisHeuristic);
|
|
opt.copyBeforeWrite = copyBeforeWrite;
|
|
opt.dumpAliasSets = dumpAliasSets;
|
|
opt.setFunctionBoundaryTypeConversion(
|
|
parseLayoutMapOption(functionBoundaryTypeConversion));
|
|
if (mustInferMemorySpace) {
|
|
opt.defaultMemorySpaceFn =
|
|
[](TensorType t) -> std::optional<Attribute> {
|
|
return std::nullopt;
|
|
};
|
|
}
|
|
opt.printConflicts = printConflicts;
|
|
opt.testAnalysisOnly = testAnalysisOnly;
|
|
opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
|
|
opt.checkParallelRegions = checkParallelRegions;
|
|
opt.noAnalysisFuncFilter = noAnalysisFuncFilter;
|
|
|
|
// Configure type converter.
|
|
LayoutMapOption unknownTypeConversionOption =
|
|
parseLayoutMapOption(unknownTypeConversion);
|
|
if (unknownTypeConversionOption == LayoutMapOption::InferLayoutMap) {
|
|
emitError(UnknownLoc::get(&getContext()),
|
|
"Invalid option: 'infer-layout-map' is not a valid value for "
|
|
"'unknown-type-conversion'");
|
|
return signalPassFailure();
|
|
}
|
|
opt.unknownTypeConverterFn = [=](Value value, Attribute memorySpace,
|
|
const BufferizationOptions &options) {
|
|
auto tensorType = cast<TensorType>(value.getType());
|
|
if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
|
|
return bufferization::getMemRefTypeWithStaticIdentityLayout(
|
|
tensorType, memorySpace);
|
|
assert(unknownTypeConversionOption ==
|
|
LayoutMapOption::FullyDynamicLayoutMap &&
|
|
"invalid layout map option");
|
|
return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
|
|
memorySpace);
|
|
};
|
|
|
|
// Configure op filter.
|
|
OpFilter::Entry::FilterFn filterFn = [&](Operation *op) {
|
|
// Filter may be specified via options.
|
|
if (this->dialectFilter.hasValue())
|
|
return llvm::is_contained(this->dialectFilter,
|
|
op->getDialect()->getNamespace());
|
|
// No filter specified: All other ops are allowed.
|
|
return true;
|
|
};
|
|
opt.opFilter.allowOperation(filterFn);
|
|
} else {
|
|
opt = *options;
|
|
}
|
|
|
|
if (opt.copyBeforeWrite && opt.testAnalysisOnly) {
|
|
// These two flags do not make sense together: "copy-before-write"
|
|
// indicates that copies should be inserted before every memory write,
|
|
// but "test-analysis-only" indicates that only the analysis should be
|
|
// tested. (I.e., no IR is bufferized.)
|
|
emitError(UnknownLoc::get(&getContext()),
|
|
"Invalid option: 'copy-before-write' cannot be used with "
|
|
"'test-analysis-only'");
|
|
return signalPassFailure();
|
|
}
|
|
|
|
if (opt.printConflicts && !opt.testAnalysisOnly) {
|
|
emitError(
|
|
UnknownLoc::get(&getContext()),
|
|
"Invalid option: 'print-conflicts' requires 'test-analysis-only'");
|
|
return signalPassFailure();
|
|
}
|
|
|
|
if (opt.dumpAliasSets && !opt.testAnalysisOnly) {
|
|
emitError(
|
|
UnknownLoc::get(&getContext()),
|
|
"Invalid option: 'dump-alias-sets' requires 'test-analysis-only'");
|
|
return signalPassFailure();
|
|
}
|
|
|
|
BufferizationStatistics statistics;
|
|
ModuleOp moduleOp = getOperation();
|
|
if (opt.bufferizeFunctionBoundaries) {
|
|
if (failed(runOneShotModuleBufferize(moduleOp, opt, &statistics))) {
|
|
signalPassFailure();
|
|
return;
|
|
}
|
|
} else {
|
|
if (!opt.noAnalysisFuncFilter.empty()) {
|
|
emitError(UnknownLoc::get(&getContext()),
|
|
"Invalid option: 'no-analysis-func-filter' requires "
|
|
"'bufferize-function-boundaries'");
|
|
return signalPassFailure();
|
|
}
|
|
if (failed(runOneShotBufferize(moduleOp, opt, &statistics))) {
|
|
signalPassFailure();
|
|
return;
|
|
}
|
|
}
|
|
|
|
// Set pass statistics.
|
|
this->numBufferAlloc = statistics.numBufferAlloc;
|
|
this->numTensorInPlace = statistics.numTensorInPlace;
|
|
this->numTensorOutOfPlace = statistics.numTensorOutOfPlace;
|
|
}
|
|
|
|
private:
|
|
std::optional<OneShotBufferizationOptions> options;
|
|
};
|
|
} // namespace
|
|
|
|
/// Creates a One-Shot Bufferize pass configured from its command-line options.
std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass() {
  return std::make_unique<OneShotBufferizePass>();
}
|
|
|
|
/// Creates a One-Shot Bufferize pass with the given, pre-built options
/// (command-line pass options are ignored in that case).
std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass(
    const OneShotBufferizationOptions &options) {
  return std::make_unique<OneShotBufferizePass>(options);
}
|
|
|
|
/// Creates the finalizing-bufferize pass, which erases any remaining
/// to_tensor/to_memref materializations from a function.
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::bufferization::createFinalizingBufferizePass() {
  return std::make_unique<FinalizingBufferizePass>();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// BufferizableOpInterface-based Bufferization
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
/// A rewriter that keeps track of extra information during bufferization.
/// All referenced containers are owned by the caller (bufferizeOp) and are
/// updated from the listener hooks below, so that ops created or erased while
/// rewriting are reflected in the caller's bookkeeping (in particular, newly
/// created tensor ops are appended to the caller's worklist).
class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
public:
  BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
                        DenseSet<Operation *> &toMemrefOps,
                        SmallVector<Operation *> &worklist,
                        const BufferizationOptions &options,
                        BufferizationStatistics *statistics)
      : IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
        worklist(worklist), analysisState(options), statistics(statistics) {
    // Register this object as its own listener so the notify* hooks fire.
    setListener(this);
  }

protected:
  void notifyOperationErased(Operation *op) override {
    erasedOps.insert(op);
    // Erase if present.
    toMemrefOps.erase(op);
  }

  void notifyOperationInserted(Operation *op, InsertPoint previous) override {
    // We only care about newly created ops. (A set InsertPoint means the op
    // was moved, not created.)
    if (previous.isSet())
      return;

    // A re-inserted op is no longer considered erased.
    erasedOps.erase(op);

    // Gather statistics about allocs.
    if (statistics) {
      if (auto sideEffectingOp = dyn_cast<MemoryEffectOpInterface>(op))
        statistics->numBufferAlloc += static_cast<int64_t>(
            sideEffectingOp.hasEffect<MemoryEffects::Allocate>());
    }

    // Keep track of to_memref ops.
    if (isa<ToMemrefOp>(op)) {
      toMemrefOps.insert(op);
      return;
    }

    // Skip to_tensor ops.
    if (isa<ToTensorOp>(op))
      return;

    // Skip non-tensor ops.
    if (!hasTensorSemantics(op))
      return;

    // Skip ops that are not allowed to be bufferized.
    auto const &options = analysisState.getOptions();
    if (!options.isOpAllowed(op))
      return;

    // Add op to worklist: a newly created tensor op must itself be
    // bufferized by the caller's loop.
    worklist.push_back(op);
  }

private:
  /// A set of all erased ops.
  DenseSet<Operation *> &erasedOps;

  /// A set of all to_memref ops.
  DenseSet<Operation *> &toMemrefOps;

  /// The worklist of ops to be bufferized.
  SmallVector<Operation *> &worklist;

  /// The analysis state. Used for debug assertions and access to the
  /// bufferization options.
  const AnalysisState analysisState;

  /// Bufferization statistics for debugging.
  BufferizationStatistics *statistics;
};
} // namespace
|
|
|
|
/// Bufferizes `op` and all tensor ops nested in it, using the interface-based
/// driver: gathers bufferizable ops, rewrites each in place via its
/// BufferizableOpInterface implementation, then cleans up materialization
/// pairs. Returns failure if any op fails to bufferize or (unless
/// options.allowUnknownOps) if a tensor op remains unbufferized.
LogicalResult bufferization::bufferizeOp(Operation *op,
                                         const BufferizationOptions &options,
                                         BufferizationStatistics *statistics) {
  // In copy-before-write mode, conservatively insert a tensor copy before
  // every write instead of relying on the One-Shot analysis.
  if (options.copyBeforeWrite) {
    AnalysisState state(options);
    if (failed(insertTensorCopies(op, state)))
      return failure();
  }

  // Keep track of to_memref ops.
  DenseSet<Operation *> toMemrefOps;
  op->walk([&](ToMemrefOp toMemrefOp) { toMemrefOps.insert(toMemrefOp); });

  // Gather all bufferizable ops in top-to-bottom order.
  //
  // We should ideally know the exact memref type of all operands when
  // bufferizing an op. (This is the case when bufferizing top-to-bottom.)
  // Otherwise, we have to use a memref type with a fully dynamic layout map to
  // avoid copies. We are currently missing patterns for layout maps to
  // canonicalize away (or canonicalize to more precise layouts).
  SmallVector<Operation *> worklist;
  op->walk<WalkOrder::PostOrder>([&](Operation *op) {
    if (options.isOpAllowed(op) && hasTensorSemantics(op))
      worklist.push_back(op);
  });

  // Keep track of all erased ops.
  DenseSet<Operation *> erasedOps;

  // Bufferize all ops. The rewriter's listener keeps erasedOps, toMemrefOps
  // and worklist up to date while rewriting.
  BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps,
                                 worklist, options, statistics);
  // NOTE: deliberately an index-based loop: the listener may append newly
  // created ops to `worklist` during iteration, so `worklist.size()` must be
  // re-read every time (a range-for or cached end iterator would be invalid).
  for (unsigned i = 0; i < worklist.size(); ++i) {
    Operation *nextOp = worklist[i];
    // Skip ops that were erased.
    if (erasedOps.contains(nextOp))
      continue;
    // Skip ops that are not bufferizable or not allowed.
    auto bufferizableOp = options.dynCastBufferizableOp(nextOp);
    if (!bufferizableOp)
      continue;
    // Skip ops that no longer have tensor semantics.
    if (!hasTensorSemantics(nextOp))
      continue;
    // Check for unsupported unstructured control flow.
    if (!bufferizableOp.supportsUnstructuredControlFlow())
      for (Region &r : nextOp->getRegions())
        if (r.getBlocks().size() > 1)
          return nextOp->emitOpError(
              "op or BufferizableOpInterface implementation does not support "
              "unstructured control flow, but at least one region has multiple "
              "blocks");

    // Bufferize the op.
    LLVM_DEBUG(llvm::dbgs()
               << "//===-------------------------------------------===//\n"
               << "IR after bufferizing: " << nextOp->getName() << "\n");
    rewriter.setInsertionPoint(nextOp);
    if (failed(bufferizableOp.bufferize(rewriter, options))) {
      LLVM_DEBUG(llvm::dbgs()
                 << "failed to bufferize\n"
                 << "//===-------------------------------------------===//\n");
      return nextOp->emitError("failed to bufferize op");
    }
    LLVM_DEBUG(llvm::dbgs()
               << *op
               << "\n//===-------------------------------------------===//\n");
  }

  // Return early if the top-level op is entirely gone.
  if (erasedOps.contains(op))
    return success();

  // Fold all to_memref(to_tensor(x)) pairs.
  for (Operation *op : toMemrefOps) {
    rewriter.setInsertionPoint(op);
    (void)bufferization::foldToMemrefToTensorPair(
        rewriter, cast<ToMemrefOp>(op), options);
  }

  // Remove all dead to_tensor ops.
  op->walk<WalkOrder::PostOrder>([&](ToTensorOp toTensorOp) {
    if (toTensorOp->getUses().empty()) {
      rewriter.eraseOp(toTensorOp);
      // Do not descend into the just-erased op.
      return WalkResult::skip();
    }
    return WalkResult::advance();
  });

  /// Check the result of bufferization. Return an error if an op was not
  /// bufferized, unless partial bufferization is allowed.
  if (options.allowUnknownOps)
    return success();

  for (Operation *op : worklist) {
    // Skip ops that are entirely gone.
    if (erasedOps.contains(op))
      continue;
    // Ops that no longer have tensor semantics (because they were updated
    // in-place) are allowed.
    if (!hasTensorSemantics(op))
      continue;
    // Continue ops that are not allowed.
    if (!options.isOpAllowed(op))
      continue;
    // Ops without any uses and no side effects will fold away.
    if (op->getUses().empty() && isMemoryEffectFree(op))
      continue;
    // ToTensorOps/ToMemrefOps are allowed in the output.
    if (isa<ToTensorOp, ToMemrefOp>(op))
      continue;
    return op->emitError("op was not bufferized");
  }

  return success();
}
|
|
|
|
/// Bufferizes the signature of `block`: rewrites every tensor-typed block
/// argument to its buffer type (inserting to_tensor ops for existing uses)
/// and updates all branch-like predecessors to forward buffer values
/// (inserting to_memref ops and, where the bufferized operand type differs
/// from the new argument type, memref.cast ops). Fails if the parent op is
/// not bufferizable, a buffer type cannot be computed, or a predecessor does
/// not implement BranchOpInterface.
LogicalResult
bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
                                       const BufferizationOptions &options) {
  OpBuilder::InsertionGuard g(rewriter);
  auto bufferizableOp = options.dynCastBufferizableOp(block->getParentOp());
  if (!bufferizableOp)
    return failure();

  // Compute the new signature.
  SmallVector<Type> newTypes;
  for (BlockArgument &bbArg : block->getArguments()) {
    auto tensorType = dyn_cast<TensorType>(bbArg.getType());
    if (!tensorType) {
      // Non-tensor arguments keep their type.
      newTypes.push_back(bbArg.getType());
      continue;
    }

    FailureOr<BaseMemRefType> memrefType =
        bufferization::getBufferType(bbArg, options);
    if (failed(memrefType))
      return failure();
    newTypes.push_back(*memrefType);
  }

  // Change the type of all block arguments.
  for (auto [bbArg, type] : llvm::zip(block->getArguments(), newTypes)) {
    if (bbArg.getType() == type)
      continue;

    // Collect all uses of the bbArg before mutating its type.
    SmallVector<OpOperand *> bbArgUses;
    for (OpOperand &use : bbArg.getUses())
      bbArgUses.push_back(&use);

    // Change the bbArg type to memref.
    bbArg.setType(type);

    // Replace all uses of the original tensor bbArg with a to_tensor of the
    // new memref bbArg.
    rewriter.setInsertionPointToStart(block);
    if (!bbArgUses.empty()) {
      Value toTensorOp =
          rewriter.create<bufferization::ToTensorOp>(bbArg.getLoc(), bbArg);
      for (OpOperand *use : bbArgUses)
        use->set(toTensorOp);
    }
  }

  // Bufferize callers of the block.
  for (Operation *op : block->getUsers()) {
    auto branchOp = dyn_cast<BranchOpInterface>(op);
    if (!branchOp)
      return op->emitOpError("cannot bufferize ops with block references that "
                             "do not implement BranchOpInterface");

    auto it = llvm::find(op->getSuccessors(), block);
    // Fixed assert message: was "could find successor".
    assert(it != op->getSuccessors().end() && "could not find successor");
    int64_t successorIdx = std::distance(op->getSuccessors().begin(), it);

    SuccessorOperands operands = branchOp.getSuccessorOperands(successorIdx);
    SmallVector<Value> newOperands;
    for (auto [operand, type] :
         llvm::zip(operands.getForwardedOperands(), newTypes)) {
      if (operand.getType() == type) {
        // Not a tensor type. Nothing to do for this operand.
        newOperands.push_back(operand);
        continue;
      }
      FailureOr<BaseMemRefType> operandBufferType =
          bufferization::getBufferType(operand, options);
      if (failed(operandBufferType))
        return failure();
      rewriter.setInsertionPointAfterValue(operand);
      Value bufferizedOperand = rewriter.create<bufferization::ToMemrefOp>(
          operand.getLoc(), *operandBufferType, operand);
      // A cast is needed if the operand and the block argument have different
      // bufferized types.
      if (type != *operandBufferType)
        bufferizedOperand = rewriter.create<memref::CastOp>(
            operand.getLoc(), type, bufferizedOperand);
      newOperands.push_back(bufferizedOperand);
    }
    operands.getMutableForwardedOperands().assign(newOperands);
  }

  return success();
}
|
|
|
|
/// Returns bufferization options suitable for a partial (dialect-by-dialect)
/// bufferization: unknown ops are tolerated and copies are inserted before
/// every write instead of running the One-Shot analysis.
BufferizationOptions bufferization::getPartialBufferizationOptions() {
  BufferizationOptions opts;
  opts.allowUnknownOps = true;
  opts.copyBeforeWrite = true;
  opts.enforceAliasingInvariants = false;
  // Map unknown tensor types to memrefs with a static identity layout.
  opts.unknownTypeConverterFn = [](Value v, Attribute memSpace,
                                   const BufferizationOptions &) {
    return getMemRefTypeWithStaticIdentityLayout(cast<TensorType>(v.getType()),
                                                 memSpace);
  };
  opts.opFilter.allowDialect<BufferizationDialect>();
  return opts;
}
|