[mlir] canonicalizer: shape_cast(poison) -> poison (#133988)
Based on the ShapeCastConstantFolder, this pattern replaces %0 = ub.poison : vector<2x3xf32> %1 = vector.shape_cast %0 vector<2x3xf32> to vector<6xf32> with %1 = ub.poison : vector<6xf32> --------- Signed-off-by: James Newling <james.newling@gmail.com>
This commit is contained in:
@@ -42,6 +42,7 @@
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cstdint>
|
||||
@@ -5611,18 +5612,20 @@ LogicalResult ShapeCastOp::verify() {
|
||||
}
|
||||
|
||||
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
|
||||
|
||||
// No-op shape cast.
|
||||
if (getSource().getType() == getResult().getType())
|
||||
if (getSource().getType() == getType())
|
||||
return getSource();
|
||||
|
||||
VectorType resultType = getType();
|
||||
|
||||
// Canceling shape casts.
|
||||
if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
|
||||
if (getResult().getType() == otherOp.getSource().getType())
|
||||
return otherOp.getSource();
|
||||
|
||||
// Only allows valid transitive folding.
|
||||
VectorType srcType = llvm::cast<VectorType>(otherOp.getSource().getType());
|
||||
VectorType resultType = llvm::cast<VectorType>(getResult().getType());
|
||||
// Only allows valid transitive folding (expand/collapse dimensions).
|
||||
VectorType srcType = otherOp.getSource().getType();
|
||||
if (resultType == srcType)
|
||||
return otherOp.getSource();
|
||||
if (srcType.getRank() < resultType.getRank()) {
|
||||
if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
|
||||
return {};
|
||||
@@ -5632,43 +5635,32 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
|
||||
} else {
|
||||
return {};
|
||||
}
|
||||
|
||||
setOperand(otherOp.getSource());
|
||||
return getResult();
|
||||
}
|
||||
|
||||
// Cancelling broadcast and shape cast ops.
|
||||
if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
|
||||
if (bcastOp.getSourceType() == getType())
|
||||
if (bcastOp.getSourceType() == resultType)
|
||||
return bcastOp.getSource();
|
||||
}
|
||||
|
||||
// shape_cast(constant) -> constant
|
||||
if (auto splatAttr =
|
||||
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
|
||||
return DenseElementsAttr::get(resultType,
|
||||
splatAttr.getSplatValue<Attribute>());
|
||||
}
|
||||
|
||||
// shape_cast(poison) -> poison
|
||||
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
|
||||
return ub::PoisonAttr::get(getContext());
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
|
||||
class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto constantOp =
|
||||
shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>();
|
||||
if (!constantOp)
|
||||
return failure();
|
||||
// Only handle splat for now.
|
||||
auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue());
|
||||
if (!dense)
|
||||
return failure();
|
||||
auto newAttr =
|
||||
DenseElementsAttr::get(llvm::cast<VectorType>(shapeCastOp.getType()),
|
||||
dense.getSplatValue<Attribute>());
|
||||
rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Helper function that computes a new vector type based on the input vector
|
||||
/// type by removing the trailing one dims:
|
||||
@@ -5828,8 +5820,9 @@ public:
|
||||
|
||||
void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
MLIRContext *context) {
|
||||
results.add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
|
||||
ShapeCastBroadcastFolder>(context);
|
||||
results
|
||||
.add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
|
||||
context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -1167,6 +1167,20 @@ func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: shape_cast_poison
|
||||
// CHECK-DAG: %[[CST1:.*]] = ub.poison : vector<3x4x2xi32>
|
||||
// CHECK-DAG: %[[CST0:.*]] = ub.poison : vector<20x2xf32>
|
||||
// CHECK: return %[[CST0]], %[[CST1]] : vector<20x2xf32>, vector<3x4x2xi32>
|
||||
func.func @shape_cast_poison() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
|
||||
%poison = ub.poison : vector<5x4x2xf32>
|
||||
%poison_1 = ub.poison : vector<12x2xi32>
|
||||
%0 = vector.shape_cast %poison : vector<5x4x2xf32> to vector<20x2xf32>
|
||||
%1 = vector.shape_cast %poison_1 : vector<12x2xi32> to vector<3x4x2xi32>
|
||||
return %0, %1 : vector<20x2xf32>, vector<3x4x2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: extract_strided_constant
|
||||
// CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1> : vector<2x13x3xi32>
|
||||
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<12x2xf32>
|
||||
|
||||
Reference in New Issue
Block a user