[mlir][Vector] Canonicalize empty vector.mask into arith.select (#140976)

This PR adds a missing canonicalization for empty `vector.mask` ops with
a passthru value.

```
   %0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } :
     vector<8xi1> -> vector<8xf32>

 becomes:

   %0 = arith.select %mask, %a, %passthru : vector<8xf32>
```
This commit is contained in:
Diego Caballero
2025-05-23 08:29:57 -07:00
committed by GitHub
parent 1bdec97799
commit 204eb70af8
3 changed files with 66 additions and 8 deletions

View File

@@ -2559,6 +2559,7 @@ def Vector_MaskOp : Vector_Op<"mask", [
Location loc);
}];
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;

View File

@@ -6661,6 +6661,9 @@ LogicalResult MaskOp::verify() {
///
/// %0 = user_op %a : vector<8xf32>
///
/// Empty `vector.mask` with passthru operand are handled by the canonicalizer
/// as it requires creating new operations.
static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
if (!maskOp.isEmpty() || maskOp.hasPassthru())
@@ -6696,6 +6699,47 @@ LogicalResult MaskOp::fold(FoldAdaptor adaptor,
return success();
}
/// Canonialize empty `vector.mask` operations that can't be handled in
/// `VectorMask::fold` as they require creating new operations.
///
/// Example 1: Empty `vector.mask` with passthru operand.
///
/// %0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } :
/// vector<8xi1> -> vector<8xf32>
///
/// becomes:
///
/// %0 = arith.select %mask, %a, %passthru : vector<8xf32>
///
class CanonializeEmptyMaskOp : public OpRewritePattern<MaskOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(MaskOp maskOp,
PatternRewriter &rewriter) const override {
if (!maskOp.isEmpty())
return failure();
if (!maskOp.hasPassthru())
return failure();
Block *block = maskOp.getMaskBlock();
auto terminator = cast<vector::YieldOp>(block->front());
assert(terminator.getNumOperands() == 1 &&
"expected one result when passthru is provided");
rewriter.replaceOpWithNewOp<arith::SelectOp>(
maskOp, maskOp.getResultTypes(), maskOp.getMask(),
terminator.getOperand(0), maskOp.getPassthru());
return success();
}
};
void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<CanonializeEmptyMaskOp>(context);
}
// MaskingOpInterface definitions.
/// Returns the operation masked by this 'vector.mask'.

View File

@@ -719,7 +719,7 @@ func.func @fold_extract_transpose(
// CHECK-LABEL: fold_extract_broadcast_same_input_output_scalar
// CHECK-SAME: %[[A:.*]]: f32
// CHECK: return %[[A]] : f32
func.func @fold_extract_broadcast_same_input_output_scalar(%a : f32,
func.func @fold_extract_broadcast_same_input_output_scalar(%a : f32,
%idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
%b = vector.broadcast %a : f32 to vector<1x2x4xf32>
%r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
@@ -731,7 +731,7 @@ func.func @fold_extract_broadcast_same_input_output_scalar(%a : f32,
// CHECK-LABEL: fold_extract_broadcast_same_input_output_vec
// CHECK-SAME: %[[A:.*]]: vector<4xf32>
// CHECK: return %[[A]] : vector<4xf32>
func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>,
func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>,
%idx0 : index, %idx1 : index) -> vector<4xf32> {
%b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
@@ -744,7 +744,7 @@ func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>,
// CHECK-SAME: %[[A:.*]]: vector<f32>
// CHECK: %[[B:.+]] = vector.extract %[[A]][] : f32 from vector<f32>
// CHECK: return %[[B]] : f32
func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
%idx0 : index, %idx1 : index, %idx2: index) -> f32 {
%b = vector.broadcast %a : vector<f32> to vector<1x2x4xf32>
%r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
@@ -780,7 +780,7 @@ func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : in
// CHECK-SAME: %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
// CHECK: %[[R:.*]] = vector.extract %[[A]][%[[IDX1]], 0] : f32 from vector<2x1xf32>
// CHECK: return %[[R]] : f32
func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
%idx : index, %idx1 : index, %idx2 : index) -> f32 {
%b = vector.broadcast %a : vector<2x1xf32> to vector<1x2x4xf32>
%r = vector.extract %b[%idx, %idx1, %idx2] : f32 from vector<1x2x4xf32>
@@ -795,7 +795,7 @@ func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
// CHECK: %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32>
// CHECK: return %[[B]] : vector<4xf32>
// rank(extract_output) < rank(broadcast_input)
func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
%idx0 : index, %idx1 : index) -> vector<4xf32> {
%b = vector.broadcast %a : vector<2x4xf32> to vector<1x2x4xf32>
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
@@ -808,7 +808,7 @@ func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
// CHECK: return %[[B]] : vector<4xf32>
// rank(extract_output) > rank(broadcast_input)
func.func @fold_extract_broadcast_to_higher_rank(%a : f32, %idx0 : index, %idx1 : index)
func.func @fold_extract_broadcast_to_higher_rank(%a : f32, %idx0 : index, %idx1 : index)
-> vector<4xf32> {
%b = vector.broadcast %a : f32 to vector<1x2x4xf32>
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
@@ -822,7 +822,7 @@ func.func @fold_extract_broadcast_to_higher_rank(%a : f32, %idx0 : index, %idx1
// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32>
// CHECK: return %[[R]] : vector<8xf32>
// rank(extract_output) == rank(broadcast_input)
func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : index)
func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : index)
-> vector<8xf32> {
%b = vector.broadcast %a : vector<1xf32> to vector<1x8xf32>
%r = vector.extract %b[%idx0] : vector<8xf32> from vector<1x8xf32>
@@ -1169,7 +1169,7 @@ func.func @broadcast_poison() -> vector<4x6xi8> {
return %broadcast : vector<4x6xi8>
}
// -----
// -----
// CHECK-LABEL: broadcast_splat_constant
// CHECK: %[[CONST:.*]] = arith.constant dense<1> : vector<4x6xi8>
@@ -2756,6 +2756,19 @@ func.func @empty_vector_mask_with_return(%a : vector<8xf32>, %mask : vector<8xi1
// -----
// CHECK-LABEL: func @empty_vector_mask_with_passthru
// CHECK-SAME: %[[IN:.*]]: vector<8xf32>, %[[MASK:.*]]: vector<8xi1>, %[[PASSTHRU:.*]]: vector<8xf32>
func.func @empty_vector_mask_with_passthru(%a : vector<8xf32>, %mask : vector<8xi1>,
%passthru : vector<8xf32>) -> vector<8xf32> {
// CHECK-NOT: vector.mask
// CHECK: %[[SEL:.*]] = arith.select %[[MASK]], %[[IN]], %[[PASSTHRU]] : vector<8xi1>, vector<8xf32>
// CHECK: return %[[SEL]] : vector<8xf32>
%0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>
return %0 : vector<8xf32>
}
// -----
// CHECK-LABEL: func @all_true_vector_mask
// CHECK-SAME: %[[IN:.*]]: tensor<3x4xf32>
func.func @all_true_vector_mask(%ta : tensor<3x4xf32>) -> vector<3x4xf32> {