//===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h" #include "mlir/Dialect/ArmSVE/Transforms.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/PatternMatch.h" using namespace mlir; using namespace mlir::arm_sve; // Extract an LLVM IR type from the LLVM IR dialect type. static Type unwrap(Type type) { if (!type) return nullptr; auto *mlirContext = type.getContext(); if (!LLVM::isCompatibleType(type)) emitError(UnknownLoc::get(mlirContext), "conversion resulted in a non-LLVM type"); return type; } static Optional convertScalableVectorTypeToLLVM(ScalableVectorType svType, LLVMTypeConverter &converter) { auto elementType = unwrap(converter.convertType(svType.getElementType())); if (!elementType) return {}; auto sVectorType = LLVM::LLVMScalableVectorType::get(elementType, svType.getShape().back()); return sVectorType; } template class ForwardOperands : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(OpTy op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { if (ValueRange(operands).getTypes() == op->getOperands().getTypes()) return rewriter.notifyMatchFailure(op, "operand types already match"); rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); }); return success(); } }; class ReturnOpTypeConversion : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ReturnOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); }); return success(); } }; static Optional addUnrealizedCast(OpBuilder &builder, ScalableVectorType svType, ValueRange inputs, Location loc) { if (inputs.size() != 1 || !inputs[0].getType().isa()) return Value(); return builder.create(loc, svType, inputs) .getResult(0); } using SdotOpLowering = OneToOneConvertToLLVMPattern; using SmmlaOpLowering = OneToOneConvertToLLVMPattern; using UdotOpLowering = OneToOneConvertToLLVMPattern; using UmmlaOpLowering = OneToOneConvertToLLVMPattern; using VectorScaleOpLowering = OneToOneConvertToLLVMPattern; using ScalableMaskedAddIOpLowering = OneToOneConvertToLLVMPattern; using ScalableMaskedAddFOpLowering = OneToOneConvertToLLVMPattern; using ScalableMaskedSubIOpLowering = OneToOneConvertToLLVMPattern; using ScalableMaskedSubFOpLowering = OneToOneConvertToLLVMPattern; using ScalableMaskedMulIOpLowering = OneToOneConvertToLLVMPattern; using ScalableMaskedMulFOpLowering = OneToOneConvertToLLVMPattern; using ScalableMaskedSDivIOpLowering = OneToOneConvertToLLVMPattern; using ScalableMaskedUDivIOpLowering = OneToOneConvertToLLVMPattern; using ScalableMaskedDivFOpLowering = OneToOneConvertToLLVMPattern; // Load operation is lowered to code that obtains a pointer to the indexed // element and loads from it. struct ScalableLoadOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(ScalableLoadOp loadOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto type = loadOp.getMemRefType(); if (!isConvertibleAndHasIdentityMaps(type)) return failure(); ScalableLoadOp::Adaptor transformed(operands); LLVMTypeConverter converter(loadOp.getContext()); auto resultType = loadOp.result().getType(); LLVM::LLVMPointerType llvmDataTypePtr; if (resultType.isa()) { llvmDataTypePtr = LLVM::LLVMPointerType::get(resultType.cast()); } else if (resultType.isa()) { llvmDataTypePtr = LLVM::LLVMPointerType::get( convertScalableVectorTypeToLLVM(resultType.cast(), converter) .getValue()); } Value dataPtr = getStridedElementPtr(loadOp.getLoc(), type, transformed.base(), transformed.index(), rewriter); Value bitCastedPtr = rewriter.create( loadOp.getLoc(), llvmDataTypePtr, dataPtr); rewriter.replaceOpWithNewOp(loadOp, bitCastedPtr); return success(); } }; // Store operation is lowered to code that obtains a pointer to the indexed // element, and stores the given value to it. struct ScalableStoreOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(ScalableStoreOp storeOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto type = storeOp.getMemRefType(); if (!isConvertibleAndHasIdentityMaps(type)) return failure(); ScalableStoreOp::Adaptor transformed(operands); LLVMTypeConverter converter(storeOp.getContext()); auto resultType = storeOp.value().getType(); LLVM::LLVMPointerType llvmDataTypePtr; if (resultType.isa()) { llvmDataTypePtr = LLVM::LLVMPointerType::get(resultType.cast()); } else if (resultType.isa()) { llvmDataTypePtr = LLVM::LLVMPointerType::get( convertScalableVectorTypeToLLVM(resultType.cast(), converter) .getValue()); } Value dataPtr = getStridedElementPtr(storeOp.getLoc(), type, transformed.base(), transformed.index(), rewriter); Value bitCastedPtr = rewriter.create( storeOp.getLoc(), llvmDataTypePtr, dataPtr); rewriter.replaceOpWithNewOp(storeOp, transformed.value(), bitCastedPtr); return success(); } }; static void populateBasicSVEArithmeticExportPatterns(LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { // clang-format off patterns.add, OneToOneConvertToLLVMPattern, OneToOneConvertToLLVMPattern, OneToOneConvertToLLVMPattern, OneToOneConvertToLLVMPattern, OneToOneConvertToLLVMPattern, OneToOneConvertToLLVMPattern, OneToOneConvertToLLVMPattern, OneToOneConvertToLLVMPattern >(converter); // clang-format on } static void configureBasicSVEArithmeticLegalizations(LLVMConversionTarget &target) { // clang-format off target.addIllegalOp(); // clang-format on } static void populateSVEMaskGenerationExportPatterns(LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { // clang-format off patterns.add, OneToOneConvertToLLVMPattern >(converter); // clang-format on } static void configureSVEMaskGenerationLegalizations(LLVMConversionTarget &target) { // clang-format off target.addIllegalOp(); // clang-format on } /// Populate the given list with patterns that convert from ArmSVE to LLVM. void mlir::populateArmSVELegalizeForLLVMExportPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { // Populate conversion patterns // Remove any ArmSVE-specific types from function signatures and results. populateFuncOpTypeConversionPattern(patterns, converter); converter.addConversion([&converter](ScalableVectorType svType) { return convertScalableVectorTypeToLLVM(svType, converter); }); converter.addSourceMaterialization(addUnrealizedCast); // clang-format off patterns.add, ForwardOperands, ForwardOperands>(converter, &converter.getContext()); patterns.add(converter); patterns.add(converter); // clang-format on populateBasicSVEArithmeticExportPatterns(converter, patterns); populateSVEMaskGenerationExportPatterns(converter, patterns); } void mlir::configureArmSVELegalizeForExportTarget( LLVMConversionTarget &target) { // clang-format off target.addLegalOp(); target.addIllegalOp(); // clang-format on auto hasScalableVectorType = [](TypeRange types) { for (Type type : types) if (type.isa()) return true; return false; }; target.addDynamicallyLegalOp([hasScalableVectorType](FuncOp op) { return !hasScalableVectorType(op.getType().getInputs()) && !hasScalableVectorType(op.getType().getResults()); }); target.addDynamicallyLegalOp( [hasScalableVectorType](Operation *op) { return !hasScalableVectorType(op->getOperandTypes()) && !hasScalableVectorType(op->getResultTypes()); }); configureBasicSVEArithmeticLegalizations(target); configureSVEMaskGenerationLegalizations(target); }