[MLIR][Shape] Support >2 args in shape.broadcast folder (#126808)

Hi!

As the title says, this PR adds support for more than two arguments in the
`shape.broadcast` folder by calling `getBroadcastedShape` sequentially over the operands.
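
Conceptually, the folder now left-folds the pairwise broadcast over all of the constant operand shapes instead of bailing out beyond two. Below is a minimal standalone C++ sketch of that idea, with illustrative names and demo shapes; it is not the MLIR code itself, which reads the extents out of `DenseIntElementsAttr`s and delegates the pairwise step to `OpTrait::util::getBroadcastedShape`:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

using Shape = std::vector<int64_t>;

// Pairwise broadcast of two static shapes; std::nullopt if incompatible.
std::optional<Shape> broadcastPair(const Shape &a, const Shape &b) {
  Shape result(std::max(a.size(), b.size()), 1);
  for (size_t i = 0; i < result.size(); ++i) {
    // Align extents from the right; a missing dimension behaves like 1.
    int64_t da = i < a.size() ? a[a.size() - 1 - i] : 1;
    int64_t db = i < b.size() ? b[b.size() - 1 - i] : 1;
    if (da != db && da != 1 && db != 1)
      return std::nullopt; // incompatible extents, nothing to fold to
    result[result.size() - 1 - i] = std::max(da, db);
  }
  return result;
}

// Left-fold the pairwise step over any number of shapes, mirroring the
// loop over the adaptor's shapes in the new fold.
std::optional<Shape> broadcastAll(const std::vector<Shape> &shapes) {
  Shape acc = shapes.front();
  for (size_t i = 1; i < shapes.size(); ++i) {
    std::optional<Shape> next = broadcastPair(acc, shapes[i]);
    if (!next)
      return std::nullopt;
    acc = *next;
  }
  return acc;
}

int main() {
  if (auto r = broadcastAll({{4, 1}, {1, 5}, {3, 1, 1}})) {
    for (int64_t d : *r)
      std::cout << d << ' '; // prints: 3 4 5
    std::cout << '\n';
  }
}
```
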
Author: Mateusz Sokół
Date: 2025-04-14 19:50:52 +02:00
Committed by: GitHub
Parent: da17ced11b
Commit: df84aa8e06
3 changed files with 33 additions and 15 deletions


@@ -649,24 +649,29 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
     return getShapes().front();
   }
 
-  // TODO: Support folding with more than 2 input shapes
-  if (getShapes().size() > 2)
+  if (!adaptor.getShapes().front())
     return nullptr;
 
-  if (!adaptor.getShapes()[0] || !adaptor.getShapes()[1])
-    return nullptr;
-  auto lhsShape = llvm::to_vector<6>(
-      llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[0])
+  SmallVector<int64_t, 6> resultShape(
+      llvm::cast<DenseIntElementsAttr>(adaptor.getShapes().front())
           .getValues<int64_t>());
-  auto rhsShape = llvm::to_vector<6>(
-      llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[1])
-          .getValues<int64_t>());
-  SmallVector<int64_t, 6> resultShape;
 
-  // If the shapes are not compatible, we can't fold it.
-  // TODO: Fold to an "error".
-  if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
-    return nullptr;
+  for (auto next : adaptor.getShapes().drop_front()) {
+    if (!next)
+      return nullptr;
+    auto nextShape = llvm::to_vector<6>(
+        llvm::cast<DenseIntElementsAttr>(next).getValues<int64_t>());
+
+    SmallVector<int64_t, 6> tmpShape;
+    // If the shapes are not compatible, we can't fold it.
+    // TODO: Fold to an "error".
+    if (!OpTrait::util::getBroadcastedShape(resultShape, nextShape, tmpShape))
+      return nullptr;
+
+    resultShape.clear();
+    std::copy(tmpShape.begin(), tmpShape.end(),
+              std::back_inserter(resultShape));
+  }
 
   Builder builder(getContext());
   return builder.getIndexTensorAttr(resultShape);
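
With the three constant shapes from the new test case below, the loop first broadcasts [2, 1] with [7, 2, 1] into [7, 2, 1], then broadcasts that intermediate result with [1, 10] into [7, 2, 10], which is returned as the folded index tensor attribute.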


@@ -84,7 +84,7 @@ bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
     if (ShapedType::isDynamic(*i1) || ShapedType::isDynamic(*i2)) {
       // One or both dimensions is unknown. Follow TensorFlow behavior:
       // - If either dimension is greater than 1, we assume that the program is
-      //   correct, and the other dimension will be broadcast to match it.
+      //   correct, and the other dimension will be broadcasted to match it.
       // - If either dimension is 1, the other dimension is the output.
       if (*i1 > 1) {
         *iR = *i1;
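
The dynamic-extent rule quoted in that comment is easiest to see one dimension pair at a time. Here is a hedged standalone sketch of just that rule, using an illustrative `kDyn` marker in place of `ShapedType::kDynamic`; it only applies when at least one of the two extents is unknown, and it restates the comment rather than the actual MLIR implementation:

```cpp
#include <cstdint>

constexpr int64_t kDyn = -1; // illustrative stand-in for an unknown extent

// Pair one dimension from each shape when at least one extent is unknown.
constexpr int64_t pairWithUnknown(int64_t d1, int64_t d2) {
  if (d1 > 1)
    return d1; // assume the program is correct; the other side matches it
  if (d2 > 1)
    return d2;
  if (d1 == 1)
    return d2; // a static 1 never constrains the result
  if (d2 == 1)
    return d1;
  return kDyn; // both extents unknown, so the result extent is unknown
}

static_assert(pairWithUnknown(kDyn, 7) == 7, "unknown vs >1 takes the known extent");
static_assert(pairWithUnknown(kDyn, 1) == kDyn, "unknown vs 1 stays unknown");
static_assert(pairWithUnknown(kDyn, kDyn) == kDyn, "both unknown stays unknown");
```
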


@@ -86,6 +86,19 @@ func.func @broadcast() -> !shape.shape {
 
 // -----
 
+// Variadic case including extent tensors.
+// CHECK-LABEL: @broadcast_variadic
+func.func @broadcast_variadic() -> !shape.shape {
+  // CHECK: shape.const_shape [7, 2, 10] : !shape.shape
+  %0 = shape.const_shape [2, 1] : tensor<2xindex>
+  %1 = shape.const_shape [7, 2, 1] : tensor<3xindex>
+  %2 = shape.const_shape [1, 10] : tensor<2xindex>
+  %3 = shape.broadcast %0, %1, %2 : tensor<2xindex>, tensor<3xindex>, tensor<2xindex> -> !shape.shape
+  return %3 : !shape.shape
+}
+
+// -----
+
 // Rhs is a scalar.
 // CHECK-LABEL: func @f
 func.func @f(%arg0 : !shape.shape) -> !shape.shape {