[mlir][vector] Update CombineContractBroadcastMask (#140050)
This patch updates `CombineContractBroadcastMask` to inherit from
`MaskableOpRewritePattern`, enabling it to handle masked
`vector.contract` operations. The pattern rewrites:
```mlir
%a_bc = vector.broadcast %a
%res = vector.contract %a_bc, %b, ...
```
into:
```mlir
// Move the broadcast into vector.contract (by updating the indexing
// maps)
%res = vector.contract %a, %b, ...
```
The main challenge is supporting cases where the pattern drops a leading
unit dimension. For example:
```mlir
func.func @contract_broadcast_unit_dim_reduction_masked(
%arg0 : vector<8x4xi32>,
%arg1 : vector<8x4xi32>,
%arg2 : vector<8x8xi32>,
%mask: vector<1x8x8x4xi1>) -> vector<8x8xi32> {
%0 = vector.broadcast %arg0 : vector<8x4xi32> to vector<1x8x4xi32>
%1 = vector.broadcast %arg1 : vector<8x4xi32> to vector<1x8x4xi32>
%result = vector.mask %mask {
vector.contract {
indexing_maps = [#map0, #map1, #map2],
iterator_types = ["reduction", "parallel", "parallel", "reduction"],
kind = #vector.kind<add>
} %0, %1, %arg2 : vector<1x8x4xi32>, vector<1x8x4xi32> into vector<8x8xi32>
} : vector<1x8x8x4xi1> -> vector<8x8xi32>
return %result : vector<8x8xi32>
}
```
Here, the leading unit dimension is dropped. To handle this, the mask is
cast to the correct shape using a `vector.shape_cast`:
```mlir
func.func @contract_broadcast_unit_dim_reduction_masked(
%arg0: vector<8x4xi32>,
%arg1: vector<8x4xi32>,
%arg2: vector<8x8xi32>,
%arg3: vector<1x8x8x4xi1>) -> vector<8x8xi32> {
%mask_sc = vector.shape_cast %arg3 : vector<1x8x8x4xi1> to vector<8x8x4xi1>
%res = vector.mask %mask_sc {
vector.contract {
indexing_maps = [#map, #map1, #map2],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>
} %arg0, %arg1, %arg2 : vector<8x4xi32>, vector<8x4xi32> into vector<8x8xi32>
} : vector<8x8x4xi1> -> vector<8x8xi32>
return %res : vector<8x8xi32>
}
```
While this isn't ideal - since it introduces a `vector.shape_cast` that
must be cleaned up later - it reflects the best we can do once the input
reaches `CombineContractBroadcastMask`. A more robust solution may
involve simplifying the input earlier. I am leaving that as a TODO for
myself to explore this further. Posting this now to unblock downstream
work.
LIMITATIONS
Currently, this pattern assumes:
* Only leading dimensions are dropped in the mask.
* All dropped dimensions must be unit-sized.
This commit is contained in:
committed by
GitHub
parent
e3e5bd1cb1
commit
e22508ea81
@@ -264,109 +264,172 @@ struct CombineContractResultTranspose final
|
||||
/// iterator_types = ["parallel", "parallel", "reduction"],
|
||||
/// kind = add} %arg0, %arg1, %cst_f0
|
||||
/// : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
|
||||
/// ```
|
||||
struct CombineContractBroadcast
|
||||
: public OpRewritePattern<vector::ContractionOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
SmallVector<AffineMap> maps =
|
||||
llvm::to_vector<4>(contractOp.getIndexingMapsArray());
|
||||
Value lhs = contractOp.getLhs();
|
||||
Value rhs = contractOp.getRhs();
|
||||
size_t index = 0;
|
||||
bool changed = false;
|
||||
for (Value *operand : {&lhs, &rhs}) {
|
||||
AffineMap &map = maps[index++];
|
||||
auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
|
||||
if (!broadcast)
|
||||
continue;
|
||||
// contractionOp can only take vector as operands.
|
||||
auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
|
||||
if (!srcType ||
|
||||
srcType.getRank() == broadcast.getResultVectorType().getRank())
|
||||
continue;
|
||||
int64_t rankDiff =
|
||||
broadcast.getResultVectorType().getRank() - srcType.getRank();
|
||||
bool innerDimBroadcast = false;
|
||||
SmallVector<AffineExpr> originalDims;
|
||||
for (const auto &dim : llvm::enumerate(srcType.getShape())) {
|
||||
if (dim.value() != broadcast.getResultVectorType().getDimSize(
|
||||
rankDiff + dim.index())) {
|
||||
innerDimBroadcast = true;
|
||||
break;
|
||||
}
|
||||
originalDims.push_back(
|
||||
rewriter.getAffineDimExpr(dim.index() + rankDiff));
|
||||
/// ```
|
||||
///
|
||||
/// For masked vector.contract, the mask requires updating when a dimension is
|
||||
/// dropped. In such cases, the dropped dimensions must correspond to the mask's
|
||||
/// leading unit dimensions. More generic cases (e.g. non-unit dims) are not
|
||||
/// supported.
|
||||
FailureOr<Value> combineContractAndBroadcast(vector::ContractionOp contractOp,
|
||||
MaskingOpInterface maskingOp,
|
||||
PatternRewriter &rewriter) {
|
||||
SmallVector<AffineMap> maps =
|
||||
llvm::to_vector<4>(contractOp.getIndexingMapsArray());
|
||||
Value lhs = contractOp.getLhs();
|
||||
Value rhs = contractOp.getRhs();
|
||||
size_t index = 0;
|
||||
bool changed = false;
|
||||
for (Value *operand : {&lhs, &rhs}) {
|
||||
AffineMap &map = maps[index++];
|
||||
auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
|
||||
if (!broadcast)
|
||||
continue;
|
||||
// contractionOp can only take vector as operands.
|
||||
auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
|
||||
if (!srcType ||
|
||||
srcType.getRank() == broadcast.getResultVectorType().getRank())
|
||||
continue;
|
||||
int64_t rankDiff =
|
||||
broadcast.getResultVectorType().getRank() - srcType.getRank();
|
||||
bool innerDimBroadcast = false;
|
||||
SmallVector<AffineExpr> originalDims;
|
||||
for (const auto &dim : llvm::enumerate(srcType.getShape())) {
|
||||
if (dim.value() !=
|
||||
broadcast.getResultVectorType().getDimSize(rankDiff + dim.index())) {
|
||||
innerDimBroadcast = true;
|
||||
break;
|
||||
}
|
||||
// Contract doesn't support inner dimension broadcast. Once this is
|
||||
// relaxed we can remove this case.
|
||||
if (innerDimBroadcast)
|
||||
continue;
|
||||
|
||||
// It would be incorrect to fold a broadcast onto a reduction dimension
|
||||
// of non-unit size.
|
||||
bool nonUnitDimReductionBroadcast = false;
|
||||
for (int64_t i = 0; i < rankDiff; ++i) {
|
||||
if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
|
||||
isReductionIterator(contractOp.getIteratorTypes()
|
||||
.getValue()[map.getDimPosition(i)])) {
|
||||
nonUnitDimReductionBroadcast = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (nonUnitDimReductionBroadcast)
|
||||
continue;
|
||||
|
||||
AffineMap broadcastMap =
|
||||
AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
|
||||
originalDims, contractOp.getContext());
|
||||
map = broadcastMap.compose(map);
|
||||
*operand = broadcast.getSource();
|
||||
changed = true;
|
||||
originalDims.push_back(rewriter.getAffineDimExpr(dim.index() + rankDiff));
|
||||
}
|
||||
// Contract doesn't support inner dimension broadcast. Once this is
|
||||
// relaxed we can remove this case.
|
||||
if (innerDimBroadcast)
|
||||
continue;
|
||||
|
||||
if (!changed)
|
||||
return failure();
|
||||
|
||||
// Determine which dims are unused, now that the maps have been composed
|
||||
// with the broadcast maps.
|
||||
llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
|
||||
// Compress unused dims.
|
||||
for (auto &m : maps)
|
||||
m = compressDims(m, unusedDimsBitVector);
|
||||
// Compute the combined iterators.
|
||||
SmallVector<Attribute> iterators;
|
||||
for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
|
||||
if (!unusedDimsBitVector.test(i))
|
||||
iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
|
||||
}
|
||||
// Check that compressing unused dims isn't removing all reduction dimension
|
||||
// pairs. For example, if the vector.contract had only one reduction
|
||||
// iterator and that was a unit-dimension created by a broadcast,
|
||||
// then we should bail here, otherwise we would create a contract without
|
||||
// a reduction dimension pair.
|
||||
bool hasReductionIteratorApplyingOnBothSides = false;
|
||||
for (unsigned i = 0; i < iterators.size(); ++i) {
|
||||
if (!isReductionIterator(iterators[i]))
|
||||
continue;
|
||||
if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
|
||||
hasReductionIteratorApplyingOnBothSides = true;
|
||||
// It would be incorrect to fold a broadcast onto a reduction dimension
|
||||
// of non-unit size.
|
||||
bool nonUnitDimReductionBroadcast = false;
|
||||
for (int64_t i = 0; i < rankDiff; ++i) {
|
||||
if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
|
||||
isReductionIterator(contractOp.getIteratorTypes()
|
||||
.getValue()[map.getDimPosition(i)])) {
|
||||
nonUnitDimReductionBroadcast = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!hasReductionIteratorApplyingOnBothSides)
|
||||
return failure();
|
||||
if (nonUnitDimReductionBroadcast)
|
||||
continue;
|
||||
|
||||
// If the compressed maps have a dimension that is not used by either LHS or
|
||||
// RHS then the ContractionOp verifier would fail.
|
||||
if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
|
||||
return failure();
|
||||
rewriter.replaceOpWithNewOp<vector::ContractionOp>(
|
||||
contractOp, lhs, rhs, contractOp.getAcc(),
|
||||
rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
|
||||
return success();
|
||||
AffineMap broadcastMap =
|
||||
AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
|
||||
originalDims, contractOp.getContext());
|
||||
map = broadcastMap.compose(map);
|
||||
*operand = broadcast.getSource();
|
||||
changed = true;
|
||||
}
|
||||
|
||||
if (!changed)
|
||||
return failure();
|
||||
|
||||
// Determine which dims are unused, now that the maps have been composed
|
||||
// with the broadcast maps.
|
||||
llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
|
||||
// Compress unused dims.
|
||||
for (auto &m : maps)
|
||||
m = compressDims(m, unusedDimsBitVector);
|
||||
// Compute the combined iterators.
|
||||
SmallVector<Attribute> iterators;
|
||||
for (unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) {
|
||||
if (!unusedDimsBitVector.test(i))
|
||||
iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
|
||||
}
|
||||
|
||||
// Check whether any of the unused dims is non-unit, e.g.:
|
||||
// * vector.broadcast %arg0 : vector<8x4xi32> to vector<2x8x4xi32>
|
||||
// This is only required when collapsing a mask. If there is no mask, skip.
|
||||
VectorType oldMaskType;
|
||||
bool isAnyUnusedDimNonUnit = false;
|
||||
if (maskingOp) {
|
||||
oldMaskType = cast<VectorType>(maskingOp.getMask().getType());
|
||||
for (unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) {
|
||||
if (unusedDimsBitVector.test(i) && oldMaskType.getShape()[i] != 1) {
|
||||
isAnyUnusedDimNonUnit = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check that compressing unused dims isn't removing all reduction dimension
|
||||
// pairs. For example, if the vector.contract had only one reduction
|
||||
// iterator and that was a unit-dimension created by a broadcast,
|
||||
// then we should bail here, otherwise we would create a contract without
|
||||
// a reduction dimension pair.
|
||||
bool hasReductionIteratorApplyingOnBothSides = false;
|
||||
for (unsigned i = 0; i < iterators.size(); ++i) {
|
||||
if (!isReductionIterator(iterators[i]))
|
||||
continue;
|
||||
if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
|
||||
hasReductionIteratorApplyingOnBothSides = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!hasReductionIteratorApplyingOnBothSides)
|
||||
return failure();
|
||||
|
||||
// If the compressed maps have a dimension that is not used by either LHS or
|
||||
// RHS then the ContractionOp verifier would fail.
|
||||
if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
|
||||
return failure();
|
||||
|
||||
Operation *newOp = rewriter.create<vector::ContractionOp>(
|
||||
contractOp.getLoc(), lhs, rhs, contractOp.getAcc(),
|
||||
rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
|
||||
|
||||
// Handle the mask.
|
||||
if (maskingOp) {
|
||||
if (isAnyUnusedDimNonUnit)
|
||||
return rewriter.notifyMatchFailure(contractOp,
|
||||
"Cannont drop non-unit mask dim.");
|
||||
assert(unusedDimsBitVector.size() ==
|
||||
static_cast<size_t>(oldMaskType.getRank()) &&
|
||||
"The mask rank is incorrect!");
|
||||
|
||||
// If a dimension has been dropped, update the mask accordingly. Otherwise,
|
||||
// keep it as is.
|
||||
Value mask = maskingOp.getMask();
|
||||
if (unusedDimsBitVector.count() != 0) {
|
||||
// At this point, two assumptions are made:
|
||||
// * The unused dimensions are the leading mask dimensions
|
||||
// (vector.contract does not support inner dim broadcasting).
|
||||
// * The unused dimensions are all unit.
|
||||
// These conditions are effectively verified in the blocks preceding this
|
||||
// one.
|
||||
auto newShape =
|
||||
oldMaskType.getShape().drop_front(unusedDimsBitVector.count());
|
||||
auto newShapeScalableDims =
|
||||
oldMaskType.getScalableDims().drop_front(unusedDimsBitVector.count());
|
||||
VectorType maskOpType =
|
||||
VectorType::get(newShape, rewriter.getI1Type(), newShapeScalableDims);
|
||||
mask = rewriter
|
||||
.create<vector::ShapeCastOp>(contractOp.getLoc(), maskOpType,
|
||||
maskingOp.getMask())
|
||||
.getResult();
|
||||
}
|
||||
|
||||
newOp = mlir::vector::maskOperation(rewriter, newOp, mask);
|
||||
}
|
||||
return newOp->getResult(0);
|
||||
}
|
||||
|
||||
struct CombineContractBroadcastMask
|
||||
: public MaskableOpRewritePattern<vector::ContractionOp> {
|
||||
using MaskableOpRewritePattern::MaskableOpRewritePattern;
|
||||
FailureOr<Value>
|
||||
|
||||
matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
|
||||
MaskingOpInterface maskingOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
return combineContractAndBroadcast(contractOp, maskingOp, rewriter);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -2237,7 +2300,7 @@ void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
|
||||
|
||||
void mlir::vector::populateVectorReductionToContractPatterns(
|
||||
RewritePatternSet &patterns, PatternBenefit benefit) {
|
||||
patterns.add<MultiReduceToContract, CombineContractBroadcast,
|
||||
patterns.add<MultiReduceToContract, CombineContractBroadcastMask,
|
||||
CombineContractABTranspose, CombineContractResultTranspose>(
|
||||
patterns.getContext(), benefit);
|
||||
}
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
// RUN: mlir-opt %s -test-vector-reduction-to-contract-patterns -split-input-file | FileCheck %s
|
||||
|
||||
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
|
||||
// TODO: Separate tests for vector.multi_reduction -> vector.contract and
|
||||
// * pre-op + vector.contract -> vector.contract,
|
||||
// * vector.contract + post-op -> vector.contract.
|
||||
|
||||
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
|
||||
|
||||
// CHECK-LABEL: multidimreduction_contract
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: vector<8x32x16xf32>, %[[ARG1:.*]]: vector<8x32x16xf32>, %[[ARG2:.*]]: vector<8x16xf32>)
|
||||
// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map0]], #[[$map1]]],
|
||||
// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP1]]],
|
||||
// CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>}
|
||||
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32>
|
||||
// CHECK-NEXT: return %[[R]] : vector<8x16xf32>
|
||||
@@ -13,17 +17,16 @@ func.func @multidimreduction_contract(
|
||||
%arg0: vector<8x32x16xf32>,%arg1: vector<8x32x16xf32>, %acc: vector<8x16xf32>) -> vector<8x16xf32> {
|
||||
%0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
|
||||
%1 = vector.multi_reduction <add>, %0, %acc [1] : vector<8x32x16xf32> to vector<8x16xf32>
|
||||
return %1 : vector<8x16xf32>
|
||||
}
|
||||
return %1 : vector<8x16xf32> }
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
|
||||
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
|
||||
|
||||
// CHECK-LABEL: multidimreduction_contract_int
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: vector<8x32x16xi32>, %[[ARG1:.*]]: vector<8x32x16xi32>, %[[ARG2:.*]]: vector<8x16xi32>)
|
||||
// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map0]], #[[$map1]]],
|
||||
// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP1]]],
|
||||
// CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>}
|
||||
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x32x16xi32>, vector<8x32x16xi32> into vector<8x16xi32>
|
||||
// CHECK-NEXT: return %[[R]] : vector<8x16xi32>
|
||||
@@ -36,17 +39,21 @@ func.func @multidimreduction_contract_int(
|
||||
|
||||
// -----
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// [Pattern: CombineContractABTranspose]
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
|
||||
|
||||
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
|
||||
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
|
||||
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
|
||||
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
|
||||
|
||||
// CHECK-LABEL: contract_transpose
|
||||
// CHECK-SAME: (%[[ARG0:.+]]: vector<32x16x8xf32>,
|
||||
// CHECK-NEXT: %[[C0:.+]] = arith.constant dense<0.000000e+00> : vector<8x32xf32>
|
||||
// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
|
||||
// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]],
|
||||
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
|
||||
// CHECK-SAME: %[[ARG0]], %{{.*}}, %[[C0]] : vector<32x16x8xf32>, vector<8x32x16xf32> into vector<8x32xf32>
|
||||
// CHECK-NEXT: return %[[R]] : vector<8x32xf32>
|
||||
@@ -62,17 +69,21 @@ func.func @contract_transpose(
|
||||
|
||||
// -----
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// [Pattern: CombineContractBroadcast]
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
|
||||
|
||||
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
|
||||
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
|
||||
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
|
||||
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
|
||||
|
||||
// CHECK-LABEL: contract_broadcast
|
||||
// CHECK-SAME: (%[[ARG0:.+]]: vector<32x16xf32>,
|
||||
// CHECK-NEXT: %[[C0:.+]] = arith.constant dense<0.000000e+00> : vector<8x32xf32>
|
||||
// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
|
||||
// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]],
|
||||
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
|
||||
// CHECK-SAME: %[[ARG0]], %{{.*}}, %[[C0]] : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
|
||||
// CHECK-NEXT: return %[[R]] : vector<8x32xf32>
|
||||
@@ -87,6 +98,79 @@ func.func @contract_broadcast(
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Same as above, but with a mask.
|
||||
|
||||
#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
|
||||
|
||||
// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
|
||||
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
|
||||
|
||||
// CHECK-LABEL: contract_broadcast_masked
|
||||
// CHECK-SAME: %[[ARG0:.*]]: vector<32x16xf32>,
|
||||
// CHECK-SAME: %[[ARG1:.*]]: vector<8x32x16xf32>,
|
||||
// CHECK-SAME: %[[MASK:.*]]: vector<8x32x16xi1>) -> vector<8x32xf32> {
|
||||
// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<8x32xf32>
|
||||
// CHECK: %[[R:.*]] = vector.mask %[[MASK]] {
|
||||
// CHECK-SAME: vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]],
|
||||
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"],
|
||||
// CHECK-SAME: kind = #vector.kind<add>}
|
||||
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[C0]] : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
|
||||
// CHECK-SAME } : vector<8x32x16xi1> -> vector<8x32xf32>
|
||||
// CHECK: return %[[R]] : vector<8x32xf32>
|
||||
func.func @contract_broadcast_masked(
|
||||
%arg0: vector<32x16xf32>, %arg1: vector<8x32x16xf32>, %mask: vector<8x32x16xi1>) -> vector<8x32xf32> {
|
||||
%cst = arith.constant dense<0.000000e+00> : vector<8x32xf32>
|
||||
%0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
|
||||
%1 = vector.mask %mask {
|
||||
vector.contract {indexing_maps = [#map0, #map0, #map1],
|
||||
iterator_types = ["parallel", "parallel", "reduction"],
|
||||
kind = #vector.kind<add>
|
||||
} %0, %arg1, %cst : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
|
||||
} : vector<8x32x16xi1> -> vector<8x32xf32>
|
||||
return %1 : vector<8x32xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Same as above, but with a scalable dim.
|
||||
|
||||
#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
|
||||
|
||||
// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
|
||||
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
|
||||
|
||||
// CHECK-LABEL: contract_broadcast_masked_scalable
|
||||
// CHECK-SAME: %[[ARG0:.*]]: vector<[32]x16xf32>,
|
||||
// CHECK-SAME: %[[ARG1:.*]]: vector<8x[32]x16xf32>,
|
||||
// CHECK-SAME: %[[MASK:.*]]: vector<8x[32]x16xi1>) -> vector<8x32xf32> {
|
||||
// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<8x32xf32>
|
||||
// CHECK: %[[R:.*]] = vector.mask %[[MASK]] {
|
||||
// CHECK-SAME: vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]],
|
||||
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"],
|
||||
// CHECK-SAME: kind = #vector.kind<add>}
|
||||
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[C0]] : vector<[32]x16xf32>, vector<8x[32]x16xf32> into vector<8x32xf32>
|
||||
// CHECK-SAME } : vector<8x[32]x16xi1> -> vector<8x32xf32>
|
||||
// CHECK: return %[[R]] : vector<8x32xf32>
|
||||
func.func @contract_broadcast_masked_scalable(
|
||||
%arg0: vector<[32]x16xf32>, %arg1: vector<8x[32]x16xf32>, %mask: vector<8x[32]x16xi1>) -> vector<8x32xf32> {
|
||||
%cst = arith.constant dense<0.000000e+00> : vector<8x32xf32>
|
||||
%0 = vector.broadcast %arg0 : vector<[32]x16xf32> to vector<8x[32]x16xf32>
|
||||
%1 = vector.mask %mask {
|
||||
vector.contract {indexing_maps = [#map0, #map0, #map1],
|
||||
iterator_types = ["parallel", "parallel", "reduction"],
|
||||
kind = #vector.kind<add>
|
||||
} %0, %arg1, %cst : vector<8x[32]x16xf32>, vector<8x[32]x16xf32> into vector<8x32xf32>
|
||||
} : vector<8x[32]x16xi1> -> vector<8x32xf32>
|
||||
return %1 : vector<8x32xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test that CombineContractBroadcast is able to combine a broadcast that
|
||||
// creates a unit dim that is consumed by a reduction iterator, dropping that
|
||||
// reduction iterator, as long as there is another reduction iterator left.
|
||||
@@ -95,14 +179,14 @@ func.func @contract_broadcast(
|
||||
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
|
||||
#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
|
||||
|
||||
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
|
||||
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
|
||||
// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
|
||||
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
|
||||
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
|
||||
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
|
||||
|
||||
// CHECK-LABEL: contract_broadcast_unit_dim_reduction
|
||||
// CHECK-SAME: (%[[ARG0:.+]]: vector<8x4xi32>, %[[ARG1:.+]]: vector<8x4xi32>, %[[ARG2:.+]]: vector<8x8xi32>)
|
||||
// CHECK: vector.contract
|
||||
// CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]]
|
||||
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
|
||||
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
|
||||
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x4xi32>, vector<8x4xi32> into vector<8x8xi32>
|
||||
func.func @contract_broadcast_unit_dim_reduction(%arg0 : vector<8x4xi32>, %arg1 : vector<8x4xi32>, %arg2 : vector<8x8xi32>) -> vector<8x8xi32> {
|
||||
@@ -116,6 +200,72 @@ func.func @contract_broadcast_unit_dim_reduction(%arg0 : vector<8x4xi32>, %arg1
|
||||
return %result : vector<8x8xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Same as above, but with a mask.
|
||||
|
||||
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
|
||||
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
|
||||
#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
|
||||
|
||||
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
|
||||
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
|
||||
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
|
||||
|
||||
// CHECK-LABEL: contract_broadcast_unit_dim_reduction_masked
|
||||
// CHECK-SAME: (%[[ARG0:.+]]: vector<8x4xi32>, %[[ARG1:.+]]: vector<8x4xi32>, %[[ARG2:.+]]: vector<8x8xi32>, %[[MASK:.+]]: vector<1x8x8x4xi1>)
|
||||
// CHECK: %[[MASK_SC:.*]] = vector.shape_cast %[[MASK]] : vector<1x8x8x4xi1> to vector<8x8x4xi1>
|
||||
// CHECK: %[[R:.*]] = vector.mask %[[MASK_SC]] {
|
||||
// CHECK-SAME: vector.contract
|
||||
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
|
||||
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
|
||||
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x4xi32>, vector<8x4xi32> into vector<8x8xi32>
|
||||
func.func @contract_broadcast_unit_dim_reduction_masked(%arg0 : vector<8x4xi32>, %arg1 : vector<8x4xi32>, %arg2 : vector<8x8xi32>, %mask: vector<1x8x8x4xi1>) -> vector<8x8xi32> {
|
||||
%0 = vector.broadcast %arg0 : vector<8x4xi32> to vector<1x8x4xi32>
|
||||
%1 = vector.broadcast %arg1 : vector<8x4xi32> to vector<1x8x4xi32>
|
||||
%result = vector.mask %mask {
|
||||
vector.contract {
|
||||
indexing_maps = [#map0, #map1, #map2],
|
||||
iterator_types = ["reduction", "parallel", "parallel", "reduction"],
|
||||
kind = #vector.kind<add>
|
||||
} %0, %1, %arg2 : vector<1x8x4xi32>, vector<1x8x4xi32> into vector<8x8xi32>
|
||||
} : vector<1x8x8x4xi1> -> vector<8x8xi32>
|
||||
return %result : vector<8x8xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Same as above, but with a scalable dim.
|
||||
|
||||
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
|
||||
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
|
||||
#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
|
||||
|
||||
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
|
||||
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
|
||||
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
|
||||
|
||||
// CHECK-LABEL: contract_broadcast_unit_dim_reduction_masked_scalable
|
||||
// CHECK-SAME: (%[[ARG0:.+]]: vector<8x4xi32>, %[[ARG1:.+]]: vector<[8]x4xi32>, %[[ARG2:.+]]: vector<8x[8]xi32>, %[[MASK:.+]]: vector<1x8x[8]x4xi1>)
|
||||
// CHECK: %[[MASK_SC:.*]] = vector.shape_cast %[[MASK]] : vector<1x8x[8]x4xi1> to vector<8x[8]x4xi1>
|
||||
// CHECK: %[[R:.*]] = vector.mask %[[MASK_SC]] {
|
||||
// CHECK-SAME: vector.contract
|
||||
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
|
||||
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
|
||||
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x4xi32>, vector<[8]x4xi32> into vector<8x[8]xi32>
|
||||
func.func @contract_broadcast_unit_dim_reduction_masked_scalable(%arg0 : vector<8x4xi32>, %arg1 : vector<[8]x4xi32>, %arg2 : vector<8x[8]xi32>, %mask: vector<1x8x[8]x4xi1>) -> vector<8x[8]xi32> {
|
||||
%0 = vector.broadcast %arg0 : vector<8x4xi32> to vector<1x8x4xi32>
|
||||
%1 = vector.broadcast %arg1 : vector<[8]x4xi32> to vector<1x[8]x4xi32>
|
||||
%result = vector.mask %mask {
|
||||
vector.contract {
|
||||
indexing_maps = [#map0, #map1, #map2],
|
||||
iterator_types = ["reduction", "parallel", "parallel", "reduction"],
|
||||
kind = #vector.kind<add>
|
||||
} %0, %1, %arg2 : vector<1x8x4xi32>, vector<1x[8]x4xi32> into vector<8x[8]xi32>
|
||||
} : vector<1x8x[8]x4xi1> -> vector<8x[8]xi32>
|
||||
return %result : vector<8x[8]xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// Test that CombineContractBroadcast will not combine a broadcast that creates
|
||||
// a non-unit dim that is consumed by a reduction iterator.
|
||||
@@ -127,16 +277,16 @@ func.func @contract_broadcast_unit_dim_reduction(%arg0 : vector<8x4xi32>, %arg1
|
||||
#map1 = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
|
||||
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
|
||||
|
||||
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d0, d3)>
|
||||
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
|
||||
// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
|
||||
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d0, d3)>
|
||||
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
|
||||
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
|
||||
|
||||
// CHECK-LABEL: contract_broadcast_non_unit_dim_reduction_with_permutation
|
||||
// CHECK-SAME: (%[[ARG0:.+]]: vector<8x4xi32>, %[[ARG1:.+]]: vector<8x4xi32>, %[[ARG2:.+]]: vector<8x8xi32>)
|
||||
// CHECK: %[[BROADCAST0:.+]] = vector.broadcast %[[ARG0]] : vector<8x4xi32> to vector<2x8x4xi32>
|
||||
// CHECK: %[[BROADCAST1:.+]] = vector.broadcast %[[ARG1]] : vector<8x4xi32> to vector<2x8x4xi32>
|
||||
// CHECK: vector.contract
|
||||
// CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]]
|
||||
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
|
||||
// CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel", "reduction"]
|
||||
// CHECK-SAME: %[[BROADCAST0]], %[[BROADCAST1]], %[[ARG2]] : vector<2x8x4xi32>, vector<2x8x4xi32> into vector<8x8xi32>
|
||||
func.func @contract_broadcast_non_unit_dim_reduction_with_permutation(%arg0 : vector<8x4xi32>, %arg1 : vector<8x4xi32>, %arg2 : vector<8x8xi32>) -> vector<8x8xi32> {
|
||||
@@ -159,16 +309,16 @@ func.func @contract_broadcast_non_unit_dim_reduction_with_permutation(%arg0 : ve
|
||||
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
|
||||
#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
|
||||
|
||||
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
|
||||
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
|
||||
// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
|
||||
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
|
||||
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
|
||||
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
|
||||
|
||||
// CHECK-LABEL: contract_broadcast_unit_dim_reduction_as_only_reduction
|
||||
// CHECK-SAME: (%[[ARG0:.+]]: vector<8xi32>, %[[ARG1:.+]]: vector<8xi32>, %[[ARG2:.+]]: vector<8x8xi32>)
|
||||
// CHECK: %[[BROADCAST0:.+]] = vector.broadcast %[[ARG0]] : vector<8xi32> to vector<1x8xi32>
|
||||
// CHECK: %[[BROADCAST1:.+]] = vector.broadcast %[[ARG1]] : vector<8xi32> to vector<1x8xi32>
|
||||
// CHECK: vector.contract
|
||||
// CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]]
|
||||
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
|
||||
// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel"]
|
||||
// CHECK-SAME: %[[BROADCAST0]], %[[BROADCAST1]], %[[ARG2]] : vector<1x8xi32>, vector<1x8xi32> into vector<8x8xi32>
|
||||
func.func @contract_broadcast_unit_dim_reduction_as_only_reduction(%arg0 : vector<8xi32>, %arg1 : vector<8xi32>, %arg2 : vector<8x8xi32>) -> vector<8x8xi32> {
|
||||
@@ -191,15 +341,15 @@ func.func @contract_broadcast_unit_dim_reduction_as_only_reduction(%arg0 : vecto
|
||||
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
#map2 = affine_map<(d0, d1, d2) -> (d1)>
|
||||
|
||||
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
|
||||
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d1)>
|
||||
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
|
||||
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d1)>
|
||||
|
||||
// CHECK-LABEL: contract_broadcast_dimension_would_go_unused_in_lhs_rhs
|
||||
// CHECK-SAME: (%[[ARG0:.+]]: vector<1x2xi32>, %[[ARG1:.+]]: vector<2xi32>, %[[ARG2:.+]]: vector<1xi32>)
|
||||
// CHECK: %[[BROADCAST1:.+]] = vector.broadcast %[[ARG1]] : vector<2xi32> to vector<1x1x2xi32>
|
||||
// CHECK: vector.contract
|
||||
// CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]]
|
||||
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
|
||||
// CHECK-SAME: iterator_types = ["reduction", "parallel", "reduction"]
|
||||
// CHECK-SAME: %[[ARG0]], %[[BROADCAST1]], %[[ARG2]] : vector<1x2xi32>, vector<1x1x2xi32> into vector<1xi32>
|
||||
|
||||
@@ -230,7 +380,7 @@ func.func @contract_broadcast_dimension_would_go_unused_in_lhs_rhs(%arg0 : vecto
|
||||
// CHECK-SAME: (%[[ARG0:.+]]: vector<1xf32>, %[[ARG1:.+]]: vector<1xf32>, %[[ARG2:.+]]: vector<1xf32>)
|
||||
// CHECK: %[[BROADCAST1:.+]] = vector.broadcast %[[ARG1]] : vector<1xf32> to vector<1x1xf32>
|
||||
// CHECK: vector.contract
|
||||
// CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]]
|
||||
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
|
||||
// CHECK-SAME: iterator_types = ["parallel", "reduction"]
|
||||
// CHECK-SAME: %[[ARG0]], %[[BROADCAST1]], %[[ARG2]] : vector<1xf32>, vector<1x1xf32> into vector<1xf32>
|
||||
|
||||
@@ -247,6 +397,10 @@ func.func @contract_broadcast_would_have_no_reduction_dim_pair(%arg0 : vector<1x
|
||||
|
||||
// -----
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// [Pattern: CombineContractResultTranspose]
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
// CHECK-DAG: #[[$LHS_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
|
||||
// CHECK-DAG: #[[$RHS_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
|
||||
// CHECK-DAG: #[[$ACC_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
|
||||
|
||||
Reference in New Issue
Block a user