Files
clang-p2996/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
River Riddle b54c724be0 [mlir:OpConversionPattern] Add overloads for taking an Adaptor instead of ArrayRef
This has been a TODO for a long time, and it brings about many advantages (namely nice accessors, and less fragile code). The existing overloads that accept ArrayRef are now treated as deprecated and will be removed in a followup (after a small grace period). Most of the upstream MLIR usages have been fixed by this commit, the rest will be handled in a followup.

Differential Revision: https://reviews.llvm.org/D110293
2021-09-24 17:51:41 +00:00

326 lines
13 KiB
C++

//===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
#include "mlir/Dialect/ArmSVE/Transforms.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
using namespace mlir::arm_sve;
// Extract an LLVM IR type from the LLVM IR dialect type.
/// Extract an LLVM IR type from the LLVM IR dialect type. Reports an error at
/// an unknown location when the type is not LLVM-compatible, but still returns
/// the type unchanged (best-effort, matching the surrounding conversion flow).
static Type unwrap(Type type) {
  if (!type)
    return nullptr;
  MLIRContext *ctx = type.getContext();
  if (!LLVM::isCompatibleType(type))
    emitError(UnknownLoc::get(ctx), "conversion resulted in a non-LLVM type");
  return type;
}
/// Convert an ArmSVE scalable vector type into the equivalent LLVM dialect
/// scalable vector type, or None when the element type cannot be converted.
static Optional<Type>
convertScalableVectorTypeToLLVM(ScalableVectorType svType,
                                LLVMTypeConverter &converter) {
  // Convert the element type first; bail out if that fails.
  Type convertedElt = unwrap(converter.convertType(svType.getElementType()));
  if (!convertedElt)
    return {};
  // Only the innermost dimension carries the scalable element count.
  return Type(LLVM::LLVMScalableVectorType::get(convertedElt,
                                                svType.getShape().back()));
}
/// Generic conversion pattern that rewrites `OpTy` in place, replacing its
/// operands with the (already type-converted) operands from the adaptor.
/// Used to strip ArmSVE-specific types from standard ops (call, call_indirect,
/// return) during legalization.
template <typename OpTy>
class ForwardOperands : public OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    // Nothing to do when no operand type changed during conversion.
    if (adaptor.getOperands().getTypes() == op->getOperands().getTypes())
      return rewriter.notifyMatchFailure(op, "operand types already match");

    // Mutate in place through the rewriter so the change is tracked/rollable.
    rewriter.updateRootInPlace(
        op, [&]() { op->setOperands(adaptor.getOperands()); });
    return success();
  }
};
/// Conversion pattern that unconditionally rewrites `std.return` in place with
/// the type-converted operands supplied by the adaptor, so function results
/// containing ArmSVE types are forwarded as their LLVM equivalents.
class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
public:
  using OpConversionPattern<ReturnOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    // Only the operands change; the op itself is kept.
    rewriter.updateRootInPlace(
        op, [&]() { op->setOperands(adaptor.getOperands()); });
    return success();
  }
};
/// Source materialization callback: bridge an LLVM scalable vector value back
/// to the ArmSVE scalable vector type via an unrealized_conversion_cast.
/// Returns a null Value (i.e. "no materialization") for any other input shape.
static Optional<Value> addUnrealizedCast(OpBuilder &builder,
                                         ScalableVectorType svType,
                                         ValueRange inputs, Location loc) {
  // Only a single LLVM scalable-vector input can be cast back.
  if (inputs.size() != 1)
    return Value();
  if (!inputs[0].getType().isa<LLVM::LLVMScalableVectorType>())
    return Value();
  auto cast = builder.create<UnrealizedConversionCastOp>(loc, svType, inputs);
  return cast.getResult(0);
}
// One-to-one lowerings of ArmSVE dot-product / matrix-multiply ops to their
// LLVM dialect intrinsic counterparts; operands are forwarded unchanged.
using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;

// Lowering of the vector-length query op to its intrinsic.
using VectorScaleOpLowering =
    OneToOneConvertToLLVMPattern<VectorScaleOp, VectorScaleIntrOp>;

// One-to-one lowerings of the predicated (masked) arithmetic ops to their
// intrinsic forms.
using ScalableMaskedAddIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
                                 ScalableMaskedAddIIntrOp>;
using ScalableMaskedAddFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp,
                                 ScalableMaskedAddFIntrOp>;
using ScalableMaskedSubIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp,
                                 ScalableMaskedSubIIntrOp>;
using ScalableMaskedSubFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp,
                                 ScalableMaskedSubFIntrOp>;
using ScalableMaskedMulIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp,
                                 ScalableMaskedMulIIntrOp>;
using ScalableMaskedMulFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp,
                                 ScalableMaskedMulFIntrOp>;
using ScalableMaskedSDivIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp,
                                 ScalableMaskedSDivIIntrOp>;
using ScalableMaskedUDivIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp,
                                 ScalableMaskedUDivIIntrOp>;
using ScalableMaskedDivFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
                                 ScalableMaskedDivFIntrOp>;
// Load operation is lowered to code that obtains a pointer to the indexed
// element and loads from it: the strided element pointer is bitcast to a
// pointer to the (fixed or scalable) vector result type, then loaded.
struct ScalableLoadOpLowering : public ConvertOpToLLVMPattern<ScalableLoadOp> {
  using ConvertOpToLLVMPattern<ScalableLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(ScalableLoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    LLVMTypeConverter converter(loadOp.getContext());
    auto resultType = loadOp.result().getType();
    LLVM::LLVMPointerType llvmDataTypePtr;
    if (resultType.isa<VectorType>()) {
      llvmDataTypePtr =
          LLVM::LLVMPointerType::get(resultType.cast<VectorType>());
    } else if (resultType.isa<ScalableVectorType>()) {
      auto convertedVectorType = convertScalableVectorTypeToLLVM(
          resultType.cast<ScalableVectorType>(), converter);
      // Fail the match instead of asserting on an empty Optional when the
      // element type could not be converted.
      if (!convertedVectorType)
        return failure();
      llvmDataTypePtr = LLVM::LLVMPointerType::get(*convertedVectorType);
    } else {
      // Unhandled result type: previously `llvmDataTypePtr` stayed null here
      // and the bitcast below crashed; bail out cleanly instead.
      return failure();
    }
    Value dataPtr = getStridedElementPtr(loadOp.getLoc(), type, adaptor.base(),
                                         adaptor.index(), rewriter);
    Value bitCastedPtr = rewriter.create<LLVM::BitcastOp>(
        loadOp.getLoc(), llvmDataTypePtr, dataPtr);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, bitCastedPtr);
    return success();
  }
};
// Store operation is lowered to code that obtains a pointer to the indexed
// element, and stores the given value to it: the strided element pointer is
// bitcast to a pointer to the (fixed or scalable) vector value type first.
struct ScalableStoreOpLowering
    : public ConvertOpToLLVMPattern<ScalableStoreOp> {
  using ConvertOpToLLVMPattern<ScalableStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(ScalableStoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = storeOp.getMemRefType();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    LLVMTypeConverter converter(storeOp.getContext());
    auto resultType = storeOp.value().getType();
    LLVM::LLVMPointerType llvmDataTypePtr;
    if (resultType.isa<VectorType>()) {
      llvmDataTypePtr =
          LLVM::LLVMPointerType::get(resultType.cast<VectorType>());
    } else if (resultType.isa<ScalableVectorType>()) {
      auto convertedVectorType = convertScalableVectorTypeToLLVM(
          resultType.cast<ScalableVectorType>(), converter);
      // Fail the match instead of asserting on an empty Optional when the
      // element type could not be converted.
      if (!convertedVectorType)
        return failure();
      llvmDataTypePtr = LLVM::LLVMPointerType::get(*convertedVectorType);
    } else {
      // Unhandled value type: previously `llvmDataTypePtr` stayed null here
      // and the bitcast below crashed; bail out cleanly instead.
      return failure();
    }
    Value dataPtr = getStridedElementPtr(storeOp.getLoc(), type, adaptor.base(),
                                         adaptor.index(), rewriter);
    Value bitCastedPtr = rewriter.create<LLVM::BitcastOp>(
        storeOp.getLoc(), llvmDataTypePtr, dataPtr);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.value(),
                                               bitCastedPtr);
    return success();
  }
};
/// Register one-to-one lowerings of the unpredicated scalable arithmetic ops
/// to the corresponding plain LLVM arithmetic ops (types are handled by the
/// converter's scalable-vector conversion).
static void
populateBasicSVEArithmeticExportPatterns(LLVMTypeConverter &converter,
                                         OwningRewritePatternList &patterns) {
  // clang-format off
  patterns.add<OneToOneConvertToLLVMPattern<ScalableAddIOp, LLVM::AddOp>,
               OneToOneConvertToLLVMPattern<ScalableAddFOp, LLVM::FAddOp>,
               OneToOneConvertToLLVMPattern<ScalableSubIOp, LLVM::SubOp>,
               OneToOneConvertToLLVMPattern<ScalableSubFOp, LLVM::FSubOp>,
               OneToOneConvertToLLVMPattern<ScalableMulIOp, LLVM::MulOp>,
               OneToOneConvertToLLVMPattern<ScalableMulFOp, LLVM::FMulOp>,
               OneToOneConvertToLLVMPattern<ScalableSDivIOp, LLVM::SDivOp>,
               OneToOneConvertToLLVMPattern<ScalableUDivIOp, LLVM::UDivOp>,
               OneToOneConvertToLLVMPattern<ScalableDivFOp, LLVM::FDivOp>
              >(converter);
  // clang-format on
}
/// Mark the unpredicated scalable arithmetic ops illegal so the patterns
/// registered by populateBasicSVEArithmeticExportPatterns must fire.
static void
configureBasicSVEArithmeticLegalizations(LLVMConversionTarget &target) {
  // clang-format off
  target.addIllegalOp<ScalableAddIOp,
                      ScalableAddFOp,
                      ScalableSubIOp,
                      ScalableSubFOp,
                      ScalableMulIOp,
                      ScalableMulFOp,
                      ScalableSDivIOp,
                      ScalableUDivIOp,
                      ScalableDivFOp>();
  // clang-format on
}
/// Register one-to-one lowerings of the scalable comparison (mask-generating)
/// ops to LLVM's fcmp/icmp.
static void
populateSVEMaskGenerationExportPatterns(LLVMTypeConverter &converter,
                                        OwningRewritePatternList &patterns) {
  // clang-format off
  patterns.add<OneToOneConvertToLLVMPattern<ScalableCmpFOp, LLVM::FCmpOp>,
               OneToOneConvertToLLVMPattern<ScalableCmpIOp, LLVM::ICmpOp>
              >(converter);
  // clang-format on
}
/// Mark the scalable comparison ops illegal so the patterns registered by
/// populateSVEMaskGenerationExportPatterns must fire.
static void
configureSVEMaskGenerationLegalizations(LLVMConversionTarget &target) {
  // clang-format off
  target.addIllegalOp<ScalableCmpFOp,
                      ScalableCmpIOp>();
  // clang-format on
}
/// Populate the given list with patterns that convert from ArmSVE to LLVM.
void mlir::populateArmSVELegalizeForLLVMExportPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  // Populate conversion patterns
  // Remove any ArmSVE-specific types from function signatures and results.
  populateFuncOpTypeConversionPattern(patterns, converter);
  // Teach the type converter how to map ArmSVE scalable vector types onto
  // LLVM dialect scalable vector types. NOTE(review): the lambda captures
  // `converter` by reference and is stored inside it; assumes the converter
  // outlives the conversion — confirm against callers.
  converter.addConversion([&converter](ScalableVectorType svType) {
    return convertScalableVectorTypeToLLVM(svType, converter);
  });
  // Bridge LLVM scalable vectors back to ArmSVE types where needed.
  converter.addSourceMaterialization(addUnrealizedCast);

  // clang-format off
  // Forward converted operands through standard call/return ops.
  patterns.add<ForwardOperands<CallOp>,
               ForwardOperands<CallIndirectOp>,
               ForwardOperands<ReturnOp>>(converter,
                                          &converter.getContext());
  // One-to-one ArmSVE op -> intrinsic lowerings.
  patterns.add<SdotOpLowering,
               SmmlaOpLowering,
               UdotOpLowering,
               UmmlaOpLowering,
               VectorScaleOpLowering,
               ScalableMaskedAddIOpLowering,
               ScalableMaskedAddFOpLowering,
               ScalableMaskedSubIOpLowering,
               ScalableMaskedSubFOpLowering,
               ScalableMaskedMulIOpLowering,
               ScalableMaskedMulFOpLowering,
               ScalableMaskedSDivIOpLowering,
               ScalableMaskedUDivIOpLowering,
               ScalableMaskedDivFOpLowering>(converter);
  // Memory access lowerings (pointer arithmetic + bitcast + load/store).
  patterns.add<ScalableLoadOpLowering,
               ScalableStoreOpLowering>(converter);
  // clang-format on
  populateBasicSVEArithmeticExportPatterns(converter, patterns);
  populateSVEMaskGenerationExportPatterns(converter, patterns);
}
/// Configure the conversion target for ArmSVE legalization: intrinsic ops are
/// legal, the high-level ArmSVE ops are illegal, and standard func/call/return
/// ops are legal only once no ArmSVE scalable vector types remain in their
/// signatures.
void mlir::configureArmSVELegalizeForExportTarget(
    LLVMConversionTarget &target) {
  // clang-format off
  // The intrinsic forms the patterns produce are the legal target ops.
  target.addLegalOp<SdotIntrOp,
                    SmmlaIntrOp,
                    UdotIntrOp,
                    UmmlaIntrOp,
                    VectorScaleIntrOp,
                    ScalableMaskedAddIIntrOp,
                    ScalableMaskedAddFIntrOp,
                    ScalableMaskedSubIIntrOp,
                    ScalableMaskedSubFIntrOp,
                    ScalableMaskedMulIIntrOp,
                    ScalableMaskedMulFIntrOp,
                    ScalableMaskedSDivIIntrOp,
                    ScalableMaskedUDivIIntrOp,
                    ScalableMaskedDivFIntrOp>();
  // The high-level ArmSVE ops must all be rewritten away.
  target.addIllegalOp<SdotOp,
                      SmmlaOp,
                      UdotOp,
                      UmmlaOp,
                      VectorScaleOp,
                      ScalableMaskedAddIOp,
                      ScalableMaskedAddFOp,
                      ScalableMaskedSubIOp,
                      ScalableMaskedSubFOp,
                      ScalableMaskedMulIOp,
                      ScalableMaskedMulFOp,
                      ScalableMaskedSDivIOp,
                      ScalableMaskedUDivIOp,
                      ScalableMaskedDivFOp,
                      ScalableLoadOp,
                      ScalableStoreOp>();
  // clang-format on

  // True if any type in the range is an ArmSVE scalable vector type.
  auto hasScalableVectorType = [](TypeRange types) {
    for (Type type : types)
      if (type.isa<arm_sve::ScalableVectorType>())
        return true;
    return false;
  };
  // Functions are legal only once their signature is free of ArmSVE types.
  target.addDynamicallyLegalOp<FuncOp>([hasScalableVectorType](FuncOp op) {
    return !hasScalableVectorType(op.getType().getInputs()) &&
           !hasScalableVectorType(op.getType().getResults());
  });
  // Same for call/return ops: operand and result types must be converted.
  target.addDynamicallyLegalOp<CallOp, CallIndirectOp, ReturnOp>(
      [hasScalableVectorType](Operation *op) {
        return !hasScalableVectorType(op->getOperandTypes()) &&
               !hasScalableVectorType(op->getResultTypes());
      });
  configureBasicSVEArithmeticLegalizations(target);
  configureSVEMaskGenerationLegalizations(target);
}