Files
clang-p2996/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
George Mitenkov cf2b4d5cb6 [MLIR][SPIRVToLLVM] Implemented shift conversion pattern
This patch implements the conversion of shift ops. In the SPIR-V dialect,
`Shift` and `Base` may have different bit widths. In contrast,
in the LLVM dialect both `Base` and `Shift` must have the same bit width.
This leads to the following cases:
- if `Base` has the same bit width as `Shift`, the conversion is
  straightforward.
- if `Base` has a greater bit width than `Shift`, `Shift` is first sign- or
  zero-extended to the bit width of `Base`, and the extended value is passed
  to the shift.
- otherwise the conversion is considered to be illegal.

Differential Revision: https://reviews.llvm.org/D81546
2020-06-12 19:04:30 -04:00

208 lines
8.6 KiB
C++

//===- ConvertSPIRVToLLVM.cpp - SPIR-V dialect to LLVM dialect conversion -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert SPIR-V dialect to LLVM dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//
/// Returns true if the given type is an unsigned integer or vector type
/// Returns true if `type` is an unsigned integer type, or a vector type whose
/// element type is an unsigned integer type.
static bool isUnsignedIntegerOrVector(Type type) {
  // Look through a vector to its element type; scalars are checked as-is.
  Type scalarType = type;
  if (auto vecType = type.dyn_cast<VectorType>())
    scalarType = vecType.getElementType();
  return scalarType.isUnsignedInteger();
}
//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//
namespace {
/// Converts SPIR-V operations that have straightforward LLVM equivalent
/// into LLVM dialect operations.
/// Lowers a SPIR-V op one-to-one to the LLVM dialect op `LLVMOp`.
///
/// The result type is converted with the pattern's type converter, and the
/// already-converted `operands` are forwarded unchanged to the new op. The
/// pattern fails if the result type cannot be converted.
template <typename SPIRVOp, typename LLVMOp>
class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
public:
  using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;

  LogicalResult
  matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedResultType =
        this->typeConverter.convertType(operation.getType());
    if (convertedResultType) {
      rewriter.template replaceOpWithNewOp<LLVMOp>(
          operation, convertedResultType, operands);
      return success();
    }
    // Type conversion failed: leave the op for another pattern / report.
    return failure();
  }
};
/// Converts SPIR-V floating-point comparisons to `llvm.fcmp` using the
/// predicate given by the `predicate` template parameter.
///
/// The pattern fails if the result type cannot be converted.
template <typename SPIRVOp, LLVM::FCmpPredicate predicate>
class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
public:
  using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;

  LogicalResult
  matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto dstType = this->typeConverter.convertType(operation.getType());
    if (!dstType)
      return failure();
    // Build the comparison from the remapped `operands` (the converted
    // operand1 and operand2), not from the original SPIR-V values. This is
    // consistent with the other patterns in this file that forward `operands`.
    rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
        operation, dstType,
        rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate)),
        operands[0], operands[1]);
    return success();
  }
};
/// Converts SPIR-V integer comparisons to `llvm.icmp` using the predicate
/// given by the `predicate` template parameter.
///
/// The pattern fails if the result type cannot be converted.
template <typename SPIRVOp, LLVM::ICmpPredicate predicate>
class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
public:
  using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;

  LogicalResult
  matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto dstType = this->typeConverter.convertType(operation.getType());
    if (!dstType)
      return failure();
    // Build the comparison from the remapped `operands` (the converted
    // operand1 and operand2), not from the original SPIR-V values. This is
    // consistent with the other patterns in this file that forward `operands`.
    rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
        operation, dstType,
        rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate)),
        operands[0], operands[1]);
    return success();
  }
};
/// Converts SPIR-V shift ops to LLVM shift ops.
///
/// SPIR-V allows `Shift` and `Base` to have different bit widths, but the LLVM
/// dialect requires them to match. Bit widths are compared on the (element)
/// integer types of the original SPIR-V operands:
///   - `Shift` as wide as `Base`: the converted `Shift` is used directly. This
///     also covers types that differ only in signedness.
///   - `Shift` narrower than `Base`: `Shift` is zero extended if its SPIR-V
///     type is unsigned, and sign extended otherwise, to the converted type.
///   - `Shift` wider than `Base`, or a non-integer operand: considered illegal.
///     The pattern fails instead of emitting an invalid truncating "extension".
template <typename SPIRVOp, typename LLVMOp>
class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
public:
  using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;

  LogicalResult
  matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto dstType = this->typeConverter.convertType(operation.getType());
    if (!dstType)
      return failure();

    // Widths and signedness are read from the original SPIR-V operand types.
    Type baseType = operation.operand1().getType();
    Type shiftType = operation.operand2().getType();
    unsigned baseWidth = getIntegerOrVectorElementWidth(baseType);
    unsigned shiftWidth = getIntegerOrVectorElementWidth(shiftType);
    if (baseWidth == 0 || shiftWidth == 0 || shiftWidth > baseWidth)
      return failure();

    // Always build from the remapped operands ({base, shift}), so both
    // branches below feed the LLVM op converted values.
    Value base = operands[0];
    Value shift = operands[1];
    if (shiftWidth < baseWidth) {
      Location loc = operation.getLoc();
      if (isUnsignedIntegerOrVector(shiftType))
        shift = rewriter.template create<LLVM::ZExtOp>(loc, dstType, shift);
      else
        shift = rewriter.template create<LLVM::SExtOp>(loc, dstType, shift);
    }
    rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType, base,
                                                 shift);
    return success();
  }

private:
  /// Returns the bit width of `type` if it is an integer type or a vector of
  /// integers, and 0 otherwise.
  static unsigned getIntegerOrVectorElementWidth(Type type) {
    if (auto vecType = type.dyn_cast<VectorType>())
      type = vecType.getElementType();
    if (auto intType = type.dyn_cast<IntegerType>())
      return intType.getWidth();
    return 0;
  }
};
} // namespace
//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//
/// Adds every SPIR-V to LLVM dialect conversion pattern defined in this file
/// to `patterns`. Each pattern is constructed with `context` and
/// `typeConverter`. Patterns are inserted grouped by category, in a fixed order.
void mlir::populateSPIRVToLLVMConversionPatterns(
    MLIRContext *context, LLVMTypeConverter &typeConverter,
    OwningRewritePatternList &patterns) {
  // Arithmetic ops.
  patterns.insert<DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
                  DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
                  DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
                  DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
                  DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
                  DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
                  DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
                  DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
                  DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
                  DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
                  DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>>(
      context, typeConverter);

  // Bitwise ops.
  patterns.insert<DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
                  DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
                  DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>>(
      context, typeConverter);

  // Comparison ops: integer (in)equality, then ordered and unordered float
  // comparisons, then signed and unsigned integer orderings.
  patterns.insert<
      IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
      IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
      FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
      FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
      FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
      FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
      FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
      FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
      FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
      FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
      FComparePattern<spirv::FUnordGreaterThanEqualOp,
                      LLVM::FCmpPredicate::uge>,
      FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
      FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
      FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
      IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
      IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
      IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
      IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
      IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
      IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
      IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
      IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>>(
      context, typeConverter);

  // Shift ops.
  patterns.insert<ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
                  ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
                  ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>>(
      context, typeConverter);
}