[mlir][Vector] Canonicalize empty vector.mask into arith.select (#140976)
This PR adds a missing canonicalization for empty `vector.mask` ops with
a passthru value.
```
%0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } :
vector<8xi1> -> vector<8xf32>
```
becomes:
```
%0 = arith.select %mask, %a, %passthru : vector<8xf32>
```
This commit is contained in:
@@ -2559,6 +2559,7 @@ def Vector_MaskOp : Vector_Op<"mask", [
|
||||
Location loc);
|
||||
}];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
let hasFolder = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let hasVerifier = 1;
|
||||
|
||||
@@ -6661,6 +6661,9 @@ LogicalResult MaskOp::verify() {
|
||||
///
|
||||
/// %0 = user_op %a : vector<8xf32>
|
||||
///
|
||||
/// Empty `vector.mask` ops with a passthru operand are handled by the canonicalizer
|
||||
/// as it requires creating new operations.
|
||||
|
||||
static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
|
||||
SmallVectorImpl<OpFoldResult> &results) {
|
||||
if (!maskOp.isEmpty() || maskOp.hasPassthru())
|
||||
@@ -6696,6 +6699,47 @@ LogicalResult MaskOp::fold(FoldAdaptor adaptor,
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Canonicalize empty `vector.mask` operations that can't be handled in
|
||||
/// `MaskOp::fold` as they require creating new operations.
|
||||
///
|
||||
/// Example 1: Empty `vector.mask` with passthru operand.
|
||||
///
|
||||
/// %0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } :
|
||||
/// vector<8xi1> -> vector<8xf32>
|
||||
///
|
||||
/// becomes:
|
||||
///
|
||||
/// %0 = arith.select %mask, %a, %passthru : vector<8xi1>, vector<8xf32>
|
||||
///
|
||||
class CanonializeEmptyMaskOp : public OpRewritePattern<MaskOp> {
  using OpRewritePattern::OpRewritePattern;

  /// Replaces an empty `vector.mask` that has a passthru operand with an
  /// `arith.select` over the mask, the yielded value and the passthru value.
  /// Fails to match when the mask is not empty or has no passthru operand.
  LogicalResult matchAndRewrite(MaskOp maskOp,
                                PatternRewriter &rewriter) const override {
    // Both conditions must hold; `hasPassthru` is only queried for empty masks.
    if (!(maskOp.isEmpty() && maskOp.hasPassthru()))
      return failure();

    // The first operation of the mask block is expected to be the
    // `vector.yield`; the cast asserts if it is anything else.
    Operation &firstOp = maskOp.getMaskBlock()->front();
    auto yieldOp = cast<vector::YieldOp>(firstOp);
    assert(yieldOp.getNumOperands() == 1 &&
           "expected one result when passthru is provided");

    // Select the yielded value where the mask is set, the passthru elsewhere.
    Value yieldedValue = yieldOp.getOperand(0);
    rewriter.replaceOpWithNewOp<arith::SelectOp>(
        maskOp, maskOp.getResultTypes(), maskOp.getMask(), yieldedValue,
        maskOp.getPassthru());
    return success();
  }
};
|
||||
|
||||
/// Adds the canonicalization patterns for `vector.mask` to `results`.
void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  // Single pattern: rewrite of empty masks that carry a passthru operand.
  results.add<CanonializeEmptyMaskOp>(context);
}
|
||||
|
||||
// MaskingOpInterface definitions.
|
||||
|
||||
/// Returns the operation masked by this 'vector.mask'.
|
||||
|
||||
@@ -719,7 +719,7 @@ func.func @fold_extract_transpose(
|
||||
// CHECK-LABEL: fold_extract_broadcast_same_input_output_scalar
|
||||
// CHECK-SAME: %[[A:.*]]: f32
|
||||
// CHECK: return %[[A]] : f32
|
||||
func.func @fold_extract_broadcast_same_input_output_scalar(%a : f32,
|
||||
func.func @fold_extract_broadcast_same_input_output_scalar(%a : f32,
|
||||
%idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
|
||||
%b = vector.broadcast %a : f32 to vector<1x2x4xf32>
|
||||
%r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
|
||||
@@ -731,7 +731,7 @@ func.func @fold_extract_broadcast_same_input_output_scalar(%a : f32,
|
||||
// CHECK-LABEL: fold_extract_broadcast_same_input_output_vec
|
||||
// CHECK-SAME: %[[A:.*]]: vector<4xf32>
|
||||
// CHECK: return %[[A]] : vector<4xf32>
|
||||
func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>,
|
||||
func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>,
|
||||
%idx0 : index, %idx1 : index) -> vector<4xf32> {
|
||||
%b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
|
||||
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
|
||||
@@ -744,7 +744,7 @@ func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>,
|
||||
// CHECK-SAME: %[[A:.*]]: vector<f32>
|
||||
// CHECK: %[[B:.+]] = vector.extract %[[A]][] : f32 from vector<f32>
|
||||
// CHECK: return %[[B]] : f32
|
||||
func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
|
||||
func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
|
||||
%idx0 : index, %idx1 : index, %idx2: index) -> f32 {
|
||||
%b = vector.broadcast %a : vector<f32> to vector<1x2x4xf32>
|
||||
%r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
|
||||
@@ -780,7 +780,7 @@ func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : in
|
||||
// CHECK-SAME: %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
|
||||
// CHECK: %[[R:.*]] = vector.extract %[[A]][%[[IDX1]], 0] : f32 from vector<2x1xf32>
|
||||
// CHECK: return %[[R]] : f32
|
||||
func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
|
||||
func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
|
||||
%idx : index, %idx1 : index, %idx2 : index) -> f32 {
|
||||
%b = vector.broadcast %a : vector<2x1xf32> to vector<1x2x4xf32>
|
||||
%r = vector.extract %b[%idx, %idx1, %idx2] : f32 from vector<1x2x4xf32>
|
||||
@@ -795,7 +795,7 @@ func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
|
||||
// CHECK: %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32>
|
||||
// CHECK: return %[[B]] : vector<4xf32>
|
||||
// rank(extract_output) < rank(broadcast_input)
|
||||
func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
|
||||
func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
|
||||
%idx0 : index, %idx1 : index) -> vector<4xf32> {
|
||||
%b = vector.broadcast %a : vector<2x4xf32> to vector<1x2x4xf32>
|
||||
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
|
||||
@@ -808,7 +808,7 @@ func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
|
||||
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
|
||||
// CHECK: return %[[B]] : vector<4xf32>
|
||||
// rank(extract_output) > rank(broadcast_input)
|
||||
func.func @fold_extract_broadcast_to_higher_rank(%a : f32, %idx0 : index, %idx1 : index)
|
||||
func.func @fold_extract_broadcast_to_higher_rank(%a : f32, %idx0 : index, %idx1 : index)
|
||||
-> vector<4xf32> {
|
||||
%b = vector.broadcast %a : f32 to vector<1x2x4xf32>
|
||||
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
|
||||
@@ -822,7 +822,7 @@ func.func @fold_extract_broadcast_to_higher_rank(%a : f32, %idx0 : index, %idx1
|
||||
// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32>
|
||||
// CHECK: return %[[R]] : vector<8xf32>
|
||||
// rank(extract_output) == rank(broadcast_input)
|
||||
func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : index)
|
||||
func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : index)
|
||||
-> vector<8xf32> {
|
||||
%b = vector.broadcast %a : vector<1xf32> to vector<1x8xf32>
|
||||
%r = vector.extract %b[%idx0] : vector<8xf32> from vector<1x8xf32>
|
||||
@@ -1169,7 +1169,7 @@ func.func @broadcast_poison() -> vector<4x6xi8> {
|
||||
return %broadcast : vector<4x6xi8>
|
||||
}
|
||||
|
||||
// -----
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: broadcast_splat_constant
|
||||
// CHECK: %[[CONST:.*]] = arith.constant dense<1> : vector<4x6xi8>
|
||||
@@ -2756,6 +2756,19 @@ func.func @empty_vector_mask_with_return(%a : vector<8xf32>, %mask : vector<8xi1
|
||||
|
||||
// -----
|
||||
|
||||
// An empty mask op with a passthru operand is expected to be replaced by a select on the mask.
// CHECK-LABEL: func @empty_vector_mask_with_passthru
|
||||
// CHECK-SAME: %[[IN:.*]]: vector<8xf32>, %[[MASK:.*]]: vector<8xi1>, %[[PASSTHRU:.*]]: vector<8xf32>
|
||||
func.func @empty_vector_mask_with_passthru(%a : vector<8xf32>, %mask : vector<8xi1>,
|
||||
%passthru : vector<8xf32>) -> vector<8xf32> {
|
||||
// CHECK-NOT: vector.mask
|
||||
// CHECK: %[[SEL:.*]] = arith.select %[[MASK]], %[[IN]], %[[PASSTHRU]] : vector<8xi1>, vector<8xf32>
|
||||
// CHECK: return %[[SEL]] : vector<8xf32>
|
||||
%0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>
|
||||
return %0 : vector<8xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @all_true_vector_mask
|
||||
// CHECK-SAME: %[[IN:.*]]: tensor<3x4xf32>
|
||||
func.func @all_true_vector_mask(%ta : tensor<3x4xf32>) -> vector<3x4xf32> {
|
||||
|
||||
Reference in New Issue
Block a user