//===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::arm_sve;

/// Updates an op in place to use its type-converted operands (when they
/// differ from the original operand types).
template <typename OpTy>
class ForwardOperands : public OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    if (adaptor.getOperands().getTypes() == op->getOperands().getTypes())
      return rewriter.notifyMatchFailure(op, "operand types already match");

    rewriter.modifyOpInPlace(
        op, [&]() { op->setOperands(adaptor.getOperands()); });
    return success();
  }
};

using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
using ScalableMaskedAddIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
                                 ScalableMaskedAddIIntrOp>;
using ScalableMaskedAddFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp,
                                 ScalableMaskedAddFIntrOp>;
using ScalableMaskedSubIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp,
                                 ScalableMaskedSubIIntrOp>;
using ScalableMaskedSubFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp,
                                 ScalableMaskedSubFIntrOp>;
using ScalableMaskedMulIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp,
                                 ScalableMaskedMulIIntrOp>;
using ScalableMaskedMulFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp,
                                 ScalableMaskedMulFIntrOp>;
using ScalableMaskedSDivIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp,
                                 ScalableMaskedSDivIIntrOp>;
using ScalableMaskedUDivIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp,
                                 ScalableMaskedUDivIIntrOp>;
using ScalableMaskedDivFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
                                 ScalableMaskedDivFIntrOp>;

namespace {

/// Unrolls a conversion to/from equivalent vector types, to allow using a
/// conversion intrinsic that only supports 1-D vector types.
///
/// Example:
///
/// ```
/// %result = arm_sve.convert_to_svbool %source : vector<2x[4]xi1>
/// ```
/// is rewritten into:
/// ```
/// %cst = arith.constant dense<false> : vector<2x[16]xi1>
/// %1 = vector.extract %source[0] : vector<[4]xi1> from vector<2x[4]xi1>
/// %2 = "arm_sve.intr.convert.to.svbool"(%1)
///                                : (vector<[4]xi1>) -> vector<[16]xi1>
/// %3 = vector.insert %2, %cst[0] : vector<[16]xi1> into vector<2x[16]xi1>
/// %4 = vector.extract %source[1] : vector<[4]xi1> from vector<2x[4]xi1>
/// %5 = "arm_sve.intr.convert.to.svbool"(%4)
///                                : (vector<[4]xi1>) -> vector<[16]xi1>
/// %result = vector.insert %5, %3[1] : vector<[16]xi1> into vector<2x[16]xi1>
/// ```
template <typename Op, typename IntrOp>
struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> {
  using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(Op convertOp, typename Op::Adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = convertOp.getLoc();

    auto source = convertOp.getSource();
    VectorType sourceType = source.getType();
    VectorType resultType = convertOp.getResult().getType();

    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultType, rewriter.getZeroAttr(resultType));

    // We want to iterate over the input vector in steps of the trailing
    // dimension. So this creates a tile shape where all leading dimensions are
    // 1, and the trailing dimension step is the size of that dimension.
    SmallVector<int64_t> tileShape(sourceType.getRank(), 1);
    tileShape.back() = sourceType.getShape().back();

    // Iterate over all scalable mask/predicate slices of the source vector.
    for (SmallVector<int64_t> index :
         StaticTileOffsetRange(sourceType.getShape(), tileShape)) {
      auto extractOrInsertPosition = ArrayRef(index).drop_back();
      auto sourceVector = rewriter.create<vector::ExtractOp>(
          loc, source, extractOrInsertPosition);
      VectorType convertedType =
          VectorType::Builder(llvm::cast<VectorType>(sourceVector.getType()))
              .setDim(0, resultType.getShape().back());
      auto convertedVector =
          rewriter.create<IntrOp>(loc, TypeRange{convertedType}, sourceVector);
      result = rewriter.create<vector::InsertOp>(loc, convertedVector, result,
                                                 extractOrInsertPosition);
    }

    rewriter.replaceOp(convertOp, result);
    return success();
  }
};

using ConvertToSvboolOpLowering =
    SvboolConversionOpLowering<ConvertToSvboolOp, ConvertToSvboolIntrOp>;

using ConvertFromSvboolOpLowering =
    SvboolConversionOpLowering<ConvertFromSvboolOp, ConvertFromSvboolIntrOp>;

using ZipX2OpLowering = OneToOneConvertToLLVMPattern<ZipX2Op, ZipX2IntrOp>;
using ZipX4OpLowering = OneToOneConvertToLLVMPattern<ZipX4Op, ZipX4IntrOp>;

/// Lower `arm_sve.psel` to LLVM intrinsics. This is almost a 1-to-1
/// conversion, but the first input (P1) and the result predicate need to be
/// converted to/from an svbool.
struct PselOpLowering : public ConvertOpToLLVMPattern<PselOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(PselOp pselOp, PselOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto svboolType = VectorType::get(16, rewriter.getI1Type(), true);
    auto loc = pselOp.getLoc();
    auto svboolP1 = rewriter.create<ConvertToSvboolIntrOp>(loc, svboolType,
                                                           adaptor.getP1());
    auto indexI32 = rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getI32Type(), pselOp.getIndex());
    auto pselIntr = rewriter.create<PselIntrOp>(loc, svboolType, svboolP1,
                                                pselOp.getP2(), indexI32);
    rewriter.replaceOpWithNewOp<ConvertFromSvboolIntrOp>(
        pselOp, adaptor.getP1().getType(), pselIntr);
    return success();
  }
};
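
// Illustrative sketch of the `arm_sve.psel` lowering above. The predicate
// types and the generic printed forms of the intrinsic ops (e.g.
// "arm_sve.intr.psel") are assumptions, shown only to visualize the rewrite.
// For an op roughly like:
//
//   %res = arm_sve.psel %p1, %p2[%index] : vector<[4]xi1>, vector<[8]xi1>
//
// the pattern produces approximately:
//
//   %p1_sv = "arm_sve.intr.convert.to.svbool"(%p1)
//       : (vector<[4]xi1>) -> vector<[16]xi1>
//   %i = arith.index_cast %index : index to i32
//   %sel = "arm_sve.intr.psel"(%p1_sv, %p2, %i)
//       : (vector<[16]xi1>, vector<[8]xi1>, i32) -> vector<[16]xi1>
//   %res = "arm_sve.intr.convert.from.svbool"(%sel)
//       : (vector<[16]xi1>) -> vector<[4]xi1>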
/// Converts `vector.create_mask` ops that match the size of an SVE predicate
/// to the `whilelt` intrinsic. This produces more canonical codegen than the
/// generic LLVM lowering, see
/// https://github.com/llvm/llvm-project/issues/81840 for more details. Note
/// that we can't use (the more general) active.lane.mask, as its semantics
/// don't neatly map onto `vector.create_mask`: it does an unsigned comparison
/// (whereas `create_mask` is signed), and it is UB/poison if `n` is zero
/// (whereas `create_mask` just returns an all-false mask).
struct CreateMaskOpLowering
    : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::CreateMaskOp createMaskOp,
                  vector::CreateMaskOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto maskType = createMaskOp.getVectorType();
    if (maskType.getRank() != 1 || !maskType.isScalable())
      return rewriter.notifyMatchFailure(createMaskOp, "not 1-D and scalable");

    // TODO: Support masks which are multiples of SVE predicates.
    auto maskBaseSize = maskType.getDimSize(0);
    if (maskBaseSize < 2 || maskBaseSize > 16 ||
        !llvm::isPowerOf2_32(uint32_t(maskBaseSize)))
      return rewriter.notifyMatchFailure(createMaskOp,
                                         "not SVE predicate-sized");

    auto loc = createMaskOp.getLoc();
    auto zero = rewriter.create<LLVM::ZeroOp>(loc, rewriter.getI64Type());
    rewriter.replaceOpWithNewOp<WhileLTIntrOp>(createMaskOp, maskType, zero,
                                               adaptor.getOperands()[0]);
    return success();
  }
};

} // namespace

/// Populate the given list with patterns that convert from ArmSVE to LLVM.
void mlir::populateArmSVELegalizeForLLVMExportPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  // Populate conversion patterns
  // clang-format off
  patterns.add<ForwardOperands<func::CallOp>,
               ForwardOperands<func::CallIndirectOp>,
               ForwardOperands<func::ReturnOp>>(converter,
                                                &converter.getContext());
  patterns.add<SdotOpLowering,
               SmmlaOpLowering,
               UdotOpLowering,
               UmmlaOpLowering,
               ScalableMaskedAddIOpLowering,
               ScalableMaskedAddFOpLowering,
               ScalableMaskedSubIOpLowering,
               ScalableMaskedSubFOpLowering,
               ScalableMaskedMulIOpLowering,
               ScalableMaskedMulFOpLowering,
               ScalableMaskedSDivIOpLowering,
               ScalableMaskedUDivIOpLowering,
               ScalableMaskedDivFOpLowering,
               ConvertToSvboolOpLowering,
               ConvertFromSvboolOpLowering,
               ZipX2OpLowering,
               ZipX4OpLowering,
               PselOpLowering>(converter);
  // Add vector.create_mask conversion with a high benefit as it produces much
  // nicer code than the generic lowering.
  patterns.add<CreateMaskOpLowering>(converter, /*benefit=*/4096);
  // clang-format on
}

void mlir::configureArmSVELegalizeForExportTarget(
    LLVMConversionTarget &target) {
  // clang-format off
  target.addLegalOp<SdotIntrOp,
                    SmmlaIntrOp,
                    UdotIntrOp,
                    UmmlaIntrOp,
                    ScalableMaskedAddIIntrOp,
                    ScalableMaskedAddFIntrOp,
                    ScalableMaskedSubIIntrOp,
                    ScalableMaskedSubFIntrOp,
                    ScalableMaskedMulIIntrOp,
                    ScalableMaskedMulFIntrOp,
                    ScalableMaskedSDivIIntrOp,
                    ScalableMaskedUDivIIntrOp,
                    ScalableMaskedDivFIntrOp,
                    ConvertToSvboolIntrOp,
                    ConvertFromSvboolIntrOp,
                    ZipX2IntrOp,
                    ZipX4IntrOp,
                    PselIntrOp,
                    WhileLTIntrOp>();
  target.addIllegalOp<SdotOp,
                      SmmlaOp,
                      UdotOp,
                      UmmlaOp,
                      ScalableMaskedAddIOp,
                      ScalableMaskedAddFOp,
                      ScalableMaskedSubIOp,
                      ScalableMaskedSubFOp,
                      ScalableMaskedMulIOp,
                      ScalableMaskedMulFOp,
                      ScalableMaskedSDivIOp,
                      ScalableMaskedUDivIOp,
                      ScalableMaskedDivFOp,
                      ConvertToSvboolOp,
                      ConvertFromSvboolOp,
                      ZipX2Op,
                      ZipX4Op,
                      PselOp>();
  // clang-format on
}
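
// Usage sketch (illustrative only, not part of this file's API): the two entry
// points above are meant to be called from an LLVM conversion pass. The pass
// name and surrounding boilerplate below are hypothetical; only the two ArmSVE
// entry points are real.
//
//   struct LegalizeArmSVEForLLVMPass
//       : PassWrapper<LegalizeArmSVEForLLVMPass, OperationPass<ModuleOp>> {
//     void runOnOperation() override {
//       MLIRContext &context = getContext();
//       LLVMTypeConverter converter(&context);
//       RewritePatternSet patterns(&context);
//       LLVMConversionTarget target(context);
//       populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
//       configureArmSVELegalizeForExportTarget(target);
//       if (failed(applyPartialConversion(getOperation(), target,
//                                         std::move(patterns))))
//         signalPassFailure();
//     }
//   };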