[mlir] Do not bufferize parallel_insert_slice dest to read for full slices (#112761)
In the insert_slice bufferization interface implementation, the destination tensor is not considered read if the full tensor is overwritten by the slice. This PR adds the same check for tensor.parallel_insert_slice.

Adds two new StaticValueUtils:
- `areAllConstantIntValue` checks whether every `OpFoldResult` in an array is a constant integer equal to a passed `int64_t` value.
- `areConstantIntValues` checks whether every `OpFoldResult` in an array is a constant integer equal to the corresponding entry of a passed array of `int64_t` values.

Fixes https://github.com/llvm/llvm-project/issues/112435

---------

Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
This commit is contained in:
@@ -92,6 +92,12 @@ getConstantIntValues(ArrayRef<OpFoldResult> ofrs);
|
|||||||
|
|
||||||
/// Return true if `ofr` is constant integer equal to `value`.
|
/// Return true if `ofr` is constant integer equal to `value`.
|
||||||
bool isConstantIntValue(OpFoldResult ofr, int64_t value);
|
bool isConstantIntValue(OpFoldResult ofr, int64_t value);
|
||||||
|
/// Return true if all of `ofrs` are constant integers equal to `value`.
|
||||||
|
bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value);
|
||||||
|
/// Return true if all of `ofrs` are constant integers equal to the
|
||||||
|
/// corresponding value in `values`.
|
||||||
|
bool areConstantIntValues(ArrayRef<OpFoldResult> ofrs,
|
||||||
|
ArrayRef<int64_t> values);
|
||||||
|
|
||||||
/// Return true if ofr1 and ofr2 are the same integer constant attribute
|
/// Return true if ofr1 and ofr2 are the same integer constant attribute
|
||||||
/// values or the same SSA value. Ignore integer bitwidth and type mismatch
|
/// values or the same SSA value. Ignore integer bitwidth and type mismatch
|
||||||
|
|||||||
@@ -19,6 +19,7 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
|
#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
|
||||||
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
||||||
|
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||||
#include "mlir/IR/Dialect.h"
|
#include "mlir/IR/Dialect.h"
|
||||||
#include "mlir/IR/Operation.h"
|
#include "mlir/IR/Operation.h"
|
||||||
|
|
||||||
@@ -636,6 +637,28 @@ struct InsertOpInterface
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename InsertOpTy>
|
||||||
|
static bool insertSliceOpRequiresRead(InsertOpTy insertSliceOp,
|
||||||
|
OpOperand &opOperand) {
|
||||||
|
// The source is always read.
|
||||||
|
if (opOperand == insertSliceOp.getSourceMutable())
|
||||||
|
return true;
|
||||||
|
|
||||||
|
// For the destination, it depends...
|
||||||
|
assert(opOperand == insertSliceOp.getDestMutable() && "expected dest");
|
||||||
|
|
||||||
|
// Dest is not read if it is entirely overwritten. E.g.:
|
||||||
|
// tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
|
||||||
|
bool allOffsetsZero =
|
||||||
|
llvm::all_of(insertSliceOp.getMixedOffsets(), isZeroIndex);
|
||||||
|
RankedTensorType destType = insertSliceOp.getDestType();
|
||||||
|
bool sizesMatchDestSizes =
|
||||||
|
areConstantIntValues(insertSliceOp.getMixedSizes(), destType.getShape());
|
||||||
|
bool allStridesOne =
|
||||||
|
areAllConstantIntValue(insertSliceOp.getMixedStrides(), 1);
|
||||||
|
return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
|
||||||
|
}
|
||||||
|
|
||||||
/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
|
/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
|
||||||
/// certain circumstances, this op can also be a no-op.
|
/// certain circumstances, this op can also be a no-op.
|
||||||
///
|
///
|
||||||
@@ -646,32 +669,8 @@ struct InsertSliceOpInterface
|
|||||||
tensor::InsertSliceOp> {
|
tensor::InsertSliceOp> {
|
||||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||||
const AnalysisState &state) const {
|
const AnalysisState &state) const {
|
||||||
auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
|
return insertSliceOpRequiresRead(cast<tensor::InsertSliceOp>(op),
|
||||||
RankedTensorType destType = insertSliceOp.getDestType();
|
opOperand);
|
||||||
|
|
||||||
// The source is always read.
|
|
||||||
if (opOperand == insertSliceOp.getSourceMutable())
|
|
||||||
return true;
|
|
||||||
|
|
||||||
// For the destination, it depends...
|
|
||||||
assert(opOperand == insertSliceOp.getDestMutable() && "expected dest");
|
|
||||||
|
|
||||||
// Dest is not read if it is entirely overwritten. E.g.:
|
|
||||||
// tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
|
|
||||||
bool allOffsetsZero =
|
|
||||||
llvm::all_of(insertSliceOp.getMixedOffsets(), [](OpFoldResult ofr) {
|
|
||||||
return isConstantIntValue(ofr, 0);
|
|
||||||
});
|
|
||||||
bool sizesMatchDestSizes = llvm::all_of(
|
|
||||||
llvm::enumerate(insertSliceOp.getMixedSizes()), [&](const auto &it) {
|
|
||||||
return getConstantIntValue(it.value()) ==
|
|
||||||
destType.getDimSize(it.index());
|
|
||||||
});
|
|
||||||
bool allStridesOne =
|
|
||||||
llvm::all_of(insertSliceOp.getMixedStrides(), [](OpFoldResult ofr) {
|
|
||||||
return isConstantIntValue(ofr, 1);
|
|
||||||
});
|
|
||||||
return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||||
@@ -931,7 +930,8 @@ struct ParallelInsertSliceOpInterface
|
|||||||
|
|
||||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||||
const AnalysisState &state) const {
|
const AnalysisState &state) const {
|
||||||
return true;
|
return insertSliceOpRequiresRead(cast<tensor::ParallelInsertSliceOp>(op),
|
||||||
|
opOperand);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||||
|
|||||||
@@ -16,11 +16,6 @@ namespace mlir {
|
|||||||
namespace tensor {
|
namespace tensor {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
static bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {
|
|
||||||
return llvm::all_of(
|
|
||||||
ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); });
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the number of shape sizes that is either dynamic or greater than 1.
|
/// Returns the number of shape sizes that is either dynamic or greater than 1.
|
||||||
static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) {
|
static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) {
|
||||||
return llvm::count_if(
|
return llvm::count_if(
|
||||||
|
|||||||
@@ -10,6 +10,7 @@
|
|||||||
#include "mlir/IR/Matchers.h"
|
#include "mlir/IR/Matchers.h"
|
||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
#include "llvm/ADT/APSInt.h"
|
#include "llvm/ADT/APSInt.h"
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/Support/MathExtras.h"
|
#include "llvm/Support/MathExtras.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
@@ -131,12 +132,24 @@ getConstantIntValues(ArrayRef<OpFoldResult> ofrs) {
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return true if `ofr` is constant integer equal to `value`.
|
|
||||||
bool isConstantIntValue(OpFoldResult ofr, int64_t value) {
|
bool isConstantIntValue(OpFoldResult ofr, int64_t value) {
|
||||||
auto val = getConstantIntValue(ofr);
|
auto val = getConstantIntValue(ofr);
|
||||||
return val && *val == value;
|
return val && *val == value;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {
|
||||||
|
return llvm::all_of(
|
||||||
|
ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); });
|
||||||
|
}
|
||||||
|
|
||||||
|
bool areConstantIntValues(ArrayRef<OpFoldResult> ofrs,
|
||||||
|
ArrayRef<int64_t> values) {
|
||||||
|
if (ofrs.size() != values.size())
|
||||||
|
return false;
|
||||||
|
std::optional<SmallVector<int64_t>> constOfrs = getConstantIntValues(ofrs);
|
||||||
|
return constOfrs && llvm::equal(constOfrs.value(), values);
|
||||||
|
}
|
||||||
|
|
||||||
/// Return true if ofr1 and ofr2 are the same integer constant attribute values
|
/// Return true if ofr1 and ofr2 are the same integer constant attribute values
|
||||||
/// or the same SSA value.
|
/// or the same SSA value.
|
||||||
/// Ignore integer bitwidth and type mismatch that come from the fact there is
|
/// Ignore integer bitwidth and type mismatch that come from the fact there is
|
||||||
|
|||||||
@@ -213,6 +213,21 @@ func.func @rank_reducing_parallel_insert_slice(%in: tensor<100xf32>, %out: tenso
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @parallel_insert_full_slice_in_place
|
||||||
|
// CHECK-NOT: memref.alloc()
|
||||||
|
func.func @parallel_insert_full_slice_in_place(%2: tensor<2xf32>) -> tensor<2xf32> {
|
||||||
|
%cst = arith.constant 0.000000e+00 : f32
|
||||||
|
%3 = scf.forall (%arg0) in (1) shared_outs(%arg2 = %2) -> (tensor<2xf32>) {
|
||||||
|
%fill = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<2xf32>) -> tensor<2xf32>
|
||||||
|
scf.forall.in_parallel {
|
||||||
|
tensor.parallel_insert_slice %fill into %arg2[0] [2] [1] : tensor<2xf32> into tensor<2xf32>
|
||||||
|
}
|
||||||
|
} {mapping = [#gpu.thread<linear_dim_0>]}
|
||||||
|
return %3 : tensor<2xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// This test case could bufferize in-place with a better analysis. However, it
|
// This test case could bufferize in-place with a better analysis. However, it
|
||||||
// is simpler to let the canonicalizer fold away the tensor.insert_slice.
|
// is simpler to let the canonicalizer fold away the tensor.insert_slice.
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user