As described in issue llvm/llvm-project#91518, a previous PR llvm/llvm-project#78484 introduced the `defaultMemorySpaceFn` into bufferization options, allowing one to inform OneShotBufferize that it should use a specified function to derive the memory space attribute from the encoding attribute attached to tensor types.

However, introducing this feature exposed unhandled edge cases, examples of which are introduced by this change in the new test under `test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir`.

Fixing the inconsistencies introduced by `defaultMemorySpaceFn` is pretty simple. This change:

- Updates the `bufferization.to_memref` and `bufferization.to_tensor` operations to explicitly include operand and destination types, whereas previously they relied on type inference to deduce the tensor types. Since type inference cannot recover the correct tensor encoding/memory space, the operand and result types must be stated explicitly. This is a small assembly format change, but it touches a large number of test files.
- Makes minor updates to other bufferization functions to handle the changes in building the above ops.
- Updates bufferization of `tensor.from_elements` to handle memory space.

Integration/upgrade guide: in downstream projects, if you have tests or MLIR files that explicitly use `bufferization.to_tensor` or `bufferization.to_memref`, update them to the new assembly format as follows:

```
%1 = bufferization.to_memref %0 : memref<10xf32>
%2 = bufferization.to_tensor %1 : memref<10xf32>
```

becomes

```
%1 = bufferization.to_memref %0 : tensor<10xf32> to memref<10xf32>
%2 = bufferization.to_tensor %1 : memref<10xf32> to tensor<10xf32>
```
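To make the motivation concrete, here is a hypothetical sketch (the `#enc` encoding and memory space `1` are illustrative placeholders, not part of this change): with a `defaultMemorySpaceFn` that maps `#enc` to memory space `1`, type inference from the memref side alone cannot recover the tensor encoding, whereas the explicit form keeps both sides visible in the IR:

```
%m = bufferization.to_memref %t : tensor<10xf32, #enc> to memref<10xf32, 1>
%t2 = bufferization.to_tensor %m : memref<10xf32, 1> to tensor<10xf32, #enc>
```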
//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
|
|
using namespace mlir::bufferization;
|
|
using namespace mlir::scf;
|
|
|
|
namespace mlir {
|
|
namespace scf {
|
|
namespace {
|
|
|
|
/// Helper function for loop bufferization. Cast the given buffer to the given
/// memref type.
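///
/// For example (illustrative only): casting `memref<5xf32>` to
/// `memref<5xf32, strided<[?], offset: ?>>` merely relaxes the layout and is
/// cast-compatible, whereas casting to a memref in a different memory space is
/// not and would trip the assertion below.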
static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
|
|
assert(isa<BaseMemRefType>(type) && "expected BaseMemRefType");
|
|
assert(isa<BaseMemRefType>(buffer.getType()) && "expected BaseMemRefType");
|
|
// If the buffer already has the correct type, no cast is needed.
|
|
if (buffer.getType() == type)
|
|
return buffer;
|
|
// TODO: In case `type` has a layout map that is not the fully dynamic
|
|
// one, we may not be able to cast the buffer. In that case, the loop
|
|
// iter_arg's layout map must be changed (see uses of `castBuffer`).
|
|
assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
|
|
"scf.while op bufferization: cast incompatible");
|
|
return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
|
|
}
|
|
|
|
/// Helper function for loop bufferization. Return "true" if the given value
|
|
/// is guaranteed to not alias with an external tensor apart from values in
|
|
/// `exceptions`. A value is external if it is defined outside of the given
|
|
/// region or if it is an entry block argument of the region.
|
|
static bool doesNotAliasExternalValue(Value value, Region *region,
|
|
ValueRange exceptions,
|
|
const OneShotAnalysisState &state) {
|
|
assert(region->getBlocks().size() == 1 &&
|
|
"expected region with single block");
|
|
bool result = true;
|
|
state.applyOnAliases(value, [&](Value alias) {
|
|
if (llvm::is_contained(exceptions, alias))
|
|
return;
|
|
Region *aliasRegion = alias.getParentRegion();
|
|
if (isa<BlockArgument>(alias) && !region->isProperAncestor(aliasRegion))
|
|
result = false;
|
|
if (isa<OpResult>(alias) && !region->isAncestor(aliasRegion))
|
|
result = false;
|
|
});
|
|
return result;
|
|
}
|
|
|
|
/// Bufferization of scf.condition.
|
|
struct ConditionOpInterface
|
|
: public BufferizableOpInterface::ExternalModel<ConditionOpInterface,
|
|
scf::ConditionOp> {
|
|
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
|
const AnalysisState &state) const {
|
|
return true;
|
|
}
|
|
|
|
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
|
const AnalysisState &state) const {
|
|
return false;
|
|
}
|
|
|
|
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
|
|
const AnalysisState &state) const {
|
|
return {};
|
|
}
|
|
|
|
bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
|
|
const AnalysisState &state) const {
|
|
// Condition operands always bufferize inplace. Otherwise, an alloc + copy
// may be generated inside the block. We should avoid returning/yielding
// allocations whenever possible.
|
|
return true;
|
|
}
|
|
|
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
|
const BufferizationOptions &options) const {
|
|
auto conditionOp = cast<scf::ConditionOp>(op);
|
|
auto whileOp = cast<scf::WhileOp>(conditionOp->getParentOp());
|
|
|
|
SmallVector<Value> newArgs;
|
|
for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
|
|
Value value = it.value();
|
|
if (isa<TensorType>(value.getType())) {
|
|
FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
|
|
if (failed(maybeBuffer))
|
|
return failure();
|
|
FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
|
|
whileOp.getAfterArguments()[it.index()], options);
|
|
if (failed(resultType))
|
|
return failure();
|
|
Value buffer = castBuffer(rewriter, *maybeBuffer, *resultType);
|
|
newArgs.push_back(buffer);
|
|
} else {
|
|
newArgs.push_back(value);
|
|
}
|
|
}
|
|
|
|
replaceOpWithNewBufferizedOp<scf::ConditionOp>(
|
|
rewriter, op, conditionOp.getCondition(), newArgs);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Return the unique scf.yield op. If there are multiple or no scf.yield ops,
|
|
/// return an empty op.
|
|
static scf::YieldOp getUniqueYieldOp(scf::ExecuteRegionOp executeRegionOp) {
|
|
scf::YieldOp result;
|
|
for (Block &block : executeRegionOp.getRegion()) {
|
|
if (auto yieldOp = dyn_cast<scf::YieldOp>(block.getTerminator())) {
|
|
if (result)
|
|
return {};
|
|
result = yieldOp;
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
|
|
/// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
|
|
/// fully implemented at the moment.
|
|
struct ExecuteRegionOpInterface
|
|
: public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel<
|
|
ExecuteRegionOpInterface, scf::ExecuteRegionOp> {
|
|
|
|
static bool supportsUnstructuredControlFlow() { return true; }
|
|
|
|
bool isWritable(Operation *op, Value value,
|
|
const AnalysisState &state) const {
|
|
return true;
|
|
}
|
|
|
|
LogicalResult verifyAnalysis(Operation *op,
|
|
const AnalysisState &state) const {
|
|
auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
|
|
// TODO: scf.execute_region with multiple yields are not supported.
|
|
if (!getUniqueYieldOp(executeRegionOp))
|
|
return op->emitOpError("op without unique scf.yield is not supported");
|
|
return success();
|
|
}
|
|
|
|
AliasingOpOperandList
|
|
getAliasingOpOperands(Operation *op, Value value,
|
|
const AnalysisState &state) const {
|
|
if (auto bbArg = dyn_cast<BlockArgument>(value))
|
|
return getAliasingBranchOpOperands(op, bbArg, state);
|
|
|
|
// ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
|
|
// any SSA value that is in scope. To allow for use-def chain traversal
|
|
// through ExecuteRegionOps in the analysis, the corresponding yield value
|
|
// is considered to be aliasing with the result.
|
|
auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
|
|
auto it = llvm::find(op->getOpResults(), value);
|
|
assert(it != op->getOpResults().end() && "invalid value");
|
|
size_t resultNum = std::distance(op->getOpResults().begin(), it);
|
|
auto yieldOp = getUniqueYieldOp(executeRegionOp);
|
|
// Note: If there is no unique scf.yield op, `verifyAnalysis` will fail.
|
|
if (!yieldOp)
|
|
return {};
|
|
return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
|
|
}
|
|
|
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
|
const BufferizationOptions &options) const {
|
|
auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
|
|
auto yieldOp = getUniqueYieldOp(executeRegionOp);
|
|
TypeRange newResultTypes(yieldOp.getResults());
|
|
|
|
// Create new op and move over region.
|
|
auto newOp =
|
|
rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
|
|
newOp.getRegion().takeBody(executeRegionOp.getRegion());
|
|
|
|
// Bufferize every block.
|
|
for (Block &block : newOp.getRegion())
|
|
if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
|
|
options)))
|
|
return failure();
|
|
|
|
// Update all uses of the old op.
|
|
rewriter.setInsertionPointAfter(newOp);
|
|
SmallVector<Value> newResults;
|
|
for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
|
|
if (isa<TensorType>(it.value())) {
|
|
newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
|
|
executeRegionOp.getLoc(), it.value(),
|
|
newOp->getResult(it.index())));
|
|
} else {
|
|
newResults.push_back(newOp->getResult(it.index()));
|
|
}
|
|
}
|
|
|
|
// Replace old op.
|
|
rewriter.replaceOp(executeRegionOp, newResults);
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
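///
/// A rough sketch of the rewrite (illustrative only; the exact result layouts
/// depend on the bufferization options and on how the yielded values
/// bufferize):
///
///   %r = scf.if %cond -> (tensor<5xf32>) {
///     scf.yield %t0 : tensor<5xf32>
///   } else {
///     scf.yield %t1 : tensor<5xf32>
///   }
///
/// becomes approximately:
///
///   %0 = scf.if %cond -> (memref<5xf32>) {
///     scf.yield %m0 : memref<5xf32>
///   } else {
///     scf.yield %m1 : memref<5xf32>
///   }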
struct IfOpInterface
|
|
: public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
|
|
AliasingOpOperandList
|
|
getAliasingOpOperands(Operation *op, Value value,
|
|
const AnalysisState &state) const {
|
|
// IfOps do not have tensor OpOperands. The yielded value can be any SSA
|
|
// value that is in scope. To allow for use-def chain traversal through
|
|
// IfOps in the analysis, both corresponding yield values from the then/else
|
|
// branches are considered to be aliasing with the result.
|
|
auto ifOp = cast<scf::IfOp>(op);
|
|
size_t resultNum = std::distance(op->getOpResults().begin(),
|
|
llvm::find(op->getOpResults(), value));
|
|
OpOperand *thenOperand = &ifOp.thenYield()->getOpOperand(resultNum);
|
|
OpOperand *elseOperand = &ifOp.elseYield()->getOpOperand(resultNum);
|
|
return {{thenOperand, BufferRelation::Equivalent, /*isDefinite=*/false},
|
|
{elseOperand, BufferRelation::Equivalent, /*isDefinite=*/false}};
|
|
}
|
|
|
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
|
const BufferizationOptions &options) const {
|
|
OpBuilder::InsertionGuard g(rewriter);
|
|
auto ifOp = cast<scf::IfOp>(op);
|
|
|
|
// Compute bufferized result types.
|
|
SmallVector<Type> newTypes;
|
|
for (Value result : ifOp.getResults()) {
|
|
if (!isa<TensorType>(result.getType())) {
|
|
newTypes.push_back(result.getType());
|
|
continue;
|
|
}
|
|
auto bufferType = bufferization::getBufferType(result, options);
|
|
if (failed(bufferType))
|
|
return failure();
|
|
newTypes.push_back(*bufferType);
|
|
}
|
|
|
|
// Create new op.
|
|
rewriter.setInsertionPoint(ifOp);
|
|
auto newIfOp =
|
|
rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
|
|
/*withElseRegion=*/true);
|
|
|
|
// Move over then/else blocks.
|
|
rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
|
|
rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());
|
|
|
|
// Replace op results.
|
|
replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());
|
|
|
|
return success();
|
|
}
|
|
|
|
FailureOr<BaseMemRefType>
|
|
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
|
|
SmallVector<Value> &invocationStack) const {
|
|
auto ifOp = cast<scf::IfOp>(op);
|
|
auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
|
|
auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());
|
|
assert(value.getDefiningOp() == op && "invalid value");
|
|
|
|
// Determine buffer types of the true/false branches.
|
|
auto opResult = cast<OpResult>(value);
|
|
auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber());
|
|
auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber());
|
|
BaseMemRefType thenBufferType, elseBufferType;
|
|
if (isa<BaseMemRefType>(thenValue.getType())) {
|
|
// True branch was already bufferized.
|
|
thenBufferType = cast<BaseMemRefType>(thenValue.getType());
|
|
} else {
|
|
auto maybeBufferType =
|
|
bufferization::getBufferType(thenValue, options, invocationStack);
|
|
if (failed(maybeBufferType))
|
|
return failure();
|
|
thenBufferType = *maybeBufferType;
|
|
}
|
|
if (isa<BaseMemRefType>(elseValue.getType())) {
|
|
// False branch was already bufferized.
|
|
elseBufferType = cast<BaseMemRefType>(elseValue.getType());
|
|
} else {
|
|
auto maybeBufferType =
|
|
bufferization::getBufferType(elseValue, options, invocationStack);
|
|
if (failed(maybeBufferType))
|
|
return failure();
|
|
elseBufferType = *maybeBufferType;
|
|
}
|
|
|
|
// Best case: Both branches have the exact same buffer type.
|
|
if (thenBufferType == elseBufferType)
|
|
return thenBufferType;
|
|
|
|
// Memory space mismatch.
|
|
if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
|
|
return op->emitError("inconsistent memory space on then/else branches");
|
|
|
|
// Layout maps are different: Promote to fully dynamic layout map.
|
|
return getMemRefTypeWithFullyDynamicLayout(
|
|
cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace());
|
|
}
|
|
};
|
|
|
|
/// Bufferization of scf.index_switch. Replace with a new scf.index_switch that
/// yields memrefs.
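///
/// The rewrite mirrors the scf.if case above: every case (and the default
/// block) yields the bufferized memref instead of a tensor. If the cases
/// bufferize to memrefs with different layouts, getBufferType below promotes
/// the result to a memref with a fully dynamic layout map.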
struct IndexSwitchOpInterface
|
|
: public BufferizableOpInterface::ExternalModel<IndexSwitchOpInterface,
|
|
scf::IndexSwitchOp> {
|
|
AliasingOpOperandList
|
|
getAliasingOpOperands(Operation *op, Value value,
|
|
const AnalysisState &state) const {
|
|
// IndexSwitchOps do not have tensor OpOperands. The yielded value can be
// any SSA value that is in scope. This is similar to IfOps.
|
|
auto switchOp = cast<scf::IndexSwitchOp>(op);
|
|
int64_t resultNum = cast<OpResult>(value).getResultNumber();
|
|
AliasingOpOperandList result;
|
|
for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
|
|
auto yieldOp =
|
|
cast<scf::YieldOp>(switchOp.getCaseBlock(i).getTerminator());
|
|
result.addAlias(AliasingOpOperand(&yieldOp->getOpOperand(resultNum),
|
|
BufferRelation::Equivalent,
|
|
/*isDefinite=*/false));
|
|
}
|
|
auto defaultYieldOp =
|
|
cast<scf::YieldOp>(switchOp.getDefaultBlock().getTerminator());
|
|
result.addAlias(AliasingOpOperand(&defaultYieldOp->getOpOperand(resultNum),
|
|
BufferRelation::Equivalent,
|
|
/*isDefinite=*/false));
|
|
return result;
|
|
}
|
|
|
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
|
const BufferizationOptions &options) const {
|
|
OpBuilder::InsertionGuard g(rewriter);
|
|
auto switchOp = cast<scf::IndexSwitchOp>(op);
|
|
|
|
// Compute bufferized result types.
|
|
SmallVector<Type> newTypes;
|
|
for (Value result : switchOp.getResults()) {
|
|
if (!isa<TensorType>(result.getType())) {
|
|
newTypes.push_back(result.getType());
|
|
continue;
|
|
}
|
|
auto bufferType = bufferization::getBufferType(result, options);
|
|
if (failed(bufferType))
|
|
return failure();
|
|
newTypes.push_back(*bufferType);
|
|
}
|
|
|
|
// Create new op.
|
|
rewriter.setInsertionPoint(switchOp);
|
|
auto newSwitchOp = rewriter.create<scf::IndexSwitchOp>(
|
|
switchOp.getLoc(), newTypes, switchOp.getArg(), switchOp.getCases(),
|
|
switchOp.getCases().size());
|
|
|
|
// Move over blocks.
|
|
for (auto [src, dest] :
|
|
llvm::zip(switchOp.getCaseRegions(), newSwitchOp.getCaseRegions()))
|
|
rewriter.inlineRegionBefore(src, dest, dest.begin());
|
|
rewriter.inlineRegionBefore(switchOp.getDefaultRegion(),
|
|
newSwitchOp.getDefaultRegion(),
|
|
newSwitchOp.getDefaultRegion().begin());
|
|
|
|
// Replace op results.
|
|
replaceOpWithBufferizedValues(rewriter, op, newSwitchOp->getResults());
|
|
|
|
return success();
|
|
}
|
|
|
|
FailureOr<BaseMemRefType>
|
|
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
|
|
SmallVector<Value> &invocationStack) const {
|
|
auto switchOp = cast<scf::IndexSwitchOp>(op);
|
|
assert(value.getDefiningOp() == op && "invalid value");
|
|
int64_t resultNum = cast<OpResult>(value).getResultNumber();
|
|
|
|
// Helper function to get buffer type of a case.
|
|
SmallVector<BaseMemRefType> yieldedTypes;
|
|
auto getYieldedBufferType = [&](Block &b) -> FailureOr<BaseMemRefType> {
|
|
auto yieldOp = cast<scf::YieldOp>(b.getTerminator());
|
|
Value yieldedValue = yieldOp->getOperand(resultNum);
|
|
if (auto bufferType = dyn_cast<BaseMemRefType>(yieldedValue.getType()))
|
|
return bufferType;
|
|
auto maybeBufferType =
|
|
bufferization::getBufferType(yieldedValue, options, invocationStack);
|
|
if (failed(maybeBufferType))
|
|
return failure();
|
|
return maybeBufferType;
|
|
};
|
|
|
|
// Compute buffer type of the default case.
|
|
auto maybeBufferType = getYieldedBufferType(switchOp.getDefaultBlock());
|
|
if (failed(maybeBufferType))
|
|
return failure();
|
|
BaseMemRefType bufferType = *maybeBufferType;
|
|
|
|
// Compute buffer types of all other cases.
|
|
for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
|
|
auto yieldedBufferType = getYieldedBufferType(switchOp.getCaseBlock(i));
|
|
if (failed(yieldedBufferType))
|
|
return failure();
|
|
|
|
// Best case: Both branches have the exact same buffer type.
|
|
if (bufferType == *yieldedBufferType)
|
|
continue;
|
|
|
|
// Memory space mismatch.
|
|
if (bufferType.getMemorySpace() != yieldedBufferType->getMemorySpace())
|
|
return op->emitError("inconsistent memory space on switch cases");
|
|
|
|
// Layout maps are different: Promote to fully dynamic layout map.
|
|
bufferType = getMemRefTypeWithFullyDynamicLayout(
|
|
cast<TensorType>(value.getType()), bufferType.getMemorySpace());
|
|
}
|
|
|
|
return bufferType;
|
|
}
|
|
};
|
|
|
|
/// Helper function for loop bufferization. Return the indices of all values
|
|
/// that have a tensor type.
|
|
static DenseSet<int64_t> getTensorIndices(ValueRange values) {
|
|
DenseSet<int64_t> result;
|
|
for (const auto &it : llvm::enumerate(values))
|
|
if (isa<TensorType>(it.value().getType()))
|
|
result.insert(it.index());
|
|
return result;
|
|
}
|
|
|
|
/// Helper function for loop bufferization. Return the indices of all
|
|
/// bbArg/yielded value pairs whose buffer relation is "Equivalent".
|
|
DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
|
|
ValueRange yieldedValues,
|
|
const AnalysisState &state) {
|
|
unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
|
|
DenseSet<int64_t> result;
|
|
for (unsigned int i = 0; i < minSize; ++i) {
|
|
if (!isa<TensorType>(bbArgs[i].getType()) ||
|
|
!isa<TensorType>(yieldedValues[i].getType()))
|
|
continue;
|
|
if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
|
|
result.insert(i);
|
|
}
|
|
return result;
|
|
}
|
|
|
|
/// Helper function for loop bufferization. Return the bufferized values of the
|
|
/// given OpOperands. If an operand is not a tensor, return the original value.
|
|
static FailureOr<SmallVector<Value>>
|
|
getBuffers(RewriterBase &rewriter, const MutableOperandRange &operands,
|
|
const BufferizationOptions &options) {
|
|
SmallVector<Value> result;
|
|
for (OpOperand &opOperand : operands) {
|
|
if (isa<TensorType>(opOperand.get().getType())) {
|
|
FailureOr<Value> resultBuffer =
|
|
getBuffer(rewriter, opOperand.get(), options);
|
|
if (failed(resultBuffer))
|
|
return failure();
|
|
result.push_back(*resultBuffer);
|
|
} else {
|
|
result.push_back(opOperand.get());
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
|
|
/// Helper function for loop bufferization. Given a list of bbArgs of the new
|
|
/// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
|
|
/// ToTensorOps, so that the block body can be moved over to the new op.
|
|
static SmallVector<Value>
|
|
getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
|
|
Block::BlockArgListType oldBbArgs,
|
|
const DenseSet<int64_t> &tensorIndices) {
|
|
SmallVector<Value> result;
|
|
for (const auto &it : llvm::enumerate(bbArgs)) {
|
|
size_t idx = it.index();
|
|
Value val = it.value();
|
|
if (tensorIndices.contains(idx)) {
|
|
result.push_back(rewriter
|
|
.create<bufferization::ToTensorOp>(
|
|
val.getLoc(), oldBbArgs[idx].getType(), val)
|
|
.getResult());
|
|
} else {
|
|
result.push_back(val);
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
|
|
/// Compute the bufferized type of a loop iter_arg. This type must be equal to
/// the bufferized type of the corresponding init_arg and the bufferized type
/// of the corresponding yielded value.
///
/// This function uses bufferization::getBufferType to compute the bufferized
/// type of the init_arg and of the yielded value. (The computation of the
/// bufferized yielded value type usually requires computing the bufferized
/// type of the iter_arg again; the implementation of getBufferType traces back
/// the use-def chain of the given value and computes a buffer type along the
/// way.) If both buffer types are equal, no casts are needed and the computed
/// buffer type can be used directly. Otherwise, the buffer types can only
/// differ in their layout map and a cast must be inserted.
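///
/// For example (illustrative): if the init_arg bufferizes to `memref<5xf32>`
/// but the yielded value bufferizes to
/// `memref<5xf32, strided<[?], offset: ?>>`, the iter_arg is given the fully
/// dynamic layout `memref<5xf32, strided<[?], offset: ?>>` and a memref.cast
/// is inserted for the init_arg.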
static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
|
|
Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
|
|
const BufferizationOptions &options, SmallVector<Value> &invocationStack) {
|
|
// Determine the buffer type of the init_arg.
|
|
auto initArgBufferType =
|
|
bufferization::getBufferType(initArg, options, invocationStack);
|
|
if (failed(initArgBufferType))
|
|
return failure();
|
|
|
|
if (llvm::count(invocationStack, iterArg) >= 2) {
|
|
// If the iter_arg is already twice on the invocation stack, just take the
|
|
// type of the init_arg. This is to avoid infinite loops when calculating
|
|
// the buffer type. This will most likely result in computing a memref type
|
|
// with a fully dynamic layout map.
|
|
|
|
// Note: For more precise layout map computation, a fixpoint iteration could
|
|
// be done (i.e., re-computing the yielded buffer type until the bufferized
|
|
// iter_arg type no longer changes). This current implementation immediately
|
|
// switches to a fully dynamic layout map when a mismatch between bufferized
|
|
// init_arg type and bufferized yield value type is detected.
|
|
return *initArgBufferType;
|
|
}
|
|
|
|
// Compute the buffer type of the yielded value.
|
|
BaseMemRefType yieldedValueBufferType;
|
|
if (isa<BaseMemRefType>(yieldedValue.getType())) {
|
|
// scf.yield was already bufferized.
|
|
yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
|
|
} else {
|
|
// Note: This typically triggers a recursive call for the buffer type of
|
|
// the iter_arg.
|
|
auto maybeBufferType =
|
|
bufferization::getBufferType(yieldedValue, options, invocationStack);
|
|
if (failed(maybeBufferType))
|
|
return failure();
|
|
yieldedValueBufferType = *maybeBufferType;
|
|
}
|
|
|
|
// If yielded type and init_arg type are the same, use that type directly.
|
|
if (*initArgBufferType == yieldedValueBufferType)
|
|
return yieldedValueBufferType;
|
|
|
|
// If there is a mismatch between the yielded buffer type and the init_arg
|
|
// buffer type, the buffer type must be promoted to a fully dynamic layout
|
|
// map.
|
|
auto yieldedBufferType = cast<BaseMemRefType>(yieldedValueBufferType);
|
|
auto iterTensorType = cast<TensorType>(iterArg.getType());
|
|
auto initBufferType = llvm::cast<BaseMemRefType>(*initArgBufferType);
|
|
if (initBufferType.getMemorySpace() != yieldedBufferType.getMemorySpace())
|
|
return loopOp->emitOpError(
|
|
"init_arg and yielded value bufferize to inconsistent memory spaces");
|
|
#ifndef NDEBUG
|
|
if (auto yieldedRankedBufferType = dyn_cast<MemRefType>(yieldedBufferType)) {
|
|
assert(
|
|
llvm::all_equal({yieldedRankedBufferType.getShape(),
|
|
cast<MemRefType>(initBufferType).getShape(),
|
|
cast<RankedTensorType>(iterTensorType).getShape()}) &&
|
|
"expected same shape");
|
|
}
|
|
#endif // NDEBUG
|
|
return getMemRefTypeWithFullyDynamicLayout(
|
|
iterTensorType, yieldedBufferType.getMemorySpace());
|
|
}
|
|
|
|
/// Return `true` if the given loop may have 0 iterations.
|
|
bool mayHaveZeroIterations(scf::ForOp forOp) {
|
|
std::optional<int64_t> lb = getConstantIntValue(forOp.getLowerBound());
|
|
std::optional<int64_t> ub = getConstantIntValue(forOp.getUpperBound());
|
|
if (!lb.has_value() || !ub.has_value())
|
|
return true;
|
|
return *ub <= *lb;
|
|
}
|
|
|
|
/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
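///
/// A rough sketch of the rewrite (illustrative only; layouts depend on the
/// bufferization options):
///
///   %r = scf.for %i = %lb to %ub step %s iter_args(%a = %t) -> (tensor<5xf32>) {
///     ...
///     scf.yield %next : tensor<5xf32>
///   }
///
/// becomes approximately a loop over the bufferized iter_arg:
///
///   %0 = scf.for %i = %lb to %ub step %s iter_args(%a = %m) -> (memref<5xf32>) {
///     ...
///     scf.yield %next_buf : memref<5xf32>
///   }
///
/// where %m is the buffer of %t (possibly cast to the iter_arg buffer type).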
struct ForOpInterface
|
|
: public BufferizableOpInterface::ExternalModel<ForOpInterface,
|
|
scf::ForOp> {
|
|
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
|
const AnalysisState &state) const {
|
|
auto forOp = cast<scf::ForOp>(op);
|
|
|
|
// If the loop has zero iterations, the results of the op are their
|
|
// corresponding init_args, meaning that the init_args bufferize to a read.
|
|
if (mayHaveZeroIterations(forOp))
|
|
return true;
|
|
|
|
// scf::ForOp alone doesn't bufferize to a memory read, one of the uses of
|
|
// its matching bbArg may.
|
|
return state.isValueRead(forOp.getTiedLoopRegionIterArg(&opOperand));
|
|
}
|
|
|
|
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
|
const AnalysisState &state) const {
|
|
// Tensor iter_args of scf::ForOps are always considered as a write.
|
|
return true;
|
|
}
|
|
|
|
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
|
|
const AnalysisState &state) const {
|
|
auto forOp = cast<scf::ForOp>(op);
|
|
OpResult opResult = forOp.getTiedLoopResult(&opOperand);
|
|
BufferRelation relation = bufferRelation(op, opResult, state);
|
|
return {{opResult, relation,
|
|
/*isDefinite=*/relation == BufferRelation::Equivalent}};
|
|
}
|
|
|
|
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
|
const AnalysisState &state) const {
|
|
// ForOp results are equivalent to their corresponding init_args if the
|
|
// corresponding iter_args and yield values are equivalent.
|
|
auto forOp = cast<scf::ForOp>(op);
|
|
BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
|
|
bool equivalentYield = state.areEquivalentBufferizedValues(
|
|
bbArg, forOp.getTiedLoopYieldedValue(bbArg)->get());
|
|
return equivalentYield ? BufferRelation::Equivalent
|
|
: BufferRelation::Unknown;
|
|
}
|
|
|
|
bool isWritable(Operation *op, Value value,
|
|
const AnalysisState &state) const {
|
|
// Interestingly, scf::ForOp's bbArg can **always** be viewed
|
|
// inplace from the perspective of ops nested under:
|
|
// 1. Either the matching iter operand is not bufferized inplace and an
|
|
// alloc + optional copy makes the bbArg itself inplaceable.
|
|
// 2. Or the matching iter operand is bufferized inplace and bbArg just
|
|
// bufferizes to that too.
|
|
return true;
|
|
}
|
|
|
|
LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
|
|
const AnalysisState &state) const {
|
|
auto bufferizableOp = cast<BufferizableOpInterface>(op);
|
|
if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
|
|
return failure();
|
|
|
|
if (!state.getOptions().enforceAliasingInvariants ||
|
|
state.getOptions().copyBeforeWrite)
|
|
return success();
|
|
|
|
// According to the `getAliasing...` implementations, a bufferized OpResult
|
|
// may alias only with the corresponding bufferized init_arg (or with a
|
|
// newly allocated buffer) and not with other buffers defined outside of the
|
|
// loop. I.e., the i-th OpResult may alias with the i-th init_arg;
|
|
// but not with any other OpOperand.
|
|
auto forOp = cast<scf::ForOp>(op);
|
|
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
|
|
OpBuilder::InsertionGuard g(rewriter);
|
|
rewriter.setInsertionPoint(yieldOp);
|
|
|
|
// Indices of all iter_args that have tensor type. These are the ones that
|
|
// are bufferized.
|
|
DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
|
|
// For every yielded value, does it alias with something defined outside of
|
|
// the loop?
|
|
SmallVector<Value> yieldValues;
|
|
for (const auto it : llvm::enumerate(yieldOp.getResults())) {
|
|
// Note: `state` is guaranteed to be a `OneShotAnalysisState`, but this
|
|
// type cannot be used in the signature of `resolveConflicts` because the
|
|
// op interface is in the "IR" build unit and the `OneShotAnalysisState`
|
|
// is defined in the "Transforms" build unit.
|
|
if (!indices.contains(it.index()) ||
|
|
doesNotAliasExternalValue(
|
|
it.value(), &forOp.getRegion(),
|
|
/*exceptions=*/forOp.getRegionIterArg(it.index()),
|
|
static_cast<const OneShotAnalysisState &>(state))) {
|
|
yieldValues.push_back(it.value());
|
|
continue;
|
|
}
|
|
FailureOr<Value> alloc = allocateTensorForShapedValue(
|
|
rewriter, yieldOp.getLoc(), it.value(), state.getOptions());
|
|
if (failed(alloc))
|
|
return failure();
|
|
yieldValues.push_back(*alloc);
|
|
}
|
|
|
|
rewriter.modifyOpInPlace(
|
|
yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
|
|
return success();
|
|
}
|
|
|
|
FailureOr<BaseMemRefType>
|
|
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
|
|
SmallVector<Value> &invocationStack) const {
|
|
auto forOp = cast<scf::ForOp>(op);
|
|
assert(getOwnerOfValue(value) == op && "invalid value");
|
|
assert(isa<TensorType>(value.getType()) && "expected tensor type");
|
|
|
|
if (auto opResult = dyn_cast<OpResult>(value)) {
|
|
// The type of an OpResult must match the corresponding iter_arg type.
|
|
BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
|
|
return bufferization::getBufferType(bbArg, options, invocationStack);
|
|
}
|
|
|
|
// Compute result/argument number.
|
|
BlockArgument bbArg = cast<BlockArgument>(value);
|
|
unsigned resultNum = forOp.getTiedLoopResult(bbArg).getResultNumber();
|
|
|
|
// Compute the bufferized type.
|
|
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
|
|
Value yieldedValue = yieldOp.getOperand(resultNum);
|
|
BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum];
|
|
Value initArg = forOp.getInitArgs()[resultNum];
|
|
return computeLoopRegionIterArgBufferType(
|
|
op, iterArg, initArg, yieldedValue, options, invocationStack);
|
|
}
|
|
|
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
|
const BufferizationOptions &options) const {
|
|
auto forOp = cast<scf::ForOp>(op);
|
|
Block *oldLoopBody = forOp.getBody();
|
|
|
|
// Indices of all iter_args that have tensor type. These are the ones that
|
|
// are bufferized.
|
|
DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
|
|
|
|
// The new memref init_args of the loop.
|
|
FailureOr<SmallVector<Value>> maybeInitArgs =
|
|
getBuffers(rewriter, forOp.getInitArgsMutable(), options);
|
|
if (failed(maybeInitArgs))
|
|
return failure();
|
|
SmallVector<Value> initArgs = *maybeInitArgs;
|
|
|
|
// Cast init_args if necessary.
|
|
SmallVector<Value> castedInitArgs;
|
|
for (const auto &it : llvm::enumerate(initArgs)) {
|
|
Value initArg = it.value();
|
|
Value result = forOp->getResult(it.index());
|
|
// If the type is not a tensor, bufferization doesn't need to touch it.
|
|
if (!isa<TensorType>(result.getType())) {
|
|
castedInitArgs.push_back(initArg);
|
|
continue;
|
|
}
|
|
auto targetType = bufferization::getBufferType(result, options);
|
|
if (failed(targetType))
|
|
return failure();
|
|
castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
|
|
}
|
|
|
|
// Construct a new scf.for op with memref instead of tensor values.
|
|
auto newForOp = rewriter.create<scf::ForOp>(
|
|
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
|
|
forOp.getStep(), castedInitArgs);
|
|
newForOp->setAttrs(forOp->getAttrs());
|
|
Block *loopBody = newForOp.getBody();
|
|
|
|
// Set up new iter_args. The loop body uses tensors, so wrap the (memref)
|
|
// iter_args of the new loop in ToTensorOps.
|
|
rewriter.setInsertionPointToStart(loopBody);
|
|
SmallVector<Value> iterArgs =
|
|
getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(),
|
|
forOp.getRegionIterArgs(), indices);
|
|
iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());
|
|
|
|
// Move loop body to new loop.
|
|
rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);
|
|
|
|
// Replace loop results.
|
|
replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Assert that yielded values of an scf.for op are equivalent to their
|
|
/// corresponding bbArgs. In that case, the buffer relations of the
|
|
/// corresponding OpResults are "Equivalent".
|
|
///
|
|
/// If this is not the case, allocs+copies are inserted and yielded from
/// the loop. This could be a performance problem, so it must be explicitly
/// activated with `allow-return-allocs`.
|
|
LogicalResult verifyAnalysis(Operation *op,
|
|
const AnalysisState &state) const {
|
|
const auto &options =
|
|
static_cast<const OneShotBufferizationOptions &>(state.getOptions());
|
|
if (options.allowReturnAllocsFromLoops)
|
|
return success();
|
|
|
|
auto forOp = cast<scf::ForOp>(op);
|
|
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
|
|
for (OpResult opResult : op->getOpResults()) {
|
|
if (!isa<TensorType>(opResult.getType()))
|
|
continue;
|
|
|
|
// Note: This is overly strict. We should check for aliasing bufferized
|
|
// values. But we don't have a "must-alias" analysis yet.
|
|
if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
|
|
return yieldOp->emitError()
|
|
<< "Yield operand #" << opResult.getResultNumber()
|
|
<< " is not equivalent to the corresponding iter bbArg";
|
|
}
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Bufferization of scf.while. Replace with a new scf.while that operates on
/// memrefs.
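///
/// In contrast to scf.for, the "before" and "after" regions may carry
/// different argument types, so the bufferized types of the "before" bbArgs,
/// the "after" bbArgs, and the results are computed separately (see
/// getBufferType below); scf.condition and scf.yield operands are cast to
/// those types where necessary.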
struct WhileOpInterface
|
|
: public BufferizableOpInterface::ExternalModel<WhileOpInterface,
|
|
scf::WhileOp> {
|
|
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
|
const AnalysisState &state) const {
|
|
// Tensor iter_args of scf::WhileOps are always considered as a read.
|
|
return true;
|
|
}
|
|
|
|
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
|
const AnalysisState &state) const {
|
|
// Tensor iter_args of scf::WhileOps are always considered as a write.
|
|
return true;
|
|
}
|
|
|
|
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
|
|
const AnalysisState &state) const {
|
|
auto whileOp = cast<scf::WhileOp>(op);
|
|
unsigned int idx = opOperand.getOperandNumber();
|
|
|
|
// The OpResults and OpOperands may not match. They may not even have the
|
|
// same type. The number of OpResults and OpOperands can also differ.
|
|
if (idx >= op->getNumResults() ||
|
|
opOperand.get().getType() != op->getResult(idx).getType())
|
|
return {};
|
|
|
|
// The only aliasing OpResult may be the one at the same index.
|
|
OpResult opResult = whileOp->getResult(idx);
|
|
BufferRelation relation = bufferRelation(op, opResult, state);
|
|
return {{opResult, relation,
|
|
/*isDefinite=*/relation == BufferRelation::Equivalent}};
|
|
}
|
|
|
|
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
|
const AnalysisState &state) const {
|
|
// WhileOp results are equivalent to their corresponding init_args if the
|
|
// corresponding iter_args and yield values are equivalent (for both the
|
|
// "before" and the "after" block).
|
|
unsigned int resultNumber = opResult.getResultNumber();
|
|
auto whileOp = cast<scf::WhileOp>(op);
|
|
|
|
// The "before" region bbArgs and the OpResults may not match.
|
|
if (resultNumber >= whileOp.getBeforeArguments().size())
|
|
return BufferRelation::Unknown;
|
|
if (opResult.getType() !=
|
|
whileOp.getBeforeArguments()[resultNumber].getType())
|
|
return BufferRelation::Unknown;
|
|
|
|
auto conditionOp = whileOp.getConditionOp();
|
|
BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
|
|
Value conditionOperand = conditionOp.getArgs()[resultNumber];
|
|
bool equivCondition =
|
|
state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);
|
|
|
|
auto yieldOp = whileOp.getYieldOp();
|
|
BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
|
|
Value yieldOperand = yieldOp.getOperand(resultNumber);
|
|
bool equivYield =
|
|
state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);
|
|
|
|
return equivCondition && equivYield ? BufferRelation::Equivalent
|
|
: BufferRelation::Unknown;
|
|
}
|
|
|
|
bool isWritable(Operation *op, Value value,
|
|
const AnalysisState &state) const {
|
|
// Interestingly, scf::WhileOp's bbArg can **always** be viewed
|
|
// inplace from the perspective of ops nested under:
|
|
// 1. Either the matching iter operand is not bufferized inplace and an
|
|
// alloc + optional copy makes the bbArg itself inplaceable.
|
|
// 2. Or the matching iter operand is bufferized inplace and bbArg just
|
|
// bufferizes to that too.
|
|
return true;
|
|
}
|
|
|
|
LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
|
|
const AnalysisState &state) const {
|
|
auto bufferizableOp = cast<BufferizableOpInterface>(op);
|
|
if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
|
|
return failure();
|
|
|
|
if (!state.getOptions().enforceAliasingInvariants ||
|
|
state.getOptions().copyBeforeWrite)
|
|
return success();
|
|
|
|
// According to the `getAliasing...` implementations, a bufferized OpResult
|
|
// may alias only with the corresponding bufferized init_arg and with no
|
|
// other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
|
|
// but not with any other OpOperand. If a corresponding OpResult/init_arg
|
|
// pair bufferizes to equivalent buffers, this aliasing requirement is
|
|
// satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
|
|
// (New buffer copies do not alias with any buffer.)
|
|
OpBuilder::InsertionGuard g(rewriter);
|
|
auto whileOp = cast<scf::WhileOp>(op);
|
|
auto conditionOp = whileOp.getConditionOp();
|
|
|
|
// For every yielded value, is the value equivalent to its corresponding
|
|
// bbArg?
|
|
DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
|
|
whileOp.getBeforeArguments(), conditionOp.getArgs(), state);
|
|
DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
|
|
whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(), state);
|
|
|
|
// Update "before" region.
|
|
rewriter.setInsertionPoint(conditionOp);
|
|
SmallVector<Value> beforeYieldValues;
|
|
for (int64_t idx = 0;
|
|
idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
|
|
Value value = conditionOp.getArgs()[idx];
|
|
if (!isa<TensorType>(value.getType()) ||
|
|
(equivalentYieldsAfter.contains(idx) &&
|
|
equivalentYieldsBefore.contains(idx))) {
|
|
beforeYieldValues.push_back(value);
|
|
continue;
|
|
}
|
|
FailureOr<Value> alloc = allocateTensorForShapedValue(
|
|
rewriter, conditionOp.getLoc(), value, state.getOptions());
|
|
if (failed(alloc))
|
|
return failure();
|
|
beforeYieldValues.push_back(*alloc);
|
|
}
|
|
rewriter.modifyOpInPlace(conditionOp, [&]() {
|
|
conditionOp.getArgsMutable().assign(beforeYieldValues);
|
|
});
|
|
|
|
return success();
|
|
}
|
|
|
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
|
const BufferizationOptions &options) const {
|
|
auto whileOp = cast<scf::WhileOp>(op);
|
|
|
|
// Indices of all bbArgs that have tensor type. These are the ones that
|
|
// are bufferized. The "before" and "after" regions may have different args.
|
|
DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
|
|
DenseSet<int64_t> indicesAfter =
|
|
getTensorIndices(whileOp.getAfterArguments());
|
|
|
|
// The new memref init_args of the loop.
|
|
FailureOr<SmallVector<Value>> maybeInitArgs =
|
|
getBuffers(rewriter, whileOp.getInitsMutable(), options);
|
|
if (failed(maybeInitArgs))
|
|
return failure();
|
|
SmallVector<Value> initArgs = *maybeInitArgs;
|
|
|
|
// Cast init_args if necessary.
|
|
SmallVector<Value> castedInitArgs;
|
|
for (const auto &it : llvm::enumerate(initArgs)) {
|
|
Value initArg = it.value();
|
|
Value beforeArg = whileOp.getBeforeArguments()[it.index()];
|
|
// If the type is not a tensor, bufferization doesn't need to touch it.
|
|
if (!isa<TensorType>(beforeArg.getType())) {
|
|
castedInitArgs.push_back(initArg);
|
|
continue;
|
|
}
|
|
auto targetType = bufferization::getBufferType(beforeArg, options);
|
|
if (failed(targetType))
|
|
return failure();
|
|
castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
|
|
}
|
|
|
|
// The result types of a WhileOp are the same as the "after" bbArg types.
|
|
SmallVector<Type> argsTypesAfter = llvm::to_vector(
|
|
llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
|
|
if (!isa<TensorType>(bbArg.getType()))
|
|
return bbArg.getType();
|
|
// TODO: error handling
|
|
return llvm::cast<Type>(
|
|
*bufferization::getBufferType(bbArg, options));
|
|
}));
|
|
|
|
// Construct a new scf.while op with memref instead of tensor values.
|
|
ValueRange argsRangeBefore(castedInitArgs);
|
|
TypeRange argsTypesBefore(argsRangeBefore);
|
|
auto newWhileOp = rewriter.create<scf::WhileOp>(
|
|
whileOp.getLoc(), argsTypesAfter, castedInitArgs);
|
|
|
|
// Add before/after regions to the new op.
|
|
SmallVector<Location> bbArgLocsBefore(castedInitArgs.size(),
|
|
whileOp.getLoc());
|
|
SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
|
|
whileOp.getLoc());
|
|
Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
|
|
newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
|
|
Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
|
|
newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);
|
|
|
|
// Set up new iter_args and move the loop condition block to the new op.
|
|
// The old block uses tensors, so wrap the (memref) bbArgs of the new block
|
|
// in ToTensorOps.
|
|
rewriter.setInsertionPointToStart(newBeforeBody);
|
|
SmallVector<Value> newBeforeArgs =
|
|
getBbArgReplacements(rewriter, newWhileOp.getBeforeArguments(),
|
|
whileOp.getBeforeArguments(), indicesBefore);
|
|
rewriter.mergeBlocks(whileOp.getBeforeBody(), newBeforeBody, newBeforeArgs);
|
|
|
|
// Set up new iter_args and move the loop body block to the new op.
|
|
// The old block uses tensors, so wrap the (memref) bbArgs of the new block
|
|
// in ToTensorOps.
|
|
rewriter.setInsertionPointToStart(newAfterBody);
|
|
SmallVector<Value> newAfterArgs =
|
|
getBbArgReplacements(rewriter, newWhileOp.getAfterArguments(),
|
|
whileOp.getAfterArguments(), indicesAfter);
|
|
rewriter.mergeBlocks(whileOp.getAfterBody(), newAfterBody, newAfterArgs);
|
|
|
|
// Replace loop results.
|
|
replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());
|
|
|
|
return success();
|
|
}
|
|
|
|
FailureOr<BaseMemRefType>
|
|
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
|
|
SmallVector<Value> &invocationStack) const {
|
|
auto whileOp = cast<scf::WhileOp>(op);
|
|
assert(getOwnerOfValue(value) == op && "invalid value");
|
|
assert(isa<TensorType>(value.getType()) && "expected tensor type");
|
|
|
|
// Case 1: Block argument of the "before" region.
|
|
if (auto bbArg = dyn_cast<BlockArgument>(value)) {
|
|
if (bbArg.getOwner()->getParent() == &whileOp.getBefore()) {
|
|
Value initArg = whileOp.getInits()[bbArg.getArgNumber()];
|
|
auto yieldOp = whileOp.getYieldOp();
|
|
Value yieldedValue = yieldOp.getOperand(bbArg.getArgNumber());
|
|
return computeLoopRegionIterArgBufferType(
|
|
op, bbArg, initArg, yieldedValue, options, invocationStack);
|
|
}
|
|
}
|
|
|
|
// Case 2: OpResult of the loop or block argument of the "after" region.
|
|
// The bufferized "after" bbArg type can be directly computed from the
|
|
// bufferized "before" bbArg type.
|
|
unsigned resultNum;
|
|
if (auto opResult = dyn_cast<OpResult>(value)) {
|
|
resultNum = opResult.getResultNumber();
|
|
} else if (cast<BlockArgument>(value).getOwner()->getParent() ==
|
|
&whileOp.getAfter()) {
|
|
resultNum = cast<BlockArgument>(value).getArgNumber();
|
|
} else {
|
|
llvm_unreachable("invalid value");
|
|
}
|
|
Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
|
|
if (!isa<TensorType>(conditionYieldedVal.getType())) {
|
|
// scf.condition was already bufferized.
|
|
return cast<BaseMemRefType>(conditionYieldedVal.getType());
|
|
}
|
|
return bufferization::getBufferType(conditionYieldedVal, options,
|
|
invocationStack);
|
|
}
|
|
|
|
/// Assert that yielded values of an scf.while op are equivalent to their
|
|
/// corresponding bbArgs. In that case, the buffer relations of the
|
|
/// corresponding OpResults are "Equivalent".
|
|
///
|
|
/// If this is not the case, allocs+copies are inserted and yielded from
|
|
/// the loop. This could be a performance problem, so it must be explicitly
|
|
/// activated with `allow-return-allocs`.
|
|
///
|
|
/// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
|
|
/// equivalence condition must be checked for both.
|
|
LogicalResult verifyAnalysis(Operation *op,
|
|
const AnalysisState &state) const {
|
|
auto whileOp = cast<scf::WhileOp>(op);
|
|
const auto &options =
|
|
static_cast<const OneShotBufferizationOptions &>(state.getOptions());
|
|
if (options.allowReturnAllocsFromLoops)
|
|
return success();
|
|
|
|
auto conditionOp = whileOp.getConditionOp();
|
|
for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
|
|
Block *block = conditionOp->getBlock();
|
|
if (!isa<TensorType>(it.value().getType()))
|
|
continue;
|
|
if (it.index() >= block->getNumArguments() ||
|
|
!state.areEquivalentBufferizedValues(it.value(),
|
|
block->getArgument(it.index())))
|
|
return conditionOp->emitError()
|
|
<< "Condition arg #" << it.index()
|
|
<< " is not equivalent to the corresponding iter bbArg";
|
|
}
|
|
|
|
auto yieldOp = whileOp.getYieldOp();
|
|
for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
|
|
Block *block = yieldOp->getBlock();
|
|
if (!isa<TensorType>(it.value().getType()))
|
|
continue;
|
|
if (it.index() >= block->getNumArguments() ||
|
|
!state.areEquivalentBufferizedValues(it.value(),
|
|
block->getArgument(it.index())))
|
|
return yieldOp->emitError()
|
|
<< "Yield operand #" << it.index()
|
|
<< " is not equivalent to the corresponding iter bbArg";
|
|
}
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
|
|
/// this is for analysis only.
|
|
struct YieldOpInterface
|
|
: public BufferizableOpInterface::ExternalModel<YieldOpInterface,
|
|
scf::YieldOp> {
|
|
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
|
const AnalysisState &state) const {
|
|
return true;
|
|
}
|
|
|
|
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
|
const AnalysisState &state) const {
|
|
return false;
|
|
}
|
|
|
|
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
|
|
const AnalysisState &state) const {
|
|
if (auto ifOp = dyn_cast<scf::IfOp>(op->getParentOp())) {
|
|
return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
|
|
BufferRelation::Equivalent, /*isDefinite=*/false}};
|
|
}
|
|
if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
|
|
return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
|
|
BufferRelation::Equivalent}};
|
|
return {};
|
|
}
|
|
|
|
bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
|
|
const AnalysisState &state) const {
|
|
// Yield operands always bufferize inplace. Otherwise, an alloc + copy
// may be generated inside the block. We should avoid returning/yielding
// allocations whenever possible.
|
|
return true;
|
|
}
|
|
|
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
|
const BufferizationOptions &options) const {
|
|
auto yieldOp = cast<scf::YieldOp>(op);
|
|
if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::IndexSwitchOp, scf::ForOp,
|
|
scf::WhileOp>(yieldOp->getParentOp()))
|
|
return yieldOp->emitError("unsupported scf::YieldOp parent");
|
|
|
|
SmallVector<Value> newResults;
|
|
for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
|
|
Value value = it.value();
|
|
if (isa<TensorType>(value.getType())) {
|
|
FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
|
|
if (failed(maybeBuffer))
|
|
return failure();
|
|
Value buffer = *maybeBuffer;
|
|
// We may have to cast the value before yielding it.
|
|
if (isa<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(
|
|
yieldOp->getParentOp())) {
|
|
FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
|
|
yieldOp->getParentOp()->getResult(it.index()), options);
|
|
if (failed(resultType))
|
|
return failure();
|
|
buffer = castBuffer(rewriter, buffer, *resultType);
|
|
} else if (auto whileOp =
|
|
dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
|
|
FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
|
|
whileOp.getBeforeArguments()[it.index()], options);
|
|
if (failed(resultType))
|
|
return failure();
|
|
buffer = castBuffer(rewriter, buffer, *resultType);
|
|
}
|
|
newResults.push_back(buffer);
|
|
} else {
|
|
newResults.push_back(value);
|
|
}
|
|
}
|
|
|
|
replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Return `true` if the given loop may have 0 iterations.
|
|
bool mayHaveZeroIterations(scf::ForallOp forallOp) {
|
|
for (auto [lb, ub] : llvm::zip(forallOp.getMixedLowerBound(),
|
|
forallOp.getMixedUpperBound())) {
|
|
std::optional<int64_t> lbConst = getConstantIntValue(lb);
|
|
std::optional<int64_t> ubConst = getConstantIntValue(ub);
|
|
if (!lbConst.has_value() || !ubConst.has_value() || *lbConst >= *ubConst)
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
/// Bufferization of ForallOp. This also bufferizes the terminator of the
/// region. There are op interfaces for the terminators (InParallelOp and
/// ParallelInsertSliceOp), but these are only used during analysis, not for
/// bufferization.
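///
/// A rough sketch of the rewrite (illustrative only):
///
///   %r = scf.forall (%i) in (4) shared_outs(%o = %t) -> (tensor<4xf32>) {
///     ...
///     scf.forall.in_parallel {
///       tensor.parallel_insert_slice %v into %o[%i] [1] [1]
///           : tensor<1xf32> into tensor<4xf32>
///     }
///   }
///
/// becomes a result-less forall that operates directly on the buffer of %t;
/// all uses of %r are replaced with that buffer.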
struct ForallOpInterface
|
|
: public BufferizableOpInterface::ExternalModel<ForallOpInterface,
|
|
ForallOp> {
|
|
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
|
const AnalysisState &state) const {
|
|
auto forallOp = cast<ForallOp>(op);
|
|
|
|
// If the loop has zero iterations, the results of the op are their
|
|
// corresponding shared_outs, meaning that the shared_outs bufferize to a
|
|
// read.
|
|
if (mayHaveZeroIterations(forallOp))
|
|
return true;
|
|
|
|
// scf::ForallOp alone doesn't bufferize to a memory read, one of the
|
|
// uses of its matching bbArg may.
|
|
return state.isValueRead(forallOp.getTiedBlockArgument(&opOperand));
|
|
}
|
|
|
|
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
|
const AnalysisState &state) const {
|
|
// Outputs of scf::ForallOps are always considered as a write.
|
|
return true;
|
|
}
|
|
|
|
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
|
|
const AnalysisState &state) const {
|
|
auto forallOp = cast<ForallOp>(op);
|
|
return {
|
|
{{forallOp.getTiedOpResult(&opOperand), BufferRelation::Equivalent}}};
|
|
}
|
|
|
|
bool isWritable(Operation *op, Value value,
|
|
const AnalysisState &state) const {
|
|
return true;
|
|
}
|
|
|
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
|
const BufferizationOptions &options) const {
|
|
OpBuilder::InsertionGuard guard(rewriter);
|
|
auto forallOp = cast<ForallOp>(op);
|
|
int64_t rank = forallOp.getRank();
|
|
|
|
// Get buffers for all output operands.
|
|
SmallVector<Value> buffers;
|
|
for (Value out : forallOp.getOutputs()) {
|
|
FailureOr<Value> buffer = getBuffer(rewriter, out, options);
|
|
if (failed(buffer))
|
|
return failure();
|
|
buffers.push_back(*buffer);
|
|
}
|
|
|
|
// Use buffers instead of block arguments.
|
|
rewriter.setInsertionPointToStart(forallOp.getBody());
|
|
for (const auto &it : llvm::zip(
|
|
forallOp.getBody()->getArguments().drop_front(rank), buffers)) {
|
|
BlockArgument bbArg = std::get<0>(it);
|
|
Value buffer = std::get<1>(it);
|
|
Value bufferAsTensor = rewriter.create<ToTensorOp>(
|
|
forallOp.getLoc(), bbArg.getType(), buffer);
|
|
bbArg.replaceAllUsesWith(bufferAsTensor);
|
|
}
|
|
|
|
// Create new ForallOp without any results and drop the automatically
|
|
// introduced terminator.
|
|
rewriter.setInsertionPoint(forallOp);
|
|
ForallOp newForallOp;
|
|
newForallOp = rewriter.create<ForallOp>(
|
|
forallOp.getLoc(), forallOp.getMixedLowerBound(),
|
|
forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
|
|
/*outputs=*/ValueRange(), forallOp.getMapping());
|
|
|
|
// Keep discardable attributes from the original op.
|
|
newForallOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
|
|
|
|
rewriter.eraseOp(newForallOp.getBody()->getTerminator());
|
|
|
|
// Move over block contents of the old op.
|
|
SmallVector<Value> replacementBbArgs;
|
|
replacementBbArgs.append(newForallOp.getBody()->getArguments().begin(),
|
|
newForallOp.getBody()->getArguments().end());
|
|
replacementBbArgs.append(forallOp.getOutputs().size(), Value());
|
|
rewriter.mergeBlocks(forallOp.getBody(), newForallOp.getBody(),
|
|
replacementBbArgs);
|
|
|
|
// Remove the old op and replace all of its uses.
|
|
replaceOpWithBufferizedValues(rewriter, op, buffers);
|
|
|
|
return success();
|
|
}
|
|
|
|
FailureOr<BaseMemRefType>
|
|
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
|
|
SmallVector<Value> &invocationStack) const {
|
|
auto forallOp = cast<ForallOp>(op);
|
|
|
|
if (auto bbArg = dyn_cast<BlockArgument>(value))
|
|
// A tensor block argument has the same bufferized type as the
|
|
// corresponding output operand.
|
|
return bufferization::getBufferType(
|
|
forallOp.getTiedOpOperand(bbArg)->get(), options, invocationStack);
|
|
|
|
// The bufferized result type is the same as the bufferized type of the
|
|
// corresponding output operand.
|
|
return bufferization::getBufferType(
|
|
forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
|
|
invocationStack);
|
|
}
|
|
|
|
bool isRepetitiveRegion(Operation *op, unsigned index) const {
|
|
auto forallOp = cast<ForallOp>(op);
|
|
|
|
// This op is repetitive if it has 1 or more steps.
|
|
// If the control variables are dynamic, it is also considered so.
|
|
for (auto [lb, ub, step] :
|
|
llvm::zip(forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
|
|
forallOp.getMixedStep())) {
|
|
std::optional<int64_t> lbConstant = getConstantIntValue(lb);
|
|
if (!lbConstant)
|
|
return true;
|
|
|
|
std::optional<int64_t> ubConstant = getConstantIntValue(ub);
|
|
if (!ubConstant)
|
|
return true;
|
|
|
|
std::optional<int64_t> stepConstant = getConstantIntValue(step);
|
|
if (!stepConstant)
|
|
return true;
|
|
|
|
if (*lbConstant + *stepConstant < *ubConstant)
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool isParallelRegion(Operation *op, unsigned index) const {
|
|
return isRepetitiveRegion(op, index);
|
|
}
|
|
};
|
|
|
|
/// Nothing to do for InParallelOp.
|
|
struct InParallelOpInterface
|
|
: public BufferizableOpInterface::ExternalModel<InParallelOpInterface,
|
|
InParallelOp> {
|
|
LogicalResult bufferize(Operation *op, RewriterBase &b,
|
|
const BufferizationOptions &options) const {
|
|
llvm_unreachable("op does not have any tensor OpOperands / OpResults");
|
|
return failure();
|
|
}
|
|
};
|
|
|
|
} // namespace
} // namespace scf
} // namespace mlir

void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    ConditionOp::attachInterface<ConditionOpInterface>(*ctx);
    ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
    ForOp::attachInterface<ForOpInterface>(*ctx);
    IfOp::attachInterface<IfOpInterface>(*ctx);
    IndexSwitchOp::attachInterface<IndexSwitchOpInterface>(*ctx);
    ForallOp::attachInterface<ForallOpInterface>(*ctx);
    InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
    WhileOp::attachInterface<WhileOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}