Patterns in `LowerContractionToSMMLAPattern` are designed to handle vector-to-matrix multiplication but not matrix-to-vector. This leads to the following error when processing an `rhs` with rank < 2: ``` iree-compile: /usr/local/google/home/kooljblack/code/iree-build/llvm-project/tools/mlir/include/mlir/IR/BuiltinTypeInterfaces.h.inc:268: int64_t mlir::detail::ShapedTypeTrait<mlir::VectorType>::getDimSize(unsigned int) const [ConcreteType = mlir::VectorType]: Assertion `idx < getRank() && "invalid index for shaped type"' failed. ``` This change updates the pattern to explicitly check the rhs rank and to fail on cases it cannot process.
234 lines
9.9 KiB
C++
234 lines
9.9 KiB
C++
//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements lowering patterns from vector.contract to
|
|
// arm_neon.intr.smmla
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
|
|
#include "mlir/Dialect/ArmNeon/Transforms.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
#include "mlir/Dialect/Utils/IndexingUtils.h"
|
|
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
|
#include "mlir/IR/AffineMap.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Support/LogicalResult.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
#define DEBUG_TYPE "lower-contract-to-arm-neon"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::arm_neon;
|
|
|
|
namespace {
|
|
|
|
/// Return the shaped type with new element type.
|
|
static Type matchContainerType(Type element, Type container) {
|
|
if (auto shapedTy = dyn_cast<ShapedType>(container)) {
|
|
return shapedTy.clone(element);
|
|
}
|
|
return element;
|
|
}
|
|
|
|
/// Lowering from a vector::contractOp to the arm neon smmla intrinsic. This
/// will tile any vector.contract into multiple smmla instructions with
/// unrolling so long as [2,2,8] is a divisor of its shape. It can also process
/// vecmats with dimM = 1 (either explicitly, or inferred when the LHS has only
/// dimK). If no unrolling is necessary, a single smmla instruction is emitted.
class LowerContractionToSMMLAPattern
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Infer tile sizes from operands. For vecmat, LHS may only have 1 dim.
    // Note: RHS is not transposed.
    mlir::VectorType lhsType = op.getLhsType();
    mlir::VectorType rhsType = op.getRhsType();
    // Avoid 0-D vectors and 1-D rhs: the rank check guards the
    // rhsType.getDimSize(1) access below, which would otherwise assert on a
    // rank-1 RHS (matrix-to-vector contractions are not handled here).
    if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
      return failure();
    // A rank-1 LHS is an implicit vecmat (dimM == 1).
    auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
    auto dimN = rhsType.getDimSize(0);
    auto dimK = rhsType.getDimSize(1);
    bool isVecmat = dimM == 1 ? true : false;
    // The shared contraction dimension must agree between LHS and RHS.
    if (lhsType.getDimSize(lhsType.getRank() - 1) !=
        rhsType.getDimSize(rhsType.getRank() - 1)) {
      return failure(); // dimK mismatch
    }
    // Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs
    // for tiling (for vecmat, dimM == 1 is tiled as a single row).
    if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) {
      return failure();
    }

    // Check iterator types for contract. All iterators except the inner-most
    // dimension (the reduction over dimK) must be parallel.
    auto iteratorTypes = op.getIteratorTypesArray();
    if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] !=
                                        vector::IteratorType::reduction) {
      return failure();
    }
    if (llvm::any_of(ArrayRef<vector::IteratorType>(iteratorTypes).drop_back(1),
                     [](vector::IteratorType iteratorType) {
                       return iteratorType != vector::IteratorType::parallel;
                     })) {
      return failure();
    }

    // Both contraction operands must come from arith.extsi ops, since smmla
    // performs a signed i8 x i8 -> i32 multiply-accumulate.
    arith::ExtSIOp origLhsExtOp =
        dyn_cast_or_null<arith::ExtSIOp>(op.getLhs().getDefiningOp());
    arith::ExtSIOp origRhsExtOp =
        dyn_cast_or_null<arith::ExtSIOp>(op.getRhs().getDefiningOp());
    if (!origLhsExtOp || !origRhsExtOp) {
      return failure();
    }

    // Match any iX (X <= 8) extended to i32 and rebuild the extension to an
    // i8 container so it can feed the neon instruction directly.
    // NOTE(review): the member-form `.dyn_cast<>()`/`.cast<>()` used below is
    // deprecated in newer MLIR in favor of the free functions — confirm
    // against the tree's MLIR version before modernizing.
    Value extsiLhs;
    Value extsiRhs;
    if (auto lhsExtInType =
            origLhsExtOp.getIn().getType().dyn_cast<mlir::VectorType>()) {
      if (lhsExtInType.getElementTypeBitWidth() <= 8) {
        Type targetLhsExtTy =
            matchContainerType(rewriter.getI8Type(), lhsExtInType);
        extsiLhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy,
                                                         origLhsExtOp.getIn());
      }
    }
    if (auto rhsExtInType =
            origRhsExtOp.getIn().getType().dyn_cast<mlir::VectorType>()) {
      if (rhsExtInType.getElementTypeBitWidth() <= 8) {
        Type targetRhsExtTy =
            matchContainerType(rewriter.getI8Type(), rhsExtInType);
        extsiRhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy,
                                                         origRhsExtOp.getIn());
      }
    }

    // Fail if either extension could not be rebuilt (non-vector input or
    // element wider than 8 bits).
    if (!extsiLhs || !extsiRhs) {
      return failure();
    }

    // Initial accumulator for the final result. This is the un-tiled result
    // if tiling is done; tiles are inserted into it as they are produced.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));

    SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
    // Base tile is [dimN=2, dimK=8]; a leading dimM of 2 (or 1 for vecmat) is
    // prepended when the contraction has three iteration dimensions.
    SmallVector<int64_t> smmlaShape{2, 8};
    SmallVector<int64_t> loopOrder{0, 1};
    if (unrolledSize.size() == 3) {
      smmlaShape.insert(smmlaShape.begin(), isVecmat ? 1 : 2);
      loopOrder.push_back(2);
    }
    // Walk every tile offset in the unrolled iteration space and emit one
    // smmla per tile.
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
      // Helper to compute the new shape of each operand and extract the
      // slice.
      auto extractOperand = [&](Value operand, AffineMap permutationMap,
                                ArrayRef<int64_t> operandOffsets) {
        SmallVector<int64_t> operandShape =
            applyPermutationMap(permutationMap, ArrayRef<int64_t>(smmlaShape));
        SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
        return rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, operand, operandOffsets, operandShape, operandStrides);
      };

      // Extract tiled lhs, rhs, and acc by mapping the iteration-space
      // offsets through each operand's indexing map.
      AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0];
      SmallVector<int64_t> lhsOffsets =
          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
      Value tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets);
      AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1];
      SmallVector<int64_t> rhsOffsets =
          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
      Value tiledRhs = extractOperand(extsiRhs, rhsPermutationMap, rhsOffsets);
      AffineMap accPermutationMap = op.getIndexingMapsArray()[2];
      SmallVector<int64_t> accOffsets =
          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
      Value tiledAcc =
          extractOperand(op.getAcc(), accPermutationMap, accOffsets);

      auto inputElementType =
          tiledLhs.getType().cast<ShapedType>().getElementType();
      auto accElementType =
          tiledAcc.getType().cast<ShapedType>().getElementType();
      // smmla operates on 2x8 inputs producing a 2x2 accumulator tile.
      auto inputExpandedType = VectorType::get({2, 8}, inputElementType);
      auto outputExpandedType = VectorType::get({2, 2}, accElementType);

      // With vecmat, tiled LHS and ACC will contain only one of the 2
      // necessary rows along dimM. Expand their shapes (padding with zeros)
      // to match the smmla op.
      if (isVecmat) {
        auto expandForSMMLA = [&](Value tiledOperand,
                                  VectorType expandedTypeType) {
          auto emptyOperand = rewriter.create<arith::ConstantOp>(
              loc, expandedTypeType, rewriter.getZeroAttr(expandedTypeType));
          SmallVector<int64_t> offsets(
              emptyOperand.getType().cast<ShapedType>().getRank(), 0);
          SmallVector<int64_t> strides(
              tiledOperand.getType().cast<ShapedType>().getRank(), 1);
          return rewriter.createOrFold<vector::InsertStridedSliceOp>(
              loc, tiledOperand, emptyOperand, offsets, strides);
        };
        tiledLhs = expandForSMMLA(tiledLhs, inputExpandedType);
        tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType);
      }

      // Collapse tiled operands to the 1-D vectors required by the smmla
      // intrinsic (16-element inputs, 4-element accumulator).
      auto collapsedInputType =
          VectorType::get(inputExpandedType.getNumElements(), inputElementType);
      auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
          tiledLhs.getLoc(), collapsedInputType, tiledLhs);
      auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
          tiledRhs.getLoc(), collapsedInputType, tiledRhs);
      auto collapsedOutputType =
          VectorType::get(outputExpandedType.getNumElements(), accElementType);
      auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
          tiledAcc.getLoc(), collapsedOutputType, tiledAcc);

      // Emit the smmla intrinsic in place of the tiled contraction.
      auto smmlaOp = rewriter.createOrFold<arm_neon::SmmlaOp>(
          op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
          collapsedRhs);

      // Reshape the output back to the tiled accumulator's 2-D shape.
      Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
          smmlaOp.getLoc(), tiledAcc.getType(), smmlaOp);

      // With vecmat, only one row of the tiled ACC can be inserted into the
      // final result: drop the zero-padded second row.
      if (isVecmat) {
        tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
      }

      // Insert the tiled result back into the non-tiled result of the
      // contract op.
      SmallVector<int64_t> strides(
          tiledRes.getType().cast<ShapedType>().getRank(), 1);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, tiledRes, result, accOffsets, strides);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
|
|
|
|
} // namespace
|
|
|
|
/// Register the contract-to-smmla lowering pattern into `patterns`.
void mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns(
    RewritePatternSet &patterns) {
  patterns.add<LowerContractionToSMMLAPattern>(patterns.getContext(),
                                               /*benefit=*/1);
}
|