[mlir][Vector] Add pattern to reorder elementwise and broadcast ops

The new pattern rewrites elementwise(broadcast) as broadcast(elementwise)
when it is safe to do so, i.e. when all operands of the elementwise Op are
broadcast from identical (scalar or vector) types.
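
For example (this is the example from the new pattern's documentation), the
rewrite turns:

  %a = vector.broadcast %arg1 : index to vector<1x4xindex>
  %b = vector.broadcast %arg2 : index to vector<1x4xindex>
  %r = arith.addi %a, %b : vector<1x4xindex>

into:

  %r = arith.addi %arg1, %arg2 : index
  %b = vector.broadcast %r : index to vector<1x4xindex>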

This change affects the tests for vectorising nD `tensor.extract`. In one
case ("vectorize_nd_tensor_extract_with_tensor_extract") I trimmed the test
and preserved only the key parts (the scalar load and the contiguous load
generated from the original Op). We could do the same with some other tests
if that helps maintainability.

Differential Revision: https://reviews.llvm.org/D152812
Author: Andrzej Warzynski
Date:   2023-06-02 15:32:12 +01:00
Commit: 4d339ec91e (parent: e9d77cd9b2)
6 changed files with 200 additions and 48 deletions


@@ -137,6 +137,10 @@ void populateVectorTransferFullPartialPatterns(
 void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
+/// Patterns that remove redundant vector broadcasts.
+void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
+                                         PatternBenefit benefit = 1);
+
 /// Populate `patterns` with the following patterns.
 ///
 /// [DecomposeDifferentRankInsertStridedSlice]


@@ -3066,6 +3066,8 @@ transform::VectorizeOp::applyToOne(Operation *target,
   if (!getDisableMultiReductionToContractPatterns())
     vector::populateVectorReductionToContractPatterns(patterns);
 
+  vector::populateSinkVectorBroadcastPatterns(patterns);
+
   patterns.add<linalg::LinalgCopyVTRForwardingPattern,
                linalg::LinalgCopyVTWForwardingPattern>(ctx,
                                                        /*benefit=*/2);


@@ -885,6 +885,66 @@ private:
   std::function<bool(BitCastOp)> controlFn;
 };
 
+/// Reorders elementwise(broadcast) to broadcast(elementwise). Ex:
+/// ```
+/// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
+/// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
+/// %r = arith.addi %a, %b : vector<1x4xindex>
+/// ```
+/// Gets converted to:
+/// ```
+/// %r = arith.addi %arg1, %arg2 : index
+/// %b = vector.broadcast %r : index to vector<1x4xindex>
+/// ```
+struct ReorderElementwiseOpsOnBroadcast final
+    : public OpTraitRewritePattern<OpTrait::Elementwise> {
+  using OpTraitRewritePattern::OpTraitRewritePattern;
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    if (op->getNumResults() != 1)
+      return failure();
+    if (!llvm::isa<ShapedType>(op->getResults()[0].getType()))
+      return failure();
+    if (!OpTrait::hasElementwiseMappableTraits(op))
+      return failure();
+
+    // Get the type of the pre-broadcast value feeding the first operand.
+    auto firstBcast = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
+    if (!firstBcast)
+      return failure();
+    auto firstOpType = firstBcast.getOperand().getType();
+
+    // Make sure that all operands are broadcast from identical (scalar or
+    // vector) types. That indicates that it's safe to skip the broadcasting
+    // of operands.
+    if (!llvm::all_of(op->getOperands(), [&firstOpType](Value val) {
+          auto bcast = val.getDefiningOp<vector::BroadcastOp>();
+          return (bcast && (bcast.getOperand().getType() == firstOpType));
+        })) {
+      return failure();
+    }
+
+    // Collect the pre-broadcast source values.
+    SmallVector<Value> srcValues;
+    srcValues.reserve(op->getNumOperands());
+    for (Value operand : op->getOperands()) {
+      srcValues.push_back(
+          operand.getDefiningOp<vector::BroadcastOp>().getOperand());
+    }
+
+    // Re-create the elementwise Op on the narrow operands and broadcast its
+    // result back to the original vector type.
+    Operation *elementwiseOp =
+        rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
+                        firstOpType, op->getAttrs());
+    auto vectorType = op->getResultTypes()[0];
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+        op, vectorType, elementwiseOp->getResults());
+    return success();
+  }
+};
+
 // Helper that returns a vector comparison that constructs a mask:
 //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
 //
@@ -1311,6 +1371,12 @@ void mlir::vector::
   patterns.add<DropInnerMostUnitDims>(patterns.getContext(), benefit);
 }
 
+void mlir::vector::populateSinkVectorBroadcastPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<ReorderElementwiseOpsOnBroadcast>(patterns.getContext(),
+                                                 benefit);
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd enum attribute definitions
 //===----------------------------------------------------------------------===//
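
For downstream users, the new populate function plugs into the greedy
pattern driver the same way the test pass added below does. A minimal
sketch, assuming the usual MLIR includes and a pass class (the name
`SinkBroadcastPass` here is hypothetical):

  // Hypothetical pass body: gather the sink-broadcast patterns and apply
  // them greedily to the enclosing function.
  void SinkBroadcastPass::runOnOperation() {
    RewritePatternSet patterns(&getContext());
    vector::populateSinkVectorBroadcastPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }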


@@ -130,27 +130,29 @@ func.func @vectorize_nd_tensor_extract_transfer_read_complex(%6: tensor<45x80x16
   return %25 : tensor<1x4xf32>
 }
 
-// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_complex
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_complex(
 // CHECK-SAME:    %[[VAL_0:.*]]: tensor<45x80x16xf32>,
-// CHECK-SAME:    {{.*}}: index,
+// CHECK-SAME:    %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index, %[[VAL_4:.*]]: index,
 // CHECK-SAME:    %[[VAL_5:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
 // CHECK:         %[[VAL_6:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
 // CHECK:         %[[VAL_7:.*]] = arith.constant 0 : i32
 // CHECK:         %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK:         %[[VAL_9:.*]] = arith.constant 0 : index
 // CHECK:         %[[VAL_10:.*]] = arith.constant 79 : index
-// CHECK:         %[[VAL_11:.*]] = vector.broadcast %{{.*}} : index to vector<1x4xindex>
-// CHECK:         %[[VAL_12:.*]] = vector.broadcast %{{.*}} : index to vector<1x4xindex>
-// CHECK:         %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_12]] : vector<1x4xindex>
-// CHECK:         %[[VAL_14:.*]] = vector.broadcast %{{.*}} : index to vector<4xindex>
-// CHECK:         %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[VAL_6]] : vector<4xindex>
-// CHECK:         %[[VAL_16:.*]] = vector.broadcast %{{.*}} : index to vector<4xindex>
-// CHECK:         %[[VAL_17:.*]] = arith.addi %[[VAL_15]], %[[VAL_16]] : vector<4xindex>
-// CHECK:         %[[VAL_18:.*]] = vector.shape_cast %[[VAL_13]] : vector<1x4xindex> to vector<4xindex>
-// CHECK:         %[[VAL_19:.*]] = vector.extractelement %[[VAL_18]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex>
-// CHECK:         %[[VAL_20:.*]] = vector.extractelement %[[VAL_17]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex>
-// CHECK:         %[[VAL_21:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_19]], %[[VAL_10]], %[[VAL_20]]], %[[VAL_8]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32>
-// CHECK:         %[[VAL_22:.*]] = vector.transfer_write %[[VAL_21]], %[[VAL_5]]{{\[}}%[[VAL_9]], %[[VAL_9]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
-// CHECK:         return %[[VAL_22]] : tensor<1x4xf32>
+// CHECK:         %[[VAL_11:.*]] = arith.addi %[[VAL_1]], %[[VAL_2]] : index
+// CHECK:         %[[VAL_12:.*]] = vector.broadcast %[[VAL_11]] : index to vector<1x4xindex>
+// CHECK:         %[[VAL_13:.*]] = vector.broadcast %[[VAL_3]] : index to vector<4xindex>
+// CHECK:         %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_6]] : vector<4xindex>
+// CHECK:         %[[VAL_15:.*]] = vector.broadcast %[[VAL_4]] : index to vector<4xindex>
+// CHECK:         %[[VAL_16:.*]] = arith.addi %[[VAL_14]], %[[VAL_15]] : vector<4xindex>
+// CHECK:         %[[VAL_17:.*]] = vector.shape_cast %[[VAL_12]] : vector<1x4xindex> to vector<4xindex>
+// CHECK:         %[[VAL_18:.*]] = vector.extractelement %[[VAL_17]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex>
+// CHECK:         %[[VAL_19:.*]] = vector.extractelement %[[VAL_16]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex>
+// CHECK:         %[[VAL_20:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_18]], %[[VAL_10]], %[[VAL_19]]], %[[VAL_8]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32>
+// CHECK:         %[[VAL_21:.*]] = vector.transfer_write %[[VAL_20]], %[[VAL_5]]{{\[}}%[[VAL_9]], %[[VAL_9]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
+// CHECK:         return %[[VAL_21]] : tensor<1x4xf32>
 // CHECK:       }
 
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
@@ -317,43 +319,16 @@ func.func @vectorize_nd_tensor_extract_with_tensor_extract(%input_1: tensor<1x20
 }
 
 // CHECK-LABEL: func.func @vectorize_nd_tensor_extract_with_tensor_extract(
-// CHECK-SAME:    %[[VAL_0:.*]]: tensor<1x20xi32>,
-// CHECK-SAME:    %[[VAL_1:.*]]: tensor<257x24xf32>,
-// CHECK-SAME:    %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index, %[[VAL_4:.*]]: index, %[[VAL_5:.*]]: index) -> tensor<1x1x4xf32> {
-// CHECK-DAG:     %[[VAL_6:.*]] = arith.constant dense<0> : vector<1x1x4xindex>
-// CHECK-DAG:     %[[VAL_7:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-// CHECK-DAG:     %[[VAL_8:.*]] = arith.constant 0 : i32
-// CHECK-DAG:     %[[VAL_9:.*]] = arith.constant dense<256> : vector<1x1x4xindex>
-// CHECK-DAG:     %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG:     %[[VAL_11:.*]] = arith.constant 0 : index
-// CHECK:         %[[VAL_12:.*]] = tensor.empty() : tensor<1x1x4xf32>
-// CHECK:         %[[VAL_13:.*]] = vector.broadcast %[[VAL_2]] : index to vector<1x1x4xindex>
-// CHECK:         %[[VAL_14:.*]] = vector.broadcast %[[VAL_4]] : index to vector<1x1x4xindex>
-// CHECK:         %[[VAL_15:.*]] = arith.addi %[[VAL_13]], %[[VAL_14]] : vector<1x1x4xindex>
-// CHECK:         %[[VAL_16:.*]] = vector.broadcast %[[VAL_3]] : index to vector<1x1x4xindex>
-// CHECK:         %[[VAL_17:.*]] = vector.broadcast %[[VAL_7]] : vector<4xindex> to vector<1x1x4xindex>
-// CHECK:         %[[VAL_18:.*]] = arith.addi %[[VAL_16]], %[[VAL_17]] : vector<1x1x4xindex>
-// CHECK:         %[[VAL_19:.*]] = vector.broadcast %[[VAL_5]] : index to vector<1x1x4xindex>
-// CHECK:         %[[VAL_20:.*]] = arith.addi %[[VAL_18]], %[[VAL_19]] : vector<1x1x4xindex>
-// CHECK:         %[[VAL_21:.*]] = vector.shape_cast %[[VAL_15]] : vector<1x1x4xindex> to vector<4xindex>
-// CHECK:         %[[VAL_22:.*]] = vector.extractelement %[[VAL_21]][%[[VAL_8]] : i32] : vector<4xindex>
-// CHECK:         %[[VAL_23:.*]] = tensor.extract %[[VAL_0]][%[[VAL_11]], %[[VAL_22]]] : tensor<1x20xi32>
-// CHECK:         %[[VAL_24:.*]] = arith.index_cast %[[VAL_23]] : i32 to index
-// CHECK:         %[[VAL_25:.*]] = vector.broadcast %[[VAL_24]] : index to vector<1x1x4xindex>
-// CHECK:         %[[VAL_26:.*]] = arith.maxsi %[[VAL_25]], %[[VAL_6]] : vector<1x1x4xindex>
-// CHECK:         %[[VAL_27:.*]] = arith.minsi %[[VAL_26]], %[[VAL_9]] : vector<1x1x4xindex>
-// CHECK:         %[[VAL_28:.*]] = vector.shape_cast %[[VAL_27]] : vector<1x1x4xindex> to vector<4xindex>
-// CHECK:         %[[VAL_29:.*]] = vector.extractelement %[[VAL_28]][%[[VAL_8]] : i32] : vector<4xindex>
-// CHECK:         %[[VAL_30:.*]] = vector.shape_cast %[[VAL_20]] : vector<1x1x4xindex> to vector<4xindex>
-// CHECK:         %[[VAL_31:.*]] = vector.extractelement %[[VAL_30]][%[[VAL_8]] : i32] : vector<4xindex>
-// CHECK:         %[[VAL_32:.*]] = vector.transfer_read %[[VAL_1]][%[[VAL_29]], %[[VAL_31]]], %[[VAL_10]] {in_bounds = [true, true]} : tensor<257x24xf32>, vector<1x4xf32>
-// CHECK:         %[[VAL_33:.*]] = vector.broadcast %[[VAL_32]] : vector<1x4xf32> to vector<1x1x4xf32>
-// CHECK:         %[[VAL_34:.*]] = vector.transfer_write %[[VAL_33]], %[[VAL_12]][%[[VAL_11]], %[[VAL_11]], %[[VAL_11]]] {in_bounds = [true, true, true]} : vector<1x1x4xf32>, tensor<1x1x4xf32>
-// CHECK:         return %[[VAL_34]] : tensor<1x1x4xf32>
-// CHECK:       }
+// CHECK-SAME:    %[[INPUT_1:.*]]: tensor<1x20xi32>,
+// CHECK-SAME:    %[[INPUT_2:.*]]: tensor<257x24xf32>,
+// CHECK:         %[[EXTRACTED_0_IDX_0:.*]] = arith.constant 0 : index
+// CHECK:         %[[EXTRACTED_0_IDX_1:.*]] = vector.extractelement %{{.*}}[%{{.*}} : i32] : vector<4xindex>
+// First `tensor.extract` from the generic Op - a loop-invariant scalar load.
+// CHECK:         tensor.extract %[[INPUT_1]][%[[EXTRACTED_0_IDX_0]], %[[EXTRACTED_0_IDX_1]]] : tensor<1x20xi32>
+// The following `tensor.extract` from the generic Op is a contiguous load (all Ops used
+// for address calculation also satisfy the required conditions).
+// CHECK:         vector.transfer_read %[[INPUT_2]][%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true]} : tensor<257x24xf32>, vector<1x4xf32>
 
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):


@@ -0,0 +1,78 @@
+// RUN: mlir-opt %s -test-sink-vector-broadcast -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @broadcast_scalar(
+// CHECK-SAME:    %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) -> vector<1x4xindex> {
+// CHECK:         %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[ARG_1]] : index
+// CHECK:         %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
+// CHECK:         return %[[BCAST]] : vector<1x4xindex>
+// CHECK:       }
+func.func @broadcast_scalar(%arg1: index, %arg2: index) -> vector<1x4xindex> {
+  %0 = vector.broadcast %arg1 : index to vector<1x4xindex>
+  %1 = vector.broadcast %arg2 : index to vector<1x4xindex>
+  %2 = arith.addi %0, %1 : vector<1x4xindex>
+  return %2 : vector<1x4xindex>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_vector(
+// CHECK-SAME:    %[[ARG_0:.*]]: vector<4xf32>,
+// CHECK-SAME:    %[[ARG_1:.*]]: vector<4xf32>) -> vector<3x4xf32> {
+// CHECK:         %[[ADDF:.*]] = arith.addf %[[ARG_0]], %[[ARG_1]] : vector<4xf32>
+// CHECK:         %[[BCAST:.*]] = vector.broadcast %[[ADDF]] : vector<4xf32> to vector<3x4xf32>
+// CHECK:         return %[[BCAST]] : vector<3x4xf32>
+// CHECK:       }
+func.func @broadcast_vector(%arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vector<3x4xf32> {
+  %arg1_bcast = vector.broadcast %arg1 : vector<4xf32> to vector<3x4xf32>
+  %arg2_bcast = vector.broadcast %arg2 : vector<4xf32> to vector<3x4xf32>
+  %2 = arith.addf %arg1_bcast, %arg2_bcast : vector<3x4xf32>
+  return %2 : vector<3x4xf32>
+}
+
+// -----
+
+// Only the first operand is broadcast; the pattern must leave this case alone.
+// CHECK-LABEL: func.func @broadcast_vector_and_scalar(
+// CHECK-SAME:    %[[ARG_0:.*]]: i32,
+// CHECK-SAME:    %[[ARG_1:.*]]: vector<4xi32>) -> vector<4xi32> {
+// CHECK:         %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : i32 to vector<4xi32>
+// CHECK:         %[[ADD:.*]] = arith.addi %[[BCAST]], %[[ARG_1]] : vector<4xi32>
+// CHECK:         return %[[ADD]] : vector<4xi32>
+// CHECK:       }
+func.func @broadcast_vector_and_scalar(%arg1: i32, %arg2: vector<4xi32>) -> vector<4xi32> {
+  %arg1_bcast = vector.broadcast %arg1 : i32 to vector<4xi32>
+  %2 = arith.addi %arg1_bcast, %arg2 : vector<4xi32>
+  return %2 : vector<4xi32>
+}
+
+// -----
+
+#matmat_accesses = [
+  affine_map<(i, j, k) -> (i, k)>,
+  affine_map<(i, j, k) -> (k, j)>,
+  affine_map<(i, j, k) -> (i, j)>
+]
+#matmat_trait = {
+  indexing_maps = #matmat_accesses,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// vector.contract is not elementwise, so the pattern must not fire here (the
+// broadcasts of constants simply fold into dense vector constants).
+// CHECK-LABEL: func.func @broadcast_not_elementwise() -> vector<2x2xf32> {
+// CHECK-DAG:     %[[VAL_0:.*]] = arith.constant dense<1.000000e+00> : vector<2x2xf32>
+// CHECK-DAG:     %[[VAL_1:.*]] = arith.constant dense<2.000000e+00> : vector<2x2xf32>
+// CHECK-DAG:     %[[VAL_2:.*]] = arith.constant dense<3.000000e+00> : vector<2x2xf32>
+// CHECK:         %[[VAL_3:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+func.func @broadcast_not_elementwise() -> vector<2x2xf32> {
+  %f1 = arith.constant 1.0 : f32
+  %f2 = arith.constant 2.0 : f32
+  %f3 = arith.constant 3.0 : f32
+  %A = vector.broadcast %f1 : f32 to vector<2x2xf32>
+  %B = vector.broadcast %f2 : f32 to vector<2x2xf32>
+  %C = vector.broadcast %f3 : f32 to vector<2x2xf32>
+  %mm1 = vector.contract #matmat_trait %A, %B, %C
+    : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+  return %mm1 : vector<2x2xf32>
+}


@@ -374,6 +374,31 @@ struct TestVectorTransferCollapseInnerMostContiguousDims
   }
 };
 
+struct TestSinkVectorBroadcast
+    : public PassWrapper<TestSinkVectorBroadcast, OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSinkVectorBroadcast)
+
+  TestSinkVectorBroadcast() = default;
+  TestSinkVectorBroadcast(const TestSinkVectorBroadcast &pass) = default;
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<memref::MemRefDialect, affine::AffineDialect>();
+  }
+
+  StringRef getArgument() const final { return "test-sink-vector-broadcast"; }
+
+  StringRef getDescription() const final {
+    return "Test lowering patterns that eliminate redundant broadcast "
+           "operations.";
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateSinkVectorBroadcastPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 struct TestVectorReduceToContractPatternsPatterns
     : public PassWrapper<TestVectorReduceToContractPatternsPatterns,
                          OperationPass<func::FuncOp>> {
struct TestVectorReduceToContractPatternsPatterns
: public PassWrapper<TestVectorReduceToContractPatternsPatterns,
OperationPass<func::FuncOp>> {
@@ -735,6 +760,8 @@ void registerTestVectorLowerings() {
   PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
 
+  PassRegistration<TestSinkVectorBroadcast>();
+
   PassRegistration<TestVectorReduceToContractPatternsPatterns>();
 
   PassRegistration<TestFlattenVectorTransferPatterns>();