Historically, the Vector to LLVM dialect conversion subsumed the Standard to LLVM dialect conversion patterns. This was necessary because the conversion infrastructure did not have sufficient support for reconciling type conversions. This support is now available. Only keep the patterns related to the Vector dialect in the Vector to LLVM conversion, and require type cast operations to be inserted if necessary. These casts will be removed by subsequent conversions if possible. Update integration tests to also run the Standard to LLVM conversion. There is a significant amount of test churn, which is due to (a) unnecessarily strict tests in VectorToLLVM and (b) many patterns actually targeting Standard dialect ops instead of LLVM dialect ops, which led to tests actually exercising a Vector->Standard->LLVM conversion. This churn is a good illustration of the reason to make the conversion partial: now the tests only check the code in the Vector to LLVM conversion and will not be randomly broken by changes in the Standard to LLVM conversion. Arguably, it may be possible to extract Vector to Standard patterns into a separate pass, but given the ongoing splitting of the Standard dialect, such a pass would be short-lived and would require further refactoring. Depends On D95626 Reviewed By: nicolasvasilache, aartbik Differential Revision: https://reviews.llvm.org/D95685
118 lines
4.1 KiB
C++
118 lines
4.1 KiB
C++
//===- ArmSVEToLLVM.cpp - Convert ArmSVE to the LLVM dialect --------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h"
|
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
|
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
#include "mlir/Dialect/Vector/VectorOps.h"
|
|
#include "mlir/IR/BuiltinOps.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::arm_sve;
|
|
using namespace mlir::vector;
|
|
|
|
using SdotOpLowering =
|
|
OneToOneConvertToLLVMPattern<SdotOp, LLVM::aarch64_arm_sve_sdot>;
|
|
|
|
using SmmlaOpLowering =
|
|
OneToOneConvertToLLVMPattern<SmmlaOp, LLVM::aarch64_arm_sve_smmla>;
|
|
|
|
using UdotOpLowering =
|
|
OneToOneConvertToLLVMPattern<UdotOp, LLVM::aarch64_arm_sve_udot>;
|
|
|
|
using UmmlaOpLowering =
|
|
OneToOneConvertToLLVMPattern<UmmlaOp, LLVM::aarch64_arm_sve_ummla>;
|
|
|
|
using VectorScaleOpLowering =
|
|
OneToOneConvertToLLVMPattern<VectorScaleOp, LLVM::vector_scale>;
|
|
|
|
// Extract an LLVM IR type from the LLVM IR dialect type.
|
|
static Type unwrap(Type type) {
|
|
if (!type)
|
|
return nullptr;
|
|
auto *mlirContext = type.getContext();
|
|
if (!LLVM::isCompatibleType(type))
|
|
emitError(UnknownLoc::get(mlirContext),
|
|
"conversion resulted in a non-LLVM type");
|
|
return type;
|
|
}
|
|
|
|
/// Converts an ArmSVE `ScalableVectorType` into an
/// `LLVM::LLVMScalableVectorType`.
///
/// The element type is converted with `converter` and passed through
/// `unwrap`. If that yields a null type, no conversion is provided (an empty
/// Optional is returned). Otherwise the last entry of the shape is used as the
/// element count of the LLVM scalable vector.
/// NOTE(review): using only `getShape().back()` assumes `svType` has a
/// non-empty, effectively 1-D shape — confirm against the ArmSVE type
/// definition.
static Optional<Type>
convertScalableVectorTypeToLLVM(ScalableVectorType svType,
                                LLVMTypeConverter &converter) {
  Type llvmElementType =
      unwrap(converter.convertType(svType.getElementType()));
  if (!llvmElementType)
    return {};

  int64_t minNumElements = svType.getShape().back();
  return LLVM::LLVMScalableVectorType::get(llvmElementType, minNumElements);
}
|
|
|
|
template <typename OpTy>
|
|
class ForwardOperands : public OpConversionPattern<OpTy> {
|
|
using OpConversionPattern<OpTy>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(OpTy op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const final {
|
|
if (ValueRange(operands).getTypes() == op->getOperands().getTypes())
|
|
return rewriter.notifyMatchFailure(op, "operand types already match");
|
|
|
|
rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); });
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Conversion pattern that updates a `ReturnOp` in place so that it returns
/// the operands handed to `matchAndRewrite` by the conversion driver.
///
/// NOTE(review): populateArmSVEToLLVMConversionPatterns in this file does not
/// register this pattern; it registers ForwardOperands<ReturnOp> instead.
/// Consider removing this class if nothing else refers to it.
class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
public:
  using OpConversionPattern<ReturnOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // When the operand types are unchanged there is nothing to forward.
    // Report a match failure instead of claiming success for an in-place
    // "update" that changes nothing, mirroring ForwardOperands.
    if (ValueRange(operands).getTypes() == op->getOperands().getTypes())
      return rewriter.notifyMatchFailure(op, "operand types already match");

    rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); });
    return success();
  }
};
|
|
|
|
/// Materialization callback for ArmSVE scalable vector types.
///
/// Given a single input value of `LLVM::LLVMScalableVectorType`, creates an
/// `UnrealizedConversionCastOp` at `loc` that produces a value of `svType` and
/// returns that value. It is registered via `addSourceMaterialization` in
/// populateArmSVEToLLVMConversionPatterns.
///
/// For any other `inputs`, returns `llvm::None`, meaning "not handled here".
/// The previous code returned a null `Value`, which the TypeConverter treats
/// as a hard failure that stops any other registered materialization from
/// being tried.
static Optional<Value> addUnrealizedCast(OpBuilder &builder,
                                         ScalableVectorType svType,
                                         ValueRange inputs, Location loc) {
  if (inputs.size() != 1 ||
      !inputs[0].getType().isa<LLVM::LLVMScalableVectorType>())
    return llvm::None;
  return builder.create<UnrealizedConversionCastOp>(loc, svType, inputs)
      .getResult(0);
}
|
|
|
|
/// Populate the given list with patterns that convert from ArmSVE to LLVM.
|
|
void mlir::populateArmSVEToLLVMConversionPatterns(
|
|
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
|
|
converter.addConversion([&converter](ScalableVectorType svType) {
|
|
return convertScalableVectorTypeToLLVM(svType, converter);
|
|
});
|
|
converter.addSourceMaterialization(addUnrealizedCast);
|
|
|
|
// clang-format off
|
|
patterns.insert<ForwardOperands<CallOp>,
|
|
ForwardOperands<CallIndirectOp>,
|
|
ForwardOperands<ReturnOp>>(converter,
|
|
&converter.getContext());
|
|
patterns.insert<SdotOpLowering,
|
|
SmmlaOpLowering,
|
|
UdotOpLowering,
|
|
UmmlaOpLowering,
|
|
VectorScaleOpLowering>(converter);
|
|
// clang-format on
|
|
}
|