[mlir] canonicalizer: shape_cast(poison) -> poison (#133988)

Based on the ShapeCastConstantFolder, this pattern replaces

%0 = ub.poison : vector<2x3xf32>
%1 = vector.shape_cast %0 vector<2x3xf32> to vector<6xf32>

with 

%1 = ub.poison : vector<6xf32>

---------

Signed-off-by: James Newling <james.newling@gmail.com>
This commit is contained in:
James Newling
2025-04-11 07:13:03 -07:00
committed by GitHub
parent a9225251c4
commit cd85f5dbdf
2 changed files with 39 additions and 32 deletions

View File

@@ -42,6 +42,7 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include <cassert>
#include <cstdint>
@@ -5611,18 +5612,20 @@ LogicalResult ShapeCastOp::verify() {
}
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
// No-op shape cast.
if (getSource().getType() == getResult().getType())
if (getSource().getType() == getType())
return getSource();
VectorType resultType = getType();
// Canceling shape casts.
if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
if (getResult().getType() == otherOp.getSource().getType())
return otherOp.getSource();
// Only allows valid transitive folding.
VectorType srcType = llvm::cast<VectorType>(otherOp.getSource().getType());
VectorType resultType = llvm::cast<VectorType>(getResult().getType());
// Only allows valid transitive folding (expand/collapse dimensions).
VectorType srcType = otherOp.getSource().getType();
if (resultType == srcType)
return otherOp.getSource();
if (srcType.getRank() < resultType.getRank()) {
if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
return {};
@@ -5632,43 +5635,32 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
} else {
return {};
}
setOperand(otherOp.getSource());
return getResult();
}
// Cancelling broadcast and shape cast ops.
if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
if (bcastOp.getSourceType() == getType())
if (bcastOp.getSourceType() == resultType)
return bcastOp.getSource();
}
// shape_cast(constant) -> constant
if (auto splatAttr =
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
return DenseElementsAttr::get(resultType,
splatAttr.getSplatValue<Attribute>());
}
// shape_cast(poison) -> poison
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
return ub::PoisonAttr::get(getContext());
}
return {};
}
namespace {
// Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
PatternRewriter &rewriter) const override {
auto constantOp =
shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>();
if (!constantOp)
return failure();
// Only handle splat for now.
auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue());
if (!dense)
return failure();
auto newAttr =
DenseElementsAttr::get(llvm::cast<VectorType>(shapeCastOp.getType()),
dense.getSplatValue<Attribute>());
rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr);
return success();
}
};
/// Helper function that computes a new vector type based on the input vector
/// type by removing the trailing one dims:
@@ -5828,8 +5820,9 @@ public:
void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
ShapeCastBroadcastFolder>(context);
results
.add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
context);
}
//===----------------------------------------------------------------------===//

View File

@@ -1167,6 +1167,20 @@ func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
// -----
// CHECK-LABEL: shape_cast_poison
// CHECK-DAG: %[[CST1:.*]] = ub.poison : vector<3x4x2xi32>
// CHECK-DAG: %[[CST0:.*]] = ub.poison : vector<20x2xf32>
// CHECK: return %[[CST0]], %[[CST1]] : vector<20x2xf32>, vector<3x4x2xi32>
func.func @shape_cast_poison() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
%poison = ub.poison : vector<5x4x2xf32>
%poison_1 = ub.poison : vector<12x2xi32>
%0 = vector.shape_cast %poison : vector<5x4x2xf32> to vector<20x2xf32>
%1 = vector.shape_cast %poison_1 : vector<12x2xi32> to vector<3x4x2xi32>
return %0, %1 : vector<20x2xf32>, vector<3x4x2xi32>
}
// -----
// CHECK-LABEL: extract_strided_constant
// CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1> : vector<2x13x3xi32>
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<12x2xf32>