Files
clang-p2996/mlir/lib/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.cpp
Alex Zinenko ba87f99168 [mlir] make vector to llvm conversion truly partial
Historically, the Vector to LLVM dialect conversion subsumed the Standard to
LLVM dialect conversion patterns. This was necessary because the conversion
infrastructure did not have sufficient support for reconciling type
conversions. This support is now available. Only keep the patterns related to
the Vector dialect in the Vector to LLVM conversion and require type cast
operations to be inserted if necessary. These casts will be removed by
following conversions if possible. Update integration tests to also run the
Standard to LLVM conversion.

There is a significant amount of test churn, which is due to (a) unnecessarily
strict tests in VectorToLLVM and (b) many patterns actually targeting Standard
dialect ops instead of LLVM dialect ops leading to tests actually exercising a
Vector->Standard->LLVM conversion. This churn is a good illustration of the
reason to make the conversion partial: now the tests only check the code in the
Vector to LLVM conversion and will not be randomly broken by changes in
Standard to LLVM conversion.

Arguably, it may be possible to extract Vector to Standard patterns into a
separate pass, but given the ongoing splitting of the Standard dialect, such a
pass will be short-lived and will require further refactoring.

Depends On D95626

Reviewed By: nicolasvasilache, aartbik

Differential Revision: https://reviews.llvm.org/D95685
2021-02-04 11:33:24 +01:00

118 lines
4.1 KiB
C++

//===- ArmSVEToLLVM.cpp - Convert ArmSVE to the LLVM dialect --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
using namespace mlir::arm_sve;
using namespace mlir::vector;
// Lowering pattern aliases. Each one instantiates OneToOneConvertToLLVMPattern
// with an ArmSVE dialect op as the source and an LLVM dialect op as the target.
// NOTE(review): presumably each pattern replaces the source op with the target
// op on the converted operands — confirm against OneToOneConvertToLLVMPattern.

// arm_sve SdotOp -> LLVM::aarch64_arm_sve_sdot.
using SdotOpLowering =
OneToOneConvertToLLVMPattern<SdotOp, LLVM::aarch64_arm_sve_sdot>;
// arm_sve SmmlaOp -> LLVM::aarch64_arm_sve_smmla.
using SmmlaOpLowering =
OneToOneConvertToLLVMPattern<SmmlaOp, LLVM::aarch64_arm_sve_smmla>;
// arm_sve UdotOp -> LLVM::aarch64_arm_sve_udot.
using UdotOpLowering =
OneToOneConvertToLLVMPattern<UdotOp, LLVM::aarch64_arm_sve_udot>;
// arm_sve UmmlaOp -> LLVM::aarch64_arm_sve_ummla.
using UmmlaOpLowering =
OneToOneConvertToLLVMPattern<UmmlaOp, LLVM::aarch64_arm_sve_ummla>;
// VectorScaleOp -> LLVM::vector_scale.
using VectorScaleOpLowering =
OneToOneConvertToLLVMPattern<VectorScaleOp, LLVM::vector_scale>;
/// Validates that `type` is compatible with the LLVM dialect.
///
/// Returns a null Type when `type` is itself null. When `type` is not
/// LLVM-compatible (per LLVM::isCompatibleType), an error is emitted at an
/// unknown location and a null Type is returned, so that callers testing the
/// result for null treat the conversion as failed instead of continuing with a
/// non-LLVM type. Otherwise `type` is returned unchanged.
static Type unwrap(Type type) {
  if (!type)
    return nullptr;
  if (!LLVM::isCompatibleType(type)) {
    emitError(UnknownLoc::get(type.getContext()),
              "conversion resulted in a non-LLVM type");
    // Do not hand the offending type back: the diagnostic alone does not stop
    // the caller from using it.
    return nullptr;
  }
  return type;
}
/// Builds the LLVM dialect scalable vector type that corresponds to `svType`.
///
/// The element type of `svType` is converted with `converter` and passed
/// through unwrap(). If the outcome is a null type, an empty Optional is
/// returned. Otherwise the result is an LLVM::LLVMScalableVectorType over the
/// converted element type, sized by the last entry of `svType`'s shape.
static Optional<Type>
convertScalableVectorTypeToLLVM(ScalableVectorType svType,
                                LLVMTypeConverter &converter) {
  Type convertedElement =
      unwrap(converter.convertType(svType.getElementType()));
  if (convertedElement) {
    Type llvmVector = LLVM::LLVMScalableVectorType::get(
        convertedElement, svType.getShape().back());
    return llvmVector;
  }
  return {};
}
/// Conversion pattern that updates an operation of type `OpTy` in place so it
/// uses the operand values supplied to matchAndRewrite.
///
/// The match fails when the supplied operands already have the same types as
/// the operation's current operands; otherwise the operands are swapped in
/// through the rewriter's in-place root update.
template <typename OpTy>
class ForwardOperands : public OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OpTy op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // Nothing to forward if no operand type would change.
    const bool typesUnchanged =
        ValueRange(operands).getTypes() == op->getOperands().getTypes();
    if (typesUnchanged)
      return rewriter.notifyMatchFailure(op, "operand types already match");

    auto installOperands = [&]() { op->setOperands(operands); };
    rewriter.updateRootInPlace(op, installOperands);
    return success();
  }
};
/// Conversion pattern that unconditionally replaces the operands of a ReturnOp
/// with the operand values supplied to matchAndRewrite, updating the op in
/// place. Always reports success.
///
/// NOTE(review): populateArmSVEToLLVMConversionPatterns in this file registers
/// ForwardOperands<ReturnOp> rather than this class — confirm whether this
/// pattern is still needed.
class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
public:
  using OpConversionPattern<ReturnOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    auto installOperands = [&]() { op->setOperands(operands); };
    rewriter.updateRootInPlace(op, installOperands);
    return success();
  }
};
/// Creates an UnrealizedConversionCastOp at `loc` that turns a single LLVM
/// scalable vector value into a value of type `svType`, and returns the cast's
/// first result.
///
/// Returns a null Value unless `inputs` holds exactly one value whose type is
/// an LLVM::LLVMScalableVectorType.
static Optional<Value> addUnrealizedCast(OpBuilder &builder,
                                         ScalableVectorType svType,
                                         ValueRange inputs, Location loc) {
  if (inputs.size() != 1)
    return Value();
  if (!inputs[0].getType().isa<LLVM::LLVMScalableVectorType>())
    return Value();

  auto castOp =
      builder.create<UnrealizedConversionCastOp>(loc, svType, inputs);
  return castOp.getResult(0);
}
/// Populate the given list with patterns that convert from ArmSVE to LLVM.
///
/// Besides adding patterns, this registers on `converter` a type conversion
/// for ScalableVectorType and a source materialization (addUnrealizedCast).
void mlir::populateArmSVEToLLVMConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  // Type handling. The callback holds a reference to `converter` itself, so it
  // is only usable while that converter object is alive.
  converter.addConversion([&converter](ScalableVectorType svType) {
    return convertScalableVectorTypeToLLVM(svType, converter);
  });
  converter.addSourceMaterialization(addUnrealizedCast);

  auto *context = &converter.getContext();

  // Patterns that forward converted operands onto call and return ops.
  // clang-format off
  patterns.insert<
      ForwardOperands<CallOp>,
      ForwardOperands<CallIndirectOp>,
      ForwardOperands<ReturnOp>>(converter, context);

  // Lowerings built from the OneToOneConvertToLLVMPattern aliases.
  patterns.insert<
      SdotOpLowering,
      SmmlaOpLowering,
      UdotOpLowering,
      UmmlaOpLowering,
      VectorScaleOpLowering>(converter);
  // clang-format on
}