[MLIR][Shape] Support >2 args in shape.broadcast folder (#126808)
Hi! As the title says, this PR adds support for >2 arguments in `shape.broadcast` folder by sequentially calling `getBroadcastedShape`.
This commit is contained in:
@@ -649,24 +649,29 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
|
||||
return getShapes().front();
|
||||
}
|
||||
|
||||
// TODO: Support folding with more than 2 input shapes
|
||||
if (getShapes().size() > 2)
|
||||
if (!adaptor.getShapes().front())
|
||||
return nullptr;
|
||||
|
||||
if (!adaptor.getShapes()[0] || !adaptor.getShapes()[1])
|
||||
return nullptr;
|
||||
auto lhsShape = llvm::to_vector<6>(
|
||||
llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[0])
|
||||
SmallVector<int64_t, 6> resultShape(
|
||||
llvm::cast<DenseIntElementsAttr>(adaptor.getShapes().front())
|
||||
.getValues<int64_t>());
|
||||
auto rhsShape = llvm::to_vector<6>(
|
||||
llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[1])
|
||||
.getValues<int64_t>());
|
||||
SmallVector<int64_t, 6> resultShape;
|
||||
|
||||
// If the shapes are not compatible, we can't fold it.
|
||||
// TODO: Fold to an "error".
|
||||
if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
|
||||
return nullptr;
|
||||
for (auto next : adaptor.getShapes().drop_front()) {
|
||||
if (!next)
|
||||
return nullptr;
|
||||
auto nextShape = llvm::to_vector<6>(
|
||||
llvm::cast<DenseIntElementsAttr>(next).getValues<int64_t>());
|
||||
|
||||
SmallVector<int64_t, 6> tmpShape;
|
||||
// If the shapes are not compatible, we can't fold it.
|
||||
// TODO: Fold to an "error".
|
||||
if (!OpTrait::util::getBroadcastedShape(resultShape, nextShape, tmpShape))
|
||||
return nullptr;
|
||||
|
||||
resultShape.clear();
|
||||
std::copy(tmpShape.begin(), tmpShape.end(),
|
||||
std::back_inserter(resultShape));
|
||||
}
|
||||
|
||||
Builder builder(getContext());
|
||||
return builder.getIndexTensorAttr(resultShape);
|
||||
|
||||
@@ -84,7 +84,7 @@ bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
|
||||
if (ShapedType::isDynamic(*i1) || ShapedType::isDynamic(*i2)) {
|
||||
// One or both dimensions is unknown. Follow TensorFlow behavior:
|
||||
// - If either dimension is greater than 1, we assume that the program is
|
||||
// correct, and the other dimension will be broadcast to match it.
|
||||
// correct, and the other dimension will be broadcasted to match it.
|
||||
// - If either dimension is 1, the other dimension is the output.
|
||||
if (*i1 > 1) {
|
||||
*iR = *i1;
|
||||
|
||||
@@ -86,6 +86,19 @@ func.func @broadcast() -> !shape.shape {
|
||||
|
||||
// -----
|
||||
|
||||
// Variadic case including extent tensors.
|
||||
// CHECK-LABEL: @broadcast_variadic
|
||||
func.func @broadcast_variadic() -> !shape.shape {
|
||||
// CHECK: shape.const_shape [7, 2, 10] : !shape.shape
|
||||
%0 = shape.const_shape [2, 1] : tensor<2xindex>
|
||||
%1 = shape.const_shape [7, 2, 1] : tensor<3xindex>
|
||||
%2 = shape.const_shape [1, 10] : tensor<2xindex>
|
||||
%3 = shape.broadcast %0, %1, %2 : tensor<2xindex>, tensor<3xindex>, tensor<2xindex> -> !shape.shape
|
||||
return %3 : !shape.shape
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Rhs is a scalar.
|
||||
// CHECK-LABEL: func @f
|
||||
func.func @f(%arg0 : !shape.shape) -> !shape.shape {
|
||||
|
||||
Reference in New Issue
Block a user