This patch implements the conversion of shift ops. In the SPIR-V dialect, `Shift` and `Base` may have different bit widths. In contrast, in the LLVM dialect both `Base` and `Shift` must have the same bit width. This leads to the following cases: - if `Base` has the same bit width as `Shift`, the conversion is straightforward. - if `Base` has a greater bit width than `Shift`, `Shift` is sign- or zero-extended first, and the extended value is then passed to the shift op. - otherwise, the conversion is considered illegal. Differential Revision: https://reviews.llvm.org/D81546
208 lines
8.6 KiB
C++
208 lines
8.6 KiB
C++
//===- ConvertSPIRVToLLVM.cpp - SPIR-V dialect to LLVM dialect conversion -===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements patterns to convert SPIR-V dialect to LLVM dialect.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h"
|
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
|
|
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
#include "mlir/IR/Module.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Support/LogicalResult.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
|
|
using namespace mlir;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Utility functions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Returns true if the given type is an unsigned integer or vector type
|
|
static bool isUnsignedIntegerOrVector(Type type) {
|
|
if (type.isUnsignedInteger())
|
|
return true;
|
|
if (auto vecType = type.dyn_cast<VectorType>())
|
|
return vecType.getElementType().isUnsignedInteger();
|
|
return false;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Operation conversion
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
|
|
/// Converts SPIR-V operations that have straightforward LLVM equivalent
|
|
/// into LLVM dialect operations.
|
|
template <typename SPIRVOp, typename LLVMOp>
|
|
class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
|
|
public:
|
|
using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto dstType = this->typeConverter.convertType(operation.getType());
|
|
if (!dstType)
|
|
return failure();
|
|
rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType, operands);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate"
|
|
template <typename SPIRVOp, LLVM::FCmpPredicate predicate>
|
|
class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
|
|
public:
|
|
using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
auto dstType = this->typeConverter.convertType(operation.getType());
|
|
if (!dstType)
|
|
return failure();
|
|
|
|
rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
|
|
operation, dstType,
|
|
rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate)),
|
|
operation.operand1(), operation.operand2());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Converts SPIR-V integer comparisons to llvm.icmp "predicate"
|
|
template <typename SPIRVOp, LLVM::ICmpPredicate predicate>
|
|
class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
|
|
public:
|
|
using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
auto dstType = this->typeConverter.convertType(operation.getType());
|
|
if (!dstType)
|
|
return failure();
|
|
|
|
rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
|
|
operation, dstType,
|
|
rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate)),
|
|
operation.operand1(), operation.operand2());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
|
|
/// puts a restriction on `Shift` and `Base` to have the same bit width,
|
|
/// `Shift` is zero or sign extended to match this specification. Cases when
|
|
/// `Shift` bit width > `Base` bit width are considered to be illegal.
|
|
template <typename SPIRVOp, typename LLVMOp>
|
|
class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
|
|
public:
|
|
using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
auto dstType = this->typeConverter.convertType(operation.getType());
|
|
if (!dstType)
|
|
return failure();
|
|
|
|
Type op1Type = operation.operand1().getType();
|
|
Type op2Type = operation.operand2().getType();
|
|
|
|
if (op1Type == op2Type) {
|
|
rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType,
|
|
operands);
|
|
return success();
|
|
}
|
|
|
|
Location loc = operation.getLoc();
|
|
Value extended;
|
|
if (isUnsignedIntegerOrVector(op2Type)) {
|
|
extended = rewriter.template create<LLVM::ZExtOp>(loc, dstType,
|
|
operation.operand2());
|
|
} else {
|
|
extended = rewriter.template create<LLVM::SExtOp>(loc, dstType,
|
|
operation.operand2());
|
|
}
|
|
Value result = rewriter.template create<LLVMOp>(
|
|
loc, dstType, operation.operand1(), extended);
|
|
rewriter.replaceOp(operation, result);
|
|
return success();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Pattern population
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Appends to `patterns` the SPIR-V to LLVM conversion patterns defined in
/// this file. Every pattern is constructed with `context` and `typeConverter`.
/// Patterns are registered category by category, in a fixed order.
void mlir::populateSPIRVToLLVMConversionPatterns(
    MLIRContext *context, LLVMTypeConverter &typeConverter,
    OwningRewritePatternList &patterns) {
  // Arithmetic ops
  patterns.insert<DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
                  DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
                  DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
                  DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
                  DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
                  DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
                  DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
                  DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
                  DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
                  DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
                  DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>>(
      context, typeConverter);

  // Bitwise ops
  patterns.insert<DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
                  DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
                  DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>>(
      context, typeConverter);

  // Comparison ops
  patterns.insert<
      IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
      IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
      FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
      FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
      FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
      FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
      FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
      FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
      FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
      FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
      FComparePattern<spirv::FUnordGreaterThanEqualOp,
                      LLVM::FCmpPredicate::uge>,
      FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
      FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
      FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
      IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
      IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
      IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
      IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
      IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
      IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
      IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
      IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>>(
      context, typeConverter);

  // Shift ops
  patterns.insert<ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
                  ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
                  ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>>(
      context, typeConverter);
}
|