There is currently an op interface for subset insertion ops (`SubsetInsertionOpInterface`), but not for subset extraction ops. This commit adds `SubsetExtractionOpInterface` to `mlir/Interfaces`, as well as a common dependent op interface: `SubsetOpInterface`. - `SubsetOpInterface` is for ops that operate on tensor subsets. It provides interface methods to check if two subset ops operate on equivalent or disjoint subsets. Ops that implement this interface must implement either `SubsetExtractionOpInterface` or `SubsetInsertionOpInterface`. - `SubsetExtractionOpInterface` is for ops that extract from a tensor at a subset. E.g., `tensor.extract_slice`, `tensor.gather`, `vector.transfer_read`. Currently only implemented on `tensor.extract_slice`. - `SubsetInsertionOpInterface` is for ops that insert into a destination tensor at a subset. E.g., `tensor.insert_slice`, `tensor.parallel_insert_slice`, `tensor.scatter`, `vector.transfer_write`. Currently only implemented on `tensor.insert_slice` and `tensor.parallel_insert_slice`. Other changes: - Rename `SubsetInsertionOpInterface.td` to `SubsetOpInterface.td`. - Add helper functions to `ValueBoundsOpInterface.cpp` for checking whether two slices are disjoint. The new interfaces will be utilized by a new "loop-invariant subset hoisting" transformation. (This new transform is roughly what `Linalg/Transforms/SubsetHoisting.cpp` is doing, but in a generic and interface-driven way.)
217 lines
8.1 KiB
C++
217 lines
8.1 KiB
C++
//===- EmptyTensorElimination.cpp - tensor.empty op elimination -----------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
|
|
|
|
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
|
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
|
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
|
|
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
|
|
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/IR/Dominance.h"
|
|
#include "mlir/Interfaces/SubsetOpInterface.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
|
|
namespace mlir {
|
|
namespace bufferization {
|
|
#define GEN_PASS_DEF_EMPTYTENSORELIMINATION
|
|
#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
|
|
} // namespace bufferization
|
|
} // namespace mlir
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::bufferization;
|
|
|
|
/// Return true if all `neededValues` are in scope at the given
|
|
/// `insertionPoint`.
|
|
static bool
|
|
neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
|
|
Operation *insertionPoint,
|
|
const SmallVector<Value> &neededValues) {
|
|
for (Value val : neededValues) {
|
|
if (auto bbArg = dyn_cast<BlockArgument>(val)) {
|
|
Block *owner = bbArg.getOwner();
|
|
if (!owner->findAncestorOpInBlock(*insertionPoint))
|
|
return false;
|
|
} else {
|
|
auto opResult = cast<OpResult>(val);
|
|
if (!domInfo.properlyDominates(opResult.getOwner(), insertionPoint))
|
|
return false;
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
/// Return true if the given `insertionPoint` dominates all uses of
|
|
/// `emptyTensorOp`.
|
|
static bool insertionPointDominatesUses(const DominanceInfo &domInfo,
|
|
Operation *insertionPoint,
|
|
Operation *emptyTensorOp) {
|
|
for (Operation *user : emptyTensorOp->getUsers())
|
|
if (!domInfo.dominates(insertionPoint, user))
|
|
return false;
|
|
return true;
|
|
}
|
|
|
|
/// Find a valid insertion point for a replacement of `emptyTensorOp`, assuming
|
|
/// that the replacement may use any value from `neededValues`.
|
|
static Operation *
|
|
findValidInsertionPoint(Operation *emptyTensorOp,
|
|
const SmallVector<Value> &neededValues) {
|
|
DominanceInfo domInfo;
|
|
|
|
// Gather all possible insertion points: the location of `emptyTensorOp` and
|
|
// right after the definition of each value in `neededValues`.
|
|
SmallVector<Operation *> insertionPointCandidates;
|
|
insertionPointCandidates.push_back(emptyTensorOp);
|
|
for (Value val : neededValues) {
|
|
// Note: The anchor op is using all of `neededValues`, so:
|
|
// * in case of a block argument: There must be at least one op in the block
|
|
// (the anchor op or one of its parents).
|
|
// * in case of an OpResult: There must be at least one op right after the
|
|
// defining op (the anchor op or one of its
|
|
// parents).
|
|
if (auto bbArg = dyn_cast<BlockArgument>(val)) {
|
|
insertionPointCandidates.push_back(
|
|
&bbArg.getOwner()->getOperations().front());
|
|
} else {
|
|
insertionPointCandidates.push_back(val.getDefiningOp()->getNextNode());
|
|
}
|
|
}
|
|
|
|
// Select first matching insertion point.
|
|
for (Operation *insertionPoint : insertionPointCandidates) {
|
|
// Check if all needed values are in scope.
|
|
if (!neededValuesDominateInsertionPoint(domInfo, insertionPoint,
|
|
neededValues))
|
|
continue;
|
|
// Check if the insertion point is before all uses.
|
|
if (!insertionPointDominatesUses(domInfo, insertionPoint, emptyTensorOp))
|
|
continue;
|
|
return insertionPoint;
|
|
}
|
|
|
|
// No suitable insertion point was found.
|
|
return nullptr;
|
|
}
|
|
|
|
/// Eliminate "tensor.empty" ops inside `op`: for each subset insertion op
/// whose source operand bufferizes in-place, walk the reverse SSA use-def
/// chain of the source, find tensor.empty ops, and replace them with an
/// equivalent subset extraction of the insertion destination.
///
/// `state` must contain the result of a prior One-Shot Bufferize analysis of
/// `op`; it is used to check in-placeness and is invalidated (cache reset)
/// after each rewrite.
LogicalResult mlir::bufferization::eliminateEmptyTensors(
    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
  OpBuilder::InsertionGuard g(rewriter);

  // Note: the callback parameter is deliberately not named `op` — the
  // original code shadowed the outer `op` parameter, which is error-prone.
  op->walk([&](SubsetInsertionOpInterface insertionOp) {
    OpOperand &source = insertionOp.getSourceOperand();
    // Skip operands that do not bufferize inplace. "tensor.empty" could still
    // be replaced, but the transformation may not be beneficial.
    if (!state.isInPlace(source))
      return WalkResult::skip();

    // All values that are needed to create the replacement op.
    SmallVector<Value> neededValues =
        insertionOp.getValuesNeededToBuildSubsetExtraction();

    // Find tensor.empty ops on the reverse SSA use-def chain. Only follow
    // equivalent tensors. I.e., stop when there are ops such as extract_slice
    // on the path.
    TraversalConfig config;
    config.followEquivalentOnly = true;
    config.alwaysIncludeLeaves = false;
    // Replace only if the types match or are static <-> dynamic casts. We do
    // not support slices or reshapes.
    // TODO: This could be extended to support IR such as:
    // %0 = tensor.empty() : tensor<128xf32>
    // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>)
    // %2 = tensor.expand_shape %1 ...
    // %3 = tensor.insert_slice %2 into ...
    config.followSameTypeOrCastsOnly = true;
    SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
        source.get(), /*condition=*/
        [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
        config);

    for (Value v : emptyTensors) {
      Operation *emptyTensorOp = v.getDefiningOp();

      // Find a suitable insertion point. If no suitable insertion point for
      // the replacement can be found, skip this replacement.
      Operation *insertionPoint =
          findValidInsertionPoint(emptyTensorOp, neededValues);
      if (!insertionPoint)
        continue;

      rewriter.setInsertionPoint(insertionPoint);
      Value replacement =
          insertionOp.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
      if (!replacement)
        continue;
      // Do not replace the op with (a value defined by) itself.
      if (emptyTensorOp == replacement.getDefiningOp())
        continue;
      if (replacement.getType() != v.getType()) {
        // Types differ only by static <-> dynamic dims (per the traversal
        // config above); bridge the difference with a tensor.cast.
        rewriter.setInsertionPointAfterValue(replacement);
        replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(),
                                                      replacement);
      }
      // Replace the tensor::EmptyOp.
      rewriter.replaceOp(emptyTensorOp, replacement);
      // The analysis state caches IR handles; reset after mutating the IR.
      state.resetCache();
    }

    return WalkResult::advance();
  });

  return success();
}
|
|
|
|
namespace {
|
|
struct EmptyTensorElimination
|
|
: public bufferization::impl::EmptyTensorEliminationBase<
|
|
EmptyTensorElimination> {
|
|
EmptyTensorElimination() = default;
|
|
|
|
void runOnOperation() override;
|
|
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
registry
|
|
.insert<bufferization::BufferizationDialect, tensor::TensorDialect>();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
/// Convenience overload: run the One-Shot Bufferize analysis on `op` and then
/// eliminate tensor.empty ops based on the analysis result.
LogicalResult mlir::bufferization::eliminateEmptyTensors(RewriterBase &rewriter,
                                                         Operation *op) {
  OneShotBufferizationOptions options;
  options.allowReturnAllocsFromLoops = true;
  auto module = dyn_cast<ModuleOp>(op);
  if (module)
    options.bufferizeFunctionBoundaries = true;

  OneShotAnalysisState state(op, options);
  if (module) {
    // Module analysis takes into account function boundaries.
    if (failed(analyzeModuleOp(module, state)))
      return failure();
  } else {
    // Regular One-Shot Bufferize ignores func.func block arguments, func.call,
    // func.return.
    if (failed(analyzeOp(op, state)))
      return failure();
  }

  return bufferization::eliminateEmptyTensors(rewriter, op, state);
}
|
|
|
|
void EmptyTensorElimination::runOnOperation() {
|
|
IRRewriter rewriter(getOperation()->getContext());
|
|
if (failed(bufferization::eliminateEmptyTensors(rewriter, getOperation())))
|
|
signalPassFailure();
|
|
}
|
|
|
|
/// Factory for the empty-tensor-elimination pass (see pass registration in
/// Passes.td / Passes.h.inc).
std::unique_ptr<Pass> mlir::bufferization::createEmptyTensorEliminationPass() {
  return std::make_unique<EmptyTensorElimination>();
}
|