//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements lowering patterns from vector.contract to // arm_neon.intr.smmla // //===--- #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" #include "mlir/Dialect/ArmNeon/Transforms.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #define DEBUG_TYPE "lower-contract-to-arm-neon" using namespace mlir; using namespace mlir::arm_neon; namespace { /// Return the shaped type with new element type. static Type matchContainerType(Type element, Type container) { if (auto shapedTy = dyn_cast(container)) { return shapedTy.clone(element); } return element; } /// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile /// any vector.contract into multiple smmla instructions with unrolling so long /// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM /// = 1 (either explicitly or inferred if LHS has only dimK) If no unrolling is /// necessary, a single smmla instruction is emitted. class LowerContractionToSMMLAPattern : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::ContractionOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); // Infer tile sizes from operands. For vecmat, LHS may only have 1 dim. // Note: RHS is not transposed. mlir::VectorType lhsType = op.getLhsType(); mlir::VectorType rhsType = op.getRhsType(); // Avoid 0-D vectors and 1-D rhs: if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2) return failure(); auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0); auto dimN = rhsType.getDimSize(0); auto dimK = rhsType.getDimSize(1); bool isVecmat = dimM == 1 ? true : false; if (lhsType.getDimSize(lhsType.getRank() - 1) != rhsType.getDimSize(rhsType.getRank() - 1)) { return failure(); // dimK mismatch } // Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for // tiling. if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) { return failure(); } // Check iterator types for contract. All iterators except inner-most // dimension must be parallel. auto iteratorTypes = op.getIteratorTypesArray(); if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] != vector::IteratorType::reduction) { return failure(); } if (llvm::any_of(ArrayRef(iteratorTypes).drop_back(1), [](vector::IteratorType iteratorType) { return iteratorType != vector::IteratorType::parallel; })) { return failure(); } // Check two extsi inputs Rhs Lhs for contract. arith::ExtSIOp origLhsExtOp = dyn_cast_or_null(op.getLhs().getDefiningOp()); arith::ExtSIOp origRhsExtOp = dyn_cast_or_null(op.getRhs().getDefiningOp()); if (!origLhsExtOp || !origRhsExtOp) { return failure(); } // Match any iX to i32 for X<8 then turn into an i8 output. Feed into // following neon instruction. Check inputs for extsi are <=i8 Value extsiLhs; Value extsiRhs; if (auto lhsExtInType = origLhsExtOp.getIn().getType().dyn_cast()) { if (lhsExtInType.getElementTypeBitWidth() <= 8) { Type targetLhsExtTy = matchContainerType(rewriter.getI8Type(), lhsExtInType); extsiLhs = rewriter.createOrFold(loc, targetLhsExtTy, origLhsExtOp.getIn()); } } if (auto rhsExtInType = origRhsExtOp.getIn().getType().dyn_cast()) { if (rhsExtInType.getElementTypeBitWidth() <= 8) { Type targetRhsExtTy = matchContainerType(rewriter.getI8Type(), rhsExtInType); extsiRhs = rewriter.createOrFold(loc, targetRhsExtTy, origRhsExtOp.getIn()); } } if (!extsiLhs || !extsiRhs) { return failure(); } // Initial accumulator for the final result. This is the un-tiled result if // tiling is done. Value result = rewriter.create( loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType())); SmallVector unrolledSize = *op.getShapeForUnroll(); SmallVector smmlaShape{2, 8}; SmallVector loopOrder{0, 1}; if (unrolledSize.size() == 3) { smmlaShape.insert(smmlaShape.begin(), isVecmat ? 1 : 2); loopOrder.push_back(2); } for (SmallVector offsets : StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) { // Helper to compute the new shape of each operand and extract the slice. auto extractOperand = [&](Value operand, AffineMap permutationMap, ArrayRef operandOffsets) { SmallVector operandShape = applyPermutationMap(permutationMap, ArrayRef(smmlaShape)); SmallVector operandStrides(operandOffsets.size(), 1); return rewriter.createOrFold( loc, operand, operandOffsets, operandShape, operandStrides); }; // Extract tiled lhs, rhs, and acc AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0]; SmallVector lhsOffsets = applyPermutationMap(lhsPermutationMap, ArrayRef(offsets)); Value tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets); AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1]; SmallVector rhsOffsets = applyPermutationMap(rhsPermutationMap, ArrayRef(offsets)); Value tiledRhs = extractOperand(extsiRhs, rhsPermutationMap, rhsOffsets); AffineMap accPermutationMap = op.getIndexingMapsArray()[2]; SmallVector accOffsets = applyPermutationMap(accPermutationMap, ArrayRef(offsets)); Value tiledAcc = extractOperand(op.getAcc(), accPermutationMap, accOffsets); auto inputElementType = tiledLhs.getType().cast().getElementType(); auto accElementType = tiledAcc.getType().cast().getElementType(); auto inputExpandedType = VectorType::get({2, 8}, inputElementType); auto outputExpandedType = VectorType::get({2, 2}, accElementType); // With vecmat, tiled LHS and ACC will contain only one of 2 necessary // rows along dimM. Expand their shapes to match the smmla op. if (isVecmat) { auto expandForSMMLA = [&](Value tiledOperand, VectorType expandedTypeType) { auto emptyOperand = rewriter.create( loc, expandedTypeType, rewriter.getZeroAttr(expandedTypeType)); SmallVector offsets( emptyOperand.getType().cast().getRank(), 0); SmallVector strides( tiledOperand.getType().cast().getRank(), 1); return rewriter.createOrFold( loc, tiledOperand, emptyOperand, offsets, strides); }; tiledLhs = expandForSMMLA(tiledLhs, inputExpandedType); tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType); } // Collapse tiled operands to 1D vectors required by smmla intrinsic auto collapsedInputType = VectorType::get(inputExpandedType.getNumElements(), inputElementType); auto collapsedLhs = rewriter.createOrFold( tiledLhs.getLoc(), collapsedInputType, tiledLhs); auto collapsedRhs = rewriter.createOrFold( tiledRhs.getLoc(), collapsedInputType, tiledRhs); auto collapsedOutputType = VectorType::get(outputExpandedType.getNumElements(), accElementType); auto collapsedRes = rewriter.createOrFold( tiledAcc.getLoc(), collapsedOutputType, tiledAcc); // Insert contract op auto smmlaOp = rewriter.createOrFold( op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs, collapsedRhs); // Reshape output back to 2D Value tiledRes = rewriter.createOrFold( smmlaOp.getLoc(), tiledAcc.getType(), smmlaOp); // With vecmat, only one row of tiled ACC can be inserted inot file result if (isVecmat) { tiledRes = rewriter.createOrFold(loc, tiledRes, 0); } // Insert the tiled result back into the non tiled result of the // contract op. SmallVector strides( tiledRes.getType().cast().getRank(), 1); result = rewriter.createOrFold( loc, tiledRes, result, accOffsets, strides); } rewriter.replaceOp(op, result); return success(); } }; } // namespace void mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns( RewritePatternSet &patterns) { MLIRContext *context = patterns.getContext(); patterns.add(context, /*benefit=*/1); }