[mlir][Vector] Add pattern to reorder elementwise and broadcast ops
The new pattern will replace elementwise(broadcast) with
broadcast(elementwise) when safe.
This change affects tests for vectorising nD-extract. In one case
("vectorize_nd_tensor_extract_with_tensor_extract") I just trimmed the
test and only preserved the key parts (scalar and contiguous load from
the original Op). We could do the same with some other tests if that
helps maintainability.
Differential Revision: https://reviews.llvm.org/D152812
Committed by: Andrzej Warzynski
Parent: e9d77cd9b2
Commit: 4d339ec91e
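For context, below is a minimal sketch of how the new entry point can be wired into a pass. Only `populateSinkVectorBroadcastPatterns` comes from this patch; the surrounding helper (`runSinkBroadcasts`) and the greedy-driver plumbing are illustrative assumptions, not part of the change.

```cpp
// Sketch only: a hypothetical helper showing how a downstream pass could
// apply the pattern set added in this patch.
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static void runSinkBroadcasts(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  // Registers ReorderElementwiseOpsOnBroadcast, which rewrites
  // elementwise(broadcast(x), broadcast(y)) into broadcast(elementwise(x, y))
  // when every operand is broadcast from the same scalar/vector type.
  mlir::vector::populateSinkVectorBroadcastPatterns(patterns);
  (void)mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```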
@@ -137,6 +137,10 @@ void populateVectorTransferFullPartialPatterns(
void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit = 1);

/// Patterns that remove redundant vector broadcasts.
void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
                                         PatternBenefit benefit = 1);

/// Populate `patterns` with the following patterns.
///
/// [DecomposeDifferentRankInsertStridedSlice]
@@ -3066,6 +3066,8 @@ transform::VectorizeOp::applyToOne(Operation *target,
  if (!getDisableMultiReductionToContractPatterns())
    vector::populateVectorReductionToContractPatterns(patterns);

  vector::populateSinkVectorBroadcastPatterns(patterns);

  patterns.add<linalg::LinalgCopyVTRForwardingPattern,
               linalg::LinalgCopyVTWForwardingPattern>(ctx,
                                                       /*benefit=*/2);
@@ -885,6 +885,66 @@ private:
  std::function<bool(BitCastOp)> controlFn;
};

/// Reorders elementwise(broadcast) to broadcast(elementwise). Ex:
/// ```
/// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
/// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
/// %r = arith.addi %a, %b : vector<1x4xindex>
/// ```
/// Gets converted to:
/// ```
/// %r = arith.addi %arg1, %arg2 : index
/// %b = vector.broadcast %r : index to vector<1x4xindex>
/// ```
struct ReorderElementwiseOpsOnBroadcast final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumResults() != 1)
      return failure();
    if (!llvm::isa<ShapedType>(op->getResults()[0].getType()))
      return failure();
    if (!OpTrait::hasElementwiseMappableTraits(op))
      return failure();

    // Get the type of the first operand.
    auto firstBcast = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
    if (!firstBcast)
      return failure();
    auto firstOpType = firstBcast.getOperand().getType();

    // Make sure that operands are "broadcast"ed from identical (scalar or
    // vector) types. That indicates that it's safe to skip the broadcasting of
    // operands.
    if (!llvm::all_of(op->getOperands(), [&firstOpType](Value val) {
          auto bcast = val.getDefiningOp<vector::BroadcastOp>();
          return (bcast && (bcast.getOperand().getType() == firstOpType));
        })) {
      return failure();
    }

    // Collect the source values.
    SmallVector<Value> srcValues;
    srcValues.reserve(op->getNumOperands());

    for (Value operand : op->getOperands()) {
      srcValues.push_back(
          operand.getDefiningOp<vector::BroadcastOp>().getOperand());
    }

    Operation *elementwiseOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
                        firstOpType, op->getAttrs());

    auto vectorType = op->getResultTypes()[0];
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, vectorType, elementwiseOp->getResults());

    return success();
  }
};

// Helper that returns a vector comparison that constructs a mask:
//     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
//
@@ -1311,6 +1371,12 @@ void mlir::vector::
  patterns.add<DropInnerMostUnitDims>(patterns.getContext(), benefit);
}

void mlir::vector::populateSinkVectorBroadcastPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<ReorderElementwiseOpsOnBroadcast>(patterns.getContext(),
                                                 benefit);
}

//===----------------------------------------------------------------------===//
// TableGen'd enum attribute definitions
//===----------------------------------------------------------------------===//
@@ -130,27 +130,29 @@ func.func @vectorize_nd_tensor_extract_transfer_read_complex(%6: tensor<45x80x16
  return %25 : tensor<1x4xf32>
}

// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_complex
// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_complex(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<45x80x16xf32>,
// CHECK-SAME: {{.*}}: index,
// CHECK-SAME: %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index, %[[VAL_4:.*]]: index,
// CHECK-SAME: %[[VAL_5:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
// CHECK: %[[VAL_6:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
// CHECK: %[[VAL_7:.*]] = arith.constant 0 : i32
// CHECK: %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VAL_9:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_10:.*]] = arith.constant 79 : index
// CHECK: %[[VAL_11:.*]] = vector.broadcast %{{.*}} : index to vector<1x4xindex>
// CHECK: %[[VAL_12:.*]] = vector.broadcast %{{.*}} : index to vector<1x4xindex>
// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_12]] : vector<1x4xindex>
// CHECK: %[[VAL_14:.*]] = vector.broadcast %{{.*}} : index to vector<4xindex>
// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[VAL_6]] : vector<4xindex>
// CHECK: %[[VAL_16:.*]] = vector.broadcast %{{.*}} : index to vector<4xindex>
// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_15]], %[[VAL_16]] : vector<4xindex>
// CHECK: %[[VAL_18:.*]] = vector.shape_cast %[[VAL_13]] : vector<1x4xindex> to vector<4xindex>
// CHECK: %[[VAL_19:.*]] = vector.extractelement %[[VAL_18]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex>
// CHECK: %[[VAL_20:.*]] = vector.extractelement %[[VAL_17]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex>
// CHECK: %[[VAL_21:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_19]], %[[VAL_10]], %[[VAL_20]]], %[[VAL_8]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32>
// CHECK: %[[VAL_22:.*]] = vector.transfer_write %[[VAL_21]], %[[VAL_5]]{{\[}}%[[VAL_9]], %[[VAL_9]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_1]], %[[VAL_2]] : index
// CHECK: %[[VAL_12:.*]] = vector.broadcast %[[VAL_11]] : index to vector<1x4xindex>
// CHECK: %[[VAL_13:.*]] = vector.broadcast %[[VAL_3]] : index to vector<4xindex>
// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_6]] : vector<4xindex>
// CHECK: %[[VAL_15:.*]] = vector.broadcast %[[VAL_4]] : index to vector<4xindex>
// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_14]], %[[VAL_15]] : vector<4xindex>
// CHECK: %[[VAL_17:.*]] = vector.shape_cast %[[VAL_12]] : vector<1x4xindex> to vector<4xindex>
// CHECK: %[[VAL_18:.*]] = vector.extractelement %[[VAL_17]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex>
// CHECK: %[[VAL_19:.*]] = vector.extractelement %[[VAL_16]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex>
// CHECK: %[[VAL_20:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_18]], %[[VAL_10]], %[[VAL_19]]], %[[VAL_8]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32>
// CHECK: %[[VAL_21:.*]] = vector.transfer_write %[[VAL_20]], %[[VAL_5]]{{\[}}%[[VAL_9]], %[[VAL_9]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
// CHECK: return %[[VAL_21]] : tensor<1x4xf32>
// CHECK: }

transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
@@ -317,43 +319,16 @@ func.func @vectorize_nd_tensor_extract_with_tensor_extract(%input_1: tensor<1x20
}

// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_with_tensor_extract(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x20xi32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<257x24xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index, %[[VAL_4:.*]]: index, %[[VAL_5:.*]]: index) -> tensor<1x1x4xf32> {
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant dense<0> : vector<1x1x4xindex>
// CHECK-DAG: %[[VAL_7:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[VAL_9:.*]] = arith.constant dense<256> : vector<1x1x4xindex>
// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_12:.*]] = tensor.empty() : tensor<1x1x4xf32>
// CHECK: %[[VAL_13:.*]] = vector.broadcast %[[VAL_2]] : index to vector<1x1x4xindex>
// CHECK: %[[VAL_14:.*]] = vector.broadcast %[[VAL_4]] : index to vector<1x1x4xindex>
// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_13]], %[[VAL_14]] : vector<1x1x4xindex>
// CHECK: %[[VAL_16:.*]] = vector.broadcast %[[VAL_3]] : index to vector<1x1x4xindex>
// CHECK: %[[VAL_17:.*]] = vector.broadcast %[[VAL_7]] : vector<4xindex> to vector<1x1x4xindex>
// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_16]], %[[VAL_17]] : vector<1x1x4xindex>
// CHECK: %[[VAL_19:.*]] = vector.broadcast %[[VAL_5]] : index to vector<1x1x4xindex>
// CHECK: %[[VAL_20:.*]] = arith.addi %[[VAL_18]], %[[VAL_19]] : vector<1x1x4xindex>
// CHECK: %[[VAL_21:.*]] = vector.shape_cast %[[VAL_15]] : vector<1x1x4xindex> to vector<4xindex>
// CHECK: %[[VAL_22:.*]] = vector.extractelement %[[VAL_21]][%[[VAL_8]] : i32] : vector<4xindex>
// CHECK-SAME: %[[INPUT_1:.*]]: tensor<1x20xi32>,
// CHECK-SAME: %[[INPUT_2:.*]]: tensor<257x24xf32>,
// CHECK: %[[EXTRACTED_0_IDX_0:.*]] = arith.constant 0 : index
// CHECK: %[[EXTRACTED_0_IDX_1:.*]] = vector.extractelement %{{.*}}[%{{.*}} : i32] : vector<4xindex>
// First `tensor.extract` from the generic Op - loop invariant scalar load.
// CHECK: %[[VAL_23:.*]] = tensor.extract %[[VAL_0]][%[[VAL_11]], %[[VAL_22]]] : tensor<1x20xi32>
// CHECK: %[[VAL_24:.*]] = arith.index_cast %[[VAL_23]] : i32 to index
// CHECK: %[[VAL_25:.*]] = vector.broadcast %[[VAL_24]] : index to vector<1x1x4xindex>
// CHECK: %[[VAL_26:.*]] = arith.maxsi %[[VAL_25]], %[[VAL_6]] : vector<1x1x4xindex>
// CHECK: %[[VAL_27:.*]] = arith.minsi %[[VAL_26]], %[[VAL_9]] : vector<1x1x4xindex>
// CHECK: %[[VAL_28:.*]] = vector.shape_cast %[[VAL_27]] : vector<1x1x4xindex> to vector<4xindex>
// CHECK: %[[VAL_29:.*]] = vector.extractelement %[[VAL_28]][%[[VAL_8]] : i32] : vector<4xindex>
// CHECK: %[[VAL_30:.*]] = vector.shape_cast %[[VAL_20]] : vector<1x1x4xindex> to vector<4xindex>
// CHECK: %[[VAL_31:.*]] = vector.extractelement %[[VAL_30]][%[[VAL_8]] : i32] : vector<4xindex>
// CHECK: tensor.extract %[[INPUT_1]][%[[EXTRACTED_0_IDX_0]], %[[EXTRACTED_0_IDX_1]]] : tensor<1x20xi32>
// The following `tensor.extract` from the generic Op is a contiguous load (all Ops used
// for address calculation also satisfy the required conditions).
// CHECK: %[[VAL_32:.*]] = vector.transfer_read %[[VAL_1]][%[[VAL_29]], %[[VAL_31]]], %[[VAL_10]] {in_bounds = [true, true]} : tensor<257x24xf32>, vector<1x4xf32>
// CHECK: %[[VAL_33:.*]] = vector.broadcast %[[VAL_32]] : vector<1x4xf32> to vector<1x1x4xf32>
// CHECK: %[[VAL_34:.*]] = vector.transfer_write %[[VAL_33]], %[[VAL_12]][%[[VAL_11]], %[[VAL_11]], %[[VAL_11]]] {in_bounds = [true, true, true]} : vector<1x1x4xf32>, tensor<1x1x4xf32>
// CHECK: return %[[VAL_34]] : tensor<1x1x4xf32>
// CHECK: }
// CHECK: vector.transfer_read %[[INPUT_2]][%{{.*}}, %{{.*}}, %{{.*}} {in_bounds = [true, true]} : tensor<257x24xf32>, vector<1x4xf32>

transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
mlir/test/Dialect/Vector/sink-vector-broadcast.mlir (new file, 78 lines added)
@@ -0,0 +1,78 @@
// RUN: mlir-opt %s -test-sink-vector-broadcast -split-input-file | FileCheck %s

// CHECK-LABEL: func.func @broadcast_scalar(
// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) -> vector<1x4xindex> {
// CHECK: %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[ARG_1]] : index
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
// CHECK: return %[[BCAST]] : vector<1x4xindex>
// CHECK: }

func.func @broadcast_scalar( %arg1: index, %arg2: index) -> vector<1x4xindex> {
  %0 = vector.broadcast %arg1 : index to vector<1x4xindex>
  %1 = vector.broadcast %arg2 : index to vector<1x4xindex>
  %2 = arith.addi %0, %1 : vector<1x4xindex>
  return %2 : vector<1x4xindex>
}

// -----

// CHECK-LABEL: func.func @broadcast_vector(
// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf32>,
// CHECK-SAME: %[[ARG_1:.*]]: vector<4xf32>) -> vector<3x4xf32> {
// CHECK: %[[ADDF:.*]] = arith.addf %[[ARG_0]], %[[ARG_1]] : vector<4xf32>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADDF]] : vector<4xf32> to vector<3x4xf32>
// CHECK: return %[[BCAST]] : vector<3x4xf32>
// CHECK: }

func.func @broadcast_vector( %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vector<3x4xf32> {
  %arg1_bcast = vector.broadcast %arg1 : vector<4xf32> to vector<3x4xf32>
  %arg2_bcast = vector.broadcast %arg2 : vector<4xf32> to vector<3x4xf32>
  %2 = arith.addf %arg1_bcast, %arg2_bcast : vector<3x4xf32>
  return %2 : vector<3x4xf32>
}

// -----

// CHECK-LABEL: func.func @broadcast_vector_and_scalar(
// CHECK-SAME: %[[ARG_0:.*]]: i32,
// CHECK-SAME: %[[ARG_1:.*]]: vector<4xi32>) -> vector<4xi32> {
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : i32 to vector<4xi32>
// CHECK: %[[ADD:.*]] = arith.addi %[[BCAST]], %[[ARG_1]] : vector<4xi32>
// CHECK: return %[[ADD]] : vector<4xi32>
// CHECK: }

func.func @broadcast_vector_and_scalar( %arg1: i32, %arg2: vector<4xi32>) -> vector<4xi32> {
  %arg1_bcast = vector.broadcast %arg1 : i32 to vector<4xi32>
  %2 = arith.addi %arg1_bcast, %arg2 : vector<4xi32>
  return %2 : vector<4xi32>
}

// -----

#matmat_accesses = [
  affine_map<(i, j, k) -> (i, k)>,
  affine_map<(i, j, k) -> (k, j)>,
  affine_map<(i, j, k) -> (i, j)>
]
#matmat_trait = {
  indexing_maps = #matmat_accesses,
  iterator_types = ["parallel", "parallel", "reduction"]
}

// CHECK-LABEL: func.func @broadcast_not_elementwise() -> vector<2x2xf32> {
// CHECK-DAG: %[[VAL_0:.*]] = arith.constant dense<1.000000e+00> : vector<2x2xf32>
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant dense<2.000000e+00> : vector<2x2xf32>
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<3.000000e+00> : vector<2x2xf32>
// CHECK: %[[VAL_3:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
func.func @broadcast_not_elementwise() -> vector<2x2xf32> {
  %f1 = arith.constant 1.0: f32
  %f2 = arith.constant 2.0: f32
  %f3 = arith.constant 3.0: f32

  %A = vector.broadcast %f1 : f32 to vector<2x2xf32>
  %B = vector.broadcast %f2 : f32 to vector<2x2xf32>
  %C = vector.broadcast %f3 : f32 to vector<2x2xf32>
  %mm1 = vector.contract #matmat_trait %A, %B, %C
    : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>

  return %mm1 : vector<2x2xf32>
}
@@ -374,6 +374,31 @@ struct TestVectorTransferCollapseInnerMostContiguousDims
  }
};

struct TestSinkVectorBroadcast
    : public PassWrapper<TestSinkVectorBroadcast, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSinkVectorBroadcast)

  TestSinkVectorBroadcast() = default;
  TestSinkVectorBroadcast(const TestSinkVectorBroadcast &pass) = default;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, affine::AffineDialect>();
  }

  StringRef getArgument() const final { return "test-sink-vector-broadcast"; }

  StringRef getDescription() const final {
    return "Test lowering patterns that eliminate redundant broadcast "
           "operations.";
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateSinkVectorBroadcastPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorReduceToContractPatternsPatterns
    : public PassWrapper<TestVectorReduceToContractPatternsPatterns,
                         OperationPass<func::FuncOp>> {
@@ -735,6 +760,8 @@ void registerTestVectorLowerings() {
  PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();

  PassRegistration<TestSinkVectorBroadcast>();

  PassRegistration<TestVectorReduceToContractPatternsPatterns>();

  PassRegistration<TestFlattenVectorTransferPatterns>();