Files
clang-p2996/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
Daniil Dudkin 641124a9b9 [mlir][spirv] Add conversions for Arith's maxnumf and minnumf (#66696)
This patch is part of a larger initiative aimed at fixing floating-point
`max` and `min` operations in MLIR:
https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

In this commit, we add conversion patterns for the newly introduced
operations `arith.minnumf` and `arith.maxnumf`. When converting to
`spirv.CL`, no additional guards are needed to propagate the non-NaN
operand when one of the arguments is NaN, because the `CL` ops already
have exactly these semantics. However, the `GL` ops have undefined
behavior when one of the arguments is NaN, so we insert additional guards
to enforce the semantics of the Arith ops.

This patch addresses task 1.5 of the RFC mentioned above.
2023-09-19 22:49:48 +03:00

1258 lines
49 KiB
C++

//===- ArithToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
#include "../SPIRVCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <memory>
namespace mlir {
#define GEN_PASS_DEF_CONVERTARITHTOSPIRV
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
#define DEBUG_TYPE "arith-to-spirv-pattern"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Conversion Helpers
//===----------------------------------------------------------------------===//
/// Attempts to turn `srcAttr` into a BoolAttr.
///
/// A BoolAttr is returned as-is. An IntegerAttr maps to `true` when its value
/// is nonzero and to `false` otherwise. Any other attribute kind yields a null
/// BoolAttr.
static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
  if (isa<BoolAttr>(srcAttr))
    return cast<BoolAttr>(srcAttr);
  auto intAttr = dyn_cast<IntegerAttr>(srcAttr);
  if (!intAttr)
    return {};
  return builder.getBoolAttr(intAttr.getValue().getBoolValue());
}
/// Rebuilds integer attribute `srcAttr` with integer type `dstType`.
///
/// The value is first checked against `dstType`'s width as an unsigned
/// quantity (`isIntN`). If that fails, it is checked as a signed quantity
/// (`isSignedIntN`). Returns a null attribute when neither interpretation fits.
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
                                      Builder builder) {
  const APInt &srcValue = srcAttr.getValue();
  unsigned dstWidth = dstType.getWidth();

  // The value needs no more bits than the target provides, so the
  // conversion is lossless.
  if (srcValue.isIntN(dstWidth))
    return builder.getIntegerAttr(dstType, srcAttr.getInt());

  // XXX: Retry by interpreting the value as signed. Integers here are
  // signless, so whether the bits represent a signed number is decided by the
  // consuming operation. This is a best-effort guess when the bitwidth
  // changes; at least log it.
  if (!srcValue.isSignedIntN(dstWidth)) {
    LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
                            << "' illegal: cannot fit into target type '"
                            << dstType << "'\n");
    return {};
  }

  IntegerAttr dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
  LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
                          << dstAttr << "' for type '" << dstType << "'\n");
  return dstAttr;
}
/// Converts float attribute `srcAttr` to `dstType`, which must be f32.
///
/// The value is converted to IEEE single precision, rounding toward zero. The
/// result is rejected if the conversion reports a non-OK status or loses
/// information. Returns a null attribute in that case, or when `dstType` is
/// not f32.
static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
                                  Builder builder) {
  // Only f32 is supported as a conversion target for now.
  if (!dstType.isF32())
    return FloatAttr();

  APFloat singleValue = srcAttr.getValue();
  bool inexact = false;
  APFloat::opStatus convStatus = singleValue.convert(
      APFloat::IEEEsingle(), APFloat::rmTowardZero, &inexact);
  bool representable = convStatus == APFloat::opOK && !inexact;
  if (!representable) {
    LLVM_DEBUG(llvm::dbgs()
               << srcAttr << " illegal: cannot fit into converted type '"
               << dstType << "'\n");
    return FloatAttr();
  }
  return builder.getF32FloatAttr(singleValue.convertToFloat());
}
/// Returns true if `type` is i1 or a vector whose element type is i1.
static bool isBoolScalarOrVector(Type type) {
  assert(type && "Not a valid type");
  // For vectors, inspect the element type; otherwise inspect the type itself.
  auto vecType = dyn_cast<VectorType>(type);
  Type scalarType = vecType ? vecType.getElementType() : type;
  return scalarType.isInteger(1);
}
/// Materializes `value` as a spirv.Constant of integer `type`.
///
/// For a vector type, the value is splatted across all elements. Returns a
/// null Value if `type` is neither an integer nor a vector type.
static Value getScalarOrVectorConstInt(Type type, uint64_t value,
                                       OpBuilder &builder, Location loc) {
  if (isa<IntegerType>(type))
    return builder.create<spirv::ConstantOp>(
        loc, type, builder.getIntegerAttr(type, value));

  auto vectorType = dyn_cast<VectorType>(type);
  if (!vectorType)
    return nullptr;
  Attribute splatElement =
      IntegerAttr::get(vectorType.getElementType(), value);
  return builder.create<spirv::ConstantOp>(
      loc, vectorType, SplatElementsAttr::get(vectorType, splatElement));
}
/// Returns true if scalar/vector types `a` and `b` occupy the same, nonzero
/// total number of bits.
///
/// For a vector, the total is the element bitwidth times the element count.
/// Types that are neither int/float nor vector count as 0 bits and therefore
/// never match.
static bool hasSameBitwidth(Type a, Type b) {
  auto totalBits = [](Type type) -> unsigned {
    if (type.isIntOrFloat())
      return type.getIntOrFloatBitWidth();
    if (auto vecType = dyn_cast<VectorType>(type))
      return vecType.getElementTypeBitWidth() * vecType.getNumElements();
    return 0;
  };
  unsigned aBits = totalBits(a);
  if (aBits == 0)
    return false;
  return aBits == totalBits(b);
}
/// Reports a match failure at `op`'s location saying that `srcType` could not
/// be converted, and returns the resulting failure.
static LogicalResult
getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op,
                         Type srcType) {
  Location loc = op->getLoc();
  return rewriter.notifyMatchFailure(
      loc, llvm::formatv("failed to convert source type '{0}'", srcType));
}
/// Reports a type-conversion match failure for the result type of `op`.
/// `op` must produce exactly one result.
static LogicalResult
getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op) {
  assert(op->getNumResults() == 1);
  Type resultType = op->getResult(0).getType();
  return getTypeConversionFailure(rewriter, op, resultType);
}
namespace {
//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
/// Converts a multi-element (vector or tensor) arith.constant to
/// spirv.Constant.
///
/// - Tensors of rank > 1 are flattened to rank 1; vectors of rank > 1 are not
///   handled.
/// - When the converted element type differs from the source element type,
///   every element is re-encoded in the converted type. An i1 element type is
///   rejected in that situation.
struct ConstantCompositeOpPattern final
    : public OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = dyn_cast<ShapedType>(constOp.getType());
    // Non-shaped and single-element constants are not handled here.
    if (!srcType || srcType.getNumElements() == 1)
      return failure();

    // arith.constant should only have vector or tensor types.
    assert((isa<VectorType, RankedTensorType>(srcType)));

    Type dstType = getTypeConverter()->convertType(srcType);
    if (!dstType)
      return failure();

    auto valueAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
    if (!valueAttr)
      return failure();
    ShapedType valueType = valueAttr.getType();

    // Multi-dimensional composites are linearized; only tensors support this.
    if (srcType.getRank() > 1) {
      // TODO: add support for large vectors.
      if (!isa<RankedTensorType>(srcType))
        return failure();
      valueType = RankedTensorType::get(srcType.getNumElements(),
                                        srcType.getElementType());
      valueAttr = valueAttr.reshape(valueType);
    }

    // Tensor types are converted to SPIR-V array types; vector types are
    // converted to SPIR-V vector/array types.
    Type srcElemType = srcType.getElementType();
    Type dstElemType;
    if (auto arrayType = dyn_cast<spirv::ArrayType>(dstType))
      dstElemType = arrayType.getElementType();
    else
      dstElemType = cast<VectorType>(dstType).getElementType();

    if (srcElemType != dstElemType) {
      SmallVector<Attribute, 8> convertedElems;
      if (failed(convertElements(valueAttr, srcElemType, dstElemType, rewriter,
                                 convertedElems)))
        return failure();

      // Element attributes only accept builtin types, so the new attribute
      // needs a builtin shaped type carrying the converted element type.
      if (isa<RankedTensorType>(valueType))
        valueType = RankedTensorType::get(valueType.getShape(), dstElemType);
      else
        valueType = VectorType::get(valueType.getShape(), dstElemType);
      valueAttr = DenseElementsAttr::get(valueType, convertedElems);
    }

    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
                                                   valueAttr);
    return success();
  }

private:
  /// Re-encodes every element of `values` from `srcElemType` into
  /// `dstElemType`, appending the results to `out`. Fails for i1 sources and
  /// for any element that does not fit the destination type.
  static LogicalResult convertElements(DenseElementsAttr values,
                                       Type srcElemType, Type dstElemType,
                                       Builder builder,
                                       SmallVectorImpl<Attribute> &out) {
    if (isa<FloatType>(srcElemType)) {
      for (FloatAttr srcAttr : values.getValues<FloatAttr>()) {
        FloatAttr dstAttr =
            convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), builder);
        if (!dstAttr)
          return failure();
        out.push_back(dstAttr);
      }
      return success();
    }

    if (srcElemType.isInteger(1))
      return failure();

    for (IntegerAttr srcAttr : values.getValues<IntegerAttr>()) {
      IntegerAttr dstAttr =
          convertIntegerAttr(srcAttr, cast<IntegerType>(dstElemType), builder);
      if (!dstAttr)
        return failure();
      out.push_back(dstAttr);
    }
    return success();
  }
};
/// Converts an arith.constant holding one scalar (or a one-element
/// vector/tensor) to spirv.Constant.
///
/// - i1 constants may be written as 0/1 integers and are turned into booleans.
/// - Float constants whose type changes under conversion are re-encoded with
///   convertFloatAttr.
/// - Index/integer constants are re-encoded in the converted integer type.
struct ConstantScalarOpPattern final
    : public OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = constOp.getType();
    // Unwrap single-element shaped types; larger ones are not handled here.
    if (auto shapedType = dyn_cast<ShapedType>(srcType)) {
      if (shapedType.getNumElements() != 1)
        return failure();
      srcType = shapedType.getElementType();
    }
    if (!srcType.isIntOrIndexOrFloat())
      return failure();

    Attribute cstAttr = constOp.getValue();
    if (auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr))
      cstAttr = elementsAttr.getSplatValue<Attribute>();

    Type dstType = getTypeConverter()->convertType(srcType);
    if (!dstType)
      return failure();

    // Bool type: arith.constant can use 0/1 instead of true/false for i1.
    if (srcType.isInteger(1)) {
      BoolAttr boolAttr = convertBoolAttr(cstAttr, rewriter);
      if (!boolAttr)
        return failure();
      rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
                                                     boolAttr);
      return success();
    }

    // Floating-point types: re-encode only if conversion changed the type.
    if (isa<FloatType>(srcType)) {
      FloatAttr floatAttr = cast<FloatAttr>(cstAttr);
      if (srcType != dstType) {
        floatAttr =
            convertFloatAttr(floatAttr, cast<FloatType>(dstType), rewriter);
        if (!floatAttr)
          return failure();
      }
      rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
                                                     floatAttr);
      return success();
    }

    // Index or integer type; index is lowered to the integer type chosen by
    // the type converter.
    IntegerAttr intAttr = convertIntegerAttr(
        cast<IntegerAttr>(cstAttr), cast<IntegerType>(dstType), rewriter);
    if (!intAttr)
      return failure();
    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, intAttr);
    return success();
  }
};
//===----------------------------------------------------------------------===//
// RemSIOp
//===----------------------------------------------------------------------===//
/// Computes the signed remainder of `lhs` by `rhs`, giving the result the sign
/// of `signOperand` (which must be `lhs` or `rhs`).
///
/// This is needed for Vulkan. Per Vulkan's SPIR-V environment spec, "for the
/// OpSRem and OpSMod instructions, if either operand is negative the result is
/// undefined." So spirv.SRem/spirv.SMod cannot be used when either operand may
/// be negative. Instead, the remainder of the absolute values (computed with
/// `SignedAbsOp`) is taken with spirv.UMod, and the sign is fixed up
/// afterwards.
template <typename SignedAbsOp>
static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
                                    Value signOperand, OpBuilder &builder) {
  assert(lhs.getType() == rhs.getType());
  assert(lhs == signOperand || rhs == signOperand);
  Type type = lhs.getType();

  // |lhs| umod |rhs|
  Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs);
  Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs);
  Value absRem = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs);

  // The sign operand is treated as positive when it equals its own absolute
  // value.
  bool signFromLhs = lhs == signOperand;
  Value signValue = signFromLhs ? lhs : rhs;
  Value signAbs = signFromLhs ? lhsAbs : rhsAbs;
  Value isPositive = builder.create<spirv::IEqualOp>(loc, signValue, signAbs);

  Value negatedRem = builder.create<spirv::SNegateOp>(loc, type, absRem);
  return builder.create<spirv::SelectOp>(loc, type, isPositive, absRem,
                                         negatedRem);
}
/// Converts arith.remsi to GLSL SPIR-V ops.
///
/// This cannot be merged into the template unary/binary pattern due to Vulkan
/// restrictions over spirv.SRem and spirv.SMod. Absolute values are computed
/// with `spirv::GLSAbsOp` from the GLSL extended instruction set, matching the
/// GL lowering path. (Previously this pattern instantiated the OpenCL
/// `spirv::CLSAbsOp`, which does not belong to the GL path.)
struct RemSIOpGLPattern final : public OpConversionPattern<arith::RemSIOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The lhs operand is passed as `signOperand`, so the result follows the
    // sign of lhs.
    Value result = emulateSignedRemainder<spirv::GLSAbsOp>(
        op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
        adaptor.getOperands()[0], rewriter);
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Converts arith.remsi to OpenCL SPIR-V ops.
///
/// Absolute values are computed with `spirv::CLSAbsOp` from the OpenCL
/// extended instruction set, matching the CL lowering path. (Previously this
/// pattern instantiated the GLSL `spirv::GLSAbsOp`.)
struct RemSIOpCLPattern final : public OpConversionPattern<arith::RemSIOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The lhs operand is passed as `signOperand`, so the result follows the
    // sign of lhs.
    Value result = emulateSignedRemainder<spirv::CLSAbsOp>(
        op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
        adaptor.getOperands()[0], rewriter);
    rewriter.replaceOp(op, result);
    return success();
  }
};
//===----------------------------------------------------------------------===//
// BitwiseOp
//===----------------------------------------------------------------------===//
/// Converts arith bitwise operations to SPIR-V operations.
///
/// This is a separate pattern rather than the generic binary op pattern
/// because SPIR-V splits these ops by operand type: boolean operands (i1 or
/// vector of i1) use `SPIRVLogicalOp`, and all other operands use
/// `SPIRVBitwiseOp`.
template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
struct BitwiseOpPattern final : public OpConversionPattern<Op> {
  using OpConversionPattern<Op>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto operands = adaptor.getOperands();
    assert(operands.size() == 2);

    Type dstType = this->getTypeConverter()->convertType(op.getType());
    if (!dstType)
      return getTypeConversionFailure(rewriter, op);

    bool operandsAreBool = isBoolScalarOrVector(operands.front().getType());
    if (operandsAreBool)
      rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(op, dstType,
                                                           operands);
    else
      rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(op, dstType,
                                                           operands);
    return success();
  }
};
//===----------------------------------------------------------------------===//
// XOrIOp
//===----------------------------------------------------------------------===//
/// Converts arith.xori on non-boolean operands to spirv.BitwiseXor.
/// Operands that are i1 or vector of i1 are not matched by this pattern.
struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto operands = adaptor.getOperands();
    assert(operands.size() == 2);
    bool operandsAreBool = isBoolScalarOrVector(operands.front().getType());
    if (operandsAreBool)
      return failure();

    Type resultType = getTypeConverter()->convertType(op.getType());
    if (!resultType)
      return getTypeConversionFailure(rewriter, op);

    rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, resultType, operands);
    return success();
  }
};
/// Converts arith.xori on i1 (or vector of i1) operands to
/// spirv.LogicalNotEqual, since boolean XOR is inequality.
struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto operands = adaptor.getOperands();
    assert(operands.size() == 2);
    bool operandsAreBool = isBoolScalarOrVector(operands.front().getType());
    if (!operandsAreBool)
      return failure();

    Type resultType = getTypeConverter()->convertType(op.getType());
    if (!resultType)
      return getTypeConversionFailure(rewriter, op);

    rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, resultType,
                                                          operands);
    return success();
  }
};
//===----------------------------------------------------------------------===//
// UIToFPOp
//===----------------------------------------------------------------------===//
/// Converts arith.uitofp with an i1 (or vector of i1) source to spirv.Select,
/// which picks floating-point one for true and zero for false.
struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value predicate = adaptor.getOperands().front();
    if (!isBoolScalarOrVector(predicate.getType()))
      return failure();

    Type floatType = getTypeConverter()->convertType(op.getType());
    if (!floatType)
      return getTypeConversionFailure(rewriter, op);

    Location loc = op.getLoc();
    Value zeroVal = spirv::ConstantOp::getZero(floatType, loc, rewriter);
    Value oneVal = spirv::ConstantOp::getOne(floatType, loc, rewriter);
    rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, floatType, predicate,
                                                 oneVal, zeroVal);
    return success();
  }
};
//===----------------------------------------------------------------------===//
// ExtSIOp
//===----------------------------------------------------------------------===//
/// Converts arith.extsi with an i1 (or vector of i1) source to spirv.Select.
///
/// True becomes an all-ones integer (the sign extension of a 1-bit one) and
/// false becomes zero. Integer and vector result types are supported; any
/// other converted result type is reported as a match failure.
struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value condition = adaptor.getIn();
    if (!isBoolScalarOrVector(condition.getType()))
      return failure();

    Location loc = op.getLoc();
    Type dstType = getTypeConverter()->convertType(op.getType());
    if (!dstType)
      return getTypeConversionFailure(rewriter, op);

    Value allOnes;
    if (auto vectorTy = dyn_cast<VectorType>(dstType)) {
      unsigned laneBits = vectorTy.getElementTypeBitWidth();
      allOnes = rewriter.create<spirv::ConstantOp>(
          loc, vectorTy,
          SplatElementsAttr::get(vectorTy, APInt::getAllOnes(laneBits)));
    } else if (auto intTy = dyn_cast<IntegerType>(dstType)) {
      unsigned bits = intTy.getWidth();
      allOnes = rewriter.create<spirv::ConstantOp>(
          loc, intTy, rewriter.getIntegerAttr(intTy, APInt::getAllOnes(bits)));
    } else {
      return rewriter.notifyMatchFailure(
          loc, llvm::formatv("unhandled type: {0}", dstType));
    }

    Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
    rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, condition,
                                                 allOnes, zero);
    return success();
  }
};
/// Converts arith.extsi when the source is neither i1 nor a vector of i1.
///
/// If type conversion maps the source and result to different types, this
/// emits spirv.SConvert. If both map to the same storage type (bitwidth
/// emulation), the value is sign-extended in place:
///   1. Shift left so the original sign bit lands in the top bit of the
///      storage type.
///   2. Arithmetic-shift right by the same amount.
struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = adaptor.getIn().getType();
    if (isBoolScalarOrVector(srcType))
      return failure();
    Type dstType = getTypeConverter()->convertType(op.getType());
    if (!dstType)
      return getTypeConversionFailure(rewriter, op);

    if (dstType != srcType) {
      rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
                                                     adaptor.getOperands());
      return success();
    }

    // Source and result share the same storage type due to type emulation.
    // The shift amount must be measured against the storage width, not the
    // original result width. Shifting by `originalDstBW - srcBW` would:
    //  - leave stale high bits when the original result is narrower than the
    //    storage type (e.g. i8 -> i16, both stored as i32);
    //  - shift by >= the storage width when the original result is wider
    //    (e.g. i16 -> i64, both stored as i32), which SPIR-V leaves undefined.
    unsigned srcBW =
        getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
    unsigned storageBW = getElementTypeOrSelf(dstType).getIntOrFloatBitWidth();
    if (srcBW >= storageBW) {
      // The source already occupies the whole storage width, so sign
      // extension leaves the stored bits unchanged.
      rewriter.replaceOp(op, adaptor.getIn());
      return success();
    }

    Value shiftSize = getScalarOrVectorConstInt(dstType, storageBW - srcBW,
                                                rewriter, op.getLoc());
    // First shift left to squeeze out all bits above the original source
    // width.
    Value shiftedLeft = rewriter.create<spirv::ShiftLeftLogicalOp>(
        op.getLoc(), dstType, adaptor.getIn(), shiftSize);
    // Then shift right arithmetically to replicate the sign bit.
    rewriter.replaceOpWithNewOp<spirv::ShiftRightArithmeticOp>(
        op, dstType, shiftedLeft, shiftSize);
    return success();
  }
};
//===----------------------------------------------------------------------===//
// ExtUIOp
//===----------------------------------------------------------------------===//
/// Converts arith.extui with an i1 (or vector of i1) source to spirv.Select,
/// which picks integer one for true and zero for false.
struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value predicate = adaptor.getOperands().front();
    if (!isBoolScalarOrVector(predicate.getType()))
      return failure();

    Type intType = getTypeConverter()->convertType(op.getType());
    if (!intType)
      return getTypeConversionFailure(rewriter, op);

    Location loc = op.getLoc();
    Value zeroVal = spirv::ConstantOp::getZero(intType, loc, rewriter);
    Value oneVal = spirv::ConstantOp::getOne(intType, loc, rewriter);
    rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, intType, predicate,
                                                 oneVal, zeroVal);
    return success();
  }
};
/// Converts arith.extui when the source is neither i1 nor a vector of i1.
///
/// Distinct converted source and result types lower to spirv.UConvert. When
/// both map to the same type (bitwidth emulation), the operand is ANDed with
/// a mask of the original source width. This clears the bits above that width
/// so downstream consumers do not see them.
struct ExtUIPattern final : public OpConversionPattern<arith::ExtUIOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value input = adaptor.getIn();
    Type srcType = input.getType();
    if (isBoolScalarOrVector(srcType))
      return failure();

    Type dstType = getTypeConverter()->convertType(op.getType());
    if (!dstType)
      return getTypeConversionFailure(rewriter, op);

    if (dstType != srcType) {
      rewriter.replaceOpWithNewOp<spirv::UConvertOp>(op, dstType,
                                                     adaptor.getOperands());
      return success();
    }

    // Same type because of emulation: keep only the original source bits.
    unsigned srcBW =
        getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
    uint64_t lowBits = llvm::maskTrailingOnes<uint64_t>(srcBW);
    Value mask =
        getScalarOrVectorConstInt(dstType, lowBits, rewriter, op.getLoc());
    rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType, input, mask);
    return success();
  }
};
//===----------------------------------------------------------------------===//
// TruncIOp
//===----------------------------------------------------------------------===//
/// Converts arith.trunci whose converted result is i1 (or vector of i1).
///
/// Only the lowest source bit survives truncation to one bit. The pattern
/// tests `(x & 1) == 1` and selects between the boolean constants true and
/// false.
struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type boolType = getTypeConverter()->convertType(op.getType());
    if (!boolType)
      return getTypeConversionFailure(rewriter, op);
    if (!isBoolScalarOrVector(boolType))
      return failure();

    Location loc = op.getLoc();
    Value input = adaptor.getOperands()[0];
    Type srcType = input.getType();

    Value lowBitMask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
    Value lowBit =
        rewriter.create<spirv::BitwiseAndOp>(loc, srcType, input, lowBitMask);
    Value lowBitSet = rewriter.create<spirv::IEqualOp>(loc, lowBit, lowBitMask);

    Value falseVal = spirv::ConstantOp::getZero(boolType, loc, rewriter);
    Value trueVal = spirv::ConstantOp::getOne(boolType, loc, rewriter);
    rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, boolType, lowBitSet,
                                                 trueVal, falseVal);
    return success();
  }
};
/// Converts arith.trunci whose converted result is neither i1 nor a vector of
/// i1.
///
/// Distinct converted types lower to spirv.SConvert; for truncation, signed
/// and unsigned converts behave the same. When both sides map to the same
/// type (bitwidth emulation), the operand is ANDed with a mask of the original
/// result width so no unwanted high bits reach downstream consumers.
struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value input = adaptor.getIn();
    Type srcType = input.getType();
    Type dstType = getTypeConverter()->convertType(op.getType());
    if (!dstType)
      return getTypeConversionFailure(rewriter, op);
    if (isBoolScalarOrVector(dstType))
      return failure();

    if (dstType != srcType) {
      // Given this is truncation, either SConvertOp or UConvertOp works.
      rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
                                                     adaptor.getOperands());
      return success();
    }

    // Same type because of emulation: keep only the original result bits.
    unsigned resultBW =
        getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth();
    uint64_t lowBits = llvm::maskTrailingOnes<uint64_t>(resultBW);
    Value mask =
        getScalarOrVectorConstInt(dstType, lowBits, rewriter, op.getLoc());
    rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType, input, mask);
    return success();
  }
};
//===----------------------------------------------------------------------===//
// TypeCastingOp
//===----------------------------------------------------------------------===//
/// Converts an arith type-casting op `Op` to `SPIRVOp`.
///
/// - Casts whose source or converted result type is boolean (i1 or vector of
///   i1) are not matched.
/// - If the source type equals the converted result type, the cast is a no-op
///   and the operand is forwarded directly.
template <typename Op, typename SPIRVOp>
struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
  using OpConversionPattern<Op>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto operands = adaptor.getOperands();
    assert(operands.size() == 1);
    Value input = operands.front();
    Type srcType = input.getType();

    Type dstType = this->getTypeConverter()->convertType(op.getType());
    if (!dstType)
      return getTypeConversionFailure(rewriter, op);

    bool involvesBool =
        isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType);
    if (involvesBool)
      return failure();

    if (dstType != srcType) {
      rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType, operands);
      return success();
    }

    // Type conversion turned this into an identity cast; forward the operand.
    rewriter.replaceOp(op, input);
    return success();
  }
};
//===----------------------------------------------------------------------===//
// CmpIOp
//===----------------------------------------------------------------------===//
/// Converts arith.cmpi on i1 (or vector of i1) operands.
///
/// - eq/ne map directly to spirv.LogicalEqual/spirv.LogicalNotEqual.
/// - uge/ugt/ule/ult have no boolean SPIR-V counterpart. Both operands are
///   zero-extended to i32 (or a vector of i32) with arith.extui, and a new
///   arith.cmpi with the same predicate is emitted on the extended values.
/// - All other predicates are not matched.
class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = op.getLhs().getType();
    if (!isBoolScalarOrVector(srcType))
      return failure();
    Type dstType = getTypeConverter()->convertType(srcType);
    if (!dstType)
      return getTypeConversionFailure(rewriter, op, srcType);

    using Pred = arith::CmpIPredicate;
    Pred predicate = op.getPredicate();

    if (predicate == Pred::eq) {
      rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(),
                                                         adaptor.getRhs());
      return success();
    }
    if (predicate == Pred::ne) {
      rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
          op, adaptor.getLhs(), adaptor.getRhs());
      return success();
    }

    bool isUnsignedOrdering = predicate == Pred::uge ||
                              predicate == Pred::ugt ||
                              predicate == Pred::ule || predicate == Pred::ult;
    if (!isUnsignedOrdering)
      return failure();

    // There are no direct corresponding instructions in SPIR-V for these
    // cases. Extend to 32-bit and compare there.
    Type extType = rewriter.getI32Type();
    if (auto vectorType = dyn_cast<VectorType>(dstType))
      extType = VectorType::get(vectorType.getShape(), extType);
    Location loc = op.getLoc();
    Value lhsExt =
        rewriter.create<arith::ExtUIOp>(loc, extType, adaptor.getLhs());
    Value rhsExt =
        rewriter.create<arith::ExtUIOp>(loc, extType, adaptor.getRhs());
    rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, predicate, lhsExt, rhsExt);
    return success();
  }
};
/// Converts arith.cmpi on non-boolean integer operands to the matching SPIR-V
/// integer comparison.
///
/// An unsigned comparison is rejected with an error when all of the following
/// hold, because bitwidth emulation is not implemented for unsigned ops:
///   - the operand element type is not index;
///   - type conversion changed the operand type;
///   - the new type has a different total bitwidth.
class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = op.getLhs().getType();
    if (isBoolScalarOrVector(srcType))
      return failure();
    Type dstType = getTypeConverter()->convertType(srcType);
    if (!dstType)
      return getTypeConversionFailure(rewriter, op, srcType);

    using Pred = arith::CmpIPredicate;
    switch (op.getPredicate()) {
    case Pred::eq:
      return lowerTo<spirv::IEqualOp>(op, adaptor, srcType, dstType, rewriter);
    case Pred::ne:
      return lowerTo<spirv::INotEqualOp>(op, adaptor, srcType, dstType,
                                         rewriter);
    case Pred::slt:
      return lowerTo<spirv::SLessThanOp>(op, adaptor, srcType, dstType,
                                         rewriter);
    case Pred::sle:
      return lowerTo<spirv::SLessThanEqualOp>(op, adaptor, srcType, dstType,
                                              rewriter);
    case Pred::sgt:
      return lowerTo<spirv::SGreaterThanOp>(op, adaptor, srcType, dstType,
                                            rewriter);
    case Pred::sge:
      return lowerTo<spirv::SGreaterThanEqualOp>(op, adaptor, srcType, dstType,
                                                 rewriter);
    case Pred::ult:
      return lowerTo<spirv::ULessThanOp>(op, adaptor, srcType, dstType,
                                         rewriter);
    case Pred::ule:
      return lowerTo<spirv::ULessThanEqualOp>(op, adaptor, srcType, dstType,
                                              rewriter);
    case Pred::ugt:
      return lowerTo<spirv::UGreaterThanOp>(op, adaptor, srcType, dstType,
                                            rewriter);
    case Pred::uge:
      return lowerTo<spirv::UGreaterThanEqualOp>(op, adaptor, srcType, dstType,
                                                 rewriter);
    }
    return failure();
  }

private:
  /// Replaces `op` with `SPIRVOp` applied to the converted operands. If
  /// `SPIRVOp` is an unsigned comparison and would need bitwidth emulation,
  /// an error is emitted on `op` instead.
  template <typename SPIRVOp>
  static LogicalResult lowerTo(arith::CmpIOp op, OpAdaptor adaptor,
                               Type srcType, Type dstType,
                               ConversionPatternRewriter &rewriter) {
    bool needsUnsignedEmulation =
        SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
        !getElementTypeOrSelf(srcType).isIndex() && srcType != dstType &&
        !hasSameBitwidth(srcType, dstType);
    if (needsUnsignedEmulation)
      return op.emitError(
          "bitwidth emulation is not implemented yet on unsigned op");
    rewriter.replaceOpWithNewOp<SPIRVOp>(op, adaptor.getLhs(),
                                         adaptor.getRhs());
    return success();
  }
};
//===----------------------------------------------------------------------===//
// CmpFOpPattern
//===----------------------------------------------------------------------===//
/// Converts the ordered (O*) and unordered (U*) comparison predicates of
/// arith.cmpf to the matching spirv.FOrd* / spirv.FUnord* op. All other
/// predicates are not matched by this pattern.
class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    using Pred = arith::CmpFPredicate;
    switch (op.getPredicate()) {
    // Ordered.
    case Pred::OEQ:
      return replaceWith<spirv::FOrdEqualOp>(op, adaptor, rewriter);
    case Pred::OGT:
      return replaceWith<spirv::FOrdGreaterThanOp>(op, adaptor, rewriter);
    case Pred::OGE:
      return replaceWith<spirv::FOrdGreaterThanEqualOp>(op, adaptor, rewriter);
    case Pred::OLT:
      return replaceWith<spirv::FOrdLessThanOp>(op, adaptor, rewriter);
    case Pred::OLE:
      return replaceWith<spirv::FOrdLessThanEqualOp>(op, adaptor, rewriter);
    case Pred::ONE:
      return replaceWith<spirv::FOrdNotEqualOp>(op, adaptor, rewriter);
    // Unordered.
    case Pred::UEQ:
      return replaceWith<spirv::FUnordEqualOp>(op, adaptor, rewriter);
    case Pred::UGT:
      return replaceWith<spirv::FUnordGreaterThanOp>(op, adaptor, rewriter);
    case Pred::UGE:
      return replaceWith<spirv::FUnordGreaterThanEqualOp>(op, adaptor,
                                                          rewriter);
    case Pred::ULT:
      return replaceWith<spirv::FUnordLessThanOp>(op, adaptor, rewriter);
    case Pred::ULE:
      return replaceWith<spirv::FUnordLessThanEqualOp>(op, adaptor, rewriter);
    case Pred::UNE:
      return replaceWith<spirv::FUnordNotEqualOp>(op, adaptor, rewriter);
    default:
      break;
    }
    return failure();
  }

private:
  /// Replaces `op` with `SPIRVOp` applied to the converted lhs and rhs.
  template <typename SPIRVOp>
  static LogicalResult replaceWith(arith::CmpFOp op, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) {
    rewriter.replaceOpWithNewOp<SPIRVOp>(op, adaptor.getLhs(),
                                         adaptor.getRhs());
    return success();
  }
};
/// Converts the ORD/UNO NaN checks of arith.cmpf to spirv.Ordered and
/// spirv.Unordered. This pattern requires the Kernel capability.
class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    switch (op.getPredicate()) {
    case arith::CmpFPredicate::ORD:
      rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
                                                    adaptor.getRhs());
      return success();
    case arith::CmpFPredicate::UNO:
      rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
                                                      adaptor.getRhs());
      return success();
    default:
      break;
    }
    return failure();
  }
};
/// Lowers the NaN-checking predicates of `arith.cmpf` (ORD and UNO) without
/// requiring any additional SPIR-V capability, by combining `spirv.IsNan`
/// checks. When the type converter is in fast-math mode, the comparison is
/// folded to a boolean constant instead.
class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
public:
  using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    arith::CmpFPredicate predicate = op.getPredicate();
    bool isOrdered = predicate == arith::CmpFPredicate::ORD;
    if (!isOrdered && predicate != arith::CmpFPredicate::UNO)
      return failure();

    Location loc = op.getLoc();
    bool fastMath =
        getTypeConverter<SPIRVTypeConverter>()->getOptions().enableFastMathMode;

    Value result;
    if (fastMath) {
      // Fast-math mode treats operands as never NaN. The ordered comparison
      // (neither operand NaN) is therefore always true, and the unordered
      // comparison (either operand NaN) is always false.
      result = isOrdered
                   ? spirv::ConstantOp::getOne(op.getType(), loc, rewriter)
                   : spirv::ConstantOp::getZero(op.getType(), loc, rewriter);
    } else {
      // UNO is isnan(lhs) || isnan(rhs); ORD is its logical negation.
      Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
      Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
      result = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
      if (isOrdered)
        result = rewriter.create<spirv::LogicalNotOp>(loc, result);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
//===----------------------------------------------------------------------===//
// AddUIExtendedOp
//===----------------------------------------------------------------------===//

/// Lowers `arith.addui_extended` to `spirv.IAddCarry`. Member 0 of the SPIR-V
/// result becomes the sum. Member 1 (the carry) is compared against 1 to
/// produce the boolean overflow result.
class AddUIExtendedOpPattern final
    : public OpConversionPattern<arith::AddUIExtendedOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value lhs = adaptor.getLhs();
    Value rhs = adaptor.getRhs();
    // The full converted operand type, used to build a matching constant 1.
    Type operandType = lhs.getType();
    Location loc = op->getLoc();

    Value addWithCarry = rewriter.create<spirv::IAddCarryOp>(loc, lhs, rhs);
    auto extractMember = [&](int member) -> Value {
      return rewriter.create<spirv::CompositeExtractOp>(loc, addWithCarry,
                                                        llvm::ArrayRef(member));
    };
    Value sum = extractMember(0);
    Value carry = extractMember(1);

    // Compare the carry with a 1 of the operand type to get a boolean flag.
    Value one = spirv::ConstantOp::getOne(operandType, loc, rewriter);
    Value overflowed = rewriter.create<spirv::IEqualOp>(loc, carry, one);

    rewriter.replaceOp(op, {sum, overflowed});
    return success();
  }
};
//===----------------------------------------------------------------------===//
// MulIExtendedOp
//===----------------------------------------------------------------------===//

/// Lowers `arith.mul*i_extended` to the matching `spirv.*MulExtended` op
/// (`SPIRVMulOp`). Member 0 of the SPIR-V result becomes the low half of the
/// product and member 1 becomes the high half.
template <typename ArithMulOp, typename SPIRVMulOp>
class MulIExtendedOpPattern final : public OpConversionPattern<ArithMulOp> {
public:
  using OpConversionPattern<ArithMulOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Value product =
        rewriter.create<SPIRVMulOp>(loc, adaptor.getLhs(), adaptor.getRhs());

    auto extractMember = [&](int member) -> Value {
      return rewriter.create<spirv::CompositeExtractOp>(loc, product,
                                                        llvm::ArrayRef(member));
    };
    Value lowHalf = extractMember(0);
    Value highHalf = extractMember(1);

    rewriter.replaceOp(op, {lowHalf, highHalf});
    return success();
  }
};
//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//

/// Lowers `arith.select` one-to-one onto `spirv.Select`, forwarding the
/// converted condition and both candidate values unchanged.
class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value condition = adaptor.getCondition();
    Value onTrue = adaptor.getTrueValue();
    Value onFalse = adaptor.getFalseValue();
    rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, condition, onTrue,
                                                 onFalse);
    return success();
  }
};
//===----------------------------------------------------------------------===//
// MinimumFOp, MaximumFOp
//===----------------------------------------------------------------------===//

/// Lowers `arith.maximumf` / `arith.minimumf` to `SPIRVOp`, which is one of
/// spirv.GL.FMax/FMin or spirv.CL.fmax/fmin.
///
/// The semantics being reconciled are:
///   * arith.maximumf/minimumf: "if one of the arguments is NaN, then the
///     result is also NaN."
///   * spirv.GL.FMax/FMin: "which operand is the result is undefined if one
///     of the operands is a NaN."
///   * spirv.CL.fmax/fmin: "If one argument is a NaN, Fmin returns the other
///     argument."
/// Neither SPIR-V flavor guarantees NaN propagation. Unless fast-math mode is
/// enabled, explicit selects therefore forward a NaN operand to the result.
/// NOTE(review): -0.0 vs +0.0 ordering gets no special handling here; confirm
/// this matches the arith op's contract.
template <typename Op, typename SPIRVOp>
class MinimumMaximumFOpPattern final : public OpConversionPattern<Op> {
public:
  using OpConversionPattern<Op>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto *spirvConverter =
        this->template getTypeConverter<SPIRVTypeConverter>();
    Type resultType = spirvConverter->convertType(op.getType());
    if (!resultType)
      return getTypeConversionFailure(rewriter, op);

    Location loc = op.getLoc();
    Value minOrMax =
        rewriter.create<SPIRVOp>(loc, resultType, adaptor.getOperands());
    if (spirvConverter->getOptions().enableFastMathMode) {
      rewriter.replaceOp(op, minOrMax);
      return success();
    }

    // Propagate NaN. If rhs is NaN the result is rhs; otherwise, if lhs is
    // NaN the result is lhs; otherwise it is the SPIR-V min/max result.
    Value lhs = adaptor.getLhs();
    Value rhs = adaptor.getRhs();
    Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, lhs);
    Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, rhs);
    Value afterLhsCheck = rewriter.create<spirv::SelectOp>(
        loc, resultType, lhsIsNan, lhs, minOrMax);
    Value afterRhsCheck = rewriter.create<spirv::SelectOp>(
        loc, resultType, rhsIsNan, rhs, afterLhsCheck);
    rewriter.replaceOp(op, afterRhsCheck);
    return success();
  }
};
//===----------------------------------------------------------------------===//
// MinNumFOp, MaxNumFOp
//===----------------------------------------------------------------------===//

/// Lowers `arith.maxnumf` / `arith.minnumf` to `SPIRVOp`, which is one of
/// spirv.GL.FMax/FMin or spirv.CL.fmax/fmin.
///
/// The semantics being reconciled are:
///   * arith.maxnumf/minnumf: "If one of the arguments is NaN, then the result
///     is the other argument."
///   * spirv.GL.FMax/FMin: "which operand is the result is undefined if one
///     of the operands is a NaN."
///   * spirv.CL.fmax/fmin: "If one argument is a NaN, Fmin returns the other
///     argument."
/// The CL ops already match arith semantics. The GL ops get explicit selects
/// unless fast-math mode is enabled.
template <typename Op, typename SPIRVOp>
class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> {
  /// True only for the GL variants, whose NaN behavior is undefined.
  static constexpr bool needsNanGuards =
      llvm::is_one_of<SPIRVOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;

public:
  using OpConversionPattern<Op>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto *spirvConverter =
        this->template getTypeConverter<SPIRVTypeConverter>();
    Type resultType = spirvConverter->convertType(op.getType());
    if (!resultType)
      return getTypeConversionFailure(rewriter, op);

    Location loc = op.getLoc();
    Value minOrMax =
        rewriter.create<SPIRVOp>(loc, resultType, adaptor.getOperands());
    if (!needsNanGuards || spirvConverter->getOptions().enableFastMathMode) {
      rewriter.replaceOp(op, minOrMax);
      return success();
    }

    // Prefer the non-NaN operand. If rhs is NaN the result is lhs; otherwise,
    // if lhs is NaN the result is rhs; otherwise it is the SPIR-V result.
    Value lhs = adaptor.getLhs();
    Value rhs = adaptor.getRhs();
    Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, lhs);
    Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, rhs);
    Value afterLhsCheck = rewriter.create<spirv::SelectOp>(
        loc, resultType, lhsIsNan, rhs, minOrMax);
    Value afterRhsCheck = rewriter.create<spirv::SelectOp>(
        loc, resultType, rhsIsNan, lhs, afterLhsCheck);
    rewriter.replaceOp(op, afterRhsCheck);
    return success();
  }
};
} // namespace
//===----------------------------------------------------------------------===//
// Pattern Population
//===----------------------------------------------------------------------===//
/// Appends every Arith-to-SPIR-V conversion pattern to `patterns`. All of the
/// patterns share `typeConverter`.
///
/// For the float/integer min/max ops, both a spirv.GL.* and a spirv.CL.*
/// lowering are registered. NOTE(review): presumably the SPIR-V conversion
/// target's capabilities decide which of the generated ops are legal; confirm.
/// The order of the list below is kept as-is because patterns with equal
/// benefit may be tried in registration order.
void mlir::arith::populateArithToSPIRVPatterns(
SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
// clang-format off
patterns.add<
ConstantCompositeOpPattern,
ConstantScalarOpPattern,
spirv::ElementwiseOpPattern<arith::AddIOp, spirv::IAddOp>,
spirv::ElementwiseOpPattern<arith::SubIOp, spirv::ISubOp>,
spirv::ElementwiseOpPattern<arith::MulIOp, spirv::IMulOp>,
spirv::ElementwiseOpPattern<arith::DivUIOp, spirv::UDivOp>,
spirv::ElementwiseOpPattern<arith::DivSIOp, spirv::SDivOp>,
spirv::ElementwiseOpPattern<arith::RemUIOp, spirv::UModOp>,
RemSIOpGLPattern, RemSIOpCLPattern,
BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
XOrIOpLogicalPattern, XOrIOpBooleanPattern,
spirv::ElementwiseOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
spirv::ElementwiseOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>,
spirv::ElementwiseOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>,
spirv::ElementwiseOpPattern<arith::NegFOp, spirv::FNegateOp>,
spirv::ElementwiseOpPattern<arith::AddFOp, spirv::FAddOp>,
spirv::ElementwiseOpPattern<arith::SubFOp, spirv::FSubOp>,
spirv::ElementwiseOpPattern<arith::MulFOp, spirv::FMulOp>,
spirv::ElementwiseOpPattern<arith::DivFOp, spirv::FDivOp>,
spirv::ElementwiseOpPattern<arith::RemFOp, spirv::FRemOp>,
ExtUIPattern, ExtUII1Pattern,
ExtSIPattern, ExtSII1Pattern,
TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
TruncIPattern, TruncII1Pattern,
TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
CmpIOpBooleanPattern, CmpIOpPattern,
CmpFOpNanNonePattern, CmpFOpPattern,
AddUIExtendedOpPattern,
MulIExtendedOpPattern<arith::MulSIExtendedOp, spirv::SMulExtendedOp>,
MulIExtendedOpPattern<arith::MulUIExtendedOp, spirv::UMulExtendedOp>,
SelectOpPattern,
// spirv.GL.* min/max lowerings.
MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSMaxOp>,
spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::GLUMaxOp>,
spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSMinOp>,
spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLUMinOp>,
// spirv.CL.* min/max lowerings.
MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::CLSMaxOp>,
spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::CLUMaxOp>,
spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::CLSMinOp>,
spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::CLUMinOp>
>(typeConverter, patterns.getContext());
// clang-format on
// CmpFOpNanKernelPattern and CmpFOpNanNonePattern both match the ORD/UNO
// predicates of arith.cmpf. Give CmpFOpNanKernelPattern a higher benefit so it
// can prevail when the Kernel capability is available.
patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(),
/*benefit=*/2);
}
//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//

namespace {
/// Pass that converts the `arith` ops nested under the anchored operation into
/// SPIR-V, using the target environment returned by
/// `spirv::lookupTargetEnvOrDefault`.
struct ConvertArithToSPIRVPass
    : public impl::ConvertArithToSPIRVBase<ConvertArithToSPIRVPass> {
  void runOnOperation() override {
    Operation *root = getOperation();
    spirv::TargetEnvAttr targetEnv = spirv::lookupTargetEnvOrDefault(root);
    std::unique_ptr<SPIRVConversionTarget> conversionTarget =
        SPIRVConversionTarget::get(targetEnv);

    // Mirror the pass options into the type converter's options.
    SPIRVConversionOptions conversionOptions;
    conversionOptions.emulateLT32BitScalarTypes =
        this->emulateLT32BitScalarTypes;
    conversionOptions.enableFastMathMode = this->enableFastMath;
    SPIRVTypeConverter spirvTypeConverter(targetEnv, conversionOptions);

    // UnrealizedConversionCast is legal and bridges types at dialect
    // boundaries, so patterns for other dialects are not needed. Any `arith`
    // op still present after conversion is an error.
    conversionTarget->addLegalOp<UnrealizedConversionCastOp>();
    conversionTarget->addIllegalDialect<arith::ArithDialect>();

    RewritePatternSet conversionPatterns(&getContext());
    arith::populateArithToSPIRVPatterns(spirvTypeConverter,
                                        conversionPatterns);

    LogicalResult converted = applyPartialConversion(
        root, *conversionTarget, std::move(conversionPatterns));
    if (failed(converted))
      signalPassFailure();
  }
};
} // namespace

/// Creates a new instance of the Arith-to-SPIR-V conversion pass.
std::unique_ptr<OperationPass<>> mlir::arith::createConvertArithToSPIRVPass() {
  return std::make_unique<ConvertArithToSPIRVPass>();
}