clang-p2996/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
Christopher Bate 2646c36a86 [mlir][bufferization] Change OneShotModuleBufferize to not analyze or bufferize nested symbol tables (#127726)
The existing OneShotModuleBufferize will analyze and bufferize
operations which are in nested symbol tables (e.g. nested
`builtin.module`, `gpu.module`, or similar operations). This
behavior is untested and likely unintentional given other
limitations of OneShotModuleBufferize (`func.call` can't call
into nested symbol tables). This change reverses the existing
behavior so that the operations considered by the analysis and
bufferization exclude any operations in nested symbol table
scopes. Users who want to bufferize nested modules can still do
so by applying the transformation in a pass pipeline or in a
custom pass. This also makes it possible to control the order in
which modules are bufferized and to use different options for
different kinds of modules.
2025-02-25 14:23:11 -07:00
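
The commit message points at handling nested modules with a separate pass. Below is a minimal, hypothetical sketch (not part of this patch) of such a custom pass: it assumes the public runOneShotModuleBufferize entry point from OneShotModuleBufferize.h, and the pass name, traversal, and option choices are illustrative only.

// Hypothetical custom pass (not from the patch): bufferize each directly
// nested builtin.module on its own, then the enclosing module.
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;
using namespace mlir::bufferization;

namespace {
struct BufferizeNestedModulesPass
    : public PassWrapper<BufferizeNestedModulesPass, OperationPass<ModuleOp>> {
  void runOnOperation() override {
    OneShotBufferizationOptions options;
    // Module bufferization requires function boundary bufferization.
    options.bufferizeFunctionBoundaries = true;

    // Bufferize each directly nested builtin.module separately; different
    // options could be used here for different kinds of modules.
    for (ModuleOp nested : getOperation().getBody()->getOps<ModuleOp>())
      if (failed(runOneShotModuleBufferize(nested, options)))
        return signalPassFailure();

    // Bufferize the outer module last; with this change its analysis no
    // longer descends into the nested symbol tables bufferized above.
    if (failed(runOneShotModuleBufferize(getOperation(), options)))
      signalPassFailure();
  }
};
} // namespace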


//===- TensorCopyInsertion.cpp - Resolve Bufferization Conflicts w/ Copies ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/Passes.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"

namespace mlir {
namespace bufferization {
#define GEN_PASS_DEF_TENSORCOPYINSERTION
#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
} // namespace bufferization
} // namespace mlir

using namespace mlir;
using namespace mlir::bufferization;

LogicalResult mlir::bufferization::insertTensorCopies(
    Operation *op, const OneShotBufferizationOptions &options,
    BufferizationStatistics *statistics) {
  OneShotAnalysisState state(op, options);
  // Run normal One-Shot Bufferize analysis or One-Shot Module Bufferize
  // analysis depending on whether function boundary bufferization is enabled or
  // not.
  if (options.bufferizeFunctionBoundaries) {
    if (failed(analyzeModuleOp(cast<ModuleOp>(op), state, statistics)))
      return failure();
  } else {
    if (failed(analyzeOp(op, state, statistics)))
      return failure();
  }

  if (options.testAnalysisOnly)
    return success();

  return insertTensorCopies(op, state);
}

LogicalResult
mlir::bufferization::insertTensorCopies(Operation *op,
                                        const AnalysisState &state) {
  IRRewriter rewriter(op->getContext());

  // It may be more efficient to walk in pre-order here, but the current
  // implementation visits regions of ops even if they are not allowed or
  // bufferizable, and existing tests rely on this behavior.
  // For now, only exclude nested operations if they are in a different symbol
  // table scope.
  WalkResult result = op->walk([&](Operation *nestedOp) {
    if (op->hasTrait<OpTrait::SymbolTable>() &&
        nestedOp->getParentWithTrait<OpTrait::SymbolTable>() != op)
      return WalkResult::skip();
    auto bufferizableOp = state.getOptions().dynCastBufferizableOp(nestedOp);
    if (!bufferizableOp)
      return WalkResult::skip();

    // Find inplacability conflicts and resolve them. (Typically with explicit
    // tensor copies in the form of AllocTensorOps.)
    rewriter.setInsertionPoint(nestedOp);
    if (failed(bufferizableOp.resolveConflicts(rewriter, state)))
      return WalkResult::interrupt();

    return WalkResult::advance();
  });

  return failure(result.wasInterrupted());
}